csci5521/assignments/hwk02/Param_Est.m
2023-10-25 08:32:53 -05:00

47 lines
1.4 KiB
Matlab
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

% implements Param_Est, returns the parameters for each Multivariate Gaussian
% (m1: learned mean of features for class 1, m2: learned mean of features
% for class 2, S1: learned covariance matrix for features of class 1,
% S2: learned covariance matrix for features of class 2)
function [m1, m2, S1, S2] = Param_Est(training_data, training_labels, part)
% PARAM_EST Estimate the mean and covariance of one multivariate Gaussian
% per class, optionally constraining the covariance matrices.
%
%   [m1, m2, S1, S2] = Param_Est(training_data, training_labels, part)
%
%   Inputs:
%     training_data   - matrix with one sample per row and one feature per
%                       column.
%     training_labels - vector of class labels, one per row of
%                       training_data; rows labelled 1 form class 1 and
%                       rows labelled 2 form class 2.
%     part            - model selector, '1', '2' or '3' (the numbers 1, 2
%                       or 3 are accepted as well):
%                         '1' independent S1 and S2 (textbook eq. 5.17);
%                         '2' shared S = P(C1)*S1 + P(C2)*S2, returned as
%                             both S1 and S2 (textbook eqs. 5.21, 5.22);
%                         '3' diagonal S1 and S2 (Naive Bayes, eq. 5.24).
%                       Any other value returns the unconstrained
%                       covariances, exactly as model '1' does.
%
%   Outputs:
%     m1, m2 - row vectors of per-feature sample means of class 1 / class 2.
%     S1, S2 - covariance matrices of class 1 / class 2, constrained as
%              selected by PART. They come from COV, which normalises by
%              (n-1) rather than by n (the maximum-likelihood estimate).

class1_data = training_data(training_labels == 1, :);
class2_data = training_data(training_labels == 2, :);

m1 = mean(class1_data);
m2 = mean(class2_data);
S1 = cov(class1_data);
S2 = cov(class2_data);

% Allow the selector to be passed as a number as well as text; previously a
% numeric part silently fell through to the unconstrained model.
if isnumeric(part)
    part = num2str(part);
end

switch part
    case '1'
        % Model 1: independent S1 and S2 -- the full covariances computed
        % above are returned unchanged.
    case '2'
        % Model 2: one covariance shared by both classes, the average of
        % S1 and S2 weighted by the estimated class priors.
        % Count samples with size(..., 1): length() returns the LARGEST
        % dimension, which is the feature count whenever a class has fewer
        % samples than features, and would give a wrong prior.
        n1 = size(class1_data, 1);
        n2 = size(class2_data, 1);
        % Priors are taken over the two modelled classes so they sum to 1.
        P_C1 = n1 / (n1 + n2);
        P_C2 = n2 / (n1 + n2);
        S = P_C1 * S1 + P_C2 * S2;
        S1 = S;
        S2 = S;
    case '3'
        % Model 3: diagonal covariances (features treated as independent).
        % diag() of a matrix extracts its diagonal as a vector; diag() of
        % that vector builds a diagonal matrix, zeroing all off-diagonals.
        S1 = diag(diag(S1));
        S2 = diag(diag(S2));
end
end % Function end