%% Initialization
clear; clc;
k = 7;  % number of nearest neighbours that take part in the vote
load labels
load trainData
[~, n] = size(trainData);

%% Test sample
a = 4;   % class of the test sample (the digit it shows)
b = 56;  % index of the sample within its class
dataDir = 'C:/Users/HP/Documents';
% The functional form of LOAD returns the ASCII file's matrix directly, so no
% EVAL is needed to recover the auto-generated variable name (X<a>_<b>).
testData = load(sprintf('%s/%d_%d.txt', dataDir, a, b));
testData = testData';

%% Min-max normalisation, one row (sample) at a time
% Training and test data are normalised the same way (per row). Rows with a
% constant value would divide 0 by 0 (NaN), so their range is taken as 1.
trainMin = min(trainData, [], 2);
trainRange = max(trainData, [], 2) - trainMin;
trainRange(trainRange == 0) = 1;
trainData = bsxfun(@rdivide, bsxfun(@minus, trainData, trainMin), trainRange);

testMin = min(testData, [], 2);
testRange = max(testData, [], 2) - testMin;
testRange(testRange == 0) = 1;
testData = bsxfun(@rdivide, bsxfun(@minus, testData, testMin), testRange);

%% Classify and report
resultLabel = knn(testData, trainData, labels, k);
fprintf('the real number in the picture is %d\n', a);
fprintf('predict number in the picture is %d\n', resultLabel);
function resultLabel = knn(inx, data, labels, k)
%KNN Classify each row of INX by a majority vote of its K nearest rows in DATA.
%   resultLabel = knn(inx, data, labels, k)
%
%   inx    - test samples, one per row (same number of columns as DATA)
%   data   - training samples, one per row
%   labels - numeric training labels, one per row of DATA
%   k      - number of neighbours that vote; clamped to size(data, 1)
%
%   resultLabel - column vector with one predicted label per row of INX.
%
%   Distance is squared Euclidean. A tied vote goes to the smallest label
%   value (the tie-breaking rule of MODE).

datarow = size(data, 1);
inxrow = size(inx, 1);
k = min(k, datarow);  % asking for more neighbours than samples would overrun IND

% dis(r, c) is the squared distance from training row r to test row c.
% Looping over the test rows keeps memory proportional to size(data) instead
% of replicating DATA once per test row.
dis = zeros(datarow, inxrow);
for i = 1:inxrow
    dis(:, i) = sum(bsxfun(@minus, data, inx(i, :)).^2, 2);
end

% Sort the distances for each test sample and keep the indices of the k nearest.
[~, ind] = sort(dis, 1);
ind = ind(1:k, :);

% Majority vote per column. RESHAPE is needed because indexing a vector with
% a vector follows the source's orientation (e.g. k == 1 with column LABELS).
neighbourLabels = reshape(labels(ind), size(ind));
resultLabel = mode(neighbourLabels, 1).';
end
clear; clc;
%% Load data
load ex3data1.mat;  % assumed to provide X (one sample per row) and y (labels 1..10) — confirm

numClasses = 10;
samplesPerClass = 500;  % assumes each class occupies 500 consecutive rows of X — confirm
trainPerClass = 400;    % the first 400 samples of each class are used for training

% Row indices of the first trainPerClass samples of every class block,
% in block order (block 1 first), exactly as the original per-class loop.
trainIdx = bsxfun(@plus, (1:trainPerClass)', samplesPerClass * (0:numClasses - 1));
trainIdx = trainIdx(:);
X_train = X(trainIdx, :);
y_train = y(trainIdx, :);

m = size(X_train, 1);
X_train = [ones(m, 1), X_train];  % prepend a column of ones (intercept term)
n = size(X_train, 2);
initial_theta = zeros(n, 1);

% One-vs-all targets: column c is 1 where y_train == c, 0 elsewhere.
class_y = double(bsxfun(@eq, y_train, 1:numClasses));

%% Train one regularised logistic-regression classifier per class
options = optimset('GradObj', 'on', 'MaxIter', 400);
lambda = 0.1;
theta = zeros(n, numClasses);
cost = zeros(1, numClasses);
for c = 1:numClasses
    [theta(:, c), cost(c)] = fminunc(@(t) costfun(t, X_train, class_y(:, c), lambda), ...
                                     initial_theta, options);
end
function [J, grad] = costfun(theta, X, y, lamda)
%COSTFUN Regularised logistic-regression cost and gradient.
%   [J, grad] = costfun(theta, X, y, lamda)
%
%   theta - parameter column vector; theta(1) is not regularised
%   X     - design matrix, one sample per row (first column presumably the
%           intercept column of ones)
%   y     - target column vector, one 0/1 entry per row of X
%   lamda - L2 regularisation strength
%
%   J     - mean cross-entropy + lamda/(2m) * sum(theta(2:end).^2)
%   grad  - gradient of J with respect to theta (same size as theta)
m = size(X, 1);
Z = X * theta;
H = 1 ./ (1 + exp(-Z));  % logistic sigmoid

% Cross-entropy written with softplus(z) = log(1 + exp(z)):
%   -[y.*log(H) + (1-y).*log(1-H)] = softplus(Z) - y.*Z
% This stays finite when H rounds to exactly 0 or 1. In that case the direct
% log(H) / log(1-H) form yields -Inf, and then NaN (0*-Inf) or Inf.
softplusZ = max(Z, 0) + log1p(exp(-abs(Z)));
J = sum(softplusZ - y .* Z) / m + lamda / (2 * m) * sum(theta(2:end).^2);

thetaj = theta;
thetaj(1) = 0;  % the intercept term is excluded from regularisation
grad = (X' * (H - y) + lamda * thetaj) / m;
end
%% Predict one sample with the stored one-vs-all parameters
load thetazhenze;  % expected to provide theta, one column per class — confirm
load ex3data1.mat;
m = size(X, 1);
X = [ones(m, 1), X];  % prepend a column of ones, matching the parameter layout

i = 9;   % digit class, must be in 0:9
n = 50;  % offset into the last 100 samples of that class's 500-row block, must be in 1:100
% Out-of-range values would silently select a different class's sample or a
% training sample, so reject them explicitly.
if ~ismember(i, 0:9) || ~ismember(n, 1:100)
    error('i must be in 0:9 and n in 1:100 (got i=%g, n=%g)', i, n);
end

row = 500 * (i + 1) - 100 + n;
H = X(row, :) * theta;
[~, ind] = max(H);  % class with the largest score
if ind == 10, ind = 0; end  % class 10 is reported as digit 0
fprintf('the number is %d\n', ind)  % print the result
欢迎光临 (http://www.51hei.com/bbs/) | Powered by Discuz! X3.1 |