-
Notifications
You must be signed in to change notification settings - Fork 1
/
mahalanobis.cpp
executable file
·53 lines (45 loc) · 1.57 KB
/
mahalanobis.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#include <RcppArmadillo.h>
// Natural logarithm of 2*pi: the per-dimension normalising term shared by the
// multivariate normal log-densities in this file.
const double log2pi = std::log(M_PI * 2.0);
// [[Rcpp::depends("RcppArmadillo")]]
// Mahalanobis: squared Mahalanobis distance of every row of `x` from `center`.
//
//   d_i = (x_i - center) * inv(cov) * (x_i - center)^T
//
// x      : n-by-p matrix, one observation per row.
// center : 1-by-p row vector subtracted from every row of `x`.
// cov    : p-by-p matrix whose inverse defines the metric (normally a
//          covariance matrix; the code itself only requires invertibility).
// Returns a length-n column vector of squared distances (no square root is
// taken).
// Armadillo raises an error (std::runtime_error) if `center` does not have p
// columns, if `cov` is not p-by-p, or if `cov` cannot be inverted.
// [[Rcpp::export]]
arma::vec Mahalanobis(arma::mat x, arma::rowvec center, arma::mat cov) {
  // Centre all rows in one vectorized step. This replaces a hand-written loop
  // whose signed `int` counter was compared against the unsigned row count.
  const arma::mat x_cen = x.each_row() - center;
  // Row i of (x_cen * inv(cov)), multiplied element-wise with row i of x_cen
  // and summed across columns (dim 1), is the quadratic form for observation i.
  // NOTE(review): arma::solve would avoid forming the explicit inverse, but its
  // approximate-solution fallback for singular `cov` differs from inv()'s
  // error, so the inverse is kept to preserve the existing contract.
  return arma::sum((x_cen * cov.i()) % x_cen, 1);
}
// dmvnorm_arma: multivariate normal density evaluated at each row of `x`.
//
//   log f(x_i) = -(p * log(2*pi) + log|sigma| + d_i) / 2
//
// where d_i is the squared Mahalanobis distance of row i from `mean`.
// x     : n-by-p matrix of evaluation points, one per row.
// mean  : 1-by-p mean row vector.
// sigma : p-by-p covariance matrix. log|sigma| is computed as the sum of the
//         logs of the eigenvalues returned by eig_sym(), so sigma is assumed
//         symmetric positive definite (a zero eigenvalue yields -Inf and a
//         negative one yields NaN).
// log   : when true, return log-densities instead of densities.
// Returns a length-n column vector.
// Depends on `Mahalanobis` and `log2pi` defined in this file.
// [[Rcpp::export]]
arma::vec dmvnorm_arma(arma::mat x, arma::rowvec mean, arma::mat sigma, bool log = false) {
  // Squared Mahalanobis distance of every row from `mean`.
  const arma::vec quad_form = Mahalanobis(x, mean, sigma);
  // log|sigma| as the sum of log-eigenvalues.
  const double log_det = arma::sum(arma::log(arma::eig_sym(sigma)));
  // p * log(2*pi), with p = number of columns of `x`.
  const double norm_term = x.n_cols * log2pi;

  arma::vec log_density = -((norm_term + log_det + quad_form) / 2);
  if (!log) {
    return arma::exp(log_density);
  }
  return log_density;
}
// [[Rcpp::depends("RcppArmadillo")]]
// dmvnrm_arma: multivariate normal density at each row of `x`, computed via
// the Cholesky factor of `sigma` (an alternative to dmvnorm_arma in this file).
//
// With sigma = R^T R (R upper-triangular, from chol()), let
// rooti = inv(R)^T (lower-triangular). Then for v_i = x_i - mean:
//   z_i = rooti * v_i^T  satisfies  z_i^T z_i = v_i * inv(sigma) * v_i^T,
//   sum(log(diag(rooti))) = -log|sigma| / 2.
//
// x     : n-by-p matrix of evaluation points, one per row.
// mean  : 1-by-p mean row vector.
// sigma : p-by-p covariance matrix; must be symmetric positive definite for
//         chol() to succeed (Armadillo raises an error otherwise).
// logd  : when true, return log-densities instead of densities.
// Returns a length-n column vector.
// Depends on `log2pi` defined in this file.
// [[Rcpp::export]]
arma::vec dmvnrm_arma(arma::mat x,
                      arma::rowvec mean,
                      arma::mat sigma,
                      bool logd = false) {
  const arma::uword n_obs = x.n_rows;
  // Lower-triangular "inverse root" of sigma.
  const arma::mat rooti = arma::inv(arma::trimatu(arma::chol(sigma))).t();
  // Equals -0.5 * log|sigma|.
  const double half_log_det_inv = arma::sum(arma::log(rooti.diag()));
  // -(p / 2) * log(2*pi).
  const double norm_const = -(static_cast<double>(x.n_cols) / 2.0) * log2pi;

  arma::vec density(n_obs);
  for (arma::uword row = 0; row < n_obs; ++row) {
    // Whitened residual; its squared norm is the Mahalanobis distance.
    const arma::vec z = rooti * (x.row(row) - mean).t();
    density(row) = norm_const - 0.5 * arma::sum(z % z) + half_log_det_inv;
  }

  if (!logd) {
    density = arma::exp(density);
  }
  return density;
}