Title: | Bagged k-Nearest Neighbors Survival Prediction |
---|---|
Description: | Implements a bootstrap aggregated (bagged) version of the k-nearest neighbors survival probability prediction method (Lowsky et al. 2013). In addition to the bootstrapping of training samples, the features can be subsampled in each base learner to break the correlation between the base learners. The Rcpp package is used to speed up the computation. |
Authors: | Marvin N. Wright |
Maintainer: | Marvin N. Wright |
License: | GPL-3 |
Version: | 0.1.5 |
Built: | 2025-01-23 04:10:43 UTC |
Source: | https://github.com/mnwright/bnnsurvival |
Bootstrap aggregated (bagged) version of the k-nearest neighbors survival probability prediction method (Lowsky et al. 2013). In addition to the bootstrapping of training samples, the features can be subsampled in each base learner.
bnnSurvival(formula, data, k = max(1, nrow(data)/10), num_base_learners = 50, num_features_per_base_learner = NULL, metric = "mahalanobis", weighting_function = function(x) { x * 0 + 1 }, replace = TRUE, sample_fraction = NULL)
formula | Object of class formula or character describing the model to fit. |
data | Training data of class data.frame. |
k | Number of nearest neighbors to use. If a vector is given, the optimal k among these values is found using 5-fold cross-validation. |
num_base_learners | Number of base learners to use for bootstrapping. |
num_features_per_base_learner | Number of features randomly selected in each base learner. Default: all. |
metric | Metric d(x,y) used to measure the distance between observations. Currently only "mahalanobis" is supported. |
weighting_function | Weighting function w(d(x,y)) used to weight the observations based on their distance. |
replace | Sample with or without replacement. |
sample_fraction | Fraction of observations to sample in [0,1]. Default is 1 for |
For a description of the k-nearest neighbors survival probability prediction method, see Lowsky et al. (2013). Please note that parallel processing, as currently implemented, does not work on Microsoft Windows platforms.
The weighting function needs to be defined for all distances >= 0. The default function is the constant 1; a possible alternative is w(x) = 1/(1+x).
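The single-learner idea (Lowsky et al. 2013) with a user-supplied weighting function can be sketched in base R as below. This is an illustration under stated assumptions, not the package's Rcpp implementation: it assumes the Mahalanobis distance uses the covariance of the training features and that the survival curve is a Kaplan-Meier estimate in which each of the k nearest neighbors counts with weight w(d). The name `knn_surv_one` is hypothetical, and taking the square root of `stats::mahalanobis()` (which returns squared distances) is our choice.

```r
# Minimal base-R sketch of a k-nearest-neighbour survival estimate for ONE
# test observation (illustrative only; not bnnSurvival's Rcpp code).
# Assumptions: Mahalanobis distance with the training covariance matrix, and
# a Kaplan-Meier estimate in which each neighbour counts with weight w(d).
# timepoints must be sorted and contain every event time of the neighbours.
knn_surv_one <- function(x_test, X_train, time, status, k,
                         weighting_function = function(x) x * 0 + 1,
                         timepoints = sort(unique(time))) {
  k <- max(1, min(floor(k), nrow(X_train)))           # usable integer k
  # stats::mahalanobis() returns SQUARED distances; take the square root
  d <- sqrt(mahalanobis(X_train, center = x_test, cov = cov(X_train)))
  nn <- order(d)[seq_len(k)]                           # k nearest training rows
  w <- weighting_function(d[nn])                       # weight per neighbour
  t_nn <- time[nn]
  e_nn <- status[nn]
  surv <- 1
  out <- numeric(length(timepoints))
  for (j in seq_along(timepoints)) {
    tj <- timepoints[j]
    at_risk <- sum(w[t_nn >= tj])                      # weighted number at risk
    events <- sum(w[t_nn == tj & e_nn == 1])           # weighted events at tj
    if (at_risk > 0) surv <- surv * (1 - events / at_risk)
    out[j] <- surv                                     # S(tj) of the step function
  }
  out
}

# Example on simulated data with the alternative weight w(x) = 1/(1+x)
set.seed(1)
X <- matrix(rnorm(200), ncol = 2)
time <- rexp(100)
status <- rbinom(100, 1, 0.7)
s <- knn_surv_one(X[1, ], X, time, status, k = 10,
                  weighting_function = function(x) 1 / (1 + x))
```

Because every factor (1 - events/at_risk) lies in [0, 1] for non-negative weights, the resulting curve is non-increasing regardless of which weighting function is used.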
To use the non-bagged version as in Lowsky et al. (2013), use num_base_learners=1, replace=FALSE and sample_fraction=1.
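The bagging layer around a base learner can be sketched as follows, assuming (our assumption; the text above does not state the aggregation rule) that the ensemble prediction is the mean of the base learners' survival curves on a common time grid. `bag_surv` and `km_learner` are hypothetical names, and the toy learner ignores covariates only to keep the check simple; in practice it would be a k-nearest-neighbors estimate. With one base learner, no replacement and a sample fraction of 1, the row sample is just a permutation, so the result reduces to the single learner on the full data, mirroring the non-bagged setting described above.

```r
# Hypothetical bagging wrapper (names are ours, not the bnnSurvival API).
# Assumption: the ensemble prediction is the mean of the base learners'
# survival curves evaluated on a common time grid.
bag_surv <- function(x_test, X, time, status, base_learner,
                     num_base_learners = 50, num_features = ncol(X),
                     replace = TRUE, sample_fraction = 1,
                     timepoints = sort(unique(time))) {
  n <- nrow(X)
  n_sample <- max(1, round(sample_fraction * n))
  curves <- matrix(NA_real_, nrow = num_base_learners, ncol = length(timepoints))
  for (b in seq_len(num_base_learners)) {
    rows <- sample.int(n, n_sample, replace = replace)  # bootstrap or subsample rows
    cols <- sample.int(ncol(X), num_features)           # random feature subset
    curves[b, ] <- base_learner(x_test[cols], X[rows, cols, drop = FALSE],
                                time[rows], status[rows], timepoints)
  }
  colMeans(curves)                                      # average survival curve
}

# Toy base learner: plain Kaplan-Meier that ignores the covariates.
km_learner <- function(x_test, X, time, status, timepoints) {
  vapply(timepoints, function(t) {
    ut <- sort(unique(time[status == 1 & time <= t]))   # event times up to t
    prod(1 - vapply(ut, function(u) sum(time == u & status == 1) / sum(time >= u),
                    numeric(1)))
  }, numeric(1))
}

set.seed(2)
X <- matrix(rnorm(150), ncol = 3)
time <- rexp(50)
status <- rbinom(50, 1, 0.6)
bagged <- bag_surv(X[1, ], X, time, status, km_learner,
                   num_base_learners = 20, num_features = 2)
single <- bag_surv(X[1, ], X, time, status, km_learner,
                   num_base_learners = 1, replace = FALSE, sample_fraction = 1)
```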
bnnSurvivalEnsemble object. Use predict() with a new data set to predict survival probabilities.
Marvin N. Wright
Lowsky, D.J. et al. (2013). A K-nearest neighbors survival probability prediction method. Stat Med, 32(12), 2062-2069.
require(bnnSurvival)

## Use only 1 core
options(mc.cores = 1)

## Load a dataset and split into training and test data
require(survival)
n <- nrow(veteran)
idx <- sample(n, 2/3*n)
train_data <- veteran[idx, ]
test_data <- veteran[-idx, ]

## Create model with training data and predict for test data
model <- bnnSurvival(Surv(time, status) ~ trt + karno + diagtime + age + prior,
                     train_data, k = 20, num_base_learners = 10,
                     num_features_per_base_learner = 3)
result <- predict(model, test_data)

## Plot survival curve for the first observation
plot(timepoints(result), predictions(result)[1, ])
Get optimal number of neighbors for bnnSurvival by cross-validation
get_best_k(formula, data, k, ...)
formula | Model formula, as for bnnSurvival. |
data | Training data. |
k | Candidate numbers of nearest neighbors. |
... | Further arguments passed to bnnSurvival. |
Optimal k
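The bnnSurvival argument description says a vector of k values is tuned by 5-fold cross-validation. The generic procedure can be sketched as below; the loss is left to the caller because the criterion get_best_k actually optimises is not documented here. `cv_best_k`, `loss_fn` and `toy_loss` are hypothetical names, and the toy loss only exists to make the selection easy to verify.

```r
# Hypothetical helper (not bnnSurvival's get_best_k): choose k among
# candidate values by v-fold cross-validation with a user-supplied loss.
# loss_fn(train, test, k) must return one number; lower is better.
cv_best_k <- function(data, k, loss_fn, nfolds = 5) {
  # One random fold assignment, shared by all candidate k for a fair comparison
  folds <- sample(rep_len(seq_len(nfolds), nrow(data)))
  cv_loss <- vapply(k, function(kk) {
    fold_losses <- vapply(seq_len(nfolds), function(f) {
      loss_fn(data[folds != f, , drop = FALSE],   # train on the other folds
              data[folds == f, , drop = FALSE],   # evaluate on fold f
              kk)
    }, numeric(1))
    mean(fold_losses)
  }, numeric(1))
  k[which.min(cv_loss)]                           # first k with the smallest mean loss
}

# Toy check with an artificial loss that is smallest at k = 7
set.seed(3)
df <- data.frame(x = rnorm(30))
toy_loss <- function(train, test, k) (k - 7)^2
best <- cv_best_k(df, k = c(3, 5, 7, 9), loss_fn = toy_loss)
```

Reusing one fold assignment for every candidate k means differences in mean loss reflect k rather than a different random split.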
Compute predictions for all samples.
## S4 method for signature 'bnnSurvivalBaseLearner' predict(object, train_data, test_data, timepoints, metric, weighting_function, k)
object | bnnSurvivalBaseLearner object. |
train_data | Training data (with response). |
test_data | Test data (without response). |
timepoints | Timepoints to predict at. |
metric | Metric used. |
weighting_function | Weighting function used. |
k | Number of nearest neighbors. |
Predict survival probabilities with bagged k-nearest neighbors survival prediction.
## S4 method for signature 'bnnSurvivalEnsemble' predict(object, test_data)
object | Object of class bnnSurvivalEnsemble, created with bnnSurvival(). |
test_data | Data set for which to predict survival. |
Get Predictions
predictions(object, ...)
object | Object to extract predictions from. |
... | Further arguments passed to or from other methods. |
Get Predictions
## S4 method for signature 'bnnSurvivalResult' predictions(object)
object | bnnSurvivalResult object to extract predictions from. |
Function to extract survival probability predictions from bnnSurvivalEnsemble. Use with the pec package.
## S3 method for class 'bnnSurvivalEnsemble' predictSurvProb(object, newdata, times, ...)
object | bnnSurvivalEnsemble object. |
newdata | Data used for prediction. |
times | Not used. |
... | Not used. |
Survival probability predictions.
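Since `times` is documented as not used, predictions come back on the model's own timepoints. If probabilities are needed at specific times (pec's convention is a matrix with one row per observation and one column per requested time), one option is to read them off the estimated step function. `step_surv_at` is a hypothetical helper, not part of bnnSurvival or pec; it assumes the columns of `pred` correspond to increasing `timepoints` and that survival is 1 before the first timepoint.

```r
# Hypothetical helper (not part of bnnSurvival or pec): evaluate survival
# curves stored on a grid of timepoints at arbitrary requested times,
# treating each curve as a right-continuous step function with S(t) = 1
# before the first timepoint.
step_surv_at <- function(pred, timepoints, times) {
  # Index of the last timepoint <= each requested time (0 if none);
  # findInterval() requires timepoints sorted in increasing order.
  idx <- findInterval(times, timepoints)
  padded <- cbind(1, pred)          # prepend a column of 1s for times before the grid
  padded[, idx + 1, drop = FALSE]   # rows: observations, columns: requested times
}

pred <- rbind(c(0.9, 0.7, 0.4),
              c(0.8, 0.5, 0.1))
out <- step_surv_at(pred, timepoints = c(2, 5, 9), times = c(0, 2, 6, 10))
```

With the fitted model from the bnnSurvival example, a call such as `step_surv_at(predictions(result), timepoints(result), c(100, 200))` should apply the same lookup, assuming `predictions()` returns one row per test observation as that example's plotting code suggests.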
Generic print method for bnnSurvivalEnsemble
## S4 method for signature 'bnnSurvivalEnsemble' print(x)
x | bnnSurvivalEnsemble object to print. |
Generic print method for bnnSurvivalResult
## S4 method for signature 'bnnSurvivalResult' print(x)
x | bnnSurvivalResult object to print. |
Generic show method for bnnSurvivalEnsemble
## S4 method for signature 'bnnSurvivalEnsemble' show(object)
object | bnnSurvivalEnsemble object to show. |
Generic show method for bnnSurvivalResult
## S4 method for signature 'bnnSurvivalResult' show(object)
object | bnnSurvivalResult object to show. |
Get Timepoints
timepoints(object, ...)
object | Object to extract timepoints from. |
... | Further arguments passed to or from other methods. |
Get timepoints
## S4 method for signature 'bnnSurvivalResult' timepoints(object)
object | bnnSurvivalResult object to extract timepoints from. |
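Each accessor above is documented twice: once as a generic (for example `timepoints(object, ...)`) and once as the S4 method for `bnnSurvivalResult`. The generic/method/show pattern can be sketched with the `methods` package as below. The class `SurvResultSketch` and its slot names are hypothetical, not the package's class definition; run it in a fresh session, because defining these generics would mask bnnSurvival's if the package is attached.

```r
library(methods)

# Hypothetical result class: slot names are illustrative, not bnnSurvival's.
setClass("SurvResultSketch",
         slots = c(timepoints = "numeric", predictions = "matrix"))

# Generics with the same call shape as the documented accessors.
setGeneric("timepoints", function(object, ...) standardGeneric("timepoints"))
setGeneric("predictions", function(object, ...) standardGeneric("predictions"))

# Methods dispatching on the hypothetical class.
setMethod("timepoints", "SurvResultSketch", function(object, ...) object@timepoints)
setMethod("predictions", "SurvResultSketch", function(object, ...) object@predictions)

# show() controls what is printed when the object is displayed at the console.
setMethod("show", "SurvResultSketch", function(object) {
  cat("Survival predictions for", nrow(object@predictions), "observations at",
      length(object@timepoints), "timepoints\n")
})

res <- new("SurvResultSketch", timepoints = c(1, 2, 3),
           predictions = matrix(c(0.9, 0.8, 0.6, 0.5, 0.3, 0.2), nrow = 2))
```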