Warning: This repository is deprecated. All functionalities have been moved to (a package extension of) SoleModels.jl. Please refer to that repository for continued support and updates.

SoleDecisionTreeInterface.jl


Ever wondered what to do with a trained decision tree? Start by inspecting the knowledge it encodes, and end up evaluating it in a dedicated framework! This package converts decision trees learned with JuliaAI/DecisionTree.jl into Sole decision tree models. With a Sole model in hand, you can treat the extracted knowledge in symbolic form, that is, as a set of logical formulas, which allows you to:

  • Evaluate them in terms of
    • accuracy (e.g., confidence, lift),
    • relevance (e.g., support),
    • interpretability (e.g., syntax height, number of atoms);
  • Modify them;
  • Merge them.

Usage

Converting to a Sole model

using MLJ
using MLJDecisionTreeInterface
using DataFrames

X, y = @load_iris
X = DataFrame(X)

train, test = partition(eachindex(y), 0.8, shuffle=true);
X_train, y_train = X[train, :], y[train];
X_test, y_test = X[test, :], y[test];

# Train a model
learned_dt_tree = begin
  Tree = MLJ.@load DecisionTreeClassifier pkg=DecisionTree
  model = Tree(max_depth=-1) # -1 means no depth limit
  mach = machine(model, X_train, y_train)
  fit!(mach)
  fitted_params(mach).tree
end

using SoleDecisionTreeInterface

# Convert to Sole model
sole_dt = solemodel(learned_dt_tree)
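
At this point, sole_dt is a standalone symbolic classifier. As a quick sanity check, you can let it predict the held-out instances and measure raw accuracy. This is a minimal sketch, assuming that apply (which evaluates a Sole model on a dataset and returns its predictions) is available after using Sole:

using Sole
using Statistics

# Predict the test instances with the converted model
# and compare the predictions with the true labels.
predictions = apply(sole_dt, X_test)
accuracy = mean(predictions .== y_test)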

Model inspection & rule study

julia> using Sole;

julia> # Let the test instances flow through the model, so that test metrics can then be computed.
       apply!(sole_dt, X_test, y_test);

julia> # Print Sole model
       printmodel(sole_dt; show_metrics = true);
▣ V4 < 0.8
├✔ setosa : (ninstances = 7, ncovered = 7, confidence = 1.0, lift = 1.0)
└✘ V3 < 4.95
 ├✔ V4 < 1.65
 │├✔ versicolor : (ninstances = 10, ncovered = 10, confidence = 1.0, lift = 1.0)
 │└✘ V2 < 3.1
 │ ├✔ virginica : (ninstances = 2, ncovered = 2, confidence = 1.0, lift = 1.0)
 │ └✘ versicolor : (ninstances = 0, ncovered = 0, confidence = NaN, lift = NaN)
 └✘ V3 < 5.05
  ├✔ V1 < 6.5
  │├✔ virginica : (ninstances = 0, ncovered = 0, confidence = NaN, lift = NaN)
  │└✘ versicolor : (ninstances = 0, ncovered = 0, confidence = NaN, lift = NaN)
  └✘ virginica : (ninstances = 11, ncovered = 11, confidence = 0.91, lift = 1.0)

julia> # Extract rules that are at least as good as a random baseline model
       interesting_rules = listrules(sole_dt, min_lift = 1.0, min_ninstances = 0);

julia> printmodel.(interesting_rules; show_metrics = true);
▣ (V4 < 0.8) ∧ (⊤)  ↣  setosa : (ninstances = 30, ncovered = 7, coverage = 0.23, confidence = 1.0, natoms = 1, lift = 4.29)
▣ (¬(V4 < 0.8)) ∧ (V3 < 4.95) ∧ (V4 < 1.65) ∧ (⊤)  ↣  versicolor : (ninstances = 30, ncovered = 10, coverage = 0.33, confidence = 1.0, natoms = 3, lift = 2.73)
▣ (¬(V4 < 0.8)) ∧ (V3 < 4.95) ∧ (¬(V4 < 1.65)) ∧ (V2 < 3.1) ∧ (⊤)  ↣  virginica : (ninstances = 30, ncovered = 2, coverage = 0.07, confidence = 1.0, natoms = 4, lift = 2.5)
▣ (¬(V4 < 0.8)) ∧ (¬(V3 < 4.95)) ∧ (¬(V3 < 5.05)) ∧ (⊤)  ↣  virginica : (ninstances = 30, ncovered = 11, coverage = 0.37, confidence = 0.91, natoms = 3, lift = 2.27)
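
Since readmetrics returns a plain NamedTuple for each rule, the extracted rules can be filtered with ordinary Julia code. A minimal sketch that relies only on the functions shown above:

# Keep the rules whose confidence on the test instances is at least 0.95
# (here, this retains the three rules with confidence 1.0).
high_confidence_rules = filter(r -> readmetrics(r).confidence >= 0.95, interesting_rules);

printmodel.(high_confidence_rules; show_metrics = true);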

julia> # Simplify rules during extraction and prettify the result
       interesting_rules = listrules(sole_dt, min_lift = 1.0, min_ninstances = 0, normalize = true);

julia> printmodel.(interesting_rules; show_metrics = true, syntaxstring_kwargs = (; threshold_digits = 2));
▣ V4 < 0.8  ↣  setosa : (ninstances = 30, ncovered = 7, coverage = 0.23, confidence = 1.0, natoms = 1, lift = 4.29)
▣ (V4 ∈ [0.8,1.65)) ∧ (V3 < 4.95)  ↣  versicolor : (ninstances = 30, ncovered = 10, coverage = 0.33, confidence = 1.0, natoms = 2, lift = 2.73)
▣ (V4 ≥ 1.65) ∧ (V3 < 4.95) ∧ (V2 < 3.1)  ↣  virginica : (ninstances = 30, ncovered = 2, coverage = 0.07, confidence = 1.0, natoms = 3, lift = 2.5)
▣ (V4 ≥ 0.8) ∧ (V3 ≥ 5.05)  ↣  virginica : (ninstances = 30, ncovered = 11, coverage = 0.37, confidence = 0.91, natoms = 2, lift = 2.27)

julia> # Directly access rule metrics
       readmetrics.(listrules(sole_dt; min_lift=1.0, min_ninstances = 0))
4-element Vector{NamedTuple{(:ninstances, :ncovered, :coverage, :confidence, :natoms, :lift), Tuple{Int64, Int64, Float64, Float64, Int64, Float64}}}:
 (ninstances = 30, ncovered = 7, coverage = 0.23333333333333334, confidence = 1.0, natoms = 1, lift = 4.285714285714286)
 (ninstances = 30, ncovered = 10, coverage = 0.3333333333333333, confidence = 1.0, natoms = 3, lift = 2.7272727272727275)
 (ninstances = 30, ncovered = 2, coverage = 0.06666666666666667, confidence = 1.0, natoms = 4, lift = 2.5)
 (ninstances = 30, ncovered = 11, coverage = 0.36666666666666664, confidence = 0.9090909090909091, natoms = 3, lift = 2.2727272727272725)
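
Because these metrics are plain NamedTuples, they compose with Base and the standard library. A minimal sketch that summarizes the rule set as a whole:

using Statistics

metrics = readmetrics.(listrules(sole_dt; min_lift = 1.0, min_ninstances = 0))

# The rules of a decision tree partition the covered instances,
# so their coverages sum to 1.0 here.
total_coverage = sum(m.coverage for m in metrics)

# Average confidence across the extracted rules.
mean_confidence = mean(m.confidence for m in metrics)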

julia> # Show rules with an additional metric (syntax height of the rule's antecedent)
       printmodel.(sort(interesting_rules, by = readmetrics); show_metrics = (; round_digits = nothing, additional_metrics = (; height = r->SoleLogics.height(antecedent(r)))));

▣ (V4 ≥ 1.65) ∧ (V3 < 4.95) ∧ (V2 < 3.1)  ↣  virginica : (ninstances = 30, ncovered = 2, coverage = 0.06666666666666667, confidence = 1.0, height = 2, lift = 2.5)
▣ V4 < 0.8  ↣  setosa : (ninstances = 30, ncovered = 7, coverage = 0.23333333333333334, confidence = 1.0, height = 0, lift = 4.285714285714286)
▣ (V4 ∈ [0.8,1.65)) ∧ (V3 < 4.95)  ↣  versicolor : (ninstances = 30, ncovered = 10, coverage = 0.3333333333333333, confidence = 1.0, height = 1, lift = 2.7272727272727275)
▣ (V4 ≥ 0.8) ∧ (V3 ≥ 5.05)  ↣  virginica : (ninstances = 30, ncovered = 11, coverage = 0.36666666666666664, confidence = 0.9090909090909091, height = 1, lift = 2.2727272727272725)

julia> # Pretty table of rules and their metrics
       metricstable(interesting_rules; metrics_kwargs = (; round_digits = nothing, additional_metrics = (; height = r->SoleLogics.height(antecedent(r)))))
┌────────────────────────────────────────┬────────────┬────────────┬──────────┬───────────┬────────────┬────────┬─────────┐
│                             Antecedent │ Consequent │ ninstances │ ncovered │  coverage │ confidence │ height │    lift │
├────────────────────────────────────────┼────────────┼────────────┼──────────┼───────────┼────────────┼────────┼─────────┤
│                               V4 < 0.8 │     setosa │         30 │        7 │  0.233333 │        1.0 │      0 │ 4.28571 │
│        (V4 ∈ [0.8,1.65)) ∧ (V3 < 4.95) │ versicolor │         30 │       10 │  0.333333 │        1.0 │      1 │ 2.72727 │
│ (V4 ≥ 1.65) ∧ (V3 < 4.95) ∧ (V2 < 3.1) │  virginica │         30 │        2 │ 0.0666667 │        1.0 │      2 │     2.5 │
│               (V4 ≥ 0.8) ∧ (V3 ≥ 5.05) │  virginica │         30 │       11 │  0.366667 │   0.909091 │      1 │ 2.27273 │
└────────────────────────────────────────┴────────────┴────────────┴──────────┴───────────┴────────────┴────────┴─────────┘
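
If you prefer to post-process rules with your own tooling, the same information can be collected into a DataFrame (DataFrames is already loaded above). A minimal sketch, assuming SoleLogics.syntaxstring renders a formula as a string (as suggested by the syntaxstring_kwargs option used earlier):

# One row per rule: the antecedent as a string,
# plus all the metrics returned by readmetrics.
rule_table = DataFrame([
  merge((antecedent = SoleLogics.syntaxstring(antecedent(r)),), readmetrics(r))
  for r in interesting_rules
])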
