diff --git a/R/fit_mvst.R b/R/fit_mvst.R index 0c8597d..437f099 100644 --- a/R/fit_mvst.R +++ b/R/fit_mvst.R @@ -159,7 +159,7 @@ fit_mvst <- function(X, # M-step ---------------------------------------- # nu if (optimize_nu) { - Q_nu <- function(nu) (nu/2)*sum(expect$E_logtau - log(alpha) - expect$E_tau/alpha) + T*((nu/2)*log(nu/2) - log(base::gamma(nu/2))) + Q_nu <- function(nu) (nu/2)*sum(expect$E_logtau - log(alpha) - expect$E_tau/alpha) + T*((nu/2)*log(nu/2) - lgamma(nu/2)) nu <- optimize(Q_nu, interval = c(getOption("nu_min"), getOption("nu_max")), maximum = TRUE)$maximum } @@ -351,9 +351,8 @@ besselK_ratio <- function(x, nu) { } - log_besselK <- function(x, nu) { - if (nu <= 10 && nu >= -1) { + if (nu <= 10 && nu >= -10) { return(log(besselK(x, nu))) } else if (nu >= 0){ nu_i <- nu - floor(nu) + 9 @@ -384,4 +383,3 @@ log_besselK <- function(x, nu) { return(log_values) } } - diff --git a/tests/testthat/test-fit_mvst.R b/tests/testthat/test-fit_mvst.R index 465c2d8..3b3c840 100644 --- a/tests/testthat/test-fit_mvst.R +++ b/tests/testthat/test-fit_mvst.R @@ -104,10 +104,38 @@ test_that("Bessel functions work", { res2 <- sum(dmvst(X = X, nu = 5.7, gamma = gamma, mu = mu, scatter = scatter)) expect_equal(res1, res2) + res1 <- sum(dmvst_orig(X = X, nu = 15.7, gamma = gamma, mu = mu, scatter = scatter)) + res2 <- sum(dmvst(X = X, nu = 15.7, gamma = gamma, mu = mu, scatter = scatter)) + expect_equal(res1, res2) + res1 <- sum(dmvst_orig(X = X, nu = 60.7, gamma = gamma, mu = mu, scatter = scatter)) res2 <- sum(dmvst(X = X, nu = 60.7, gamma = gamma, mu = mu, scatter = scatter)) expect_equal(res1, res2) + res1 <- besselK_ratio(x = 10, nu = 70) + res2 <- besselK(x=10, nu = 71)/besselK(x=10, nu = 70) + expect_equal(res1, res2) + res1 <- besselK_ratio(x = 0.01, nu = 5) + res2 <- besselK(x=0.01, nu = 6)/besselK(x=0.01, nu = 5) + expect_equal(res1, res2) + res1 <- besselK_ratio(x = 10, nu = 5) + res2 <- besselK(x=10, nu = 6)/besselK(x=10, nu = 5) + expect_equal(res1, res2) + res1 <- log_besselK(x = 0.01, nu = 5) + res2 <- log(besselK(0.01, nu = 5)) + expect_equal(res1, res2) + res1 <- log_besselK(x = 0.01, nu = -5) + res2 <- log(besselK(0.01, nu = -5)) + expect_equal(res1, res2) + res1 <- log_besselK(x = 10, nu = 5) + res2 <- log(besselK(10, nu = 5)) + expect_equal(res1, res2) + res1 <- log_besselK(x = 10, nu = -5) + res2 <- log(besselK(10, nu = -5)) + expect_equal(res1, res2) + res1 <- log_besselK(x = 10, nu = 70) + res2 <- log(besselK(10, nu = 70)) + expect_equal(res1, res2) + res1 <- log_besselK(x = 10, nu = -70) + res2 <- log(besselK(10, nu = -70)) + expect_equal(res1, res2) }) - - -