用Python來實(shí)現(xiàn)K近鄰分類算法(KNN)已經(jīng)是一個(gè)老生常談的問題,網(wǎng)上也已經(jīng)有諸多資料,不過這里我還是決定記錄一下自己的學(xué)習(xí)心得,
Python實(shí)戰(zhàn)之KNN實(shí)現(xiàn)
。1、配置numpy庫(kù)
numpy庫(kù)是Python用于矩陣運(yùn)算的第三方庫(kù),大多數(shù)數(shù)學(xué)運(yùn)算都會(huì)依賴這個(gè)庫(kù)來進(jìn)行,關(guān)于numpy庫(kù)的配置參見:Python配置第三方庫(kù)Numpy和matplotlib的曲折之路,配置完成后將numpy庫(kù)整體導(dǎo)入到當(dāng)前工程中。
2、準(zhǔn)備訓(xùn)練樣本
這里簡(jiǎn)單的構(gòu)造四個(gè)點(diǎn)并配以對(duì)應(yīng)標(biāo)簽作為KNN的訓(xùn)練樣本:
# ====================創(chuàng)建訓(xùn)練樣本====================def createdataset(): group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]]) labels = ['A', 'B', 'C', 'D'] return group, labels
這里有一個(gè)小細(xì)節(jié),就是通過array()函數(shù)老構(gòu)造并初始化numpy的矩陣對(duì)象時(shí),要保證只有一個(gè)參數(shù),因此在代碼中需要將參數(shù)用中括號(hào)括起來,像下面這種調(diào)用方式是不合法的:
group = array([1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1])
3、創(chuàng)建分類函數(shù)
K近鄰算法在分類時(shí)一般是根據(jù)歐氏距離進(jìn)行分類的,因此需要將輸入的數(shù)據(jù)與訓(xùn)練數(shù)據(jù)在各個(gè)維度上相減再平方求和,再開方,如下:
# ====================歐氏距離分類====================def classify(Inx, Dataset, labels, k): DataSetSize = Dataset.shape[0] # 獲取數(shù)據(jù)的行數(shù),shape[1]位列數(shù) diffmat = tile(Inx, (DataSetSize, 1)) - Dataset SqDiffMat = diffmat**2 SqDistances = SqDiffMat.sum(axis=1) Distance = SqDistances**0.5 SortedDistanceIndicies = Distance.argsort() ClassCount = {}
這里tile()函數(shù)是numpy的矩陣擴(kuò)展函數(shù),比如說這個(gè)例子中訓(xùn)練樣本有四個(gè)二維坐標(biāo)點(diǎn),對(duì)于輸入樣本(一個(gè)二維坐標(biāo)點(diǎn)),需要將其先擴(kuò)展為一個(gè)4行1列的矩陣,然后在進(jìn)行矩陣減法,在平法求和,再開平方算距離。計(jì)算完距離之后,調(diào)用矩陣對(duì)象的排序成員函數(shù)argsort()對(duì)距離進(jìn)行升序排序。在這里介紹一個(gè)Pycharm查看源碼生命的小技巧:加入在編寫這段程序的時(shí)候我們并不確定argsort()是否為array對(duì)象的成員函數(shù),我們選中這個(gè)函數(shù)然后 右鍵 -> Go to -> Declaration,這樣就會(huì)跳轉(zhuǎn)到argsort()函數(shù)的聲明代碼片中,通過查看代碼的從屬關(guān)系能夠確認(rèn)array類中確實(shí)包含這個(gè)成員函數(shù),調(diào)用沒有問題:
對(duì)距離排序之后,接下來就根據(jù)前K個(gè)最小距離值所對(duì)應(yīng)的標(biāo)簽來判斷當(dāng)前樣本屬于哪一類:
for i in range(k): VoteiLabel = labels[SortedDistanceIndicies[i]] ClassCount[VoteiLabel] = ClassCount.get(VoteiLabel, 0) + 1 SortedClassCount = sorted(ClassCount.items(), key = operator.itemgetter(1), reverse = True)
這里有一個(gè)小問題就是在Python2中獲取字典元素使用的是dict.iteritems()成員函數(shù),而在Python3中改為dict.items()函數(shù),
電腦資料
《Python實(shí)戰(zhàn)之KNN實(shí)現(xiàn)》(http://m.stanzs.com)。“key = operator.itemgetter(1)”的意思是指定函數(shù)針對(duì)字典中第二維元素進(jìn)行排序,注意這里需要在之前導(dǎo)入符號(hào)庫(kù)operator。這里是通過記錄前K個(gè)距離最下值中每類標(biāo)簽出現(xiàn)的次數(shù)來判決測(cè)試樣本的歸屬。4、測(cè)試
這里給出完整的KNN測(cè)試代碼:
# coding: utf-8from numpy import *import operator# ====================創(chuàng)建訓(xùn)練樣本====================def createdataset(): group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]]) labels = ['A', 'B', 'C', 'D'] return group, labels# ====================歐氏距離分類====================def classify(Inx, Dataset, labels, k): DataSetSize = Dataset.shape[0] # 獲取數(shù)據(jù)的行數(shù),shape[1]位列數(shù) diffmat = tile(Inx, (DataSetSize, 1)) - Dataset SqDiffMat = diffmat**2 SqDistances = SqDiffMat.sum(axis=1) Distance = SqDistances**0.5 SortedDistanceIndicies = Distance.argsort() ClassCount = {} for i in range(k): VoteiLabel = labels[SortedDistanceIndicies[i]] ClassCount[VoteiLabel] = ClassCount.get(VoteiLabel, 0) + 1 SortedClassCount = sorted(ClassCount.items(), key = operator.itemgetter(1), reverse = True) return SortedClassCount[0][0]Groups, Labels = createdataset()Result = classify([0, 0], Groups, Labels, 1)print(Result)
運(yùn)行代碼,程序答應(yīng)結(jié)果“C”。這里需要提一點(diǎn)的就是對(duì)于單訓(xùn)練樣本(每類只有一個(gè)訓(xùn)練樣本)的分類問題,KNN的K值應(yīng)該設(shè)定為1。