Skip to content

Commit

Permalink
Merge pull request #1 from neil-gallagher/stochFreq
Browse files Browse the repository at this point in the history
Stochastic optimization & more
  • Loading branch information
neil-gallagher authored Oct 6, 2020
2 parents 35c2145 + 1674d04 commit 79eac7c
Show file tree
Hide file tree
Showing 17 changed files with 1,152 additions and 485 deletions.
162 changes: 117 additions & 45 deletions +GP/CSFA.m
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,25 @@
maxW % maximum number of time windows in a partition
C % number of channels
Q % number of components in SM kernel
L % rank of FA
R % rank of factor coregionalization matrices
L % number of factors
scores % factor scores
freqBounds % [low, high] frequency boundary for spectral gaussian means

LMCkernels % cell array of L LMC kernels

eta % additive Gaussian noise
regB % L1 regularization strength for coregionalization weights

updateKernels % binary indicator for kernel updates
updateScores % binary indicator for score updates
updateNoise % binary indicator for noise floor updates
end

methods
function self = CSFA(modelOpts, s, xFft)
GMM_MAX_ITER = 1000;
GMM_REG = 0.01;
DIST_PRECISION = 1000;
DIST_SAMPLES = 1e5;

Expand All @@ -32,7 +36,9 @@
self.L = modelOpts.L;
self.C = modelOpts.C;
self.Q = modelOpts.Q;
self.R = modelOpts.R;
self.eta = modelOpts.eta;
self.regB = modelOpts.regB;

% divide windows into partitions for computations
self.setPartitions(modelOpts.W, modelOpts.maxW);
Expand Down Expand Up @@ -74,8 +80,8 @@

% normalize components and scores by max power in a channel
% for that component
compWeights = reshape(compInit', [], self.C, self.L);
compNorm = max(sum(compWeights, 1), [], 2);
compWeights = reshape(compInit', sum(modelFreqs), self.C, self.L);
compNorm = max(max(compWeights, [], 1), [], 2);
compWeights = bsxfun(@rdivide, compWeights, compNorm);
compNorm = reshape(compNorm, 1, self.L);
scoreInit = bsxfun(@times, scoreInit, compNorm);
Expand All @@ -90,7 +96,7 @@
% model component frequency usage via GMM
gmOpts = statset('MaxIter',GMM_MAX_ITER);
gmmodel = fitgmdist(simDist', self.Q,'Options',gmOpts,...
'RegularizationValue', 0.01);
'RegularizationValue', GMM_REG);
means = num2cell( reshape(gmmodel.mu, 1, self.Q));
vars = num2cell( reshape(gmmodel.Sigma, 1, self.Q));
bWeight = gmmodel.ComponentProportion;
Expand Down Expand Up @@ -120,7 +126,8 @@
end

self.updateKernels = true;
self.updateNoise = false;
self.updateScores = true;
self.updateNoise = modelOpts.learnNoise;
end
end

Expand Down Expand Up @@ -149,7 +156,7 @@ function setPartitions(self, W, maxW)
log likelihood is to be calculated. If unset, the log likelihood is
calculated for all windows in the dataset.
%}
function LL = evaluate(self,s,data,windowIdx)
function [LL, rLoss] = evaluate(self,s,data,windowIdx)
nWindows = sum(self.W);
if nargin < 4
windowIdx = true(1,nWindows);
Expand Down Expand Up @@ -205,6 +212,24 @@ function setPartitions(self, W, maxW)

% remove machine precision complex values
LL = real(LL);

% calculate regularization loss (only if both kernels and scores are
% getting updated)
if self.updateKernels && self.updateScores
params = self.getParams;
Bvals = exp(params(self.getParamIdx.coregWeights));
Bsum = sum(abs(Bvals(:)));
scoreSum = sum(abs(self.scores(:)));
weightRatio = (self.Q*self.R*self.C)/sum(self.W);
rLoss = (Bsum + scoreSum*weightRatio) * self.regB;
elseif self.regB && (self.updateKernels || self.updateScores)
warning(['Regularization is applied, but either scores or kernels are ',...
'not being updated. This can lead to instability and convergence ',...
'issues!'])
rLoss = NaN;
else
rLoss = 0;
end
end

%{
Expand All @@ -214,7 +239,7 @@ function setPartitions(self, W, maxW)
vector of all factor scores. K is a vector of parameters defining
the kernels of each factor and can be broken down into the
parameter vectors for each kernel [K1 ... KL]. Each Kl can further
be broken down into the vectors of parameters corresponding to the
be broken down into the vectors of parameters corresponding to the
coregionalization matrices and the spectral gaussian components,
(ie Kl = [Bl1 ... BlQ,kl1 ... klQ]). Blq corresponds to the qth
coregionalization matrix of the lth factor. Each coregionalization
Expand All @@ -226,11 +251,14 @@ matrix is rank R and complex (defined by R complex vectors b1...bR)
variance of the spectral gaussian
%}
function res = getParams(self)
if self.updateScores
res = self.scores(:);
else
res = [];
end
if self.updateKernels
params = cellfun(@(x)x.getParams,self.LMCkernels,'un',0);
res = vertcat(params{:},self.scores(:));
else
res = self.scores(:);
res = vertcat(params{:},res);
end
if self.updateNoise
res = [self.eta; res];
Expand All @@ -254,12 +282,18 @@ function setParams(self,params)
indB = indE + 1;
end
end
self.scores = reshape(params(indB:end),[self.L,sum(self.W)]);
if self.updateScores
self.scores = reshape(params(indB:end),[self.L,sum(self.W)]);
end
end

function [lb,ub] = getBounds(self)
lb = zeros(self.L*sum(self.W),1);
ub = 100*ones(self.L*sum(self.W),1);
if self.updateScores
lb = zeros(self.L*sum(self.W),1);
ub = 100*ones(self.L*sum(self.W),1);
else
lb = []; ub = [];
end
if self.updateKernels
[lball,uball] = cellfun(@(x)x.getBounds,self.LMCkernels,'un',0);
lb = vertcat(lball{:},lb);
Expand All @@ -281,18 +315,20 @@ function setParams(self,params)
% coregShifts (if updateKernels)
% noise (if updateNoise)
function pIdx = getParamIdx(self)
% initialize all parameter masks as boolean vectors
nParams = numel(getParams(self));
idxVec = false(1,nParams);
pIdx.noise = idxVec; pIdx.scores = idxVec;
pIdx.sgMeans = idxVec; pIdx.sgVars = idxVec;
pIdx.coregWeights = idxVec; pIdx.coregShifts = idxVec;

if self.updateNoise
pIdx.noise = idxVec; pIdx.noise(1) = 1;
pIdx.noise(1) = 1;
un = true;
else
un = false;
end
if self.updateKernels
pIdx.sgMeans = idxVec; pIdx.sgVars = idxVec;
pIdx.coregWeights = idxVec; pIdx.coregShifts = idxVec;

nFactorParams = self.LMCkernels{1}.nParams;
nKernelParams = self.LMCkernels{1}.coregs.B{1}.nParams;
R = self.LMCkernels{1}.coregs.B{1}.R;
Expand All @@ -314,10 +350,10 @@ function setParams(self,params)
pIdx.sgVars(vStart:2:vEnd) = true;
end
end
pIdx.scores = idxVec;
nScores = numel(self.scores(:));
pIdx.scores(end-nScores+1:end) = true;

if self.updateScores
nScores = numel(self.scores(:));
pIdx.scores(end-nScores+1:end) = true;
end
end

% gradient: returns gradients of the log-likelihood of the current
Expand All @@ -337,11 +373,17 @@ function setParams(self,params)
% to all parameters
% maxCondNum: largest condition number of UKU for all windows
% evaluated
function [grad, maxCondNum] = gradient(self,s,data,inds)
function [grad, maxCondNum] = gradient(self,s,data,inds,fInds)
%global gradCheck
modelFreqs = self.freqBand(s);
s = s(modelFreqs);
Ns = numel(s); % number of frequency bins

% handle whether or not learning is stochastic by frequency
if nargin < 5
fInds = true(size(s));
end

Ns = sum(fInds); % number of frequency bins
Nc = self.C; % number of channels
Nl = self.L; % number of latent factors

Expand All @@ -356,6 +398,7 @@ function setParams(self,params)
opts.smallFlag = false;
[~,UKUl] = self.LMCkernels{l}.UKU(s,opts);
vals = self.LMCkernels{l}.extractBlocks(UKUl);
vals = vals(:,:,fInds);
UKUlStore(:,l) = vals(:);

% compute derivatives of parameters for each factor
Expand All @@ -373,13 +416,13 @@ function setParams(self,params)
ngrad = 0;
sgrad = zeros(self.L,sum(self.W));

% loop through all memory partitions (unless using sgd)
if nargin == 3
Parts = self.P;
stochastic = false;
elseif nargin == 4
% loop through all memory partitions (unless using sgd)
if nargin > 3 && ~isempty(inds)
Parts = 1;
stochastic = true;
else
Parts = self.P;
stochastic = false;
end
for p = 1:Parts

Expand All @@ -390,6 +433,8 @@ function setParams(self,params)
inds = ((p-1)*self.maxW+1):((p-1)*self.maxW+self.W(p));
y = data(modelFreqs,:,inds);
end
% if stochastic by freq
y = y(fInds,:,:);
y = conj(y);
theseScores = self.scores(:,inds);

Expand Down Expand Up @@ -422,10 +467,10 @@ function setParams(self,params)

if self.updateKernels
% gradient for window
gradW = (Agt*BdAll).*kdAll;
gradW = (Agt*BdAll).*kdAll(fInds,:);
gradW = sum(real(gradW),1).'; % i.e. trace
thisKgrad = (s(2)-s(1)) * Ns * bsxfun(@times, theseScores(:,w)', ...
reshape(gradW,[nParams,Nl]) ); % check this!
thisKgrad = (s(2)-s(1)) * numel(s) * bsxfun(@times, ...
theseScores(:,w)', reshape(gradW,[nParams,Nl]) );

% add window w contribution to gradient
LMCgrad = LMCgrad + thisKgrad;
Expand All @@ -444,12 +489,13 @@ function setParams(self,params)
% util.gradientCheckNoise
% end
end

% gradient for window w factor scores
sgrad(:,inds(w)) = real(Ag(:)'*UKUlStore)';
%if strcmp(gradCheck,'scores')
% util.gradientCheckScores
%end
if self.updateScores
% gradient for window w factor scores
sgrad(:,inds(w)) = real(Ag(:)'*UKUlStore)';
%if strcmp(gradCheck,'scores')
% util.gradientCheckScores
%end
end
end
end

Expand All @@ -458,14 +504,36 @@ function setParams(self,params)
%end

% concatenate gradients if necessary
if self.updateKernels
grad = vertcat(LMCgrad(:),sgrad(:));
else
if self.updateScores
grad = sgrad(:);
else
grad = [];
end
if self.updateKernels
grad = vertcat(LMCgrad(:),grad);
end
if self.updateNoise
grad = vertcat(ngrad, grad);
end

% add regularization terms (L1 reg on weights, applied to log weights)
% (and L1 penalty on scores)
if self.updateKernels
params = self.getParams;
Bidx = self.getParamIdx.coregWeights;
Bvals = params(Bidx);
grad(Bidx) = grad(Bidx) - self.regB*exp(Bvals);
end
if self.updateScores
Sidx = self.getParamIdx.scores;
weightRatio = (self.Q*self.R*self.C)/sum(self.W);
grad(Sidx) = grad(Sidx) - self.regB*weightRatio;
end
end

function setUpdateState(self, updateKernels, updateScores)
    % setUpdateState  Record which parameter groups are currently being updated.
    %   updateKernels - value stored in self.updateKernels
    %   updateScores  - value stored in self.updateScores
    % Both flags are assigned in one statement, kernels first, then scores.
    [self.updateKernels, self.updateScores] = deal(updateKernels, updateScores);
end

% normalize factors for identifiability
Expand All @@ -477,18 +545,22 @@ function makeIdentifiable(self)
end

function res = UKU(self,s,n,UKUlstore)
    % UKU  Sum of the eta-scaled identity and the score-weighted factor blocks.
    %
    % If n is a vector, returns UKU for the windows corresponding to all of
    % its elements, with the 4th dimension iterating over windows.
    %
    % s         - frequency vector; only entries selected by self.freqBand(s)
    %             are used
    % n         - window index, or vector of window indices, into the columns
    %             of self.scores
    % UKUlstore - (optional) precomputed per-factor blocks. The sum below
    %             runs over dimension 5, so the factor index is presumably
    %             the 5th dimension -- TODO confirm against callers.
    %
    % res       - array of size [C, C, Ns, Nw], where Ns is the number of
    %             in-band frequencies and Nw = numel(n)
    s = s(self.freqBand(s));
    Ns = numel(s);
    Nw = numel(n);

    % (1/eta)*I replicated over every frequency bin and every window
    res = bsxfun(@times,1/self.eta*eye(self.C),ones([1,1,Ns,Nw]));
    if ~exist('UKUlstore','var')
        for l = 1:self.L
            % move the window index of the scores to dimension 4 so that it
            % broadcasts against the blocks of factor l
            theseScores = permute(self.scores(l,n),[1,3,4,2]);
            res = res + bsxfun(@times, theseScores, ...
                self.LMCkernels{l}.extractBlocks(self.UKUl(s,l)));
        end
    else
        % scores reshaped to [1,1,1,Nw,L]; weight the stored blocks and
        % collapse the factor dimension
        res = res + sum(bsxfun(@times,UKUlstore, ...
            permute(self.scores(:,n),[3,4,5,2,1])),5);
    end
end

Expand All @@ -501,7 +573,7 @@ function makeIdentifiable(self)
res = spalloc(Ns*Nc,Ns*Nc,Nc^2*Ns);
for q = 1:self.Q
B = self.LMCkernels{l}.coregs.getMat(q);
res = res + 1/d/2*kron(spdiags(SD(:,q),0,Ns,Ns),B);
res = res + 1/d*kron(spdiags(SD(:,q),0,Ns,Ns),B);
end
end

Expand Down
Loading

0 comments on commit 79eac7c

Please sign in to comment.