This commit is contained in:
Michael Zhang 2023-10-01 18:28:05 -05:00
parent 7f43b1d097
commit 8a019ca360

View file

@ -8,11 +8,14 @@ function [p1,p2,pc1,pc2] = Bayes_Learning(training_data, validation_data)
[train_row_size, column_size] = size (training_data); % dimension of training data [train_row_size, column_size] = size (training_data); % dimension of training data
[valid_row_size, ~] = size (validation_data); % dimension of validation data [valid_row_size, ~] = size (validation_data); % dimension of validation data
X = training_data(1:train_row_size, 1:column_size-1); %Training data X = training_data(1:train_row_size, 1:column_size-1); %Training data
y = training_data(:,column_size); % training labels
Xvalid = validation_data(1:valid_row_size, 1:column_size-1); % validation data (features)
yvalid = validation_data(:,column_size); % validation labels
% (1) TODO: find label counts of class 1 and class 2 % (1) TODO: find label counts of class 1 and class 2
% (2) TODO: get MLE p1, p2 % (2) TODO: get MLE p1, p2
[p1,p2] = MLE_Learning(training_data);
% Use different P(C_1) and P(C_2) on validation set % Use different P(C_1) and P(C_2) on validation set
% We compute g(x) = based on priors P(C_1), P(C_2), MLE estimator p1, p2, and x_{1*D} % We compute g(x) = based on priors P(C_1), P(C_2), MLE estimator p1, p2, and x_{1*D}
@ -25,9 +28,29 @@ for sigma = [0.00001,0.0001,0.001,0.01,0.1,1,2,3,4,5,6]
error_count = 0; % total number of errors to be count error_count = 0; % total number of errors to be count
% (3) TODO: compute likelihood for class1 and class2 , then compute the posterior % (3) TODO: compute likelihood for class1 and class2 , then compute the posterior
% probability for both classes (posterior = prior x likelihood). % probability for both classes (posterior = prior x likelihood).
% Classify each validation sample as whichever class has the higher posterior probability.
% If the sample is misclassified, increment the error count (error_count = error_count + 1);
for i = 1:valid_row_size
x = Xvalid(i, :);
correct_label = yvalid(i);
postc1 = prod(p1 .^ x .* (1 - p1) .^ (1 - x)) * P_C1;
postc2 = prod(p2 .^ x .* (1 - p2) .^ (1 - x)) * P_C2;
% Classify each validation sample as whichever class has the higher posterior probability.
if postc1 > postc2
lab = 1;
else
lab = 2;
end
% If the sample is misclassified, increment the error count (error_count = error_count + 1);
if lab ~= correct_label
error_count = error_count + 1;
end
end
error_table(index,1) = sigma; error_table(index,1) = sigma;
error_table(index,2) = P_C1; error_table(index,2) = P_C1;