csci5521/assignments/hwk02/Classify.m

31 lines
922 B
Mathematica
Raw Normal View History

2023-10-12 23:51:06 +00:00
% Classify  Predict the class (1 or 2) of each row of DATA with a two-class
% Gaussian discriminant.
%
% For each row x of DATA the discriminant of class k is
%     g_k(x) = log p(x | C_k) + log P(C_k)
% where p(x | C_k) is the multivariate normal density N(m_k, S_k).  The log
% odds g1(x) - g2(x) decide the class: a positive log odds assigns x to
% class 1; otherwise (zero, negative, or NaN) x is assigned to class 2.
%
% The log density is evaluated directly in the log domain rather than as
% log(mvnpdf(...)): the density itself underflows to 0 for points far from
% both means, and log(0) = -Inf for both classes would make the log odds NaN.
%
% Inputs:
%   data     - N-by-D matrix, one observation per row.
%   m1, m2   - class means with D elements (row or column vector).
%   S1, S2   - D-by-D class covariance matrices; must be positive definite
%              (assumed symmetric -- chol only reads the upper triangle).
%   pc1, pc2 - prior probabilities of class 1 and class 2.
% Output:
%   predictions - N-by-1 vector of class labels, each 1 or 2.
function [predictions] = Classify(data, m1, m2, S1, S2, pc1, pc2)
    % Discriminants g_k = log p(x|C_k) + log P(C_k), one value per row.
    g1 = log_gaussian(data, m1, S1) + log(pc1);
    g2 = log_gaussian(data, m2, S2) + log(pc2);

    log_odds = g1 - g2;

    % Class 1 where the log odds are positive, class 2 otherwise
    % (ties and NaN fall to class 2, matching the "else c2" rule).
    predictions = 2 - double(log_odds > 0);
end % Function end

% log_gaussian  Log of the multivariate normal density N(m, S) evaluated at
% each row of X; returns an N-by-1 column vector.
function lp = log_gaussian(X, m, S)
    d = size(X, 2);
    % Cholesky factor S = R' * R; p > 0 means S is not positive definite.
    [R, p] = chol(S);
    if p ~= 0
        error('Classify:covariance', ...
              'Covariance matrix must be symmetric positive definite.');
    end
    centered = bsxfun(@minus, X, m(:)');   % rows are x - m
    z = centered / R;                      % rows are (x - m) * inv(R)
    mahal = sum(z .^ 2, 2);                % (x - m) * inv(S) * (x - m)'
    log_det = 2 * sum(log(diag(R)));       % log(det(S)) without overflow
    lp = -0.5 * (d * log(2 * pi) + log_det + mahal);
end