Source code for skgrf.tree.base

from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_is_fitted

from skgrf.base import GRFMixin
from skgrf.tree._tree import Tree


class BaseGRFTree(GRFMixin, BaseEstimator):
    @property
    def tree_(self):
        check_is_fitted(self)
        return Tree(grf_forest=self.grf_forest_)

    def get_depth(self):
        """Calculate the maximum depth of the tree."""
        return self.tree_.get_depth()

    def get_n_leaves(self):
        """Calculate the number of leaves of the tree."""
        return self.tree_.get_n_leaves()

    def apply(self, X):
        """Calculate the index of the leaf for each sample.

        :param array2d X: training input features
        """
        return self.tree_.apply(X)

    def decision_path(self, X):
        """Calculate the decision path through the tree for each sample.

        :param array2d X: training input features
        """
        return self.tree_.decision_path(X)