csci5521/assignments/hwk02/Param_Est.m

40 lines
1,002 B
Mathematica
Raw Normal View History

2023-10-12 23:51:06 +00:00
function [m1, m2, S1, S2] = Param_Est(training_data, training_labels, part)
% PARAM_EST Estimate the parameters of one multivariate Gaussian per class.
%
%   [m1, m2, S1, S2] = Param_Est(training_data, training_labels, part)
%
%   Inputs:
%     training_data   - N-by-d matrix; each row is one sample, each column
%                       one feature.
%     training_labels - N-element vector of class labels; rows labelled 1
%                       form class 1 and rows labelled 2 form class 2.
%     part            - char vector selecting the covariance model:
%                       '1' : separate full covariance matrix per class;
%                       '2' : one shared covariance matrix, the prior-weighted
%                             sum P(C1)*S1 + P(C2)*S2, returned in both S1
%                             and S2 (priors = fraction of all N rows that
%                             carry each label);
%                       '3' : separate diagonal covariance matrix per class
%                             (off-diagonal entries set to zero).
%                       Any other value leaves S1/S2 as in model '1'.
%
%   Outputs:
%     m1, m2 - 1-by-d learned mean of the features for class 1 / class 2.
%     S1, S2 - d-by-d learned covariance matrix for class 1 / class 2.

num_rows = size(training_data, 1);

class1_data = training_data(training_labels == 1, :);
class2_data = training_data(training_labels == 2, :);

% Average down the rows (dim 1) so that a class with a single sample still
% yields a 1-by-d mean instead of the scalar mean of that one row.
m1 = mean(class1_data, 1);
m2 = mean(class2_data, 1);
S1 = class_covariance(class1_data);
S2 = class_covariance(class2_data);

% Parameter estimation for the 3 different models described in the homework
if strcmp(part, '3')
    % Independent features: keep only the per-feature variances.
    S1 = diag(diag(S1));
    S2 = diag(diag(S2));
elseif strcmp(part, '2')
    % Class priors as sample frequencies. size(...,1) counts samples;
    % length() returns the larger dimension and would overstate the count
    % whenever a class has fewer samples than features.
    P_C1 = size(class1_data, 1) / num_rows;
    P_C2 = size(class2_data, 1) / num_rows;
    % Shared covariance: prior-weighted sum of the class covariances.
    S = P_C1 * S1 + P_C2 * S2;
    S1 = S;
    S2 = S;
elseif strcmp(part, '1')
    % Full, class-specific covariance matrices: keep S1 and S2 as estimated.
end
end % Function end

function S = class_covariance(X)
% CLASS_COVARIANCE Covariance of the rows of X, always returned as d-by-d.
%   cov() treats a single row as one vector of observations and returns a
%   scalar, so a one-sample class is special-cased to a d-by-d zero matrix
%   (the value cov() gives each variable for a single observation).
if size(X, 1) == 1
    S = zeros(size(X, 2));
else
    S = cov(X);
end
end