26 lines
893 B
Mathematica
26 lines
893 B
Mathematica
|
% Implements k-nearest-neighbours (KNN) classification and returns its test
% error for a given number of neighbours k, using majority vote with
% distance-based tie-breaking.
function [test_err] = KNN(k, training_data, test_data, training_labels, test_labels)
% KNN  Test error rate of a k-nearest-neighbour classifier.
%
%   test_err = KNN(k, training_data, test_data, training_labels, test_labels)
%
%   Inputs
%     k               - number of neighbours that vote (integer >= 1). If k
%                       exceeds the number of training rows, every training
%                       row votes.
%     training_data   - m-by-d matrix, one training point per row.
%     test_data       - n-by-d matrix, one test point per row.
%     training_labels - m-element numeric vector of training labels
%                       (row or column).
%     test_labels     - n-element numeric vector of true test labels
%                       (row or column).
%
%   Output
%     test_err        - fraction of test points whose predicted label differs
%                       from test_labels (NaN when test_data has no rows,
%                       since 0/0 is NaN).
%
%   Each test point is assigned the label occurring most often among its k
%   nearest training points in Euclidean distance. When several labels tie
%   for the most votes, the tied label whose nearest representative among
%   those k neighbours is closest wins. Exactly equal distances are ordered
%   by training-row index, because SORT is stable.

if k < 1
    error('KNN:badK', 'k must be at least 1.');
end

n = size(test_data, 1);             % number of test points
m = size(training_data, 1);         % number of training points
k = min(k, m);                      % cannot use more neighbours than exist

% Force column vectors so the final comparison is element-wise and never
% broadcasts a column against a row.
training_labels = training_labels(:);
test_labels = test_labels(:);

preds = zeros(n, 1);                % predicted label for each test point

% Pairwise Euclidean distances (n-by-m) via |a-b|^2 = |a|^2 + |b|^2 - 2*a.b.
% Round-off can make the squared distance slightly negative, so clamp at 0
% before the sqrt.
sq_dist = bsxfun(@plus, sum(test_data .^ 2, 2), sum(training_data .^ 2, 2)') ...
          - 2 * (test_data * training_data');
dists = sqrt(max(sq_dist, 0));

% Classify each test point (row) in turn.
for t = 1:n
    % Indices of the k nearest training points, closest first.
    [~, order] = sort(dists(t, :), 'ascend');
    nn_labels = training_labels(order(1:k));

    % Majority vote: count the occurrences of each distinct neighbour label.
    [classes, ~, class_idx] = unique(nn_labels);
    votes = accumarray(class_idx(:), 1);
    tied = classes(votes == max(votes));

    % Tie-break: nn_labels is ordered by distance, so the first neighbour
    % whose label is among the tied winners belongs to the closest class.
    % With a single winner this simply picks that winner.
    first_hit = find(ismember(nn_labels, tied), 1, 'first');
    preds(t) = nn_labels(first_hit);
end

test_err = sum(preds ~= test_labels) / n;   % error rate

end % Function end
|