Source code for skgrf.tree.causal_regressor

import logging
import typing as t

from skgrf.tree.instrumental_regressor import GRFTreeInstrumentalRegressor

if t.TYPE_CHECKING:  # pragma: no cover
    from skgrf.ensemble import GRFForestCausalRegressor

logger = logging.getLogger(__name__)


class GRFTreeCausalRegressor(GRFTreeInstrumentalRegressor):
    r"""GRF Tree Causal regression implementation for scikit-learn.

    Provides a sklearn-style tree causal regressor on top of the GRF C++
    library using Cython.

    :param bool equalize_cluster_weights: Weight the samples such that clusters
        have equal weight. If ``False``, larger clusters will have more weight.
        If ``True``, the number of samples drawn from each cluster is equal to
        the size of the smallest cluster. If ``True``, sample weights should
        not be passed on fitting.
    :param float sample_fraction: Fraction of samples used in each tree. If
        ``ci_group_size`` > 1, the max allowed fraction is 0.5.
    :param int mtry: The number of features to split on each node. The default
        is ``sqrt(p) + 20`` where ``p`` is the number of features.
    :param int min_node_size: The minimum number of observations in each tree
        leaf.
    :param bool honesty: Use honest splitting (subsample splitting).
    :param float honesty_fraction: The fraction of data used for subsample
        splitting.
    :param bool honesty_prune_leaves: Prune the estimation sample tree such
        that no leaves are empty. If ``False``, trees with empty leaves are
        skipped.
    :param float alpha: The maximum imbalance of a split.
    :param float imbalance_penalty: Penalty applied to imbalanced splits.
    :param bool orthogonal_boosting: When ``y_hat`` or ``w_hat`` are ``None``,
        estimate them using boosted regression forests.
    :param bool stabilize_splits: Whether or not the instrument should be
        taken into account when determining the imbalance of a split.
    :param int n_jobs: The number of threads. Default is number of CPU cores.
        Only used for target estimation.
    :param int seed: Random seed value.

    :ivar int n_features_in\_: The number of features (columns) from the fit
        input ``X``.
    :ivar dict grf_forest\_: The returned result object from calling C++ grf.
    :ivar int mtry\_: The ``mtry`` value determined by validation.
    :ivar int outcome_index\_: The index of the grf train matrix holding the
        outcomes.
    :ivar list samples_per_cluster\_: The number of samples to train per
        cluster.
    :ivar list clusters\_: The cluster labels determined from the fit input
        ``cluster``.
    :ivar int n_clusters\_: The number of unique cluster labels from the fit
        input ``cluster``.
    :ivar str criterion: The criterion used for splitting: ``mse``.
    """

    def __init__(
        self,
        equalize_cluster_weights=False,
        sample_fraction=0.5,
        mtry=None,
        min_node_size=5,
        honesty=True,
        honesty_fraction=0.5,
        honesty_prune_leaves=True,
        alpha=0.05,
        imbalance_penalty=0,
        stabilize_splits=True,
        orthogonal_boosting=False,
        n_jobs=-1,
        seed=42,
    ):
        super().__init__(
            equalize_cluster_weights=equalize_cluster_weights,
            sample_fraction=sample_fraction,
            mtry=mtry,
            min_node_size=min_node_size,
            honesty=honesty,
            honesty_fraction=honesty_fraction,
            honesty_prune_leaves=honesty_prune_leaves,
            alpha=alpha,
            imbalance_penalty=imbalance_penalty,
            reduced_form_weight=0,
            stabilize_splits=stabilize_splits,
            n_jobs=n_jobs,
            seed=seed,
        )
        self.orthogonal_boosting = orthogonal_boosting
    @classmethod
    def from_forest(cls, forest: "GRFForestCausalRegressor", idx: int):
        """Extract a tree from a forest.

        :param GRFForestCausalRegressor forest: A trained
            GRFForestCausalRegressor instance.
        :param int idx: The tree index from the forest to extract.
        """
        # Even though we have a tree object, we keep the exact same dictionary
        # structure that exists in the forests, so that we can reuse the
        # Cython entrypoints. We also copy over some instance attributes from
        # the trained forest.
        # params
        instance = cls(
            equalize_cluster_weights=forest.equalize_cluster_weights,
            sample_fraction=forest.sample_fraction,
            mtry=forest.mtry,
            min_node_size=forest.min_node_size,
            honesty=forest.honesty,
            honesty_fraction=forest.honesty_fraction,
            honesty_prune_leaves=forest.honesty_prune_leaves,
            alpha=forest.alpha,
            imbalance_penalty=forest.imbalance_penalty,
            stabilize_splits=forest.stabilize_splits,
            n_jobs=forest.n_jobs,
            seed=forest.seed,
        )
        # forest: copy the single tree at ``idx`` out of each per-tree list
        grf_forest = {}
        for k, v in forest.grf_forest_.items():
            if isinstance(v, list):
                grf_forest[k] = [forest.grf_forest_[k][idx]]
            else:
                grf_forest[k] = v
        grf_forest["num_trees"] = 1
        instance.grf_forest_ = grf_forest
        instance._ensure_ptr()
        # vars
        instance.outcome_index_ = forest.outcome_index_
        instance.treatment_index_ = forest.treatment_index_
        instance.instrument_index_ = forest.instrument_index_
        instance.n_features_in_ = forest.n_features_in_
        instance.clusters_ = forest.clusters_
        instance.n_clusters_ = forest.n_clusters_
        instance.samples_per_cluster_ = forest.samples_per_cluster_
        instance.mtry_ = forest.mtry_
        instance.sample_weight_index_ = forest.sample_weight_index_
        return instance
    # noinspection PyMethodOverriding
    def fit(
        self,
        X,
        y,
        w,  # treatment
        y_hat=None,
        w_hat=None,
        sample_weight=None,
        cluster=None,
    ):
        """Fit the grf tree using training data.

        :param array2d X: training input features
        :param array1d y: training input targets
        :param array1d w: training input treatments
        :param array1d y_hat: estimated expected target responses
        :param array1d w_hat: estimated treatment propensities
        :param array1d sample_weight: optional weights for input samples
        :param array1d cluster: optional cluster assignments for input samples
        """
        # avoiding circular import
        from skgrf.ensemble.boosted_regressor import GRFBoostedForestRegressor

        X, y = self._validate_data(X, y, force_all_finite="allow-nan")
        self._check_num_samples(X)

        boost_params = {
            "n_estimators": 500,
            "equalize_cluster_weights": self.equalize_cluster_weights,
            "sample_fraction": self.sample_fraction,
            "mtry": self.mtry,
            "min_node_size": 5,
            "honesty": True,
            "honesty_fraction": 0.5,
            "honesty_prune_leaves": self.honesty_prune_leaves,
            "alpha": self.alpha,
            "imbalance_penalty": self.imbalance_penalty,
            "ci_group_size": 1,
            "tune_params": None,  # TODO ?
            "n_jobs": self.n_jobs,
            "seed": self.seed,
        }
        if y_hat is None and self.orthogonal_boosting:
            logger.debug("orthogonal boosting y_hat")
            br = GRFBoostedForestRegressor(**boost_params)
            br.fit(X, y, sample_weight=sample_weight, cluster=cluster)
            y_hat = br.boosted_forests_["predictions"]

        if w_hat is None and self.orthogonal_boosting:
            logger.debug("orthogonal boosting w_hat")
            br = GRFBoostedForestRegressor(**boost_params)
            br.fit(X, w, sample_weight=sample_weight, cluster=cluster)
            w_hat = br.boosted_forests_["predictions"]

        return super().fit(
            X=X,
            y=y,
            w=w,
            z=w,  # the treatment doubles as its own instrument
            y_hat=y_hat,
            w_hat=w_hat,
            z_hat=w_hat,
            cluster=cluster,
            sample_weight=sample_weight,
        )
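

# Minimal usage sketch (not part of the skgrf source): fit a single causal
# tree on synthetic data. This assumes skgrf's compiled GRF bindings are
# installed and that ``predict`` returns per-sample treatment-effect
# estimates, as in the instrumental base class. With
# ``orthogonal_boosting=True``, ``fit`` estimates the missing ``y_hat`` /
# ``w_hat`` nuisance values with boosted regression forests.
if __name__ == "__main__":
    import numpy as np

    rng = np.random.default_rng(42)
    X = rng.normal(size=(200, 5))
    w = rng.binomial(1, 0.5, size=200).astype(float)  # binary treatment
    y = X[:, 0] * w + rng.normal(scale=0.1, size=200)  # heterogeneous effect

    tree = GRFTreeCausalRegressor(orthogonal_boosting=True, seed=42)
    tree.fit(X, y, w)  # y_hat / w_hat are boosted internally
    effects = tree.predict(X)  # estimated treatment effect per sample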
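    # A second sketch: pull one tree out of a trained forest with
    # ``from_forest``. The forest's constructor arguments here
    # (``n_estimators``) and ``fit`` signature are assumed to mirror the
    # tree API above; ``from_forest`` copies tree ``idx`` and the fitted
    # attributes without refitting.
    from skgrf.ensemble import GRFForestCausalRegressor

    forest = GRFForestCausalRegressor(n_estimators=50, seed=42)
    forest.fit(X, y, w)
    tree0 = GRFTreeCausalRegressor.from_forest(forest, idx=0)
    print(tree0.predict(X)[:5])  # single-tree treatment-effect estimates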