Tree Survival
-
class
skgrf.tree.GRFTreeSurvival(*, equalize_cluster_weights=False, sample_fraction=0.5, mtry=None, min_node_size=5, honesty=True, honesty_fraction=0.5, honesty_prune_leaves=True, alpha=0.05, seed=42)
GRF Tree Survival implementation for scikit-learn.
Provides a scikit-learn survival tree interface to the GRF C++ library using Cython.
Warning
Because the training dataset is required for prediction, it is stored on the estimator instance. As a result, serializing this estimator produces a file at least as large as the serialized training dataset.
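To make the size cost concrete, the sketch below pickles a hypothetical stand-in class (not skgrf's estimator) that keeps references to its training data, as the warning describes, and compares the result with pickling the data alone. It uses only the standard library `pickle` module.

```python
import pickle


class StoresTrainingData:
    """Hypothetical stand-in for an estimator that keeps its training data."""

    def fit(self, X, y):
        # Keep references to the training data, as the warning describes.
        self.train_X_ = X
        self.train_y_ = y
        return self


X = [[float(i), float(i) * 2.0] for i in range(1000)]
y = [(i % 2 == 0, float(i)) for i in range(1000)]

# Size of the data on its own versus an "estimator" that holds it.
data_size = len(pickle.dumps((X, y)))
model_size = len(pickle.dumps(StoresTrainingData().fit(X, y)))
print(data_size, model_size)
```

The stored data is serialized in full along with the object, so the estimator's pickle can only be larger than the data's.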
- Parameters
equalize_cluster_weights (bool) – Weight the samples such that clusters have equal weight. If False, larger clusters will have more weight. If True, the number of samples drawn from each cluster is equal to the size of the smallest cluster, and sample weights should not be passed when fitting.
sample_fraction (float) – Fraction of samples used in each tree.
mtry (int) – The number of candidate features considered for splitting at each node. The default is sqrt(p) + 20, where p is the number of features.
min_node_size (int) – The minimum number of observations in each tree leaf.
honesty (bool) – Use honest splitting (subsample splitting).
honesty_fraction (float) – The fraction of data used for subsample splitting.
honesty_prune_leaves (bool) – Prune the estimation sample tree such that no leaves are empty. If False, trees with empty leaves are skipped.
alpha (float) – The maximum imbalance of a split.
seed (int) – Random seed value.
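The snippet below is a simplified sketch of how some of these parameters could shape a single tree's inputs. `default_mtry` follows the documented sqrt(p) + 20 default; rounding up and capping at p mirror the R grf package and are assumptions here. `honest_subsample` draws `sample_fraction` of the rows without replacement and then splits them by `honesty_fraction` into a part that chooses splits and a part that populates leaves. The rounding, the order of the two parts, and the use of Python's `random` module in place of grf's own C++ random number generator are assumptions, and clusters are ignored.

```python
import math
import random


def default_mtry(n_features):
    # Documented default sqrt(p) + 20; rounding up and capping at p
    # follow the R grf package and are assumptions here.
    return min(math.ceil(math.sqrt(n_features) + 20), n_features)


def honest_subsample(n_samples, sample_fraction=0.5, honesty_fraction=0.5, seed=42):
    """Draw one tree's subsample and split it into honest parts (simplified)."""
    rng = random.Random(seed)
    # Subsample without replacement: sample_fraction of the rows go to this tree.
    n_tree = math.ceil(n_samples * sample_fraction)
    tree_rows = rng.sample(range(n_samples), n_tree)
    # Honesty: one part chooses the splits, the other populates the leaves.
    n_split = math.ceil(n_tree * honesty_fraction)
    split_rows = sorted(tree_rows[:n_split])
    estimate_rows = sorted(tree_rows[n_split:])
    return split_rows, estimate_rows


split_rows, estimate_rows = honest_subsample(100)
print(default_mtry(10), default_mtry(400))
print(len(split_rows), len(estimate_rows))
```

With 10 features the cap applies, so every feature is a candidate; with 400 features the default is 40.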
- Variables
n_features_in_ (int) – The number of features (columns) from the fit input X.
grf_forest_ (dict) – The result object returned from calling C++ grf.
mtry_ (int) – The mtry value determined by validation.
outcome_index_ (int) – The index of the grf train matrix column holding the outcomes.
censor_index_ (int) – The index of the grf train matrix column holding the censoring.
failure_times_ (array1d) – An array of unique failure times from the training set.
num_failures_ (int) – The length of the failure_times_ array.
clusters_ (list) – The cluster labels determined from the fit input cluster.
n_clusters_ (int) – The number of unique cluster labels from the fit input cluster.
criterion (str) – The criterion used for splitting: logrank
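As an illustration of how several of these attributes relate to the fit inputs, the sketch below packs features and (bool, float) targets into one matrix and derives the failure times. The column layout (features, then time, then event flag) and the rule that only times with an observed event count as failure times are assumptions, not skgrf's confirmed behavior; `build_train_matrix` and `unique_failure_times` are hypothetical helper names.

```python
def build_train_matrix(X, y):
    """Append outcome (time) and censoring (event) columns to the features.

    Hypothetical layout: features first, then time, then the event flag.
    """
    n_features = len(X[0])
    outcome_index = n_features       # column holding the times
    censor_index = n_features + 1    # column holding the event flag
    train = [list(row) + [time, float(event)] for row, (event, time) in zip(X, y)]
    return train, outcome_index, censor_index


def unique_failure_times(y):
    # Assumption: only times with an observed event (flag True) count as failures.
    return sorted({time for event, time in y if event})


X = [[0.1, 1.0], [0.4, 0.0], [0.9, 1.0], [0.3, 0.0]]
y = [(True, 5.0), (False, 3.0), (True, 5.0), (True, 2.0)]

train, outcome_index, censor_index = build_train_matrix(X, y)
failure_times = unique_failure_times(y)
num_failures = len(failure_times)
print(outcome_index, censor_index, failure_times, num_failures)
```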
-
apply(X)
Calculate the index of the leaf for each sample.
- Parameters
X (array2d) – input features
-
decision_path(X)
Calculate the decision path through the tree for each sample.
- Parameters
X (array2d) – input features
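The relationship between apply and decision_path can be sketched on a hypothetical flat tree stored in parallel arrays, in the style of scikit-learn's tree internals rather than grf's actual storage. The "go left when the value is <= the threshold" rule is an assumption, and this `decision_path` returns a list of node indices rather than a node-indicator matrix.

```python
# Hypothetical flat tree: node 0 splits on feature 0 at 0.5,
# node 2 splits on feature 1 at 1.5; nodes 1, 3 and 4 are leaves (-1).
children_left = [1, -1, 3, -1, -1]
children_right = [2, -1, 4, -1, -1]
feature = [0, -1, 1, -1, -1]
threshold = [0.5, 0.0, 1.5, 0.0, 0.0]


def decision_path(x):
    """Return the node indices visited by sample x, root first."""
    node, path = 0, [0]
    while children_left[node] != -1:
        # Assumed rule: go left when the feature value is <= the threshold.
        if x[feature[node]] <= threshold[node]:
            node = children_left[node]
        else:
            node = children_right[node]
        path.append(node)
    return path


def apply(X):
    # The leaf is the last node on each sample's decision path.
    return [decision_path(x)[-1] for x in X]


X = [[0.2, 9.0], [0.7, 1.0], [0.7, 2.0]]
print(apply(X))
print([decision_path(x) for x in X])
```

Every path starts at the root and ends at the leaf that apply reports for the same sample.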
-
fit(X, y, sample_weight=None, cluster=None)
Fit the GRF tree using training data.
- Parameters
X (array2d) – training input features
y (array1d) – training input targets, rows of (bool, float) representing (event, time), where the bool indicates whether the event (failure) was observed rather than censored
sample_weight (array1d) – optional weights for input samples
cluster (array1d) – optional cluster assignments for input samples
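When cluster is passed together with equalize_cluster_weights=True, the class parameters state that each cluster contributes as many samples as the smallest cluster has. A simplified sketch of that draw follows; `equalized_cluster_sample` is a hypothetical helper, and it ignores sample_fraction and honesty, so it shows only the equalization step. Here `len(drawn)` plays the role of n_clusters_.

```python
import random
from collections import defaultdict


def equalized_cluster_sample(cluster, seed=42):
    """Draw the same number of rows from every cluster (the smallest cluster size)."""
    rng = random.Random(seed)
    rows_by_cluster = defaultdict(list)
    for row, label in enumerate(cluster):
        rows_by_cluster[label].append(row)
    # Each cluster contributes as many rows as the smallest cluster has.
    per_cluster = min(len(rows) for rows in rows_by_cluster.values())
    drawn = {
        label: sorted(rng.sample(rows, per_cluster))
        for label, rows in rows_by_cluster.items()
    }
    return per_cluster, drawn


cluster = ["a", "a", "a", "a", "b", "b", "c", "c", "c"]
per_cluster, drawn = equalized_cluster_sample(cluster)
n_clusters = len(drawn)
print(per_cluster, n_clusters, {k: len(v) for k, v in drawn.items()})
```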
-
classmethod
from_forest(forest: GRFForestSurvival, idx: int)
Extract a tree from a forest.
- Parameters
forest (GRFForestSurvival) – A trained GRFForestSurvival instance
idx (int) – The index of the tree to extract from the forest.
-
get_depth()
Calculate the maximum depth of the tree.
-
get_n_leaves()
Calculate the number of leaves of the tree.
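Both quantities can be computed by walking the tree. The sketch below uses a hypothetical array-based tree and counts depth in edges, so a tree consisting of a single leaf has depth 0; that convention matches scikit-learn's get_depth and is assumed, not confirmed, for this class.

```python
# Hypothetical flat tree: child index lists, -1 marks "no child" (a leaf).
children_left = [1, -1, 3, -1, -1]
children_right = [2, -1, 4, -1, -1]


def get_depth(node=0):
    """Maximum number of edges from `node` down to any leaf."""
    if children_left[node] == -1:
        return 0
    return 1 + max(get_depth(children_left[node]), get_depth(children_right[node]))


def get_n_leaves():
    # A node is a leaf when it has no left child.
    return sum(1 for left in children_left if left == -1)


print(get_depth(), get_n_leaves())
```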
-
get_params(deep=True)
Get parameters for this estimator.
- Parameters
deep (bool, default=True) – If True, will return the parameters for this estimator and contained subobjects that are estimators.
- Returns
params – Parameter names mapped to their values.
- Return type
dict
-
predict_cumulative_hazard_function(X)
Predict the cumulative hazard function.
- Parameters
X (array2d) – prediction input features
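For a single tree, a prediction can be built from the training observations that share the query sample's leaf. One standard estimator of the cumulative hazard is Nelson–Aalen, H(t) = Σ_{s ≤ t} d(s)/n(s), where d(s) counts events at time s and n(s) counts samples still at risk at s. The pure-Python sketch below applies it to one leaf's (event, time) pairs; whether this class uses Nelson–Aalen, and which observations populate the leaf, are assumptions.

```python
def nelson_aalen(leaf_y, times):
    """Cumulative hazard H(t) from (event, time) pairs that share one leaf.

    H(t) = sum over event times s <= t of d(s) / n(s), where d(s) counts
    events at s and n(s) counts samples still at risk (time >= s).
    """
    event_times = sorted({t for e, t in leaf_y if e})
    increments = []
    for s in event_times:
        d = sum(1 for e, t in leaf_y if e and t == s)
        n = sum(1 for _, t in leaf_y if t >= s)
        increments.append((s, d / n))
    # Evaluate the step function at each requested time.
    return [sum(inc for s, inc in increments if s <= t) for t in times]


leaf_y = [(True, 1.0), (False, 2.0), (True, 3.0), (True, 3.0)]
print(nelson_aalen(leaf_y, [0.5, 1.0, 2.5, 3.0]))
```

The censored observation at time 2.0 adds no jump but still counts in the risk set for the event at time 1.0.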
-
predict_survival_function(X)
Predict the survival function.
- Parameters
X (array2d) – prediction input features
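One standard way to turn a leaf's (event, time) pairs into a survival curve is the Kaplan–Meier estimator, S(t) = Π_{s ≤ t} (1 − d(s)/n(s)), where d(s) counts events at time s and n(s) counts samples still at risk at s. The sketch below assumes the tree applies Kaplan–Meier to its leaf members; the estimator this class actually uses is not stated in this reference.

```python
def kaplan_meier(leaf_y, times):
    """Survival S(t) from (event, time) pairs that share one leaf.

    S(t) = product over event times s <= t of (1 - d(s) / n(s)).
    """
    event_times = sorted({t for e, t in leaf_y if e})
    factors = []
    for s in event_times:
        d = sum(1 for e, t in leaf_y if e and t == s)
        n = sum(1 for _, t in leaf_y if t >= s)
        factors.append((s, 1.0 - d / n))
    result = []
    for t in times:
        survival = 1.0
        for s, factor in factors:
            if s <= t:
                survival *= factor
        result.append(survival)
    return result


leaf_y = [(True, 1.0), (False, 2.0), (True, 3.0), (True, 3.0)]
print(kaplan_meier(leaf_y, [0.5, 1.0, 2.5, 3.0]))
```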
-
set_params(**params)
Set the parameters of this estimator.
The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it's possible to update each component of a nested object.
- Parameters
**params (dict) – Estimator parameters.
- Returns
self – Estimator instance.
- Return type
estimator instance
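The <component>__<parameter> convention can be illustrated with a hypothetical minimal `Estimator` class; it mirrors the idea behind scikit-learn's set_params but is not its implementation. Keys containing `__` are split at the first occurrence and forwarded to the named component.

```python
class Estimator:
    """Hypothetical minimal estimator supporting nested parameter names."""

    def __init__(self, **params):
        self.params = dict(params)

    def set_params(self, **params):
        nested = {}
        for key, value in params.items():
            name, delim, sub_key = key.partition("__")
            if delim:
                # "<component>__<parameter>": defer to the named component.
                nested.setdefault(name, {})[sub_key] = value
            elif name in self.params:
                self.params[name] = value
            else:
                raise ValueError(f"Invalid parameter {name!r}")
        for name, sub_params in nested.items():
            self.params[name].set_params(**sub_params)
        return self


tree = Estimator(min_node_size=5, honesty=True)
pipe = Estimator(tree=tree)
pipe.set_params(tree__min_node_size=10)
print(tree.params["min_node_size"])
```

Returning self allows calls to be chained, which is also how the documented method behaves.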