vignettes/FastSurvivalSVM-wavelet-kernel.Rmd
The FastSurvivalSVM package provides an efficient R interface to the Python class FastKernelSurvivalSVM from the scikit-survival library.
A powerful feature of this package is the ability to define fully
custom kernel functions in R. Previously, this required complex
programming patterns (function factories and closures). Now, with the
helper function grid_kernel(), defining and using a custom
kernel is straightforward.
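For readers unfamiliar with the pattern mentioned above, a function factory is a function that returns another function with some parameters already fixed in its enclosing environment. The sketch below is plain base R, with a Gaussian kernel chosen only for illustration. `make_rbf_kernel` is a hypothetical name, not part of FastSurvivalSVM, and this is not necessarily how the package handled custom kernels before `grid_kernel()`.

```r
# Generic function factory: returns a two-argument kernel with `gamma` fixed.
# `make_rbf_kernel` is a hypothetical name, not a FastSurvivalSVM function.
make_rbf_kernel <- function(gamma) {
  force(gamma)  # evaluate now so the closure does not rely on lazy evaluation
  function(x, z) {
    exp(-gamma * sum((as.numeric(x) - as.numeric(z))^2))
  }
}

# Two kernels that share the same code but carry different parameters
rbf_half <- make_rbf_kernel(gamma = 0.5)
rbf_two  <- make_rbf_kernel(gamma = 2)

x <- c(1, 2, 3)
z <- c(1, 2, 4)
rbf_half(x, z)  # exp(-0.5 * 1), squared distance between x and z is 1
rbf_two(x, z)   # exp(-2 * 1)
```

Each returned function only takes `x` and `z`; the hyperparameter travels with it in the closure, which is the bookkeeping the vignette describes as no longer needing to be written by hand.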
This vignette demonstrates how to define a wavelet kernel in R and pass it to the model with grid_kernel().

We begin by generating a synthetic survival dataset with nonlinear relationships to test the kernel’s capability.
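The data generator that follows draws event times with the expression `((-log(1 - u))^(1 / shape)) * scale`, which is inverse-transform sampling from a Weibull distribution; the covariates then rescale each time by `exp(-xbeta)`. As a minimal check of the baseline part only (no covariates), the sketch below compares that expression with base R's `qweibull()`. The seed and sample size are arbitrary choices for illustration.

```r
set.seed(42)
shape <- 2
scale <- 5
u <- runif(1000)

# Expression used in the vignette's generator, without the covariate factor
t_manual <- ((-log(1 - u))^(1 / shape)) * scale

# Weibull quantile function evaluated at the same uniform draws
t_quantile <- qweibull(u, shape = shape, scale = scale)

all.equal(t_manual, t_quantile)
```

Multiplying by `exp(-xbeta)` therefore amounts to giving each subject its own Weibull scale, `scale * exp(-xbeta)`, so larger values of `xbeta` shorten the expected survival time.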
library(FastSurvivalSVM)
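# Function to generate synthetic survival data
data_generation <- function(n, prop_cen) {
  # Covariates
  x1 <- rnorm(n, 1, 1)
  x2 <- rnorm(n, 2, 2)
  x3 <- rexp(n)

  # Nonlinear predictor
  xbeta <- x1 * log(abs(x2)) + 2 * sin(x3 - x2)^2

  # Weibull event times via inverse-transform sampling,
  # rescaled per subject by exp(-xbeta)
  shape <- 2
  scale <- 5
  u <- runif(n, 0, 1)
  time_t <- ((-log(1 - u))^(1 / shape)) * scale * exp(-xbeta)

  # Censor n * prop_cen subjects: replace their time with a uniform draw
  # between the smallest event time and their own event time
  ind_cens <- sample(1:n, n * prop_cen, replace = FALSE)
  time_gerado <- time_t
  time_gerado[ind_cens] <- runif(
    length(ind_cens),
    min(time_t),
    time_t[ind_cens]
  )

  # Event indicator: 1 = event observed, 0 = censored
  cens <- rep(1, n)
  cens[ind_cens] <- 0

  data.frame(
    tempo = time_gerado,
    cens = cens,
    x1 = x1,
    x2 = x2,
    x3 = x3
  )
}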
set.seed(123)
df <- data_generation(n = 300L, prop_cen = 0.1)
head(df)
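#>         tempo cens        x1         x2         x3
#> 1   1.4801101    1 0.4395244  0.5695156 0.33901239
#> 2   2.3390461    1 0.7698225  0.4946221 0.43618362
#> 3 788.5807128    1 2.5587083  0.1229226 0.73802924
#> 4  17.1329416    1 1.0705084 -0.1050266 0.57069724
#> 5   1.6826516    1 1.1292877  1.1256809 0.08437658
#> 6   0.0492852    1 2.7150650  2.6623583 1.35933656

The wavelet kernel described in Ará et al. (2016) uses the following mother function:

$$h(u) = \cos(1.75\,u)\,\exp\left(-\frac{u^2}{2}\right)$$

And the kernel is defined as the product over covariates:

$$K(x, z) = \prod_{j=1}^{p} h\left(\frac{x_j - z_j}{A}\right)$$

where $p$ is the number of covariates and $A$ is the scale parameter, passed as the argument `A` in the R implementation.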
wavelet_kernel_fn <- function(x, z, A = 1) {
  x <- as.numeric(x)
  z <- as.numeric(z)
  u <- (x - z) / A
  prod(cos(1.75 * u) * exp(-0.5 * u^2))
}
my_wavelet <- grid_kernel(wavelet_kernel_fn, A = 0.5)
class(my_wavelet)
#> [1] "fastsvm_custom_kernel" "function"
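Before fitting, it can help to check that a custom kernel behaves like a similarity measure. The sketch below evaluates the wavelet kernel defined above on a small random matrix and checks two properties that follow from its formula: symmetry, and a value of 1 when a point is compared with itself. `gram_matrix` is a hypothetical helper written for this example, not a FastSurvivalSVM function, and the input matrix is arbitrary toy data.

```r
# Wavelet kernel as defined in this vignette (repeated so the block runs alone)
wavelet_kernel_fn <- function(x, z, A = 1) {
  x <- as.numeric(x)
  z <- as.numeric(z)
  u <- (x - z) / A
  prod(cos(1.75 * u) * exp(-0.5 * u^2))
}

# Hypothetical helper: kernel (Gram) matrix between all pairs of rows of X
gram_matrix <- function(X, kernel_fn, ...) {
  n <- nrow(X)
  K <- matrix(0, n, n)
  for (i in seq_len(n)) {
    for (j in seq_len(n)) {
      K[i, j] <- kernel_fn(X[i, ], X[j, ], ...)
    }
  }
  K
}

set.seed(1)
X <- matrix(rnorm(15), nrow = 5, ncol = 3)  # 5 toy observations, 3 covariates
K <- gram_matrix(X, wavelet_kernel_fn, A = 0.5)

isSymmetric(K)        # cos(1.75 * u) and u^2 are both even in u
all(diag(K) == 1)     # u = 0 gives cos(0) * exp(0) = 1 for every covariate
all(abs(K) <= 1)      # each factor is bounded by 1 in absolute value
```

Unlike a Gaussian kernel, the entries can be negative because of the cosine term, so a negative similarity between two distant points is expected rather than a sign of a bug.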
fit_wavelet <- fastsvm(
  data = df,
  time_col = "tempo",
  delta_col = "cens",
  kernel = my_wavelet,
  alpha = 1,
  rank_ratio = 0,
  fit_intercept = FALSE
)
summary(fit_wavelet)
#> Summary of FastKernelSurvivalSVM model (kernel survival SVM)
#> ======================================================================
#>
#> == Data ==
#> - n (observations) : 300
#> - p (covariates) : 3
#> - Covariates : x1, x2, x3
#>
#> == Hyperparameters ==
#> - kernel : custom callable function
#> - alpha : 1
#> - rank_ratio : 0 (0 = pure regression)
#> - fit_intercept : FALSE
#>
#> == Estimated parameters (coef_ = sample-wise weights alpha_i) ==
#> - Number of support-like vectors (|alpha_i| > 1e-8): 282
#> - Summary of alpha_i (coef_):
#>       Min.  1st Qu.   Median     Mean  3rd Qu.     Max.
#>   -2.15467 -0.59048  0.00000 -0.04655  0.42036  4.91534
#>
#> - Number of optimization iterations: 14
#> ======================================================================
head(coef(fit_wavelet))
#> [1] -0.0006233144 0.3609977816 2.7753938725 1.2151206510 0.4199461135
#> [6] -1.2277373095
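The summary above describes the coefficients as sample-wise weights alpha_i, one per training observation. In kernel models of this type, the score for a new point is typically a weighted sum of kernel evaluations against the training points, f(x) = sum_i alpha_i K(x_i, x). The sketch below illustrates that sum with toy points and toy weights, not the fitted values. It assumes no intercept, matching `fit_intercept = FALSE`, and it does not attempt to reproduce any transformation the package or scikit-survival may apply to the score afterwards. `kernel_score` is a hypothetical helper.

```r
# Wavelet kernel as defined in this vignette (repeated so the block runs alone)
wavelet_kernel_fn <- function(x, z, A = 1) {
  x <- as.numeric(x)
  z <- as.numeric(z)
  u <- (x - z) / A
  prod(cos(1.75 * u) * exp(-0.5 * u^2))
}

# Toy training points and toy weights (illustrative values, not fitted ones)
X_train <- rbind(c(0, 0), c(1, 0), c(0, 1))
alpha_w <- c(0.5, -0.2, 0.1)

# Hypothetical helper: f(x) = sum_i alpha_i * K(x_i, x), no intercept
kernel_score <- function(x_new, X_train, alpha_w, kernel_fn, ...) {
  k_vec <- apply(X_train, 1, function(xi) kernel_fn(xi, x_new, ...))
  sum(alpha_w * k_vec)
}

score <- kernel_score(c(0, 0), X_train, alpha_w, wavelet_kernel_fn, A = 0.5)
score
```

Here the first training point coincides with the new point, so it contributes its full weight of 0.5, while the other two are one unit away in a single coordinate and contribute their weights multiplied by `cos(3.5) * exp(-2)`.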
get_params_fastsvm(fit_wavelet)
#> $alpha
#> [1] 1
#>
#> $coef0
#> [1] 1
#>
#> $degree
#> [1] 3
#>
#> $fit_intercept
#> [1] FALSE
#>
#> $gamma
#> NULL
#>
#> $kernel
#> <function make_python_function.<locals>.python_function at 0x11d657d80>
#> signature: (*args, **kwargs)
#>
#> $kernel_params
#> NULL
#>
#> $max_iter
#> [1] 20
#>
#> $optimizer
#> [1] "rbtree"
#>
#> $random_state
#> NULL
#>
#> $rank_ratio
#> [1] 0
#>
#> $timeit
#> [1] FALSE
#>
#> $tol
#> NULL
#>
#> $verbose
#> [1] FALSE