csci5521/gauss_class/gauss_class_2D.m

93 lines
2.2 KiB
Mathematica
Raw Normal View History

2023-10-08 03:42:35 +00:00
% CSCI 5521 Introduction to Machine Learning
% Rui Kuang
% Demonstration of Classification by 2-D Gaussians
%
% Left panel: the two class discriminants
%     g_i(x) = log p(x|C_i) + log P(C_i)
% (computed by the local function mvndis) drawn as surfaces.
% Right panel: the decision regions, i.e. where g_1(x) > g_2(x).
clf;
prior1 = 0.3;
prior2 = 0.7;
mu1 = [-1 -1];
mu2 = [1 1];
% Uncomment one of the covariance settings below (together with the smaller
% grid) to compare the resulting decision boundaries.
% Equal diagonal covariance matrix
% Sigma1 = [1 0; 0 1];
% Sigma2 = [1 0; 0 1];
% Diagonal covariance matrix
% Sigma1 = [1 0; 0 0.5];
% Sigma2 = [1 0; 0 0.5];
% Shared covariance matrix
% Sigma1 = [1 0.3; 0.3 0.5];
% Sigma2 = [1 0.3; 0.3 0.5];
% x1 = -10:.1:10; x2 = -10:.1:10;
% Different full covariance matrices (increase the range for visualization
% of the decision regions)
Sigma1 = [1 0.1; 0.1 0.5];
Sigma2 = [0.5 0.3; 0.3 1];
x1 = -40:.1:40; x2 = -40:.1:40;
[X1,X2] = meshgrid(x1,x2);

% discriminant of class 1
% (F1 = mvnpdf([X1(:) X2(:)],mu1,Sigma1); would give the raw density instead)
F1 = mvndis([X1(:) X2(:)], mu1, Sigma1, prior1);
F1 = reshape(F1,length(x2),length(x1));
% discriminant of class 2
% (F2 = mvnpdf([X1(:) X2(:)],mu2,Sigma2); would give the raw density instead)
F2 = mvndis([X1(:) X2(:)], mu2, Sigma2, prior2);
F2 = reshape(F2,length(x2),length(x1));

subplot(1,2,1);
surf(x1,x2,F1); hold on;
surf(x1,x2,F2);
% mvndis returns log values (negative for these parameters), so a fixed
% z-range of [0 .4] -- meant for mvnpdf densities -- clips both surfaces
% away, and color limits taken over the whole +/-40 grid are dominated by
% huge negative values far from the means. Derive the z and color limits
% from the window that is actually displayed.
view_lim = 4;
in_view = abs(X1) <= view_lim & abs(X2) <= view_lim;
z_vis = [F1(in_view); F2(in_view)];
z_lo = min(z_vis);
z_hi = max(z_vis);
caxis([z_lo - .5*(z_hi - z_lo), z_hi]);
axis([-view_lim view_lim -view_lim view_lim z_lo z_hi]);
xlabel('x1'); ylabel('x2'); zlabel('log P(C_i) + log p(x|C_i)');

% decision regions: 1 where class 1 wins (g_1 > g_2), 0 where class 2 wins
cmp = F1 > F2;
subplot(1,2,2);
imagesc(x1,x2,cmp);
axis xy;   % imagesc reverses the y axis by default; keep x2 increasing upward
xlabel('x1'); ylabel('x2');
2023-10-08 06:50:20 +00:00
function res = mvndis(X, mu, Sigma, prior)
%MVNDIS Gaussian class log discriminant, evaluated for every row of X.
%   res = MVNDIS(X, mu, Sigma, prior) returns, for each row x of X,
%       g(x) = -d/2*log(2*pi) - 1/2*log(det(Sigma))
%              - 1/2*(x - mu)*inv(Sigma)*(x - mu)' + log(prior)
%   i.e. log(prior * N(x; mu, Sigma)).
%
%   X      n-by-d matrix, one point per row
%   mu     1-by-d mean vector (a d-by-1 column is also accepted)
%   Sigma  d-by-d covariance matrix (assumed invertible with det > 0)
%   prior  class prior probability P(C)
%   res    n-by-1 column of discriminant values
%
%   The Mahalanobis term is computed for all rows at once with a matrix
%   right-division, instead of looping over rows and calling inv(Sigma)
%   and det(Sigma) again for every row.
    d = size(X, 2);
    D = X - reshape(mu, 1, []);          % center every row (implicit expansion)
    mdist = sum((D / Sigma) .* D, 2);    % row-wise (x-mu)*inv(Sigma)*(x-mu)'
    res = -d/2*log(2*pi) - 1/2*log(det(Sigma)) - 1/2*mdist + log(prior);
end