function [test_err] = KNN(k, training_data, test_data, training_labels, test_labels)
% KNN  Test error of a k-nearest-neighbour classifier.
%   test_err = KNN(k, training_data, test_data, training_labels, test_labels)
%   classifies every row of TEST_DATA by majority vote among its K nearest
%   rows of TRAINING_DATA (Euclidean distance) and returns the fraction of
%   test points whose predicted label differs from TEST_LABELS.
%
%   Ties in the vote are broken in favour of the tied class whose training
%   points are closest to the test point on average (mean Euclidean distance
%   over ALL training points of that class). If that metric also ties, the
%   smallest label wins.
%
%   Inputs:
%     k               - number of neighbours (positive integer). Values larger
%                       than the number of training points use all of them.
%     training_data   - N-by-D numeric matrix, one training point per row.
%     test_data       - M-by-D numeric matrix, one test point per row.
%     training_labels - N-element numeric vector of class labels.
%     test_labels     - M-element numeric vector of true labels (row or column).
%
%   Output:
%     test_err        - scalar misclassification rate in [0, 1]
%                       (NaN when TEST_DATA has no rows, since it is 0/0).

    n_train = size(training_data, 1);
    if n_train == 0 || k < 1
        error('KNN:badInput', 'Need at least one training point and k >= 1.');
    end

    n = size(test_data, 1);
    % Force column orientation so comparisons never broadcast to n-by-n.
    training_labels = training_labels(:);
    test_labels = test_labels(:);
    k = min(k, n_train);
    preds = zeros(n, 1);

    for t = 1:n
        % Euclidean distance from test point t to every training point.
        diffs = bsxfun(@minus, training_data, test_data(t, :));
        distances = sqrt(sum(diffs .^ 2, 2));

        % Labels of the k closest training points.
        [~, order] = sort(distances, 'ascend');
        neighbour_labels = training_labels(order(1:k));

        % Majority vote: count each distinct label present among the neighbours.
        [labels, ~, idx] = unique(neighbour_labels);
        counts = accumarray(idx(:), 1);
        tied = labels(counts == max(counts));

        if numel(tied) == 1
            preds(t) = tied;
        else
            % Tie-break: the class whose training points are, on average,
            % closest to this test point. `tied` is sorted (from unique), so
            % min() picking the first minimum favours the smallest label.
            mean_dist = arrayfun(@(c) mean(distances(training_labels == c)), tied);
            [~, best] = min(mean_dist);
            preds(t) = tied(best);
        end
    end

    test_err = sum(preds ~= test_labels) / n;  % error rate
end % Function end