csci5521/assignments/hwk01/Bayes_Testing.m

41 lines
1.1 KiB
Mathematica
Raw Normal View History

2023-10-01 23:09:50 +00:00
% BAYES_TESTING  Classify a test set with a two-class Bernoulli naive Bayes
% model and return (and print) the test error rate.
%
%   test_error = Bayes_Testing(test_data, p1, p2, pc1, pc2)
%
%   test_data : N-by-(D+1) matrix; columns 1..D are the features and the
%               last column is the true class label (1 or 2).
%               NOTE(review): features presumably binary 0/1 -- confirm.
%   p1, p2    : learned Bernoulli parameters of the first / second class,
%               i.e. P(x_j = 1 | class); assumed to be 1-by-D row vectors
%               aligned with the feature columns -- TODO confirm with caller.
%   pc1, pc2  : (best) prior probabilities of the first / second class.
%
%   test_error: fraction of test rows whose predicted label differs from
%               the true label (NaN when test_data has no rows, since the
%               rate is 0/0).
%
%   A sample is labelled 1 only if its class-1 score is strictly greater
%   than its class-2 score; ties are labelled 2.
%   Scores are compared in the log domain: a raw product of many per-feature
%   probabilities underflows to 0 for high-dimensional inputs, which would
%   make both posteriors 0 and force every such sample into class 2.
function test_error = Bayes_Testing(test_data, p1, p2, pc1, pc2)
X = test_data(:, 1:end-1);   % test features
y = test_data(:, end);       % true test labels
num_test = size(X, 1);

log_pc1 = log(pc1);          % loop-invariant log priors
log_pc2 = log(pc2);
num_correct = 0;
for i = 1:num_test
    x = X(i, :);
    % Sum the log of each per-feature likelihood term rather than using
    % x.*log(p): evaluating the term first keeps MATLAB's 0^0 == 1 for
    % parameters that are exactly 0 or 1, so no 0*(-Inf) = NaN can appear.
    % An impossible feature value gives a -Inf score (zero posterior).
    log_post1 = sum(log(p1 .^ x .* (1 - p1) .^ (1 - x))) + log_pc1;
    log_post2 = sum(log(p2 .^ x .* (1 - p2) .^ (1 - x))) + log_pc2;
    if log_post1 > log_post2
        predicted = 1;
    else
        predicted = 2;
    end
    if predicted == y(i)
        num_correct = num_correct + 1;
    end
end

% error rate = # of incorrectly classified / total number of test samples
test_error = (num_test - num_correct) / num_test;
fprintf('Error rate on the test dataset is: \n\n');
disp(test_error);
end