

KNN, short for K-Nearest Neighbor, is one of the simplest classification methods in data mining. To decide which of the known classes a test sample belongs to, we compute its similarity to every known sample, find the K samples that are nearest (most similar) to it, and assign the test sample to whichever class occurs most often among those K. How do we measure the similarity of two samples? It can be defined with a vector p-norm.

Suppose we have two samples X = (x1, x2, …, xn) and Y = (y1, y2, …, yn). Their similarity can be defined through the following p-norm of their difference:

d(X, Y) = \left( \sum_{i=1}^{n} |x_i - y_i|^p \right)^{1/p}

When p = 2 this is exactly the Euclidean distance between X and Y.
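To make the formula concrete, here is a small worked example (the two vectors are made up for illustration). For X = (1, 2) and Y = (4, 6) with p = 2:

d(X, Y) = \left( |1 - 4|^2 + |2 - 6|^2 \right)^{1/2} = \sqrt{9 + 16} = 5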

This article describes a Java implementation of a KNN classifier and applies it to the Iris dataset. The Iris data looks like this:

5.1,3.5,1.4,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica

The first four items are attributes and the last one is the class name; there are three classes in total. The complete dataset can be downloaded here.

To evaluate how well the KNN classifier performs, part of the original data is drawn at random as a test set, and the rest serves as the training set. The random split can be implemented with the following code:

/**
 * Randomly split the dataset into a training set and a test set.
 *
 * @param filePath   path of the dataset file
 * @param testCount  number of test samples
 * @param outputPath output directory
 * @throws Exception
 */
public static void splitDataSet(String filePath, int testCount,
        String outputPath) throws Exception
{
    BufferedWriter trainFile = new BufferedWriter(new FileWriter(new File(
            outputPath + "/train.txt")));
    BufferedWriter testFile = new BufferedWriter(new FileWriter(new File(
            outputPath + "/test.txt")));
    BufferedReader input = new BufferedReader(new FileReader(new File(
            filePath)));
    List<String> lines = new ArrayList<String>();
    String line = null;
    // read all lines into a list
    while ((line = input.readLine()) != null)
        lines.add(line);
    // walk the list once; at each position pick a random index and swap,
    // which shuffles the samples
    for (int i = 0; i < lines.size(); i++)
    {
        int ran = (int) (Math.random() * lines.size());
        String temp = lines.get(i);
        lines.set(i, lines.get(ran));
        lines.set(ran, temp);
    }
    int i = 0;
    // write the requested number of test samples into test.txt
    for (; i < testCount; i++)
    {
        testFile.write(lines.get(i) + "\n");
        testFile.flush();
    }
    // write the remaining samples into train.txt
    for (; i < lines.size(); i++)
    {
        trainFile.write(lines.get(i) + "\n");
        trainFile.flush();
    }
    testFile.close();
    trainFile.close();
}

After calling this method we have the two files train.txt and test.txt.
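As a quick illustration, a call might look like the following; the path E:/iris.data and the holdout size 30 are made-up values:

// Hypothetical usage: hold out 30 random samples for testing and
// write train.txt and test.txt into E:/
splitDataSet("E:/iris.data", 30, "E:/");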

Next, read the data in:

/**
 * Build the data list from a file. Note: the first non-numeric attribute
 * encountered on a line is treated as the class name.
 *
 * @param fileName file name
 * @param sep      separator
 * @return
 * @throws Exception
 */
public List<DataNode> getDataList(String fileName, String sep)
        throws Exception
{
    List<DataNode> list = new ArrayList<DataNode>();
    BufferedReader br = new BufferedReader(new FileReader(new File(fileName)));
    String line = null;
    while ((line = br.readLine()) != null)
    {
        // skip blank lines so they are not registered as a bogus class
        if (line.trim().isEmpty())
            continue;
        String splits[] = line.split(sep);
        // DataNode holds the attribute values and the class label of one sample
        DataNode node = new DataNode();
        for (int i = 0; i < splits.length; i++)
        {
            try
            {
                node.addAttrib(Float.valueOf(splits[i]));
            } catch (NumberFormatException e)
            {
                // not a number, so this is the class name; map it to an integer id
                if (!mTypes.containsKey(splits[i]))
                {
                    mTypes.put(splits[i], mTypeCount);
                    mTypeCount++;
                }
                node.setType(mTypes.get(splits[i]));
                list.add(node);
            }
        }
    }
    return list;
}
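The post never shows the DataNode class or the DataUtil fields mTypes and mTypeCount that getDataList relies on. Below is a minimal sketch consistent with how they are used here and in the classifier below; the singleton pattern and the reverse lookup in getTypeName are assumptions.

package com.jingchen.knn;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Holds one sample: its numeric attributes plus an integer class id
class DataNode
{
    private List<Float> mAttribs = new ArrayList<Float>();
    private int mType;

    public void addAttrib(Float value)
    {
        mAttribs.add(value);
    }

    public List<Float> getAttribs()
    {
        return mAttribs;
    }

    public void setType(int type)
    {
        mType = type;
    }

    public int getType()
    {
        return mType;
    }
}

// Singleton that reads the files and maps class names to integer ids
class DataUtil
{
    private static DataUtil mInstance;
    // class name -> integer id, filled while parsing in getDataList
    private Map<String, Integer> mTypes = new HashMap<String, Integer>();
    private int mTypeCount = 0;

    public static DataUtil getInstance()
    {
        if (mInstance == null)
            mInstance = new DataUtil();
        return mInstance;
    }

    public int getTypeCount()
    {
        return mTypeCount;
    }

    // reverse lookup: integer id -> class name (assumed implementation)
    public String getTypeName(int type)
    {
        for (Map.Entry<String, Integer> e : mTypes.entrySet())
            if (e.getValue() == type)
                return e.getKey();
        return null;
    }

    // getDataList(...) and splitDataSet(...) shown above would live here
}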

For each sample in testList, we compute its distance to every sample in trainList, pick out the K nearest, and return them.

KnnClassifier.java

package com.jingchen.knn;

import java.util.List;

/**
 * @author chenjing
 */
public class KnnClassifier
{
    // the K nearest neighbors found so far
    private int k;
    private KNode[] mNearestK;
    private List<DataNode> mTrainData;

    public KnnClassifier(int k, List<DataNode> trainList)
    {
        mTrainData = trainList;
        this.k = k;
        mNearestK = new KNode[k];
        for (int i = 0; i < k; i++)
            mNearestK[i] = new KNode();
    }

    public void setK(int k)
    {
        this.k = k;
        mNearestK = new KNode[k];
        for (int i = 0; i < k; i++)
            mNearestK[i] = new KNode();
    }

    private void train(DataNode test, float p)
    {
        for (int i = 0; i < mTrainData.size(); i++)
        {
            putNode(getSim(test, mTrainData.get(i), p));
        }
    }

    /**
     * Compare a newly computed node against the current K nearest neighbors
     * and insert it in sorted order if it is closer than one of them.
     *
     * @param node
     */
    private void putNode(KNode node)
    {
        for (int i = 0; i < k; i++)
        {
            if (node.getD() < mNearestK[i].getD())
            {
                // shift the farther neighbors down one slot and insert
                for (int j = k - 1; j > i; j--)
                    mNearestK[j] = mNearestK[j - 1];
                mNearestK[i] = node;
                break;
            }
        }
    }

    /**
     * Compute the p-norm distance between two samples and wrap it in a KNode.
     *
     * @param test
     * @param trainNode
     * @param p
     * @return
     */
    private KNode getSim(DataNode test, DataNode trainNode, float p)
    {
        List<Float> list1 = test.getAttribs();
        List<Float> list2 = trainNode.getAttribs();
        float d = 0;
        for (int i = 0; i < list1.size(); i++)
            d += Math.pow(Math.abs(list1.get(i).floatValue()
                    - list2.get(i).floatValue()), p);
        d = (float) Math.pow(d, 1 / p);
        KNode node = new KNode(d, trainNode.getType());
        return node;
    }

    private void reset()
    {
        for (int i = 0; i < k; i++)
            mNearestK[i].reset();
    }

    /**
     * Return the K nearest neighbors of the test sample.
     *
     * @param test
     * @param p
     * @return
     */
    public KNode[] getKNN(DataNode test, float p)
    {
        reset();
        train(test, p);
        return mNearestK;
    }
}

The main method:

public static void main(String[] args) throws Exception
{
    DataUtil util = DataUtil.getInstance();
    // load the training set and the test set
    List<DataNode> trainList = util.getDataList("E:/train.txt", ",");
    List<DataNode> testList = util.getDataList("E:/test.txt", ",");
    int K = BASE_K;
    KnnClassifier classifier = new KnnClassifier(K, trainList);
    BufferedWriter output = new BufferedWriter(new FileWriter(new File(
            "E:/output.txt")));
    int typeCount = util.getTypeCount();
    int[] count = new int[typeCount];
    for (int i = 0; i < testList.size();)
    {
        for (int m = 0; m < typeCount; m++)
            count[m] = 0;
        DataNode test = testList.get(i);
        classifier.setK(K);
        KNode[] nodes = classifier.getKNN(test, 2);
        // tally the classes of the K nearest neighbors
        for (int j = 0; j < nodes.length; j++)
            count[nodes[j].getType()]++;
        // find the highest vote count and the class that achieves it
        int type = -1;
        int max = -1;
        for (int j = 0; j < typeCount; j++)
        {
            if (count[j] > max)
            {
                max = count[j];
                type = j;
            }
        }
        // if another class ties with the winner, the vote is ambiguous:
        // increase K and retry the same sample
        for (int j = 0; j < typeCount; j++)
        {
            if (j != type && count[j] == max)
            {
                type = -1;
                K++;
                break;
            }
        }
        if (type == -1)
            continue;
        i++;
        K = BASE_K;
        // write the classification result to the output file
        List<Float> attribs = test.getAttribs();
        for (int n = 0; n < attribs.size(); n++)
        {
            output.write(attribs.get(n) + ",");
            output.flush();
        }
        output.write(util.getTypeName(type) + "\n");
        output.flush();
    }
    output.close();
}
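KNode is not shown in the post either. For putNode above to work, a freshly constructed or reset KNode must report a distance larger than any real one, so that genuine neighbors displace the placeholder slots; a minimal sketch under that assumption:

package com.jingchen.knn;

// One candidate neighbor: its distance to the test sample and its class id
class KNode
{
    private float d;
    private int type;

    public KNode()
    {
        // an empty slot must compare as "farther than anything"
        reset();
    }

    public KNode(float d, int type)
    {
        this.d = d;
        this.type = type;
    }

    public float getD()
    {
        return d;
    }

    public int getType()
    {
        return type;
    }

    public void reset()
    {
        // Float.MAX_VALUE guarantees the first k real neighbors are inserted
        d = Float.MAX_VALUE;
        type = -1;
    }
}

Note the tie-handling in main: whenever two classes receive the same number of votes, K is increased and the same sample is classified again, and once a sample gets a unique winner K is reset to BASE_K for the next one.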
