Low-level tree interface

The tree interface mimics that of sklearn.tree._tree.Tree. To expose this level of detail, set the argument enable_tree_details to True on ensemble estimators before fitting. Single-tree estimators always perform these calculations.

Since grf does not track this level of detail on its trees, we perform extra operations in Python to provide them. These operations are currently quite slow and not well-optimized.

These operations include:

  • determining which leaf node corresponds to each training sample for every tree

  • determining the sum of the weights of the training samples at each leaf node

  • determining the weighted average prediction value of each leaf node

These extra calculations populate the n_node_samples, weighted_n_node_samples, and value attributes of the Tree class.

Note

Since we don’t have direct access to the sub-sampled training sets used in building each of the trees, we determine the above values using the full training set.
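The three operations listed above can be illustrated with a minimal pure-Python sketch. This is not the skgrf implementation; the function name aggregate_leaves and the toy data are hypothetical, and the leaf indices are assumed to have been computed already (for example by apply).

```python
from collections import defaultdict


def aggregate_leaves(leaf_ids, y, sample_weight):
    """Return per-leaf sample counts, weight sums, and weighted means of y.

    Hypothetical helper for illustration only; not part of skgrf.
    """
    counts = defaultdict(int)
    weight_sums = defaultdict(float)
    weighted_y = defaultdict(float)
    for leaf, target, weight in zip(leaf_ids, y, sample_weight):
        counts[leaf] += 1                    # samples landing in this leaf
        weight_sums[leaf] += weight          # sum of their weights
        weighted_y[leaf] += weight * target  # numerator of the weighted mean
    # Weighted average prediction per leaf (assumes a positive weight sum).
    values = {leaf: weighted_y[leaf] / weight_sums[leaf] for leaf in counts}
    return dict(counts), dict(weight_sums), values


# Toy data: the leaf index of each of five training samples.
leaf_ids = [1, 1, 2, 2, 2]
y = [1.0, 3.0, 10.0, 20.0, 30.0]
sample_weight = [1.0, 3.0, 1.0, 1.0, 2.0]

counts, weight_sums, values = aggregate_leaves(leaf_ids, y, sample_weight)
print(counts, weight_sums, values)
```

Because the full training set is used rather than each tree's sub-sample, counts computed this way may be larger than the number of samples the tree was actually grown on.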

class skgrf.tree._tree.Tree(grf_forest)[source]

The low-level tree interface.

Tree objects can be accessed using the tree_ attribute on fitted GRF decision tree estimators. Instances of Tree provide methods and properties describing the underlying structure and attributes of the tree.

apply(X)[source]

Calculate the leaf index for each sample.

Parameters

X (array2d) – training input features
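As a rough illustration of what a leaf lookup involves, the sketch below walks a toy tree stored as parallel arrays. It assumes the sklearn conventions (a child index of -1 marks a leaf, and a sample goes left when its feature value is less than or equal to the threshold); whether skgrf encodes its arrays identically is an assumption here, and apply_tree is a hypothetical name.

```python
def apply_tree(X, children_left, children_right, feature, threshold):
    """Return the index of the leaf reached by each row of X."""
    leaves = []
    for row in X:
        node = 0  # start at the root
        while children_left[node] != -1:  # -1 marks a leaf (assumed convention)
            if row[feature[node]] <= threshold[node]:
                node = children_left[node]
            else:
                node = children_right[node]
        leaves.append(node)
    return leaves


# Toy tree: node 0 splits on feature 0, node 2 splits on feature 1.
children_left = [1, -1, 3, -1, -1]
children_right = [2, -1, 4, -1, -1]
feature = [0, -2, 1, -2, -2]
threshold = [0.5, -2.0, 2.0, -2.0, -2.0]

X = [[0.2, 5.0], [0.9, 1.0], [0.9, 3.0]]
leaf_ids = apply_tree(X, children_left, children_right, feature, threshold)
print(leaf_ids)
```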

property capacity

The total number of nodes in the tree, including pruned nodes.

property children_default

Children nodes for missing data.

property children_left

Left children nodes.

property children_right

Right children nodes.

decision_path(X)[source]

Calculate the decision path through the tree for each sample.

Parameters

X (array2d) – training input features
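A decision path records every node a sample visits on its way from the root to a leaf. The sketch below returns each path as a plain list of node indices; sklearn's decision_path returns a sparse indicator matrix instead, so the list form is a simplification. The array layout (-1 for a missing child, left on less-than-or-equal) is assumed, and decision_paths is a hypothetical name.

```python
def decision_paths(X, children_left, children_right, feature, threshold):
    """Return, for each row of X, the list of node indices it passes through."""
    paths = []
    for row in X:
        node = 0
        path = [node]  # the root is always on the path
        while children_left[node] != -1:  # stop once a leaf is reached
            if row[feature[node]] <= threshold[node]:
                node = children_left[node]
            else:
                node = children_right[node]
            path.append(node)
        paths.append(path)
    return paths


# Toy tree with two splits and three leaves.
children_left = [1, -1, 3, -1, -1]
children_right = [2, -1, 4, -1, -1]
feature = [0, -2, 1, -2, -2]
threshold = [0.5, -2.0, 2.0, -2.0, -2.0]

X = [[0.2, 5.0], [0.9, 1.0], [0.9, 3.0]]
paths = decision_paths(X, children_left, children_right, feature, threshold)
print(paths)
```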

property feature

Variables on which nodes are split.

get_depth()[source]

Calculate the maximum depth of the tree.

get_n_leaves()[source]

Calculate the number of leaves of the tree.
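Both the depth and the leaf count can be derived from the child arrays alone. The following sketch shows one way to do so on a toy tree, assuming a child index of -1 marks a leaf; tree_depth and count_leaves are hypothetical names, not skgrf functions.

```python
def tree_depth(children_left, children_right, node=0):
    """Return the number of edges on the longest root-to-leaf path."""
    if children_left[node] == -1:  # a leaf has depth 0
        return 0
    left = tree_depth(children_left, children_right, children_left[node])
    right = tree_depth(children_left, children_right, children_right[node])
    return 1 + max(left, right)


def count_leaves(children_left):
    """Count nodes that have no left child (assumed to be leaves)."""
    return sum(1 for child in children_left if child == -1)


# Toy tree: the root splits, and its right child splits again.
children_left = [1, -1, 3, -1, -1]
children_right = [2, -1, 4, -1, -1]

depth = tree_depth(children_left, children_right)
n_leaves = count_leaves(children_left)
print(depth, n_leaves)
```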

property max_depth

Max depth of the tree.

property n_classes

The number of classes.

property n_node_samples

The number of samples reaching each node.

property n_outputs

The number of outputs of the tree.

property node_count

The number of (unpruned) nodes in the tree.

property threshold

Threshold values on which nodes are split.

property value

The constant prediction value of each node.

property weighted_n_node_samples

The sum of the weights of the samples reaching each node.
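Per-node totals such as these can be obtained from leaf-level totals by summing upward through the tree, since every sample that reaches an internal node ends in exactly one leaf beneath it. The sketch below shows that bottom-up accumulation on a toy tree; it is an illustration under assumed conventions (-1 marks a leaf), not a description of how skgrf computes the property, and propagate_totals is a hypothetical name.

```python
def propagate_totals(leaf_totals, children_left, children_right):
    """Fill per-node totals by summing each node's children, bottom-up."""
    totals = [0.0] * len(children_left)

    def visit(node):
        if children_left[node] == -1:
            # Leaf: take its own total (0.0 if no sample reached it).
            totals[node] = leaf_totals.get(node, 0.0)
        else:
            # Internal node: total is the sum over both subtrees.
            totals[node] = visit(children_left[node]) + visit(children_right[node])
        return totals[node]

    visit(0)
    return totals


# Toy tree with leaves at nodes 1, 3, and 4.
children_left = [1, -1, 3, -1, -1]
children_right = [2, -1, 4, -1, -1]
leaf_weight_sums = {1: 2.0, 3: 1.5, 4: 0.5}

node_weight_sums = propagate_totals(leaf_weight_sums, children_left, children_right)
print(node_weight_sums)
```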

SHAP

Regressors and classifiers can be used with shap. skgrf provides a context manager, shap_patch, that patches skgrf objects so that they work with shap.

from shap import TreeExplainer
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from skgrf.ensemble import GRFForestClassifier
from skgrf.utils.shap import shap_patch

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y)

forest = GRFForestClassifier(enable_tree_details=True).fit(X_train, y_train)

with shap_patch():
    explainer = TreeExplainer(model=forest, data=X_train)

shap_values = explainer.shap_values(X_test, check_additivity=False)