2023-10-12 23:51:06 +00:00
|
|
|
% implements KNN, returns the test error for the k-nearest neighbors
|
|
|
|
% algorithm when using a specified number of neighbors (k) for
|
|
|
|
% classification using majority rule with tie-breaking.
|
|
|
|
function [test_err] = KNN(k, training_data, test_data, training_labels, test_labels)
% KNN  Test error rate of a k-nearest-neighbours classifier.
%
%   test_err = KNN(k, training_data, test_data, training_labels, test_labels)
%
%   Inputs:
%     k               - number of neighbours that vote; integer in
%                       1..size(training_data, 1).
%     training_data   - one training point per row.
%     test_data       - one test point per row (same number of columns as
%                       training_data).
%     training_labels - numeric label for each training row (row or column
%                       vector).
%     test_labels     - numeric true label for each test row (row or column
%                       vector).
%
%   Output:
%     test_err        - fraction of test points whose predicted label
%                       differs from test_labels (NaN when test_data has
%                       no rows, because the count is divided by zero).
%
%   Each test point is assigned the label held by the most of its k nearest
%   training points (Euclidean distance via pdist2). If several labels share
%   the highest vote count, the tie goes to the tied class whose training
%   points have the smallest mean distance to the test point; if that is
%   also equal, the smallest label value wins.
%
%   Errors:
%     KNN:invalidK     - k is not an integer in the valid range.
%     KNN:sizeMismatch - a label vector does not match its data matrix.

    % Work with column vectors so the final comparison is element-wise
    % even when the caller passes row vectors.
    training_labels = training_labels(:);
    test_labels = test_labels(:);

    num_train = size(training_data, 1);
    n = size(test_data, 1); % number of test points

    if ~isscalar(k) || k < 1 || k ~= floor(k) || k > num_train
        error('KNN:invalidK', ...
            'k must be an integer between 1 and the number of training points (%d).', ...
            num_train);
    end
    if numel(training_labels) ~= num_train
        error('KNN:sizeMismatch', ...
            'training_labels must have one entry per row of training_data.');
    end
    if numel(test_labels) ~= n
        error('KNN:sizeMismatch', ...
            'test_labels must have one entry per row of test_data.');
    end

    preds = zeros(n, 1); % predicted label for each test point

    % pairwise_distance(i, t) = Euclidean distance from training point i
    % to test point t.
    pairwise_distance = pdist2(training_data, test_data);

    % Map every training label to a compact class index 1..num_classes so
    % votes can be counted for arbitrary (zero, negative, non-contiguous)
    % label values.
    [unique_classes, ~, class_index] = unique(training_labels);
    num_classes = numel(unique_classes);
    class_sizes = accumarray(class_index, 1, [num_classes, 1]);

    for t = 1:n
        distances = pairwise_distance(:, t);

        % Indices of the k training points closest to this test point.
        [~, order] = sort(distances, 'ascend');
        nearest = order(1:k);

        % Count the votes each class receives from the k neighbours.
        votes = accumarray(class_index(nearest), 1, [num_classes, 1]);
        tied = find(votes == max(votes));

        if numel(tied) > 1
            % Tie: prefer the tied class that is closest on average,
            % measured over all training points of that class. min()
            % returns the first minimum, i.e. the smallest label value,
            % when the mean distances are equal too.
            mean_distance = accumarray(class_index, distances, [num_classes, 1]) ./ class_sizes;
            [~, closest] = min(mean_distance(tied));
            preds(t) = unique_classes(tied(closest));
        else
            preds(t) = unique_classes(tied);
        end
    end

    test_err = sum(preds ~= test_labels) / n; % error rate
end % Function end
|