This commit is contained in:
Michael Zhang 2023-11-15 22:35:25 -06:00
parent fa84edb84c
commit b444304a89
Signed by: michael
GPG key ID: BDA47A31A3C8EE6B
2 changed files with 17 additions and 10 deletions

View file

@ -22,12 +22,12 @@ function [h, m, Q] = EMG(x, k, epochs, flag)
Q = zeros(epochs*2,1); % vector that can hold complete data log-likelihood after each E and M step Q = zeros(epochs*2,1); % vector that can hold complete data log-likelihood after each E and M step
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% TODO: Initialise cluster means using k-means % Initialise cluster means using k-means
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
[~, ~, ~, D] = kmeans(x, k); [~, ~, ~, D] = kmeans(x, k);
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% TODO: Determine the b values for all data points % Determine the b values for all data points
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
for i = 1:num_data for i = 1:num_data
row = D(i,:); row = D(i,:);
@ -36,7 +36,7 @@ function [h, m, Q] = EMG(x, k, epochs, flag)
end end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% TODO: Initialize pi's (mixing coefficients) % Initialize pi's (mixing coefficients)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
pi = zeros(k, 1); pi = zeros(k, 1);
for i = 1:k for i = 1:k
@ -44,8 +44,8 @@ function [h, m, Q] = EMG(x, k, epochs, flag)
end end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% TODO: Initialize the covariance matrix estimate % Initialize the covariance matrix estimate
% further modifications will need to be made when doing 2(d) % further modifications will need to be made when doing 2(d)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
m = zeros(k, dim); m = zeros(k, dim);
for i = 1:k for i = 1:k
@ -63,13 +63,13 @@ function [h, m, Q] = EMG(x, k, epochs, flag)
[h] = E_step(x, h, pi, m, S, k); [h] = E_step(x, h, pi, m, S, k);
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% TODO: Store the value of the complete log-likelihood function % Store the value of the complete log-likelihood function
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
L = 0; L = 0;
for i = 1:num_data for i = 1:num_data
for j = 1:k for j = 1:k
prior = mvnpdf(x, m(j, :), S(:, :, j)); prior = mvnpdf(x, m(j, :), S(:, :, j));
L = L + h(i, j) * (log(pi(i)) + log(prior(i))); L = L + h(i, j) * (log(pi(j)) + log(prior(j)));
end end
end end
@ -77,7 +77,7 @@ function [h, m, Q] = EMG(x, k, epochs, flag)
% M-step % M-step
%%%%%%%%%%%%%%%% %%%%%%%%%%%%%%%%
fprintf('M-step, epoch #%d\n', n); fprintf('M-step, epoch #%d\n', n);
[Q, S, m] = M_step(x, Q, h, S, k); [Q, S, m] = M_step(x, h, S, k, flag);
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% TODO: Store the value of the complete log-likelihood function % TODO: Store the value of the complete log-likelihood function

View file

@ -18,9 +18,16 @@ function [S, m, pi] = M_step(x, h, S, k, flag)
lambda = 1e-3; % value for improved version of EM lambda = 1e-3; % value for improved version of EM
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% TODO: update mixing coefficients % update mixing coefficients
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
pi = zeros(k, 1);
for i = 1:num_data
row = h(i, :);
maxValue = max(row);
maxIdx = find(row == maxValue);
pi(maxIdx) = pi(maxIdx) + 1;
end
pi = pi ./ num_data;
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%