Forest Survival
-
class skgrf.ensemble.GRFForestSurvival(n_estimators=100, equalize_cluster_weights=False, sample_fraction=0.5, mtry=None, min_node_size=5, honesty=True, honesty_fraction=0.5, honesty_prune_leaves=True, alpha=0.05, n_jobs=-1, seed=42, enable_tree_details=False)
GRF Survival implementation for scikit-learn.
Provides a sklearn survival interface to the GRF C++ library using Cython.
Warning
Because the training dataset is required for prediction, the training dataset is recorded onto the estimator instance. This means that serializing this estimator will result in a file at least as large as the serialized training dataset.
- Parameters
n_estimators (int) – The number of survival trees to train
equalize_cluster_weights (bool) – Weight the samples such that clusters have equal weight. If False, larger clusters will have more weight. If True, the number of samples drawn from each cluster is equal to the size of the smallest cluster, and sample weights should not be passed when fitting.
sample_fraction (float) – Fraction of samples used in each tree.
mtry (int) – The number of candidate features considered for splitting at each node. The default is sqrt(p) + 20, where p is the number of features.
min_node_size (int) – The minimum number of observations in each tree leaf.
honesty (bool) – Use honest splitting (subsample splitting).
honesty_fraction (float) – The fraction of data used for subsample splitting.
honesty_prune_leaves (bool) – Prune the estimation-sample tree so that no leaves are empty. If False, trees with empty leaves are skipped.
alpha (float) – The maximum imbalance of a split.
n_jobs (int) – The number of threads. Default is number of CPU cores.
seed (int) – Random seed value.
enable_tree_details (bool) – When True, perform additional calculations to detail the underlying decision trees. Must be enabled for estimators_ and get_estimator to work. Very slow.
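The documented default for mtry can be written as a small helper. Rounding up and capping the result at p are assumptions borrowed from the R grf package's default, min(ceiling(sqrt(p) + 20), p); after fitting, the mtry_ attribute reports the value actually used. The helper name is hypothetical.

```python
import math


def default_mtry(p):
    """Hypothetical helper: default number of candidate features per split.

    The documented default is sqrt(p) + 20; rounding up and capping at p
    are assumptions taken from the R grf package.
    """
    return min(math.ceil(math.sqrt(p) + 20), p)


# With few features the cap applies; with many, sqrt(p) + 20 dominates.
print(default_mtry(10), default_mtry(100), default_mtry(1000))
```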
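How sample_fraction, honesty, honesty_fraction and equalize_cluster_weights shape the data seen by one tree can be sketched as below. This is a simplified illustration, not skgrf's internal sampler: the function name is hypothetical, the rounding with int() is an assumption, and honesty splitting is done on rows here even though grf splits at the cluster level when clusters are given. Following the parameter descriptions, the splitting half receives honesty_fraction of the drawn samples, and with equalized cluster weights each drawn cluster contributes as many rows as the smallest cluster has.

```python
import numpy as np


def draw_tree_samples(n, sample_fraction=0.5, honesty=True, honesty_fraction=0.5,
                      cluster=None, equalize_cluster_weights=False, rng=None):
    """Return (split_idx, estimate_idx) row indices for one hypothetical tree."""
    rng = np.random.default_rng(rng)
    if cluster is None:
        # No clusters: draw a fraction of the rows without replacement.
        drawn = rng.choice(n, size=int(sample_fraction * n), replace=False)
    else:
        cluster = np.asarray(cluster)
        labels, counts = np.unique(cluster, return_counts=True)
        # With clusters: draw a fraction of the clusters instead of rows.
        n_drawn = max(1, int(sample_fraction * len(labels)))
        drawn_labels = rng.choice(labels, size=n_drawn, replace=False)
        k = counts.min() if equalize_cluster_weights else None
        parts = []
        for label in drawn_labels:
            members = np.flatnonzero(cluster == label)
            if k is not None:
                # Equalized weights: keep only k random members, k = smallest cluster size.
                members = rng.choice(members, size=k, replace=False)
            parts.append(members)
        drawn = np.concatenate(parts)
    if not honesty:
        # Without honesty, the same samples both place splits and fill leaves.
        return drawn, drawn
    # Honesty: one part chooses the splits, the disjoint rest estimates leaf values.
    drawn = rng.permutation(drawn)
    n_split = int(honesty_fraction * len(drawn))
    return drawn[:n_split], drawn[n_split:]


split_idx, estimate_idx = draw_tree_samples(100, rng=0)
print(len(split_idx), len(estimate_idx))
```

With the defaults, 50 of 100 rows are drawn and divided 25/25 between splitting and estimation.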
- Variables
estimators_ (list) – A list of tree objects from the forest.
n_features_in_ (int) – The number of features (columns) in the fit input X.
grf_forest_ (dict) – The result object returned by the grf C++ library.
mtry_ (int) – The mtry value determined by validation.
outcome_index_ (int) – The column index of the grf train matrix holding the outcomes.
censor_index_ (int) – The column index of the grf train matrix holding the censoring indicator.
failure_times_ (array1d) – An array of unique failure times from the training set.
num_failures_ (int) – The length of the failure_times_ array.
clusters_ (list) – The cluster labels determined from the fit input cluster.
n_clusters_ (int) – The number of unique cluster labels in the fit input cluster.
criterion (str) – The criterion used for splitting: logrank.
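The failure_times_ and num_failures_ attributes can be sketched as follows, assuming (as R grf does by default) that the failure times are the sorted unique times of observed events, with censored times excluded. The function name is hypothetical.

```python
import numpy as np


def failure_times_from_y(event, time):
    """Hypothetical helper mirroring failure_times_ and num_failures_."""
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    # Assumption: only times of observed events count; np.unique also sorts them.
    failure_times = np.unique(time[event])
    return failure_times, len(failure_times)


# The censored time 3.0 is dropped and the repeated 5.0 is kept once.
times, n_failures = failure_times_from_y([True, False, True, True], [5.0, 3.0, 5.0, 8.0])
print(times, n_failures)
```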
-
fit(X, y, sample_weight=None, cluster=None)
Fit the grf forest using training data.
- Parameters
X (array2d) – training input features
y (array1d) – training input targets, rows of (bool, float) representing (survival, time)
sample_weight (array1d) – optional weights for input samples
cluster (array1d) – optional cluster assignments for input samples
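As a hedged illustration of the y format described above, the sketch builds a small structured array whose rows are (bool, float) pairs, in the style used by scikit-survival. The field names "event" and "time" are hypothetical, and reading the bool as an observed-event indicator is an assumption. The skgrf call is left as a comment because it needs the library installed and a realistically sized dataset.

```python
import numpy as np

# Each row of y is (bool, float); the field names here are hypothetical.
y = np.array(
    [(True, 5.0), (False, 3.0), (True, 5.0), (True, 8.0)],
    dtype=[("event", "?"), ("time", "<f8")],
)
X = np.array([[0.1, 1.0], [0.4, 0.0], [0.3, 1.0], [0.9, 0.0]])

# Recover the two components by position, whatever the fields are called.
event = y[y.dtype.names[0]]
time = y[y.dtype.names[1]]

# With skgrf installed and more data, fitting would look like:
# from skgrf.ensemble import GRFForestSurvival
# forest = GRFForestSurvival(n_estimators=100, seed=42).fit(X, y)
print(event, time)
```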
-
get_estimator(idx)
Extract a single estimator tree from the forest.
- Parameters
idx (int) – The index of the tree to extract.
-
get_feature_importances(decay_exponent=2, max_depth=4)
Get the feature importances.
- Parameters
decay_exponent (int) – Exponential decay of importance by split depth
max_depth (int) – The maximum depth of splits to consider
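The sketch below shows how decay_exponent and max_depth could combine split counts into importances, modeled on R grf's variable_importance: each depth's split counts are normalized to proportions, depth k (counting from 1) is weighted by k ** -decay_exponent, and the weighted proportions are averaged. Treating skgrf's method as identical is an assumption; the input is a depth-by-feature count matrix of the kind get_split_frequencies describes, and the function name is hypothetical.

```python
import numpy as np


def feature_importances(split_freq, decay_exponent=2):
    """split_freq: (max_depth, n_features) counts of splits per feature at each depth."""
    split_freq = np.asarray(split_freq, dtype=float)
    # Normalize each depth's counts to proportions; max(1, .) guards empty depths.
    totals = np.maximum(1.0, split_freq.sum(axis=1, keepdims=True))
    proportions = split_freq / totals
    # Depth k (1-indexed) gets weight k ** -decay_exponent, so shallow splits dominate.
    depths = np.arange(1, split_freq.shape[0] + 1, dtype=float)
    weights = depths ** -decay_exponent
    return proportions.T @ weights / weights.sum()


# Feature 0 is the only root split, so it dominates despite sharing depth 2.
print(feature_importances([[1, 0], [1, 1]]))
```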
-
get_kernel_weights(X, oob_prediction=False)
Get training sample weights for test data.
Given a trained forest and test data, compute the kernel weights for each test point.
Creates a sparse matrix in which the value at (i, j) gives the weight of training sample j for test sample i. Use oob_prediction=True when X is the training set.
- Parameters
X (array2d) – input features
oob_prediction (bool) – whether to calculate weights out of bag
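The (i, j) entries follow the GRF forest-weighting idea: in each tree, a test point spreads a weight of 1 uniformly over the training samples that share its leaf, and these weights are averaged over trees. The sketch computes a dense version from hypothetical leaf-assignment arrays. It ignores honesty (where only estimation-half samples would count) and out-of-bag handling, and the actual method returns a sparse matrix.

```python
import numpy as np


def kernel_weights(train_leaves, test_leaves):
    """train_leaves: (n_trees, n_train) leaf ids; test_leaves: (n_trees, n_test) leaf ids."""
    train_leaves = np.asarray(train_leaves)
    test_leaves = np.asarray(test_leaves)
    n_trees, n_train = train_leaves.shape
    n_test = test_leaves.shape[1]
    weights = np.zeros((n_test, n_train))
    for b in range(n_trees):
        # same[i, j] is True when test point i and training sample j share a leaf in tree b.
        same = test_leaves[b][:, None] == train_leaves[b][None, :]
        leaf_sizes = same.sum(axis=1, keepdims=True)
        # Spread weight 1 uniformly over the leaf; an empty leaf contributes nothing.
        weights += np.divide(same, leaf_sizes, out=np.zeros_like(weights),
                             where=leaf_sizes > 0)
    return weights / n_trees


w = kernel_weights([[0, 0, 1], [0, 1, 1]], [[0, 1], [1, 0]])
print(w)
```

When every test point lands in a non-empty leaf, each row sums to 1, so the weights act like a data-adaptive nearest-neighbor kernel.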
-
get_params(deep=True)
Get parameters for this estimator.
- Parameters
deep (bool, default=True) – If True, will return the parameters for this estimator and contained subobjects that are estimators.
- Returns
params – Parameter names mapped to their values.
- Return type
dict
-
get_split_frequencies(max_depth=4)
Get the split frequencies of feature indexes at various depths.
- Parameters
max_depth (int) – The maximum depth of splits to consider
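The split-frequency table can be pictured as a max_depth by n_features count matrix, as in R grf's split_frequencies. The sketch builds one from a hypothetical tree representation in which each tree is a list of (depth, feature_index) split records with the root at depth 0; the actual method reads the forest's internal trees instead.

```python
import numpy as np


def split_frequencies(trees, n_features, max_depth=4):
    """Count splits per (depth, feature) over hypothetical (depth, feature) records."""
    freq = np.zeros((max_depth, n_features), dtype=int)
    for tree in trees:
        for depth, feature in tree:
            # Splits deeper than max_depth are not counted.
            if depth < max_depth:
                freq[depth, feature] += 1
    return freq


trees = [
    [(0, 1), (1, 0), (1, 2)],  # root splits on feature 1, then features 0 and 2
    [(0, 1), (1, 1), (4, 0)],  # the depth-4 split falls outside max_depth=2
]
print(split_frequencies(trees, n_features=3, max_depth=2))
```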
-
predict_cumulative_hazard_function(X)
Predict cumulative hazard function.
- Parameters
X (array2d) – prediction input features
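A cumulative hazard can be estimated from kernel weights with a weighted Nelson-Aalen sum, H(t) = sum over failure times t_k <= t of d_k / n_k, where d_k is the weighted count of failures at t_k and n_k the weighted count of samples still at risk. R grf offers Nelson-Aalen as one of its survival prediction types; whether this method uses exactly this estimator is an assumption, and the function and inputs are hypothetical (one test point's kernel weights over the training samples, the training events and times, and the failure-time grid).

```python
import numpy as np


def weighted_nelson_aalen(weights, event, time, failure_times):
    """Cumulative hazard at each failure time for one test point."""
    weights = np.asarray(weights, dtype=float)
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    hazard = np.empty(len(failure_times))
    h = 0.0
    for k, t in enumerate(failure_times):
        at_risk = weights[time >= t].sum()           # weighted n_k
        failed = weights[(time == t) & event].sum()  # weighted d_k
        if at_risk > 0:
            h += failed / at_risk
        hazard[k] = h
    return hazard


print(weighted_nelson_aalen([0.25] * 4, [True, False, True, True],
                            [5.0, 3.0, 5.0, 8.0], [5.0, 8.0]))
```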
-
predict_survival_function(X)
Predict survival function.
- Parameters
X (array2d) – prediction input features
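One way to turn kernel weights into a survival curve is a weighted Kaplan-Meier estimator; R grf documents Kaplan-Meier as the default prediction type of its survival forest. The sketch assumes that approach for a single test point with hypothetical inputs (the point's kernel weights over training samples, the training events and times, and the failure-time grid). Whether skgrf evaluates exactly this estimator is not stated in this reference.

```python
import numpy as np


def weighted_kaplan_meier(weights, event, time, failure_times):
    """Survival probability at each failure time for one test point."""
    weights = np.asarray(weights, dtype=float)
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    survival = np.empty(len(failure_times))
    s = 1.0
    for k, t in enumerate(failure_times):
        at_risk = weights[time >= t].sum()           # weighted samples still at risk
        failed = weights[(time == t) & event].sum()  # weighted failures at t
        if at_risk > 0:
            # Product-limit step: multiply by the conditional survival at t.
            s *= 1.0 - failed / at_risk
        survival[k] = s
    return survival


print(weighted_kaplan_meier([0.25] * 4, [True, False, True, True],
                            [5.0, 3.0, 5.0, 8.0], [5.0, 8.0]))
```

The censored sample at time 3.0 leaves the risk set without counting as a failure, which is what separates this from a plain weighted fraction of survivors.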
-
set_params(**params)
Set the parameters of this estimator.
The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it’s possible to update each component of a nested object.
- Parameters
**params (dict) – Estimator parameters.
- Returns
self – Estimator instance.
- Return type
estimator instance
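The <component>__<parameter> convention amounts to splitting each key at its first double underscore. The hypothetical helper below groups keys that way; scikit-learn's own set_params performs the same partition before delegating to each component, but this is a simplified illustration rather than its implementation.

```python
def route_params(params):
    """Split flat keys into own parameters and per-component parameters."""
    own, nested = {}, {}
    for key, value in params.items():
        # "forest__seed" -> ("forest", "__", "seed"); keys without "__" stay local.
        component, sep, sub_key = key.partition("__")
        if sep:
            nested.setdefault(component, {})[sub_key] = value
        else:
            own[key] = value
    return own, nested


print(route_params({"forest__n_estimators": 200, "forest__seed": 0, "memory": None}))
```

For example, if this estimator were a Pipeline step named "forest" (a hypothetical name), pipe.set_params(forest__n_estimators=200) would update its n_estimators.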