Warning: This repository is deprecated. All of its functionality has been moved to a package extension of SoleModels.jl. Please refer to that repository for continued support and updates.
Ever wondered what to do with a trained decision tree? Start by inspecting its knowledge, and end up evaluating it in a dedicated framework! This package allows you to convert learned DecisionTree models to Sole decision tree models. With a Sole model in hand, you can then treat the extracted knowledge in symbolic form, that is, as a set of logical formulas, which allows you to:
- Evaluate them in terms of
  - accuracy (e.g., confidence, lift),
  - relevance (e.g., support),
  - interpretability (e.g., syntax height, number of atoms);
- Modify them;
- Merge them.
using MLJ
using MLJDecisionTreeInterface
using DataFrames
# Load the iris dataset and wrap the feature table in a DataFrame
X, y = @load_iris
X = DataFrame(X)

# Shuffle the row indices and split them 80% / 20% into train and test sets
train, test = partition(eachindex(y), 0.8; shuffle=true);
X_train, X_test = X[train, :], X[test, :];
y_train, y_test = y[train], y[test];

# Train a model
# (`begin` does not open a new scope: `Tree`, `model` and `mach` stay available afterwards)
learned_dt_tree = begin
    Tree = MLJ.@load DecisionTreeClassifier pkg=DecisionTree
    model = Tree(; max_depth=-1)  # -1 = no depth limit (DecisionTree.jl convention)
    mach = machine(model, X_train, y_train)
    fit!(mach)
    # The value of the block is the raw fitted tree, not the MLJ machine
    fitted_params(mach).tree
end

using SoleDecisionTreeInterface

# Convert to Sole model
sole_dt = solemodel(learned_dt_tree)
julia> using Sole;
julia> # Make test instances flow into the model so that test metrics can then be computed.
apply!(sole_dt, X_test, y_test);
julia> # Print Sole model
printmodel(sole_dt; show_metrics = true);
▣ V4 < 0.8
├✔ setosa : (ninstances = 7, ncovered = 7, confidence = 1.0, lift = 1.0)
└✘ V3 < 4.95
 ├✔ V4 < 1.65
 │├✔ versicolor : (ninstances = 10, ncovered = 10, confidence = 1.0, lift = 1.0)
 │└✘ V2 < 3.1
 │ ├✔ virginica : (ninstances = 2, ncovered = 2, confidence = 1.0, lift = 1.0)
 │ └✘ versicolor : (ninstances = 0, ncovered = 0, confidence = NaN, lift = NaN)
 └✘ V3 < 5.05
  ├✔ V1 < 6.5
  │├✔ virginica : (ninstances = 0, ncovered = 0, confidence = NaN, lift = NaN)
  │└✘ versicolor : (ninstances = 0, ncovered = 0, confidence = NaN, lift = NaN)
  └✘ virginica : (ninstances = 11, ncovered = 11, confidence = 0.91, lift = 1.0)
julia> # Extract rules that are at least as good as a random baseline model
interesting_rules = listrules(sole_dt, min_lift = 1.0, min_ninstances = 0);
julia> printmodel.(interesting_rules; show_metrics = true);
▣ (V4 < 0.8) ∧ (⊤) ↣ setosa : (ninstances = 30, ncovered = 7, coverage = 0.23, confidence = 1.0, natoms = 1, lift = 4.29)
▣ (¬(V4 < 0.8)) ∧ (V3 < 4.95) ∧ (V4 < 1.65) ∧ (⊤) ↣ versicolor : (ninstances = 30, ncovered = 10, coverage = 0.33, confidence = 1.0, natoms = 3, lift = 2.73)
▣ (¬(V4 < 0.8)) ∧ (V3 < 4.95) ∧ (¬(V4 < 1.65)) ∧ (V2 < 3.1) ∧ (⊤) ↣ virginica : (ninstances = 30, ncovered = 2, coverage = 0.07, confidence = 1.0, natoms = 4, lift = 2.5)
▣ (¬(V4 < 0.8)) ∧ (¬(V3 < 4.95)) ∧ (¬(V3 < 5.05)) ∧ (⊤) ↣ virginica : (ninstances = 30, ncovered = 11, coverage = 0.37, confidence = 0.91, natoms = 3, lift = 2.27)
julia> # Simplify rules while extracting and prettify result
interesting_rules = listrules(sole_dt, min_lift = 1.0, min_ninstances = 0, normalize = true);
julia> printmodel.(interesting_rules; show_metrics = true, syntaxstring_kwargs = (; threshold_digits = 2));
▣ V4 < 0.8 ↣ setosa : (ninstances = 30, ncovered = 7, coverage = 0.23, confidence = 1.0, natoms = 1, lift = 4.29)
▣ (V4 ∈ [0.8,1.65)) ∧ (V3 < 4.95) ↣ versicolor : (ninstances = 30, ncovered = 10, coverage = 0.33, confidence = 1.0, natoms = 2, lift = 2.73)
▣ (V4 ≥ 1.65) ∧ (V3 < 4.95) ∧ (V2 < 3.1) ↣ virginica : (ninstances = 30, ncovered = 2, coverage = 0.07, confidence = 1.0, natoms = 3, lift = 2.5)
▣ (V4 ≥ 0.8) ∧ (V3 ≥ 5.05) ↣ virginica : (ninstances = 30, ncovered = 11, coverage = 0.37, confidence = 0.91, natoms = 2, lift = 2.27)
julia> # Directly access rule metrics
readmetrics.(listrules(sole_dt; min_lift=1.0, min_ninstances = 0))
4-element Vector{NamedTuple{(:ninstances, :ncovered, :coverage, :confidence, :natoms, :lift), Tuple{Int64, Int64, Float64, Float64, Int64, Float64}}}:
(ninstances = 30, ncovered = 7, coverage = 0.23333333333333334, confidence = 1.0, natoms = 1, lift = 4.285714285714286)
(ninstances = 30, ncovered = 10, coverage = 0.3333333333333333, confidence = 1.0, natoms = 3, lift = 2.7272727272727275)
(ninstances = 30, ncovered = 2, coverage = 0.06666666666666667, confidence = 1.0, natoms = 4, lift = 2.5)
(ninstances = 30, ncovered = 11, coverage = 0.36666666666666664, confidence = 0.9090909090909091, natoms = 3, lift = 2.2727272727272725)
julia> # Show rules with an additional metric (syntax height of the rule's antecedent)
printmodel.(sort(interesting_rules, by = readmetrics); show_metrics = (; round_digits = nothing, additional_metrics = (; height = r->SoleLogics.height(antecedent(r)))));
▣ (V4 ≥ 1.65) ∧ (V3 < 4.95) ∧ (V2 < 3.1) ↣ virginica : (ninstances = 30, ncovered = 2, coverage = 0.06666666666666667, confidence = 1.0, height = 2, lift = 2.5)
▣ V4 < 0.8 ↣ setosa : (ninstances = 30, ncovered = 7, coverage = 0.23333333333333334, confidence = 1.0, height = 0, lift = 4.285714285714286)
▣ (V4 ∈ [0.8,1.65)) ∧ (V3 < 4.95) ↣ versicolor : (ninstances = 30, ncovered = 10, coverage = 0.3333333333333333, confidence = 1.0, height = 1, lift = 2.7272727272727275)
▣ (V4 ≥ 0.8) ∧ (V3 ≥ 5.05) ↣ virginica : (ninstances = 30, ncovered = 11, coverage = 0.36666666666666664, confidence = 0.9090909090909091, height = 1, lift = 2.2727272727272725)
julia> # Pretty table of rules and their metrics
metricstable(interesting_rules; metrics_kwargs = (; round_digits = nothing, additional_metrics = (; height = r->SoleLogics.height(antecedent(r)))))
┌────────────────────────────────────────┬────────────┬────────────┬──────────┬───────────┬────────────┬────────┬─────────┐
│ Antecedent │ Consequent │ ninstances │ ncovered │ coverage │ confidence │ height │ lift │
├────────────────────────────────────────┼────────────┼────────────┼──────────┼───────────┼────────────┼────────┼─────────┤
│ V4 < 0.8 │ setosa │ 30 │ 7 │ 0.233333 │ 1.0 │ 0 │ 4.28571 │
│ (V4 ∈ [0.8,1.65)) ∧ (V3 < 4.95) │ versicolor │ 30 │ 10 │ 0.333333 │ 1.0 │ 1 │ 2.72727 │
│ (V4 ≥ 1.65) ∧ (V3 < 4.95) ∧ (V2 < 3.1) │ virginica │ 30 │ 2 │ 0.0666667 │ 1.0 │ 2 │ 2.5 │
│ (V4 ≥ 0.8) ∧ (V3 ≥ 5.05) │ virginica │ 30 │ 11 │ 0.366667 │ 0.909091 │ 1 │ 2.27273 │
└────────────────────────────────────────┴────────────┴────────────┴──────────┴───────────┴────────────┴────────┴─────────┘