csci5521/assignments/hwk03/M_step.m

76 lines
2.5 KiB
Mathematica
Raw Normal View History

2023-11-10 03:29:17 +00:00
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Name:    M_step.m
% Purpose: M-step of EM for a Gaussian mixture model. Re-estimates the mixing
%          coefficients, cluster means and cluster covariance matrices from the
%          responsibilities h produced by the E-step.
% Input:   x    - a nxd data matrix (nx3 if using RGB)
%          h    - a nxk matrix; h(j, i) is the expectation of the hidden indicator
%                 z for sample j and cluster i given the data and distribution params
%          S    - cluster covariance matrices; not read here, the estimates are
%                 recomputed from scratch and returned
%          k    - the number of clusters (only the first k columns of h are used)
%          flag - if true, use the improved EM: add lambda/2 to the diagonal of every
%                 covariance estimate to avoid singular covariance matrices
% Output:  S    - dxdxk cluster covariance matrices (exactly symmetric)
%          m    - kxd cluster means, one row per cluster
%          pi   - kx1 mixing coefficients, N_i / n
% Notes:   Uses implicit expansion (MATLAB R2016b+ / Octave).
%          If a column of h sums to zero, N_i for that cluster is 0 and its mean and
%          covariance come out as NaN/Inf.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [S, m, pi] = M_step(x, h, S, k, flag)
% get size of data
[num_data, dim] = size(x);
reg = 1e-15;    % tiny ridge added before normalisation (named so it does not shadow builtin eps)
lambda = 1e-3;  % regularisation strength for the improved version of EM

% only the first k responsibility columns belong to the k clusters
h = h(:, 1:k);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% update mixing coefficients
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% effective number of points assigned to each cluster (kx1)
N_i = sum(h, 1).';
pi = N_i / num_data;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% update cluster means
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% row i is sum_j h(j, i) * x(j, :), normalised by N_i(i)
m = (h' * x) ./ N_i;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Calculate the covariance matrix estimates
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
S = zeros(dim, dim, k);
for i = 1:k
    dev = x - m(i, :);  % nxd deviations from the cluster mean
    % weighted scatter sum_j h(j, i) * dev_j' * dev_j, plus the tiny ridge, over N_i
    s = (dev' * (dev .* h(:, i)) + reg * eye(dim)) / N_i(i);
    % round-off can leave s slightly asymmetric, which breaks routines that
    % require a symmetric covariance; averaging with the transpose fixes it exactly
    s = (s + s') / 2;
    if flag
        % improved EM: ridge on the diagonal keeps the covariance non-singular.
        % NOTE(review): confirm lambda/2 (added after dividing by N_i) matches the
        % regularised M-step derived in the assignment.
        s = s + lambda * eye(dim) / 2;
    end
    S(:, :, i) = s;
end
end