Low-level tree interface
The tree interface mimics that of sklearn.tree._tree.Tree. To get this level of detail, the argument enable_tree_details must be set to True on ensemble estimators prior to fitting. Single-tree estimators always perform these calculations.
Since grf does not track this level of detail on its trees, we perform extra operations in Python to provide them. These operations are currently quite slow and not well optimized.
These operations include:
- determining which leaf node corresponds to each training sample for every tree
- determining the sum of the weights of the training samples at each leaf node
- determining the weighted average prediction value of each leaf node
These extra calculations provide the n_node_samples, weighted_n_node_samples, and value attributes of the Tree class.
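As a rough illustration of the per-leaf operations listed above, the sketch below computes sample counts, weight sums, and weighted mean predictions per node with NumPy, for a single regression output. It assumes the leaf index of every training sample is already known (for example from Tree.apply); the helper name leaf_statistics and its arguments are hypothetical, and skgrf's own implementation may differ.

```python
import numpy as np


def leaf_statistics(leaf_ids, y, sample_weight, node_count):
    """Hypothetical helper: per-node sample counts, weight sums and weighted means.

    leaf_ids[i] is the node index that training sample i falls into.
    Nodes that no sample reaches get a count and weight sum of 0 and a mean of NaN.
    """
    leaf_ids = np.asarray(leaf_ids, dtype=np.intp)
    y = np.asarray(y, dtype=float)
    sample_weight = np.asarray(sample_weight, dtype=float)

    # Number of training samples landing in each node (cf. n_node_samples).
    counts = np.bincount(leaf_ids, minlength=node_count)
    # Sum of sample weights landing in each node (cf. weighted_n_node_samples).
    weight_sums = np.bincount(leaf_ids, weights=sample_weight, minlength=node_count)
    # Weighted sum of targets per node, divided by the weight sum (cf. value).
    weighted_targets = np.bincount(leaf_ids, weights=sample_weight * y, minlength=node_count)
    with np.errstate(invalid="ignore", divide="ignore"):
        means = weighted_targets / weight_sums
    means[weight_sums == 0] = np.nan
    return counts, weight_sums, means


# Five training samples falling into leaves 1 and 2 of a three-node tree.
counts, weight_sums, means = leaf_statistics(
    leaf_ids=[1, 2, 2, 1, 2],
    y=[1.0, 3.0, 5.0, 2.0, 4.0],
    sample_weight=[1.0, 1.0, 1.0, 1.0, 2.0],
    node_count=3,
)
print(counts)       # [0 2 3]
print(weight_sums)  # [0. 2. 4.]
print(means)        # [nan 1.5 4. ]
```

np.bincount does each aggregation in a single vectorized pass, which avoids looping over leaves in Python.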
Note
Since we don’t have direct access to the sub-sampled training sets used in building each of the trees, we determine the above values using the full training set.
class skgrf.tree._tree.Tree(grf_forest)
The low-level tree interface.
Tree objects can be accessed using the tree_ attribute on fitted GRF decision tree estimators. Instances of Tree provide methods and properties describing the underlying structure and attributes of the tree.
apply(X)
Calculate the leaf index for each sample.
Parameters
X (array2d) – training input features
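To make the array layout concrete, here is a minimal, hypothetical sketch of what apply computes, written against sklearn-style node arrays (children_left, children_right, feature, threshold). It assumes scikit-learn's conventions: a leaf has child index -1 (scikit-learn's TREE_LEAF value), and a sample goes left when its feature value is <= the node's threshold. It ignores missing values (the children_default path) and is not skgrf's implementation.

```python
import numpy as np

TREE_LEAF = -1  # assumed marker for "no child", mirroring scikit-learn


def apply_sketch(X, children_left, children_right, feature, threshold):
    """Return the leaf index reached by each row of X (hypothetical helper)."""
    X = np.asarray(X, dtype=float)
    leaves = np.empty(X.shape[0], dtype=np.intp)
    for i, row in enumerate(X):
        node = 0  # every sample starts at the root
        while children_left[node] != TREE_LEAF:
            if row[feature[node]] <= threshold[node]:
                node = children_left[node]
            else:
                node = children_right[node]
        leaves[i] = node
    return leaves


# Toy tree: the root splits on feature 0 at 0.5; node 2 splits on feature 1 at 2.0.
#        0
#       / \
#      1   2
#         / \
#        3   4
children_left = [1, -1, 3, -1, -1]
children_right = [2, -1, 4, -1, -1]
feature = [0, -2, 1, -2, -2]  # -2 marks "undefined" at leaves, as in scikit-learn
threshold = [0.5, -2.0, 2.0, -2.0, -2.0]

X = [[0.1, 9.0], [0.9, 1.0], [0.9, 3.0]]
print(apply_sketch(X, children_left, children_right, feature, threshold))  # [1 3 4]
```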
property capacity
The total number of nodes in the tree, including pruned nodes.
property children_default
Child nodes for missing data.
property children_left
Left child nodes.
property children_right
Right child nodes.
decision_path(X)
Calculate the decision path through the tree for each sample.
Parameters
X (array2d) – training input features
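A decision path records every node a sample visits on its way from the root to a leaf. The hypothetical sketch below builds a dense node-indicator matrix from sklearn-style node arrays, assuming scikit-learn's conventions (leaves have child index -1; a sample goes left when its feature value is <= the threshold). The helper name and toy tree are illustrative only.

```python
import numpy as np

TREE_LEAF = -1  # assumed marker for "no child", mirroring scikit-learn


def decision_path_sketch(X, children_left, children_right, feature, threshold):
    """Dense node-indicator matrix: entry [i, j] is 1 if sample i passes through node j."""
    X = np.asarray(X, dtype=float)
    indicator = np.zeros((X.shape[0], len(children_left)), dtype=np.int8)
    for i, row in enumerate(X):
        node = 0
        indicator[i, node] = 1  # every path starts at the root
        while children_left[node] != TREE_LEAF:
            if row[feature[node]] <= threshold[node]:
                node = children_left[node]
            else:
                node = children_right[node]
            indicator[i, node] = 1
    return indicator


# Toy tree: root (0) -> leaf 1 or node 2; node 2 -> leaf 3 or leaf 4.
children_left = [1, -1, 3, -1, -1]
children_right = [2, -1, 4, -1, -1]
feature = [0, -2, 1, -2, -2]
threshold = [0.5, -2.0, 2.0, -2.0, -2.0]

paths = decision_path_sketch(
    [[0.1, 9.0], [0.9, 3.0]], children_left, children_right, feature, threshold
)
print(paths)
# [[1 1 0 0 0]
#  [1 0 1 0 1]]
```

scikit-learn's own decision_path returns a sparse CSR indicator matrix; the source does not state the return type of skgrf's method, so the dense array here is only for readability.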
property feature
Variables on which nodes are split.
property max_depth
Maximum depth of the tree.
property n_classes
The number of classes.
property n_node_samples
The number of samples reaching each node.
property n_outputs
The number of outputs of the tree.
property node_count
The number of (unpruned) nodes in the tree.
property threshold
Threshold values on which nodes are split.
property value
The constant prediction value of each node.
property weighted_n_node_samples
The sum of the weights of the samples reaching each node.
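Several of these structural properties are related. For example, max_depth can be derived from children_left and children_right alone. A minimal sketch, assuming scikit-learn's conventions (the root is node 0 at depth 0, and -1 marks a missing child); node_depths is a hypothetical helper, not part of skgrf.

```python
def node_depths(children_left, children_right):
    """Depth of every node reachable from the root (root = 0); hypothetical helper."""
    depths = {0: 0}
    stack = [0]
    while stack:
        node = stack.pop()
        for child in (children_left[node], children_right[node]):
            if child != -1:  # -1 marks "no child", as in scikit-learn
                depths[child] = depths[node] + 1
                stack.append(child)
    return depths


# Toy tree: root (0) -> leaf 1 or node 2; node 2 -> leaf 3 or leaf 4.
children_left = [1, -1, 3, -1, -1]
children_right = [2, -1, 4, -1, -1]

depths = node_depths(children_left, children_right)
print(max(depths.values()))  # 2, the value max_depth would describe for this tree
print(len(depths))           # 5, the number of nodes reachable from the root
```

An explicit stack is used instead of recursion so that very deep trees cannot hit Python's recursion limit.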
SHAP
skgrf regressors and classifiers can be used with shap. The shap_patch context manager patches skgrf objects so that they work with shap.
```python
from shap import TreeExplainer
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from skgrf.ensemble import GRFForestClassifier
from skgrf.utils.shap import shap_patch

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y)

# Tree details must be enabled before fitting so that per-node values are available.
forest = GRFForestClassifier(enable_tree_details=True).fit(X_train, y_train)

# Patch skgrf objects while shap inspects the forest and computes SHAP values.
with shap_patch():
    explainer = TreeExplainer(model=forest, data=X_train)
    shap_values = explainer.shap_values(X_test, check_additivity=False)
```
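The source does not describe how shap_patch works internally. As a generic illustration of the pattern only, a context manager that temporarily replaces an attribute and always restores it might look like the following; temporarily_set and Model are hypothetical names, not skgrf's implementation.

```python
from contextlib import contextmanager

_MISSING = object()  # sentinel meaning "attribute was not set directly on obj"


@contextmanager
def temporarily_set(obj, name, value):
    """Set obj.<name> = value inside the block, then restore the previous state."""
    original = obj.__dict__.get(name, _MISSING)
    setattr(obj, name, value)
    try:
        yield obj
    finally:
        if original is _MISSING:
            delattr(obj, name)  # remove the override so lookup falls back as before
        else:
            setattr(obj, name, original)


class Model:
    kind = "original"


with temporarily_set(Model, "kind", "patched"):
    print(Model.kind)  # patched
print(Model.kind)      # original
```

Restoring in a finally clause means the patched object is returned to its original state even if an exception is raised inside the block.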