Tree Survival
class skgrf.tree.GRFTreeSurvival(equalize_cluster_weights=False, sample_fraction=0.5, mtry=None, min_node_size=5, honesty=True, honesty_fraction=0.5, honesty_prune_leaves=True, alpha=0.05, seed=42)
GRF Tree Survival implementation for scikit-learn.
Provides a sklearn tree survival interface to the GRF C++ library using Cython.
Warning
Because the training dataset is required for prediction, the training dataset is recorded onto the estimator instance. This means that serializing this estimator will result in a file at least as large as the serialized training dataset.
- Parameters
equalize_cluster_weights (bool) – Weight the samples such that clusters have equal weight. If False, larger clusters will have more weight. If True, the number of samples drawn from each cluster is equal to the size of the smallest cluster. If True, sample weights should not be passed on fitting.
sample_fraction (float) – Fraction of samples used in each tree.
mtry (int) – The number of features to split on each node. The default is sqrt(p) + 20 where p is the number of features.
min_node_size (int) – The minimum number of observations in each tree leaf.
honesty (bool) – Use honest splitting (subsample splitting).
honesty_fraction (float) – The fraction of data used for subsample splitting.
honesty_prune_leaves (bool) – Prune estimation sample tree such that no leaves are empty. If False, trees with empty leaves are skipped.
alpha (float) – The maximum imbalance of a split.
seed (int) – Random seed value.
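The sampling rule described for equalize_cluster_weights=True can be sketched in plain Python. This is an illustration only, not skgrf or GRF code: the helper name draw_equalized and the use of random.Random are assumptions made for the example.

```python
import random
from collections import defaultdict


def draw_equalized(cluster_labels, seed=42):
    """Draw the same number of sample indices from every cluster.

    Hypothetical helper: the per-cluster draw size is the size of the
    smallest cluster, as described for equalize_cluster_weights=True.
    """
    rng = random.Random(seed)
    members = defaultdict(list)
    for index, label in enumerate(cluster_labels):
        members[label].append(index)

    # Every cluster contributes as many samples as the smallest cluster has.
    draw_size = min(len(indices) for indices in members.values())
    return {
        label: sorted(rng.sample(indices, draw_size))
        for label, indices in members.items()
    }


clusters = ["a", "a", "a", "a", "b", "b", "c", "c", "c"]
drawn = draw_equalized(clusters)
```

Here cluster "b" is the smallest, with two members, so two indices are drawn from each of the three clusters.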
- Variables
n_features_in_ (int) – The number of features (columns) from the fit input X.
grf_forest_ (dict) – The returned result object from calling C++ grf.
mtry_ (int) – The mtry value determined by validation.
outcome_index_ (int) – The index of the grf train matrix holding the outcomes.
censor_index_ (int) – The index of the grf train matrix holding the censoring.
failure_times_ (array1d) – An array of unique failure times from the training set.
num_failures_ (int) – The length of the failure_times array.
clusters_ (list) – The cluster labels determined from the fit input cluster.
n_clusters_ (int) – The number of unique cluster labels from the fit input cluster.
criterion (str) – The criterion used for splitting: logrank
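As a rough sketch of how the mtry default of sqrt(p) + 20 might be resolved into the fitted mtry_ value: the helper name resolve_mtry is hypothetical, and the rounding up and the cap at p are assumptions that are not stated in this reference.

```python
import math


def resolve_mtry(mtry, n_features):
    """Hypothetical helper: choose the number of candidate split features."""
    if mtry is None:
        # Documented default: sqrt(p) + 20, where p is the number of features.
        # Assumption: the value is rounded up and never exceeds p.
        return min(math.ceil(math.sqrt(n_features) + 20), n_features)
    # Assumption: an explicitly passed value is used as given.
    return mtry


default_for_10_features = resolve_mtry(None, 10)
default_for_900_features = resolve_mtry(None, 900)
explicit_value = resolve_mtry(7, 100)
```

Under these assumptions, a dataset with few features ends up considering every feature at each split, because sqrt(p) + 20 exceeds p.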
apply(X)
Calculate the index of the leaf for each sample.
- Parameters
X (array2d) – training input features
decision_path(X)
Calculate the decision path through the tree for each sample.
- Parameters
X (array2d) – training input features
fit(X, y, sample_weight=None, cluster=None)
Fit the grf tree using training data.
- Parameters
X (array2d) – training input features
y (array1d) – training input targets, rows of (bool, float) representing (survival, time)
sample_weight (array1d) – optional weights for input samples
cluster (array1d) – optional cluster assignments for input samples
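The row layout of y, and how the unique failure times could be derived from it, can be sketched with plain Python tuples. Two assumptions are made here: that the boolean marks an observed failure (False meaning censored), and that failure times are the distinct times of observed failures. The helper unique_failure_times is hypothetical, and how the rows must be packed into an array for fit is not covered by this sketch.

```python
# Each row of y pairs an indicator with an observed time.
# Assumption: True means the failure was observed, False means censored.
y = [
    (True, 5.0),
    (False, 8.0),
    (True, 2.0),
    (True, 5.0),
    (False, 3.0),
]


def unique_failure_times(rows):
    """Hypothetical helper: sorted unique times of observed failures."""
    return sorted({time for observed, time in rows if observed})


failure_times = unique_failure_times(y)
num_failures = len(failure_times)
```

Under these assumptions, the censored rows at times 8.0 and 3.0 contribute no failure time, and the repeated time 5.0 is counted once.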
classmethod from_forest(forest: GRFForestSurvival, idx: int)
Extract a tree from a forest.
- Parameters
forest (GRFForestSurvival) – A trained GRFForestSurvival instance
idx (int) – The tree index from the forest to extract.
get_depth()
Calculate the maximum depth of the tree.
get_n_leaves()
Calculate the number of leaves of the tree.
get_params(deep=True)
Get parameters for this estimator.
- Parameters
deep (bool, default=True) – If True, will return the parameters for this estimator and contained subobjects that are estimators.
- Returns
params – Parameter names mapped to their values.
- Return type
dict
predict_cumulative_hazard_function(X)
Predict cumulative hazard function.
- Parameters
X (array2d) – prediction input features
predict_survival_function(X)
Predict survival function.
- Parameters
X (array2d) – prediction input features
set_params(**params)
Set the parameters of this estimator.
The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it’s possible to update each component of a nested object.
- Parameters
**params (dict) – Estimator parameters.
- Returns
self – Estimator instance.
- Return type
estimator instance
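The <component>__<parameter> naming used for nested objects can be illustrated with a small standalone sketch. The helper split_nested_params and the component name "tree" are hypothetical and only show how such keys separate into a component and a parameter name. This is not the set_params implementation.

```python
def split_nested_params(params):
    """Group parameters by component using the '__' separator.

    Hypothetical helper for illustration: names without '__' belong to
    the top-level object and are stored under the key None.
    """
    grouped = {}
    for key, value in params.items():
        # partition splits on the first '__' only, so deeper nesting
        # stays inside the parameter part for the component to resolve.
        component, separator, sub_key = key.partition("__")
        if separator:
            grouped.setdefault(component, {})[sub_key] = value
        else:
            grouped.setdefault(None, {})[key] = value
    return grouped


grouped = split_nested_params(
    {"tree__min_node_size": 10, "tree__alpha": 0.1, "verbose": True}
)
```

In this sketch, min_node_size and alpha are routed to the hypothetical "tree" component, while verbose stays with the enclosing object.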