A Java Implementation of LDA Gibbs Sampling

This series of posts introduces common probabilistic language models and their variants, mainly summarizing PLSA, LDA, LDA's variant models, and the corresponding parameter inference methods. The planned contents are as follows:

Part 1: PLSA and the EM algorithm

Part 2: LDA and Gibbs Sampling

Part 3: LDA variants: Twitter LDA, TimeUserLDA, ATM, Labeled-LDA, MaxEnt-LDA, and others

Part 4: A categorized summary of papers based on LDA variants (bibliography)

Part 5: A Java implementation of LDA Gibbs Sampling

Part 5: A Java Implementation of LDA Gibbs Sampling

In the first two posts of this series we systematically introduced PLSA, LDA, and their parameter inference methods, focusing on model representation and formula derivation. A scholar once said that research should "reach the sky and stand on the ground": models and theory alone are not enough, and we also need solid code and experimental results on real data as support. This post therefore focuses on a Java implementation of LDA Gibbs Sampling and presents the topic modeling results obtained by applying it to the newsgroup18828 news corpus.

The project is on GitHub: https://github.com/yangliuy/LDAGibbsSampling

1 Corpus Preprocessing

To build topic models over text with LDA, we first need to preprocess the documents: tokenization, stop-word removal, stemming, removing noise words, dropping low-frequency terms, and so on. When the corpus is fairly large, stemming can also be skipped. The text is then converted into a term-index representation, since the LDA implementation frequently needs to map between terms and their indices. The Documents class is implemented as follows; it defines an inner Document class that describes a single document in the collection.

package liuyang.nlp.lda.main;

import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import liuyang.nlp.lda.com.FileUtil;
import liuyang.nlp.lda.com.Stopwords;

/**Class for corpus which consists of M documents
 * @author yangliu
 * @blog
 * @mail yangliuyx@gmail.com
 */
public class Documents {

    ArrayList<Document> docs;
    Map<String, Integer> termToIndexMap;
    ArrayList<String> indexToTermMap;
    Map<String, Integer> termCountMap;

    public Documents() {
        docs = new ArrayList<Document>();
        termToIndexMap = new HashMap<String, Integer>();
        indexToTermMap = new ArrayList<String>();
        termCountMap = new HashMap<String, Integer>();
    }

    public void readDocs(String docsPath) {
        for (File docFile : new File(docsPath).listFiles()) {
            Document doc = new Document(docFile.getAbsolutePath(), termToIndexMap, indexToTermMap, termCountMap);
            docs.add(doc);
        }
    }

    public static class Document {

        private String docName;
        int[] docWords;

        public Document(String docName, Map<String, Integer> termToIndexMap, ArrayList<String> indexToTermMap, Map<String, Integer> termCountMap) {
            this.docName = docName;
            //Read file and initialize word index array
            ArrayList<String> docLines = new ArrayList<String>();
            ArrayList<String> words = new ArrayList<String>();
            FileUtil.readLines(docName, docLines);
            for (String line : docLines) {
                FileUtil.tokenizeAndLowerCase(line, words);
            }
            //Remove stop words and noise words
            for (int i = 0; i < words.size(); i++) {
                if (Stopwords.isStopword(words.get(i)) || isNoiseWord(words.get(i))) {
                    words.remove(i);
                    i--;
                }
            }
            //Transfer word to index
            this.docWords = new int[words.size()];
            for (int i = 0; i < words.size(); i++) {
                String word = words.get(i);
                if (!termToIndexMap.containsKey(word)) {
                    int newIndex = termToIndexMap.size();
                    termToIndexMap.put(word, newIndex);
                    indexToTermMap.add(word);
                    termCountMap.put(word, new Integer(1));
                    docWords[i] = newIndex;
                } else {
                    docWords[i] = termToIndexMap.get(word);
                    termCountMap.put(word, termCountMap.get(word) + 1);
                }
            }
            words.clear();
        }

        public boolean isNoiseWord(String string) {
            string = string.toLowerCase().trim();
            Pattern MY_PATTERN = Pattern.compile(".*[a-zA-Z]+.*");
            Matcher m = MY_PATTERN.matcher(string);
            // filter @xxx and URL
            if (string.matches(".*www\\..*") || string.matches(".*\\.com.*") || string.matches(".*http:.*"))
                return true;
            if (!m.matches()) {
                return true;
            } else
                return false;
        }
    }
}
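To make the term-index mapping concrete, here is a minimal usage sketch. It is only an illustration under assumptions: the DocumentsDemo class and the corpus directory data/miniCorpus/ are hypothetical, and the demo is placed in the same package so that it can read the package-private fields of Documents.

package liuyang.nlp.lda.main;

// Hypothetical demo class; "data/miniCorpus/" is an assumed directory
// containing one plain-text file per document.
public class DocumentsDemo {
    public static void main(String[] args) {
        Documents docSet = new Documents();
        docSet.readDocs("data/miniCorpus/");
        // After reading, each document is an int[] of term indices;
        // indexToTermMap recovers the original term from each index.
        Documents.Document first = docSet.docs.get(0);
        for (int wordIndex : first.docWords) {
            System.out.println(wordIndex + " -> " + docSet.indexToTermMap.get(wordIndex));
        }
        System.out.println("vocabulary size = " + docSet.termToIndexMap.size());
    }
}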

2 LDA Gibbs Sampling

Once the corpus has been preprocessed, we can implement LDA Gibbs Sampling. First we need to define the model parameters. My implementation provides default values in the program and also supports overriding them with a configuration file; when a parameter appears in the configuration file, that value takes precedence over the program default. The overall flow consists of initializing the model, running the iterative inference that keeps updating the topic assignments and the estimated parameters, and finally writing out the parameter estimates at convergence.
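The parameter file format follows from getParametersFromFile below: each line is split on a single tab character, and the first field is matched against the parameters enum. A configuration file would therefore look like this (the values shown are illustrative; fields are tab-separated):

alpha	0.5
beta	0.1
topicNum	100
iteration	100
saveStep	10
beginSaveIters	50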

The class that contains the main function and parses the configuration parameters is as follows:

package liuyang.nlp.lda.main;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;

import liuyang.nlp.lda.com.FileUtil;
import liuyang.nlp.lda.conf.ConstantConfig;
import liuyang.nlp.lda.conf.PathConfig;

/**Liu Yang's implementation of Gibbs Sampling of LDA
 * @author yangliu
 * @blog
 * @mail yangliuyx@gmail.com
 */
public class LdaGibbsSampling {

    public static class modelparameters {
        float alpha = 0.5f; //usual value is 50 / K
        float beta = 0.1f; //usual value is 0.1
        int topicNum = 100;
        int iteration = 100;
        int saveStep = 10;
        int beginSaveIters = 50;
    }

    /**Get parameters from configuring file. If the
     * configuring file has value in it, use the value.
     * Else the default value in program will be used
     * @param ldaparameters
     * @param parameterFile
     */
    private static void getParametersFromFile(modelparameters ldaparameters, String parameterFile) {
        ArrayList<String> paramLines = new ArrayList<String>();
        FileUtil.readLines(parameterFile, paramLines);
        for (String line : paramLines) {
            String[] lineParts = line.split("\t");
            switch (parameters.valueOf(lineParts[0])) {
            case alpha:
                ldaparameters.alpha = Float.valueOf(lineParts[1]);
                break;
            case beta:
                ldaparameters.beta = Float.valueOf(lineParts[1]);
                break;
            case topicNum:
                ldaparameters.topicNum = Integer.valueOf(lineParts[1]);
                break;
            case iteration:
                ldaparameters.iteration = Integer.valueOf(lineParts[1]);
                break;
            case saveStep:
                ldaparameters.saveStep = Integer.valueOf(lineParts[1]);
                break;
            case beginSaveIters:
                ldaparameters.beginSaveIters = Integer.valueOf(lineParts[1]);
                break;
            }
        }
    }

    public enum parameters {
        alpha, beta, topicNum, iteration, saveStep, beginSaveIters;
    }

    /**
     * @param args
     * @throws IOException
     */
    public static void main(String[] args) throws IOException {
        String originalDocsPath = PathConfig.ldaDocsPath;
        String resultPath = PathConfig.LdaResultsPath;
        String parameterFile = ConstantConfig.LDAPARAMETERFILE;

        modelparameters ldaparameters = new modelparameters();
        getParametersFromFile(ldaparameters, parameterFile);
        Documents docSet = new Documents();
        docSet.readDocs(originalDocsPath);
        System.out.println("wordMap size " + docSet.termToIndexMap.size());
        FileUtil.mkdir(new File(resultPath));
        LdaModel model = new LdaModel(ldaparameters);
        System.out.println("1 Initialize the model ...");
        model.initializeModel(docSet);
        System.out.println("2 Learning and Saving the model ...");
        model.inferenceModel(docSet);
        System.out.println("3 Output the final model ...");
        model.saveIteratedModel(ldaparameters.iteration, docSet);
        System.out.println("Done!");
    }
}

The LdaModel class, which implements the LDA model itself, is as follows:

package liuyang.nlp.lda.main;

/**Class for Lda model
 * @author yangliu
 * @blog
 * @mail yangliuyx@gmail.com
 */
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

import liuyang.nlp.lda.com.FileUtil;
import liuyang.nlp.lda.conf.PathConfig;

public class LdaModel {

    int[][] doc; //word index array
    int V, K, M; //vocabulary size, topic number, document number
    int[][] z; //topic label array
    float alpha; //doc-topic dirichlet prior parameter
    float beta; //topic-word dirichlet prior parameter
    int[][] nmk; //given document m, count times of topic k. M*K
    int[][] nkt; //given topic k, count times of term t. K*V
    int[] nmkSum; //Sum for each row in nmk
    int[] nktSum; //Sum for each row in nkt
    double[][] phi; //Parameters for topic-word distribution K*V
    double[][] theta; //Parameters for doc-topic distribution M*K
    int iterations; //Times of iterations
    int saveStep; //The number of iterations between two saves
    int beginSaveIters; //Begin save model at this iteration

    public LdaModel(LdaGibbsSampling.modelparameters modelparam) {
        alpha = modelparam.alpha;
        beta = modelparam.beta;
        iterations = modelparam.iteration;
        K = modelparam.topicNum;
        saveStep = modelparam.saveStep;
        beginSaveIters = modelparam.beginSaveIters;
    }

    public void initializeModel(Documents docSet) {
        M = docSet.docs.size();
        V = docSet.termToIndexMap.size();
        nmk = new int[M][K];
        nkt = new int[K][V];
        nmkSum = new int[M];
        nktSum = new int[K];
        phi = new double[K][V];
        theta = new double[M][K];

        //initialize documents index array
        doc = new int[M][];
        for (int m = 0; m < M; m++) {
            //Notice the limit of memory
            int N = docSet.docs.get(m).docWords.length;
            doc[m] = new int[N];
            for (int n = 0; n < N; n++) {
                doc[m][n] = docSet.docs.get(m).docWords[n];
            }
        }

        //initialize topic label z for each word
        z = new int[M][];
        for (int m = 0; m < M; m++) {
            int N = docSet.docs.get(m).docWords.length;
            z[m] = new int[N];
            for (int n = 0; n < N; n++) {
                int initTopic = (int) (Math.random() * K); // From 0 to K - 1
                z[m][n] = initTopic;
                //number of words in doc m assigned to topic initTopic add 1
                nmk[m][initTopic]++;
                //number of terms doc[m][n] assigned to topic initTopic add 1
                nkt[initTopic][doc[m][n]]++;
                //total number of words assigned to topic initTopic add 1
                nktSum[initTopic]++;
            }
            //total number of words in document m is N
            nmkSum[m] = N;
        }
    }

    public void inferenceModel(Documents docSet) throws IOException {
        if (iterations < saveStep + beginSaveIters) {
            System.err.println("Error: the number of iterations should be larger than " + (saveStep + beginSaveIters));
            System.exit(0);
        }
        for (int i = 0; i < iterations; i++) {
            System.out.println("Iteration " + i);
            if ((i >= beginSaveIters) && (((i - beginSaveIters) % saveStep) == 0)) {
                //Saving the model
                System.out.println("Saving model at iteration " + i + " ... ");
                //Firstly update parameters
                updateEstimatedParameters();
                //Secondly print model variables
                saveIteratedModel(i, docSet);
            }

            //Use Gibbs Sampling to update z[][]
            for (int m = 0; m < M; m++) {
                int N = docSet.docs.get(m).docWords.length;
                for (int n = 0; n < N; n++) {
                    // Sample from p(z_i|z_-i, w)
                    int newTopic = sampleTopicZ(m, n);
                    z[m][n] = newTopic;
                }
            }
        }
    }

    private void updateEstimatedParameters() {
        for (int k = 0; k < K; k++) {
            for (int t = 0; t < V; t++) {
                phi[k][t] = (nkt[k][t] + beta) / (nktSum[k] + V * beta);
            }
        }
        for (int m = 0; m < M; m++) {
            for (int k = 0; k < K; k++) {
                theta[m][k] = (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);
            }
        }
    }

    private int sampleTopicZ(int m, int n) {
        // Sample from p(z_i|z_-i, w) using Gibbs update rule
        //Remove topic label for w_{m,n}
        int oldTopic = z[m][n];
        nmk[m][oldTopic]--;
        nkt[oldTopic][doc[m][n]]--;
        nmkSum[m]--;
        nktSum[oldTopic]--;

        //Compute p(z_i = k|z_-i, w)
        double[] p = new double[K];
        for (int k = 0; k < K; k++) {
            p[k] = (nkt[k][doc[m][n]] + beta) / (nktSum[k] + V * beta) * (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);
        }

        //Sample a new topic label for w_{m, n} like roulette
        //Compute cumulated probability for p
        for (int k = 1; k < K; k++) {
            p[k] += p[k - 1];
        }
        double u = Math.random() * p[K - 1]; //p[] is unnormalised
        int newTopic;
        for (newTopic = 0; newTopic < K; newTopic++) {
            if (u < p[newTopic]) {
                break;
            }
        }

        //Add new topic label for w_{m, n}
        nmk[m][newTopic]++;
        nkt[newTopic][doc[m][n]]++;
        nmkSum[m]++;
        nktSum[newTopic]++;
        return newTopic;
    }

    public void saveIteratedModel(int iters, Documents docSet) throws IOException {
        //lda.params lda.phi lda.theta lda.tassign lda.twords
        //lda.params
        String resPath = PathConfig.LdaResultsPath;
        String modelName = "lda_" + iters;
        ArrayList<String> lines = new ArrayList<String>();
        lines.add("alpha = " + alpha);
        lines.add("beta = " + beta);
        lines.add("topicNum = " + K);
        lines.add("docNum = " + M);
        lines.add("termNum = " + V);
        lines.add("iterations = " + iterations);
        lines.add("saveStep = " + saveStep);
        lines.add("beginSaveIters = " + beginSaveIters);
        FileUtil.writeLines(resPath + modelName + ".params", lines);

        //lda.phi K*V
        BufferedWriter writer = new BufferedWriter(new FileWriter(resPath + modelName + ".phi"));
        for (int i = 0; i < K; i++) {
            for (int j = 0; j < V; j++) {
                writer.write(phi[i][j] + "\t");
            }
            writer.write("\n");
        }
        writer.close();

        //lda.theta M*K
        writer = new BufferedWriter(new FileWriter(resPath + modelName + ".theta"));
        for (int i = 0; i < M; i++) {
            for (int j = 0; j < K; j++) {
                writer.write(theta[i][j] + "\t");
            }
            writer.write("\n");
        }
        writer.close();

        //lda.tassign
        writer = new BufferedWriter(new FileWriter(resPath + modelName + ".tassign"));
        for (int m = 0; m < M; m++) {
            for (int n = 0; n < doc[m].length; n++) {
                writer.write(doc[m][n] + ":" + z[m][n] + "\t");
            }
            writer.write("\n");
        }
        writer.close();

        //lda.twords phi[][] K*V
        writer = new BufferedWriter(new FileWriter(resPath + modelName + ".twords"));
        int topNum = 20; //Find the top 20 topic words in each topic
        for (int i = 0; i < K; i++) {
            List<Integer> tWordsIndexArray = new ArrayList<Integer>();
            for (int j = 0; j < V; j++) {
                tWordsIndexArray.add(new Integer(j));
            }
            Collections.sort(tWordsIndexArray, new LdaModel.TwordsComparable(phi[i]));
            writer.write("topic " + i + "\t:\t");
            for (int t = 0; t < topNum; t++) {
                writer.write(docSet.indexToTermMap.get(tWordsIndexArray.get(t)) + " " + phi[i][tWordsIndexArray.get(t)] + "\t");
            }
            writer.write("\n");
        }
        writer.close();
    }

    public class TwordsComparable implements Comparator<Integer> {

        public double[] sortProb; // Store probability of each word in topic k

        public TwordsComparable(double[] sortProb) {
            this.sortProb = sortProb;
        }

        @Override
        public int compare(Integer o1, Integer o2) {
            //Sort topic word index according to the probability of each word in topic k
            if (sortProb[o1] > sortProb[o2]) return -1;
            else if (sortProb[o1] < sortProb[o2]) return 1;
            else return 0;
        }
    }
}

The implementation details are documented in the comments throughout the code; if you understand the algorithmic flow of LDA Gibbs Sampling, the code above is easy to follow. In fact, excluding the I/O and parameter-parsing code, a standard LDA Gibbs sampler fits in fewer than 200 lines. There is, of course, plenty of room for optimization and for variant models; the update rules that the code implements are restated below for cross-checking.
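For readers who want to check the code against the derivation in Part 2: writing n_{k,t} for the number of times term t is assigned to topic k, n_{m,k} for the number of tokens in document m assigned to topic k, n_k and n_m for the corresponding row sums, and the superscript \neg i for counts with the current token excluded, the full conditional that sampleTopicZ draws from is

p(z_i = k \mid \mathbf{z}_{\neg i}, \mathbf{w}) \;\propto\; \frac{n_{k,t}^{\neg i} + \beta}{n_{k}^{\neg i} + V\beta} \cdot \frac{n_{m,k}^{\neg i} + \alpha}{n_{m}^{\neg i} + K\alpha}

and the point estimates computed in updateEstimatedParameters are

\phi_{k,t} = \frac{n_{k,t} + \beta}{n_{k} + V\beta}, \qquad \theta_{m,k} = \frac{n_{m,k} + \alpha}{n_{m} + K\alpha}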

3 Topic Analysis of the Newsgroup 18828 Corpus with LDA Gibbs Sampling
