diff --git a/assignments/hwk02/Problem1/Classify.m b/assignments/hwk02/Problem1/Classify.m index c601a8a..d001b68 100644 --- a/assignments/hwk02/Problem1/Classify.m +++ b/assignments/hwk02/Problem1/Classify.m @@ -3,9 +3,20 @@ % these posterior probabilities are compared using the log odds. function [predictions] = Classify(data, m1, m2, S1, S2, pc1, pc2) + d = 8; + % TODO: calculate P(x|C) * P(C) for both classes + + pxC1 = exp(-1/2*(data-m1)./S1*(data-m1)') / (power(2*pi,d/2) * sqrt(det(S1))); + pxC2 = exp(-1/2*(data-m2)*(S2\(data-m2).')); + + g1 = pxC1 * pc1; + g2 = pxC2 * pc2; % TODO: calculate log odds, if > 0 then data(i) belongs to class c1, else, c2 + for i = 1:length(data) + data(i) + end % TODO: get predictions from log odds calculation diff --git a/assignments/hwk02/Problem1/Param_Est.m b/assignments/hwk02/Problem1/Param_Est.m index 6c77329..4ac24af 100644 --- a/assignments/hwk02/Problem1/Param_Est.m +++ b/assignments/hwk02/Problem1/Param_Est.m @@ -2,17 +2,32 @@ % (m1: learned mean of features for class 1, m2: learned mean of features % for class 2, S1: learned covariance matrix for features of class 1, % S2: learned covariance matrix for features of class 2) -function [m1 m2 S1 S2] = Param_Est(training_data, training_labels, part) +function [m1, m2, S1, S2] = Param_Est(training_data, training_labels, part) + + [num_rows, num_cols] = size(training_data); + class1_data = training_data(training_labels==1,:); + class2_data = training_data(training_labels==2,:); + + m1 = mean(class1_data); + m2 = mean(class2_data); + + S1 = cov(class1_data); + S2 = cov(class2_data); % Parameter estimation for 3 different models described in homework if(strcmp(part, '3')) - % TODO: compute parameters for model 3 + S1 = diag(diag(S1)); + S2 = diag(diag(S2)); elseif(strcmp(part, '2')) - % TODO: compute parameters for model 2 + P_C1 = length(class1_data) / num_rows; + P_C2 = length(class2_data) / num_rows; + + S = P_C1 * S1 + P_C2 + S2; + S1 = S; + S2 = S; elseif(strcmp(part, '1')) - % TODO: compute parameters for model 1 end end % Function end