41 lines
1.1 KiB
Mathematica
41 lines
1.1 KiB
Mathematica
|
% Implements Bayes testing and returns the test error rate.
%   p1:  learned Bernoulli parameters of the first class
%   p2:  learned Bernoulli parameters of the second class
%   pc1: best prior of the first class
%   pc2: best prior of the second class
function test_error = Bayes_Testing(test_data, p1, p2, pc1, pc2)
% BAYES_TESTING  Classify a test set with a two-class Bernoulli naive Bayes
% model, print the misclassification rate and return it.
%
%   test_error = Bayes_Testing(test_data, p1, p2, pc1, pc2)
%
%   test_data : N-by-(D+1) matrix. Columns 1..D are the features and
%               column D+1 is the true label. Predictions are 1 or 2, so the
%               labels are presumably encoded the same way (assumption —
%               confirm against the training code).
%   p1, p2    : D-element vectors of per-feature Bernoulli parameters for
%               class 1 / class 2 (row or column orientation is accepted).
%   pc1, pc2  : prior probabilities of class 1 / class 2.
%
%   test_error: number of misclassified rows divided by N.
%
% Each row is scored as log(prior) + sum_j log(p_j^x_j * (1-p_j)^(1-x_j)).
% Comparing in the log domain matters because the raw product of D
% likelihoods underflows to 0 for both classes once D is moderately large.
% With raw products every sample would fall through to class 2. As before,
% a tie (including both scores being -Inf) is resolved in favour of class 2.

X = test_data(:, 1:end-1);   % features
y = test_data(:, end);       % true labels

% Normalise the parameters to 1-by-D rows so they align with the rows of X.
% A column vector would otherwise broadcast against a row of X into a
% D-by-D matrix.
p1 = reshape(p1, 1, []);
p2 = reshape(p2, 1, []);

score1 = class_log_score(X, p1, pc1);
score2 = class_log_score(X, p2, pc2);

% Class 1 only on a strict win; everything else (including ties) is class 2.
predicted = 2 * ones(size(y));
predicted(score1 > score2) = 1;

% test_error = # of incorrectly classified / total number of test samples
test_error = sum(predicted ~= y) / size(X, 1);

fprintf('Error rate on the test dataset is: \n\n');
disp(test_error);

end

function s = class_log_score(X, p, prior)
% CLASS_LOG_SCORE  Per-row log of prior * prod_j p_j^x_j * (1-p_j)^(1-x_j).
% Each per-feature factor is built exactly like the direct product (so
% 0^0 = 1 when p_j is 0 or 1) before taking the log. This avoids the
% 0*log(0) = NaN that x.*log(p) would produce.
terms = bsxfun(@power, p, X) .* bsxfun(@power, 1 - p, 1 - X);
s = sum(log(terms), 2) + log(prior);
end
|