2023-10-12 23:51:06 +00:00
|
|
|
|
% implements Param_Est, returns the parameters for each Multivariate Gaussian
|
|
|
|
|
% (m1: learned mean of features for class 1, m2: learned mean of features
|
|
|
|
|
% for class 2, S1: learned covariance matrix for features of class 1,
|
|
|
|
|
% S2: learned covariance matrix for features of class 2)
|
2023-10-13 22:11:18 +00:00
|
|
|
|
function [m1, m2, S1, S2] = Param_Est(training_data, training_labels, part)
% PARAM_EST  Estimate the Gaussian parameters (mean and covariance) of two
% classes, under one of three covariance models.
%
%   [m1, m2, S1, S2] = Param_Est(training_data, training_labels, part)
%
%   Inputs:
%     training_data   - matrix with one sample per row and one feature per
%                       column.
%     training_labels - vector with one label per row of training_data;
%                       rows labelled 1 form class 1 and rows labelled 2
%                       form class 2 (other labels are ignored).
%     part            - model selector, '1', '2' or '3' (numeric 1, 2 or 3
%                       is also accepted):
%                         '1' - separate full covariance per class
%                               (textbook equation 5.17);
%                         '2' - shared covariance S1 = S2 = S, the
%                               prior-weighted average of the per-class
%                               covariances (equations 5.21 and 5.22);
%                         '3' - per-class diagonal covariance, i.e. the
%                               Naive Bayes model (equation 5.24).
%                       Any other value returns the model '1' estimates.
%
%   Outputs:
%     m1, m2 - 1-by-d row vectors of per-feature means for class 1 / 2.
%     S1, S2 - d-by-d covariance matrices for class 1 / 2, computed with
%              MATLAB's cov (normalised by N-1).

% strcmp returns false for numeric input, so a numeric selector would
% silently fall through to model 1; normalise it to text first.
if isnumeric(part)
    part = num2str(part);
end

class1_data = training_data(training_labels == 1, :);
class2_data = training_data(training_labels == 2, :);

% Sample counts are ROW counts. length() returns the largest dimension,
% which is the feature count whenever a class has fewer samples than
% features, so size(..., 1) is used instead.
n1 = size(class1_data, 1);
n2 = size(class2_data, 1);
num_features = size(training_data, 2);

% Average down the rows explicitly: without the dimension argument, mean()
% of a single-row matrix collapses to a scalar.
m1 = mean(class1_data, 1);
m2 = mean(class2_data, 1);

S1 = cov(class1_data);
S2 = cov(class2_data);
% cov() treats a single row as a vector of observations and returns a
% scalar; one sample carries no spread, so use a d-by-d zero matrix.
if n1 == 1
    S1 = zeros(num_features);
end
if n2 == 1
    S2 = zeros(num_features);
end

if strcmp(part, '1')
    % Model 1: independent S1 and S2 -- already computed above.

elseif strcmp(part, '2')
    % Model 2: shared covariance. Priors P(C_i) = N_i / (N_1 + N_2) so the
    % two weights sum to 1 and S is a true weighted average.
    P_C1 = n1 / (n1 + n2);
    P_C2 = n2 / (n1 + n2);
    S = P_C1 * S1 + P_C2 * S2;
    S1 = S;
    S2 = S;

elseif strcmp(part, '3')
    % Model 3: diagonal covariances (Naive Bayes). The inner diag() pulls
    % the variances into a vector and the outer diag() rebuilds a diagonal
    % matrix, zeroing every off-diagonal covariance.
    S1 = diag(diag(S1));
    S2 = diag(diag(S2));
end

end % Function end
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|