csci5521/assignments/hwk02/Classify.m

40 lines
1.4 KiB
Mathematica
Raw Permalink Normal View History

2023-10-12 23:51:06 +00:00
% Classify  Two-class Gaussian (quadratic discriminant) classifier.
%
%   predictions = Classify(data, m1, m2, S1, S2, pc1, pc2) returns, for each
%   row x of data, 1 if the posterior P(C1|x) exceeds P(C2|x) and 2 otherwise.
%   The comparison is done with the log odds
%       log P(C1|x) - log P(C2|x) = [log p(x|C1) + log pc1] - [log p(x|C2) + log pc2]
%   (the common normalizer p(x) cancels). The log-densities are computed
%   directly, never as exp(...) followed by log, so the result does not underflow
%   to log(0) - log(0) = NaN for points far from both means or in high dimension.
%
%   Inputs:
%     data     - N-by-d matrix, one observation per row
%     m1, m2   - class means, length-d (row or column) vectors
%     S1, S2   - d-by-d class covariance matrices
%     pc1, pc2 - class prior probabilities
%   Output:
%     predictions - N-by-1 column of class labels (1 or 2); ties (log odds == 0)
%                   and undefined (NaN) log odds are assigned class 2.
function [predictions] = Classify(data, m1, m2, S1, S2, pc1, pc2)
num_rows = size(data, 1);

% log of the (unnormalized) posterior: log p(x|C) + log P(C)
g1 = gaussian_log_density(data, m1, S1) + log(pc1);
g2 = gaussian_log_density(data, m2, S2) + log(pc2);

% log odds > 0 means class 1 is more probable; otherwise class 2
log_odds = g1 - g2;

predictions = 2 * ones(num_rows, 1);
predictions(log_odds > 0) = 1;
end % Function end

% gaussian_log_density  Log of the multivariate normal density N(m, S),
% evaluated at every row of X. Returns an N-by-1 column.
function lp = gaussian_log_density(X, m, S)
d = size(X, 2);
% subtract the mean from every row (m may be given as a row or a column)
diffs = bsxfun(@minus, X, reshape(m, 1, []));
[R, p] = chol(S);
if p == 0
    % S = R'*R, so (x-m) S^-1 (x-m)' = ||(x-m) R^-1||^2 and log|S| = 2*sum(log(diag(R)))
    z = diffs / R;
    maha = sum(z .^ 2, 2);
    logdet = 2 * sum(log(diag(R)));
else
    % S is not positive definite: fall back to a direct solve and det
    maha = sum((diffs / S) .* diffs, 2);
    logdet = log(det(S));
end
lp = -0.5 * (d * log(2 * pi) + logdet + maha);
end