Skip to content

Commit

Permalink
Making besselK more robust...
Browse files Browse the repository at this point in the history
  • Loading branch information
dppalomar committed Apr 9, 2022
1 parent 74c789a commit e588dd2
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 34 deletions.
84 changes: 50 additions & 34 deletions R/fit_mvst.R
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ dmvst <- function(X, nu = 3, gamma = 1, mu = 0, scatter = 1) {
kappa <- sqrt(nu + rowSums(Xc * (Xc %*% scatter_inv)))
first_term <-Xc %*% solve(scatter) %*% gamma - (N/2) * log(2*pi) - 0.5 * sum(log(eigen((scatter))$values) )
second_term <- log(2) + (nu/2) * log(nu/2) - lgamma(nu/2)
third_term <- -((nu+N)/2) * log(kappa/delta) + log_besselK(delta*kappa, -((nu+N)/2))
third_term <- -((nu+N)/2) * log(kappa/delta) + log_besselK(delta*kappa, -((nu + N)/2))
return(first_term + second_term + third_term)
}
} else # Gaussian case
Expand Down Expand Up @@ -300,29 +300,32 @@ Estep_mvst <- function(X, nu, gamma_unscaled, mu, scatter_unscaled, alpha) {
# kappa <- sqrt(nu + stats::mahalanobis(x = X, center = mu, cov = scatter_inv, inverted = TRUE))

if (delta == 0) {
E_tau <- ((kappa**2)/2)*lambda
E_invtau <- (1/((kappa**2)/2)) * (1/lambda)
E_logtau <- digamma(lambda) - log((kappa**2)/4)
E_tau <- (nu + N)/(kappa**2)
E_invtau <- (kappa**2)/(nu + N -2)
E_logtau <- digamma(lambda) - log((kappa**2)/2)
} else {
tmp <- besselK_ratio(delta * kappa, lmd = lambda)
tmp <- besselK_ratio(delta * kappa, lambda)
E_tau <- (delta / kappa) * tmp
#E_invtau <- (kappa / delta) * 1/besselK_ratio(delta * kappa, lmd = lambda - 1)
E_invtau <- (kappa / delta) * (tmp - (2*lambda)/delta/kappa) # this saves computing bessel functions again

if (lambda < 150) {
dev_cal <- function(val) numDeriv::grad(func = function(lmd) log(besselK(x = val, nu = lmd, expon.scaled = FALSE)), x = lambda, method = "simple", method.args = list(eps = 1e-10))
dev_cal <- function(val) numDeriv::grad(func = function(lmd) log(besselK(x = val, nu = lmd, expon.scaled = TRUE)),
x = lambda, method = "simple", method.args = list(eps = 1e-10))
E_logtau <- log(delta / kappa) + sapply(delta * kappa, dev_cal)
} else
E_logtau <- log(delta / kappa) + log(besselK_ratio(delta * kappa, lmd = lambda))
E_logtau <- log(delta / kappa) + log(besselK_ratio(delta * kappa, lambda))
}

# return
list_to_return <- list("E_tau" = E_tau * alpha,
"E_invtau" = E_invtau / alpha,
"E_logtau" = E_logtau + log(alpha))

if (any(is.infinite(list_to_return$E_invtau)) || is.infinite(sum(list_to_return$E_invtau)) ||
any(is.nan(list_to_return$E_invtau))) {
if (any(is.infinite(list_to_return$E_invtau)) ||
any(is.nan(list_to_return$E_invtau)) ||
any(is.nan(list_to_return$E_logtau))
) {
message("Problem with the computation of E[tau], probably because of very small numbers in the evaluation of the bessel function.")
browser()
}
Expand All @@ -332,40 +335,53 @@ Estep_mvst <- function(X, nu, gamma_unscaled, mu, scatter_unscaled, alpha) {




# https://www.researchgate.net/journal/Journal-of-Inequalities-and-Applications-1029-242X[…]mating-the-modified-Bessel-function-of-the-second-kind.pdf
besselK_ratio <- function(x, lmd) {
if (lmd < 100)
return(besselK(x = x, nu = lmd + 1, expon.scaled = TRUE) / besselK(x = x, nu = lmd, expon.scaled = TRUE))
besselK_ratio <- function(x, nu) {
if (nu < 51)
return(besselK(x = x, nu = nu + 1, expon.scaled = TRUE) / besselK(x = x, nu = nu, expon.scaled = TRUE))
else {
lmd_i <- lmd - floor(lmd) + 10
R_i <- besselK(x = x, nu = lmd_i + 1, expon.scaled = TRUE) / besselK(x = x, nu = lmd_i, expon.scaled = TRUE)
while (lmd_i != lmd) {
R_i <- 1/R_i + (2 * lmd_i + 2)/x
lmd_i <- lmd_i + 1
nu_i <- nu - floor(nu) + 50
R_i <- besselK(x = x, nu = nu_i + 1, expon.scaled = TRUE) / besselK(x = x, nu = nu_i, expon.scaled = TRUE)
while (nu_i != nu) {
R_i <- 1/R_i + (2 * nu_i + 2)/x
nu_i <- nu_i + 1
}
return(R_i)
}
}



log_besselK <- function(x, lmd) {
if (lmd >= 0)
stop("lmd should be negative in this function.")
lmd_i <- lmd - floor(lmd) - 1
K_lmd_i <- besselK(x, nu = lmd_i) # K_{lmd_i}(x)
log_values <- log(K_lmd_i)
R_lmd_i <- K_lmd_i/besselK(x, nu = lmd_i - 1) # R_{lmd_i -1}(x)
log_values <- log_values - log(R_lmd_i)
lmd_i <- lmd_i - 1
while(lmd_i != lmd) {
R_lmd_i_inv <- R_lmd_i - 2*lmd_i/x
R_lmd_i <- 1/R_lmd_i_inv # R_{i -1}(x)
log_values <- log_values - log(R_lmd_i)
lmd_i <- lmd_i - 1
log_besselK <- function(x, nu) {
if (nu <= 10 && nu >= -1) {
return(log(besselK(x, nu)))
} else if (nu >= 0){
nu_i <- nu - floor(nu) + 9
K_nu_i <- besselK(x, nu = nu_i) # K_{nu_i}(x)
log_values <- log(K_nu_i)
R_nu_i <- besselK(x, nu = nu_i + 1)/K_nu_i # R_{nu_i}(x)
log_values <- log_values + log(R_nu_i)
nu_i <- nu_i + 1
while (nu_i != nu) {
R_nu_i <- 1/R_nu_i + 2 * nu_i /x
log_values <- log_values + log(R_nu_i)
nu_i <- nu_i + 1
}
return(log_values)
} else {
nu_i <- nu - floor(nu) - 9
K_nu_i <- besselK(x, nu = nu_i) # K_{nu_i}(x)
log_values <- log(K_nu_i)
R_nu_i <- K_nu_i/besselK(x, nu = nu_i - 1) # R_{nu_i -1}(x)
log_values <- log_values - log(R_nu_i)
nu_i <- nu_i - 1
while (nu_i != nu) {
R_nu_i_inv <- R_nu_i - 2*nu_i/x
R_nu_i <- 1/R_nu_i_inv # R_{i -1}(x)
log_values <- log_values - log(R_nu_i)
nu_i <- nu_i - 1
}
return(log_values)
}
return(log_values)
}


Binary file modified tests/testthat/fitted_mvst_check.RData
Binary file not shown.

0 comments on commit e588dd2

Please sign in to comment.