我的第一个svm程序:手写字识别

之前学过svm相关知识,基本原理不算复杂,今天做了一个手写字识别程序,总算验证了svm的效果。

因为只是验证效果,实现上原则是简单,使用python + libsvm + PIL(python image library)。这部分工作花了一些时间:

PIL:

下载源码包,解压之后运行:python setup.py install即可。

mac下python libsvm安装使用:libsvm python接口介绍:

说是手写字,其实只是一到十这十个汉字,这样比较简单,而且收集的样本不太多。这十个汉字,在mac上用paintbrush前前后后画了259个80*80的png图片。图片缩放为16*16,二值化之后用一个256维的向量表示,简单粗暴。准备训练数据文件:inittraindata.py#! /usr/bin/env pythonimport Imageimport osf = []for i in range(1,11): f.append(open('ocr_' + str(i), 'wb'))for i in range(1,11): for item in os.listdir(str(i)):path = os.path.join(str(i), item)if os.path.isfile(path) and path.endswith(".png"):img_org = Image.open(path)img = img_org.resize((16,16), Image.NEAREST)pixdata = img.load()# -1for j in range(1,i):line = "-1 "for k in range(0, 256): line += str(k + 1) if pixdata[k / 16,k % 16][0] == 255:line += ":0 " else:line += ":1 "f[j – 1].write(line + "\n")# -1for j in range(i + 1, 11):line = "-1 "for k in range(0, 256): line += str(k + 1) if pixdata[k / 16, k % 16][0] == 255:line += ":0 " else:line += ":1 "f[j – 1].write(line + "\n")# 1line = "1 "for k in range(0, 256):line += str(k + 1)if pixdata[k / 16, k % 16][0] == 255: line += ":0 "else: line += ":1 "f[i – 1].write(line + "\n")for o in f: o.close

训练数据并保存模型save.py:

#! /usr/bin/env pythonimport sysfrom svmutil import *import Imageimport randomfor i in range(1, 11): y, x = svm_read_problem('./ocr_' + str(i))# if i == 4 or i == 3:# m = svm_train(y, x, '-c 10000')# else: m = svm_train(y, x, '-c 3 -g 0.015625') svm_save_model('./model_' + str(i), m)预测predict.py:

#! /usr/bin/env python
"""Classify one handwritten-character image with the ten saved models.

Usage: predict.py IMAGE

The image is resized to 16x16 and encoded as 256 binary features, written
to the scratch file "tmpfile" in libsvm format, and scored by each of
./model_1 .. ./model_10. The reported class is the model whose decision
value is closest to 1.0, or -1 when no value is within 100.0 of it.
"""
import sys
from svmutil import *
import Image

NUM_CLASSES = 10
SIDE = 16


def _feature_text(pixdata):
    """Return the libsvm feature columns ("1:0 2:1 ... 256:0 ") of one image."""
    parts = []
    for k in range(SIDE * SIDE):
        # NOTE(review): indexing [0] assumes tuple pixels (e.g. RGB/RGBA) --
        # confirm for PNGs saved in other modes.
        bit = 0 if pixdata[k // SIDE, k % SIDE][0] == 255 else 1
        parts.append("%d:%d " % (k + 1, bit))
    return "".join(parts)


def main():
    """Load the models, score the image given on the command line, print the guess."""
    if len(sys.argv) < 2:
        sys.stderr.write("usage: predict.py IMAGE\n")
        sys.exit(1)

    # load
    models = []
    for i in range(1, NUM_CLASSES + 1):
        models.append(svm_load_model('./model_' + str(i)))

    # predict
    img = Image.open(sys.argv[1]).resize((SIDE, SIDE), Image.NEAREST)
    line = "-1 " + _feature_text(img.load())
    with open("tmpfile", "wb") as tmpfile:
        tmpfile.write(line + "\n")

    # The sample is the same for every model, so parse it only once.
    y, x = svm_read_problem("tmpfile")

    best_gap = 100.0
    best_idx = -1
    for i in range(1, NUM_CLASSES + 1):
        label, acc, val = svm_predict(y, x, models[i - 1])
        print(val[0][0])
        gap = abs(val[0][0] - 1.0)
        if gap < best_gap:
            best_gap = gap
            best_idx = i
    print("probably is:  %s" % best_idx)


if __name__ == '__main__':
    main()

使用c-svm,核函数使用RBF,参数c=3,gamma=1.0/64。参数是怎么选的?用的是简单粗暴的grid search,gridsearch.py:

#! /usr/bin/env python
"""Grid search over C and gamma for the ten one-vs-rest classifiers.

Reads ./ocr_1 .. ./ocr_10, shuffles the samples with one shared
permutation, and scores every (c, g) pair with 10-fold cross-validation.
"""
from svmutil import *
import random

NUM_CLASSES = 10
NUM_FOLDS = 10


def _closest_to_one(decision_values, limit):
    """Return the 1-based position of the value nearest to 1.0.

    Only values whose distance to 1.0 is below ``limit`` qualify; -1 is
    returned when none does. Ties keep the earliest position.
    """
    best_gap = limit
    best_idx = -1
    for pos, value in enumerate(decision_values):
        gap = abs(value - 1.0)
        if gap < best_gap:
            best_gap = gap
            best_idx = pos + 1
    return best_idx


def test(y, x, c, g):
    """Return the average 10-fold accuracy for cost ``c`` and gamma ``g``.

    ``y[k - 1]`` / ``x[k - 1]`` are the labels and samples of the one-vs-rest
    problem for class ``k``. All problems must list the same samples in the
    same order, because the true class of sample ``j`` is taken to be the
    ``k`` for which ``y[k - 1][j] == 1``.
    """
    count = len(y[0])
    correct_rate = 0.0
    folds_used = 0
    # n-fold cross-validation
    for i in range(NUM_FOLDS):
        lo = count * i // NUM_FOLDS
        hi = count * (i + 1) // NUM_FOLDS
        if lo == hi:
            # Fewer samples than folds: nothing to test in this fold.
            continue
        folds_used += 1
        answers = [0] * (hi - lo)
        marr = []
        tarr = []
        for k in range(1, NUM_CLASSES + 1):
            labels = y[k - 1]
            samples = x[k - 1]
            # training sets: everything outside [lo, hi)
            m = svm_train(labels[:lo] + labels[hi:],
                          samples[:lo] + samples[hi:],
                          '-c ' + str(c) + ' -g ' + str(g))
            marr.append(m)
            for j in range(lo, hi):
                if labels[j] == 1:
                    answers[j - lo] = k
            # test sets
            tarr.append((labels[lo:hi], samples[lo:hi]))
        print(answers)
        # predicting
        correct_count = 0
        for j in range(hi - lo):
            values = []
            for k in range(NUM_CLASSES):
                label, acc, val = svm_predict(tarr[k][0][j:j + 1],
                                              tarr[k][1][j:j + 1],
                                              marr[k])
                values.append(val[0][0])
            guess = _closest_to_one(values, 10000.0)
            print("probably is %s  answer is %s" % (guess, answers[j]))
            if answers[j] == guess:
                correct_count += 1
        correct_rate += float(correct_count) / (hi - lo)
    if folds_used:
        correct_rate /= folds_used
    print("c= %s g= %s avg_correct_rate= %s" % (c, g, correct_rate))
    return correct_rate


def main():
    """Load the training files, shuffle them consistently, run the grid search."""
    yarr = []
    xarr = []
    for i in range(1, NUM_CLASSES + 1):
        y, x = svm_read_problem('./ocr_' + str(i))
        yarr.append(y)
        xarr.append(x)

    # shuffle: one permutation for all ten problems keeps the rows aligned
    arr = list(range(len(yarr[0])))
    random.shuffle(arr)
    print("RANDOM ARR: %s" % (arr,))
    for i in range(NUM_CLASSES):
        y = yarr[i]
        x = xarr[i]
        yarr[i] = [y[pos] for pos in arr]
        xarr[i] = [x[pos] for pos in arr]

    # grid search
    maxcorrect = -1
    cpos = 0
    gpos = 0
    for c in range(1, 16, 1):
        # NOTE(review): the first grid value is g = 0 -- confirm how
        # svm_train treats a zero gamma.
        for gg in range(0, 256, 1):
            g = gg * 1.0 / 256
            ret = test(yarr, xarr, c, g)
            if ret > maxcorrect:
                maxcorrect = ret
                cpos = c
                gpos = g
            print("current c= %s g= %s maxcorrect= %s" % (cpos, gpos, maxcorrect))
    print("c= %s g= %s maxcorrect= %s" % (cpos, gpos, maxcorrect))
    # test(yarr, xarr, 3, 1.0 / 64)


if __name__ == '__main__':
    main()

也许不是自己该去发挥的地方,还是让自己到最适合自己战斗的方面去吧!勇敢地接受自己的失败。

我的第一个svm程序:手写字识别

相关文章:

你感兴趣的文章:

标签云: