cal_dirichlet() is the multiclass generalization of beta calibration. It
fits a linear map on the log of the predicted probabilities followed by a
softmax, which is equivalent to a multinomial logistic regression with the
log-probabilities as features. An off-diagonal and intercept regularization
(ODIR) penalty shrinks the off-diagonal weights and the intercepts toward
zero, which reduces overfitting risk when the number of classes is large.
Arguments
- p
Numeric matrix of uncalibrated probabilities with one row per observation and one column per class. Rows must sum to one within absolute tolerance 1e-6.
- y
A factor or a vector of integer class codes in 1:K, where K is the number of columns of p.
- lambda
Non-negative ODIR regularization strength. When NULL, it is chosen by cross-validation.
- eps
Clipping constant satisfying 0 < eps < 0.5. Probabilities must first be valid values in [0, 1]; values below eps and above 1 - eps are clipped before taking logarithms.
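The argument constraints above can be checked before any fitting. The sketch below uses a hypothetical helper, check_dirichlet_args(), which is not part of the package; it only mirrors the documented rules and returns the integer class codes that index the columns of p.

```r
# Hypothetical helper (not part of the package): mirrors the documented
# argument constraints and returns the integer class codes invisibly.
check_dirichlet_args <- function(p, y, lambda, eps) {
  if (!is.matrix(p) || !is.numeric(p)) stop("`p` must be a numeric matrix.")
  K <- ncol(p)
  if (anyNA(p) || any(p < 0 | p > 1)) stop("`p` must contain values in [0, 1].")
  if (any(abs(rowSums(p) - 1) > 1e-6)) stop("Rows of `p` must sum to one within 1e-6.")
  # A factor maps to its integer codes, so levels(y)[k] corresponds to column k.
  codes <- if (is.factor(y)) as.integer(y) else y
  if (length(codes) != nrow(p) || anyNA(codes) || !all(codes %in% seq_len(K))) {
    stop("`y` must hold one class code in 1:K per row of `p`.")
  }
  if (!is.null(lambda) && !(is.numeric(lambda) && length(lambda) == 1 && lambda >= 0)) {
    stop("`lambda` must be NULL or a single non-negative number.")
  }
  if (!(is.numeric(eps) && length(eps) == 1 && eps > 0 && eps < 0.5)) {
    stop("`eps` must satisfy 0 < eps < 0.5.")
  }
  invisible(codes)
}

p <- rbind(c(0.2, 0.3, 0.5), c(0.6, 0.3, 0.1))
y <- factor(c("c", "a"), levels = c("a", "b", "c"))
codes <- check_dirichlet_args(p, y, lambda = NULL, eps = 1e-6)  # c(3L, 1L)
```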
Value
A cal_dirichlet object that also inherits from cal_multiclass.
Use predict() with new probabilities to obtain calibrated probabilities.
Details
The calibrated probabilities are computed row-wise as
softmax(log(p) %*% t(W) + b), where W is a K by K weight matrix and
b is a length K intercept vector added to every row. Probabilities are
clipped to the interval [eps, 1 - eps] before taking logarithms.
When lambda is NULL, it is selected from a small deterministic grid by
cross-validated log-likelihood.
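The row-wise formula can be written out directly in base R. dirichlet_predict() below is a hypothetical illustration of the formula, not the package's predict() method; W and b are arbitrary values rather than fitted ones, and b is added to each row with sweep().

```r
# Hypothetical re-implementation of the prediction formula for illustration;
# in practice, call predict() on a fitted cal_dirichlet object.
dirichlet_predict <- function(p, W, b, eps) {
  p_star <- pmin(pmax(p, eps), 1 - eps)   # elementwise clipping to [eps, 1 - eps]
  eta <- log(p_star) %*% t(W)             # n x K logits; row k of W gives class k
  eta <- sweep(eta, 2, b, "+")            # add intercept b_k to column k
  eta <- eta - apply(eta, 1, max)         # subtract row maxima for stability
  q <- exp(eta)
  q / rowSums(q)                          # softmax within each row
}

p <- rbind(c(0.2, 0.3, 0.5), c(0.6, 0.3, 0.1))
W <- matrix(c(1.5, 0.2, 0.0,
              0.0, 1.0, -0.3,
              0.1, 0.0, 2.0), nrow = 3, byrow = TRUE)  # arbitrary weights
b <- c(0.1, -0.2, 0)                                  # arbitrary intercepts
q <- dirichlet_predict(p, W, b, eps = 1e-6)
rowSums(q)  # each row sums to one
```

Subtracting the row maximum does not change the softmax output; it only keeps exp() from overflowing when logits are large.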
Let \(p_{ik}\) be the uncalibrated probability assigned to class \(k\)
for observation \(i\). Each row of p must sum to one within absolute
tolerance 1e-6. Column \(k\) corresponds to integer class code \(k\);
if y is a factor, column \(k\) corresponds to levels(y)[k]. The
entries are clipped elementwise by
$$p_{ik}^* = \min\{\max(p_{ik}, \epsilon), 1 - \epsilon\},$$
and transformed to \(u_{ik} = \log(p_{ik}^*)\). The clipped feature matrix is not renormalized; normalization occurs only after the linear map, through the final softmax. Dirichlet calibration fits a multinomial logistic regression on these log-probability features,
$$\eta_{ik} = b_k + \sum_{\ell = 1}^K W_{k\ell} u_{i\ell},$$
followed by
$$q_{ik} = \frac{\exp(\eta_{ik})}{\sum_{m = 1}^K \exp(\eta_{im})}.$$
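The statement that the clipped features are not renormalized can be checked numerically. In this base-R sketch, eps = 0.05 is an arbitrary, deliberately large value (not a package default) chosen so that clipping is visible. With the identity map \(W = I_K\) and \(b = 0\), the logits equal \(u_{ik}\), so the final softmax alone restores row sums of one.

```r
eps <- 0.05                                # deliberately large so clipping is visible
p <- rbind(c(0.01, 0.03, 0.96),
           c(0.50, 0.30, 0.20))
p_star <- pmin(pmax(p, eps), 1 - eps)      # elementwise clipping
u <- log(p_star)                           # log-probability features u_ik
rowSums(p_star)                            # about 1.05 and 1.00: not renormalized
# With W = I and b = 0, eta = u, so the softmax returns p* renormalized by row.
q <- exp(u) / rowSums(exp(u))
all.equal(q, p_star / rowSums(p_star))     # TRUE
```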
With fixed \(\lambda\), the fitted parameters minimize
$$-\frac{1}{n}\sum_i \log q_{i y_i} + \lambda\left(\sum_{k \ne \ell} W_{k\ell}^2 + \sum_k b_k^2\right).$$
The \(\lambda\) term is the off-diagonal and intercept regularization (ODIR) penalty; diagonal
weights are not penalized. For fixed lambda, optimization uses BFGS with
analytic gradients, initial weight matrix \(W = I_K\), initial
bias \(b = 0\), and maxit = 500. True-class probabilities
entering logarithms are clipped to [1e-15, 1 - 1e-15]. The returned
weight is a \(K \times K\) matrix whose row \(k\) produces
the logit for class \(k\); bias is a length-\(K\) vector of
intercepts. The object also stores lambda, value, and the optimizer
convergence code.
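For a fixed \(\lambda\), the procedure above can be sketched with stats::optim(). fit_dirichlet_fixed() is a hypothetical helper, not the package's code. It assumes y holds integer codes 1:K and uses the gradient of the mean negative log-likelihood with respect to the logits, \((Q - Y)/n\), where \(Y\) is the one-hot label matrix. As a simplification, the gradient ignores the 1e-15 clipping, which only matters when that clipping is active.

```r
# Hypothetical sketch of the fixed-lambda fit (not the package's source).
# Assumes y holds integer class codes 1:K.
fit_dirichlet_fixed <- function(p, y, lambda, eps) {
  K <- ncol(p); n <- nrow(p)
  U <- log(pmin(pmax(p, eps), 1 - eps))       # clipped log-probability features
  Y <- diag(K)[y, , drop = FALSE]             # one-hot label matrix, n x K
  off <- row(diag(K)) != col(diag(K))         # TRUE for off-diagonal entries
  unpack <- function(theta) {
    list(W = matrix(theta[seq_len(K * K)], K, K), b = theta[K * K + seq_len(K)])
  }
  probs <- function(W, b) {
    eta <- sweep(U %*% t(W), 2, b, "+")       # eta_ik = b_k + sum_l W_kl u_il
    eta <- eta - apply(eta, 1, max)           # stabilize the softmax
    exp(eta) / rowSums(exp(eta))
  }
  fn <- function(theta) {
    pr <- unpack(theta)
    q_true <- probs(pr$W, pr$b)[cbind(seq_len(n), y)]
    q_true <- pmin(pmax(q_true, 1e-15), 1 - 1e-15)
    -mean(log(q_true)) + lambda * (sum(pr$W[off]^2) + sum(pr$b^2))
  }
  gr <- function(theta) {
    pr <- unpack(theta)
    R <- (probs(pr$W, pr$b) - Y) / n          # d(mean NLL) / d eta
    gW <- t(R) %*% U + 2 * lambda * pr$W * off  # diagonal is not penalized
    gb <- colSums(R) + 2 * lambda * pr$b
    c(gW, gb)                                 # same column-major order as theta
  }
  opt <- stats::optim(c(diag(K), rep(0, K)), fn, gr, method = "BFGS",
                      control = list(maxit = 500))
  pr <- unpack(opt$par)
  list(weight = pr$W, bias = pr$b, lambda = lambda,
       value = opt$value, convergence = opt$convergence)
}

set.seed(1)
n <- 300; K <- 3
p <- matrix(runif(n * K), n, K); p <- p / rowSums(p)
y <- apply(p, 1, function(r) sample.int(K, 1, prob = r))  # noisy labels
fit <- fit_dirichlet_fixed(p, y, lambda = 1e-2, eps = 1e-6)
fit$convergence  # 0 means optim() did not stop at the iteration limit
```

The returned list only mirrors the documented fields (weight, bias, lambda, value, convergence); it is not a cal_dirichlet object.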
If lambda = NULL, the implementation evaluates the grid
c(0, 1e-4, 1e-3, 1e-2, 1e-1) with at most three deterministic stratified
folds. Class indices are assigned to folds in their existing order. The
selected value minimizes the unweighted average of the fold mean held-out
negative log-likelihoods; ties choose the first grid value. If fewer than two
observations are available in the smallest class during selection, the
fallback value is 1e-3. With lambda = 0, the multinomial softmax
parameterization is not unique: adding the same linear function of the
features to every class logit leaves all probabilities unchanged. Only the
calibrated probabilities are identified; the individual entries of weight and
bias are not.
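The selection rule can be sketched in base R. The helper names stratified_folds() and select_lambda() are hypothetical. Two details are assumptions rather than documented behavior: the number of folds is taken as min(3, size of the smallest class), and "assigned to folds in their existing order" is read as round-robin assignment within each class. Refitting on each fold is left out, so select_lambda() takes a folds-by-grid matrix of fold mean held-out negative log-likelihoods.

```r
# Hypothetical helpers; fold count and round-robin assignment are assumptions.
stratified_folds <- function(y, max_folds = 3) {
  n_min <- min(table(y))
  if (n_min < 2) return(NULL)               # signals the documented fallback, 1e-3
  k <- min(max_folds, n_min)
  fold <- integer(length(y))
  for (cls in unique(y)) {
    idx <- which(y == cls)                  # class indices in their existing order
    fold[idx] <- rep_len(seq_len(k), length(idx))
  }
  fold
}

select_lambda <- function(cv_loss, grid = c(0, 1e-4, 1e-3, 1e-2, 1e-1)) {
  # cv_loss: folds x grid matrix of fold mean held-out negative log-likelihoods.
  # colMeans() gives the unweighted average over folds; which.min() returns
  # the first minimum, so ties choose the earlier grid value.
  grid[which.min(colMeans(cv_loss))]
}

y <- c(1, 1, 2, 2, 2, 1, 2, 1, 3, 3)
stratified_folds(y)      # 1 2 1 2 1 1 2 2 1 2 (class 3 has two members)
cv_loss <- rbind(c(0.90, 0.80, 0.80, 0.85, 1.00),
                 c(0.92, 0.82, 0.82, 0.84, 0.99))
select_lambda(cv_loss)   # columns 2 and 3 tie, so 1e-04 is chosen
```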
References
Kull, M., Perello-Nieto, M., Kängsepp, M., Silva Filho, T., Song, H., & Flach, P. (2019). Beyond temperature scaling: Obtaining well-calibrated multi-class probabilities with Dirichlet calibration. Advances in Neural Information Processing Systems 32.
Examples
set.seed(23)
prob <- matrix(stats::runif(200 * 3), ncol = 3)
prob <- prob / rowSums(prob)
labels <- max.col(prob)
fit <- cal_dirichlet(prob, labels)
head(predict(fit, prob))
#> 1 2 3
#> [1,] 2.749610e-39 1.678165e-17 1.000000e+00
#> [2,] 2.481962e-177 1.000000e+00 0.000000e+00
#> [3,] 7.895177e-175 1.000000e+00 0.000000e+00
#> [4,] 4.089151e-61 1.000000e+00 1.825486e-49
#> [5,] 1.000000e+00 0.000000e+00 1.662300e-56
#> [6,] 2.490598e-103 1.000000e+00 2.702539e-321
