对词向量进行Kmeans聚类

使用过word2vec的人都知道，用word2vec训练得到的结果是每个词对应一个向量。虽然word2vec自带了kmeans聚类功能，但它是对词表中的所有词进行聚类；如果只需要对其中一部分词按照词向量进行kmeans聚类，就只能自己实现了。

参考网上一个开源的JAVA版word2vec，可以写出JAVA版的kmeans聚类程序。输入为一个csv文件，每一行为一个词及其向量（格式：词,v1,v2,...,vN）；输出文件每行依次为词的类别编号、词，以及该词到所属类别中心的距离（欧氏距离的平方，向量已预先归一化为单位长度）。

代码如下:

import java.io.BufferedReader;import java.io.File;import java.io.FileInputStream;import java.io.IOException;import java.io.InputStreamReader;import java.util.HashMap;public class Word2VEC {private HashMap<String, float[]> wordMap = new HashMap<String, float[]>();public void loadVectorFile(String path) throws IOException {BufferedReader br = null;double len = 0;float vector = 0;int size=0;try {File f = new File(path);br = new BufferedReader(new InputStreamReader(new FileInputStream(f), "UTF-8"));String word;String line="";String[] outline=new String[210];float[] vectors = null;while((line=br.readLine())!=null){outline=line.split(",");size=outline.length-1;word = outline[0];vectors = new float[size];len = 0;for (int j = 0; j < size; j++) {vector = Float.parseFloat(outline[j+1]);len += vector * vector;vectors[j] = (float) vector;}len = Math.sqrt(len);for (int j = 0; j < size; j++) {vectors[j] /= len;}wordMap.put(word, vectors);}}finally {System.out.println("total word: "+wordMap.size()+" vector dimensions: "+size);br.close();}}public HashMap<String, float[]> getWordMap() {return wordMap;}}import java.io.BufferedWriter;import java.io.File;import java.io.FileOutputStream;import java.io.IOException;import java.io.OutputStreamWriter;import java.util.ArrayList;import java.util.Collections;import java.util.Comparator;import java.util.HashMap;import java.util.Iterator;import java.util.List;import java.util.Map;import java.util.Map.Entry;public class WordKmeans {private HashMap<String, float[]> wordMap = null;private int iter;private Classes[] cArray = null;//total 659624 words each is a 200 vector//args[0] is the word vectors csv file//args[1] is the output file //args[2] is the cluster number//args[3] is the iterator numberpublic static void main(String[] args) throws IOException {Word2VEC vec = new Word2VEC();vec.loadVectorFile(args[0]);System.out.println("load data ok!");//input cluster number and iterator numberWordKmeans wordKmeans = new WordKmeans(vec.getWordMap(), 
Integer.parseInt(args[2]),Integer.parseInt(args[3]));Classes[] explain = wordKmeans.explain();File fw = new File(args[1]);BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(fw), "UTF-8"));//explain.length is the classes numberfor (int i = 0; i < explain.length; i++) {List<Entry<String, Double>> result=explain[i].getMember();StringBuffer buf = new StringBuffer();for (int j = 0; j < result.size(); j++) {buf.append(i+"\t"+result.get(j).getKey()+"\t"+result.get(j).getValue().toString()+"\n");}bw.write(buf.toString());bw.flush();}bw.close();}public WordKmeans(HashMap<String, float[]> wordMap, int clcn, int iter) {this.wordMap = wordMap;this.iter = iter;cArray = new Classes[clcn];}public Classes[] explain() {Iterator<Entry<String, float[]>> iterator = wordMap.entrySet().iterator();for (int i = 0; i < cArray.length; i++) {Entry<String, float[]> next = iterator.next();cArray[i] = new Classes(i, next.getValue());}for (int i = 0; i < iter; i++) {for (Classes classes : cArray) {classes.clean();}iterator = wordMap.entrySet().iterator();int cnt = 0;while (iterator.hasNext()) {if(cnt % 10000 ==0){System.out.println("Iter:"+i+"\tword:"+(cnt));}cnt++;Entry<String, float[]> next = iterator.next();double miniScore = Double.MAX_VALUE;double tempScore;int classesId = 0;for (Classes classes : cArray) {tempScore = classes.distance(next.getValue());if (miniScore > tempScore) {miniScore = tempScore;classesId = classes.id;}}cArray[classesId].putValue(next.getKey(), miniScore);}for (Classes classes : cArray) {classes.updateCenter(wordMap);}System.out.println("iter " + i + " ok!");}return cArray;}public static class Classes {private int id;private float[] center;public Classes(int id, float[] center) {this.id = id;this.center = center.clone();}Map<String, Double> values = new HashMap<>();public double distance(float[] value) {double sum = 0;for (int i = 0; i < value.length; i++) {sum += (center[i] – value[i])*(center[i] – value[i]) ;}return sum ;}public void 
putValue(String word, double score) {values.put(word, score);}public void updateCenter(HashMap<String, float[]> wordMap) {for (int i = 0; i < center.length; i++) {center[i] = 0;}float[] value = null;for (String keyWord : values.keySet()) {value = wordMap.get(keyWord);for (int i = 0; i < value.length; i++) {center[i] += value[i];}}for (int i = 0; i < center.length; i++) {center[i] = center[i] / values.size();}}public void clean() {values.clear();}public List<Entry<String, Double>> getTop(int n) {List<Map.Entry<String, Double>> arrayList = new ArrayList<Map.Entry<String, Double>>(values.entrySet());Collections.sort(arrayList, new Comparator<Map.Entry<String, Double>>() {@Overridepublic int compare(Entry<String, Double> o1, Entry<String, Double> o2) {return o1.getValue() > o2.getValue() ? 1 : -1;}});int min = Math.min(n, arrayList.size() – 1);if(min<=1){return Collections.emptyList() ;}return arrayList.subList(0, min);}public List<Entry<String, Double>> getMember() {List<Map.Entry<String, Double>> arrayList = new ArrayList<Map.Entry<String, Double>>(values.entrySet());Collections.sort(arrayList, new Comparator<Map.Entry<String, Double>>() {@Overridepublic int compare(Entry<String, Double> o1, Entry<String, Double> o2) {return o1.getValue() > o2.getValue() ? 1 : -1;}});int count=arrayList.size() – 1;if(count<=1){return Collections.emptyList() ;}return arrayList.subList(0, count);}}}进行聚类时需要指定输入文件,输出文件,类别数目和迭代次数。目前经过试验觉得JAVA速度不是很快,不知道C版本的速度如何,估计肯定要比JAVA快很多。

年岁有加，并非垂老；理想丢弃，方堕暮年。

对词向量进行Kmeans聚类

相关文章:

你感兴趣的文章:

标签云: