R/random_machines.R

random_machines.Rd

Fits an ensemble of models using bootstrap aggregation (bagging).
random_machines(
data,
newdata,
time_col = "t",
delta_col = "delta",
kernels,
B = 100L,
mtry = NULL,
crop = NULL,
beta_kernel = 1,
beta_bag = 1,
cores = 1L,
seed = NULL,
prop_holdout = 0.2,
.progress = TRUE
)

A data.frame containing training data.
A data.frame containing test data for prediction.
Name of the column with survival times.
Name of the column with the event indicator (1 = event, 0 = censored).
A named list of kernel specifications.
Integer. Number of bootstrap samples.
Integer or Numeric. Number of variables to randomly sample at each split.
Numeric or NULL. Threshold for kernel selection probabilities.
Numeric. Temperature for kernel selection probabilities.
Numeric. Temperature for ensemble weighting.
Integer. Number of parallel workers (via mirai).
Optional integer passed to mirai::daemons.
Numeric in (0, 1). Proportion for internal holdout.
Logical. Show progress bar?
An object of class "random_machines".
if (FALSE) { # \dontrun{
# Example only runs when the Python backend and the parallel backend are available.
if (reticulate::py_module_available("sksurv") && requireNamespace("mirai")) {
library(FastSurvivalSVM)
# 1. Data Generation and Split
# Fix the RNG so the simulated data and the train/test split are reproducible.
set.seed(42)
# Simulate 250 observations with ~25% censoring (package helper — see data_generation()).
df <- data_generation(n = 250, prop_cen = 0.25)
# Hold out 50 rows for testing; train on the remaining 200.
train_idx <- sample(nrow(df), 200)
train_df <- df[train_idx, ]
test_df <- df[-train_idx, ]
# 2. Define Custom Kernel Functions (Math Only)
# Wavelet Kernel
my_wavelet <- function(x, z, A) {
  # Wavelet kernel: product over coordinates of a cosine-modulated
  # Gaussian applied to the scaled difference (x - z) / A.
  scaled_diff <- (as.numeric(x) - as.numeric(z)) / A
  terms <- cos(1.75 * scaled_diff) * exp(-0.5 * scaled_diff^2)
  prod(terms)
}
# Polynomial Kernel
my_poly <- function(x, z, degree, coef0) {
  # Polynomial kernel: (<x, z> + coef0)^degree.
  inner_prod <- sum(as.numeric(x) * as.numeric(z))
  (inner_prod + coef0)^degree
}
# 3. Tuning Workflow
# Before training the ensemble, we optimize the hyperparameters for each
# kernel family using 'tune_random_machines'.
# A. Define Kernel Mix (Fixed Structure)
# We set rank_ratio = 0 because we want to solve a Regression problem.
# NOTE(review): 'wavelet_ok' and 'poly_ok' omit a 'kernel' entry here because
# their kernel functions are supplied later via grid_kernel() in param_grids.
kernel_mix <- list(
linear_std = list(kernel = "linear", rank_ratio = 0),
rbf_std = list(kernel = "rbf", rank_ratio = 0),
wavelet_ok = list(rank_ratio = 0),
poly_ok = list(rank_ratio = 0)
)
# B. Define Parameter Grids (Search Space)
# We define 4 values for each hyperparameter to be tuned.
# Names must match the entries in 'kernel_mix' above.
param_grids <- list(
# Linear: Tune regularization (alpha)
linear_std = list(
alpha = c(0.01, 0.1, 1.0, 10.0)
),
# RBF: Tune alpha and kernel width (gamma)
rbf_std = list(
alpha = c(0.01, 0.1, 1.0, 10.0),
gamma = c(0.001, 0.01, 0.1, 1.0)
),
# Custom Wavelet: Tune alpha and the kernel parameter 'A'
# 'grid_kernel' generates the variants for the custom function.
wavelet_ok = list(
kernel = grid_kernel(my_wavelet, A = c(0.5, 1.0, 1.5, 2.0)),
alpha = c(0.01, 0.1, 1.0, 10.0)
),
# Custom Poly: Tune alpha and the degree
# (coef0 kept fixed at 1 for this example)
poly_ok = list(
kernel = grid_kernel(my_poly, degree = c(2, 3, 4, 5), coef0 = 1),
alpha = c(0.01, 0.1, 1.0, 10.0)
)
)
# C. Execute Hybrid Tuning
# This uses Python threads for Native kernels and R processes for Custom ones.
# Note: this data set uses column names "tempo" (time) and "cens" (event).
cat("Starting hyperparameter tuning...\n")
tune_res <- tune_random_machines(
data = train_df,
time_col = "tempo",
delta_col = "cens",
kernel_mix = kernel_mix,
param_grids = param_grids,
cv = 3,
cores = parallel::detectCores(),
verbose = 1
)
# D. Bridge: Extract Best Hyperparameters
# This creates the final configuration list ready for the ensemble.
final_kernels <- as_kernels(tune_res, kernel_mix)
print("Best configurations found:")
print(final_kernels)
# 4. Train Random Machines (Bagging)
# Now we use the optimized 'final_kernels' to train the ensemble.
cat("Training Random Machines ensemble...\n")
rm_model <- random_machines(
data = train_df,
newdata = test_df,
time_col = "tempo",
delta_col = "cens",
kernels = final_kernels, # Use tuned kernels
B = 50, # Number of bootstrap samples
mtry = NULL, # Use all features (Random Forest style)
crop = 0.10, # Prune kernels with weight < 10%
prop_holdout = 0.20, # 20% internal holdout for weighting
cores = parallel::detectCores(),
seed = 42,
.progress = TRUE
)
# 5. Evaluate and Print
# score() computes the concordance index on the held-out test set.
print(rm_model)
cidx <- score(rm_model, test_df)
cat(sprintf("Final Test C-Index: %.4f\n", cidx))
}
} # }