Implementing the KNN algorithm in Python

This post shares a concrete Python implementation of the k-nearest neighbors (KNN) algorithm for your reference. The details are as follows.

KNN algorithm description

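For each point to be classified, perform the following steps in order:

1. Compute the distance between this point and every point in the dataset whose class is known.
2. Sort those points by increasing distance.
3. Select the k points closest to this point.
4. Count how often each class occurs among these k points.
5. Return the most frequent class among the k points as the predicted class of this point.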
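
As a compact illustration of the five steps above, here is a minimal standard-library sketch. The function name `knn_predict` and the tuple-based inputs are hypothetical choices for this example, not the implementation used later in the article. It relies on `math.dist` (Python 3.8+) for step 1 and `collections.Counter.most_common` for steps 4 and 5.

```python
import math
from collections import Counter


def knn_predict(query, points, labels, k):
    """Hypothetical helper: classify `query` by majority vote of its k nearest points."""
    # Step 1: distance from the query to every labeled point (math.dist needs Python 3.8+)
    dists = [(math.dist(query, p), lab) for p, lab in zip(points, labels)]
    # Step 2: sort by increasing distance
    dists.sort(key=lambda pair: pair[0])
    # Step 3: keep the k closest points
    nearest = dists[:k]
    # Step 4: count how often each class appears among them
    votes = Counter(lab for _, lab in nearest)
    # Step 5: the most frequent class is the prediction
    return votes.most_common(1)[0][0]


print(knn_predict((0.0, 0.0), [(0.1, 0.0), (0.0, 0.2), (1.0, 1.0)], ["a", "a", "b"], k=3))  # prints a
```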

KNN implementation

Data processing

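```python
# Read data from a file; both the returned data and labels are 2-D lists
def loadDataSet(filename):
    dataSet = []
    labels = []
    with open(filename) as fr:  # the file is closed automatically
        for line in fr:
            lineArr = line.strip().split(",")
            if len(lineArr) < 3:
                continue  # skip blank lines, e.g. a trailing newline at the end of the file
            dataSet.append([float(lineArr[0]), float(lineArr[1])])
            labels.append([float(lineArr[2])])
    return dataSet, labels
```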
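
If NumPy is available, the same file can also be read in one call with `numpy.loadtxt`. The sketch below assumes the three-column layout used in this article (two features, then a class label). The helper name `load_with_numpy` is hypothetical. Unlike `loadDataSet`, it returns NumPy arrays, and the labels come back as a flat vector rather than a list of one-element lists.

```python
import numpy as np


def load_with_numpy(fname):
    """Hypothetical helper: read 'x1,x2,label' rows into a feature matrix and a label vector.

    `fname` may be a file name, an open file, or a list of text lines (all accepted by np.loadtxt).
    """
    # ndmin=2 keeps the result two-dimensional even if the input has a single row
    arr = np.loadtxt(fname, delimiter=",", ndmin=2)
    X = arr[:, :2]  # first two columns: features
    y = arr[:, 2]   # third column: class label
    return X, y
```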

KNN algorithm

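```python
# Compute the Euclidean distance between two vectors
def calDist(X1, X2):
    total = 0  # renamed from `sum` so the built-in sum() is not shadowed
    for x1, x2 in zip(X1, X2):
        total += (x1 - x2) ** 2
    return total ** 0.5
```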
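
For reference, `calDist` computes the Euclidean distance between two n-dimensional vectors X1 and X2. The worked number below plugs in the test point used in this article and one of its training points, (0.671, 0.196):

```latex
d(X_1, X_2) = \sqrt{\sum_{i=1}^{n} \left(x_{1,i} - x_{2,i}\right)^2}

d\big((0.6767,\ 0.2122),\ (0.671,\ 0.196)\big) = \sqrt{0.0057^2 + 0.0162^2} = \sqrt{0.00029493} \approx 0.0172
```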

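```python
def knn(data, dataSet, labels, k):
    # Record only the distance between the two points and the known point's class,
    # in a new list so that the caller's `labels` is not modified (or reordered)
    distLabels = []
    for i in range(len(dataSet)):
        dist = calDist(data, dataSet[i])
        distLabels.append((labels[i][0], dist))
    # Sort by increasing distance
    distLabels.sort(key=lambda x: x[1])
    count = {}
    # Count how often each class appears among the k nearest points
    for label, _ in distLabels[:k]:
        count[label] = count.get(label, 0) + 1  # dict.has_key() no longer exists in Python 3
    # Sort by decreasing frequency
    sortCount = sorted(count.items(), key=lambda item: item[1], reverse=True)
    return sortCount[0][0]  # the most frequent key, i.e. the predicted label
```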
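
The loop in `knn` computes one distance at a time. As an alternative sketch (not the article's code), NumPy can compute all distances in one step, and `numpy.argsort` then gives the indices of the nearest points. The name `knn_numpy` is hypothetical. It accepts labels either as the nested lists produced by `loadDataSet` or as a flat sequence.

```python
import numpy as np


def knn_numpy(data, dataSet, labels, k):
    """Hypothetical vectorized variant: majority vote among the k nearest training points."""
    X = np.asarray(dataSet, dtype=float)
    y = np.asarray(labels, dtype=float).ravel()  # [[1.0], [0.0], ...] -> [1.0, 0.0, ...]
    # Euclidean distance from `data` to every row of X at once
    dists = np.sqrt(((X - np.asarray(data, dtype=float)) ** 2).sum(axis=1))
    # Indices of the k smallest distances; a stable sort keeps file order for equal distances
    nearest = np.argsort(dists, kind="stable")[:k]
    # Count each class among the neighbors and return the most frequent one
    classes, counts = np.unique(y[nearest], return_counts=True)
    return classes[np.argmax(counts)]
```

Ties in the vote are broken differently from the dictionary-based `knn`. `np.unique` returns the classes in sorted order and `np.argmax` picks the first maximum, so the smaller label wins a tie.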

Testing

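Labeled data, taken from the "watermelon book" (Zhou Zhihua's Machine Learning) plus some made-up points. Each row gives two feature values followed by the class label: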

```
0.697,0.460,1
0.774,0.376,1
0.720,0.330,1
0.634,0.264,1
0.608,0.318,1
0.556,0.215,1
0.403,0.237,1
0.481,0.149,1
0.437,0.211,1
0.525,0.186,1
0.666,0.091,0
0.639,0.161,0
0.657,0.198,0
0.593,0.042,0
0.719,0.103,0
0.671,0.196,0
0.703,0.121,0
0.614,0.116,0
```

Plotting

```python
import matplotlib.pyplot as plt


def drawPoints(data, dataSet, labels):
    xcord1, ycord1 = [], []  # points of class 0
    xcord2, ycord2 = [], []  # points of class 1
    for i in range(len(dataSet)):
        if labels[i][0] == 0:
            xcord1.append(dataSet[i][0])
            ycord1.append(dataSet[i][1])
        elif labels[i][0] == 1:
            xcord2.append(dataSet[i][0])
            ycord2.append(dataSet[i][1])
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.scatter(xcord1, ycord1, s=30, c='blue', marker='s', label='0')
    ax.scatter(xcord2, ycord2, s=30, c='green', label='1')
    ax.scatter(data[0], data[1], s=30, c='red', label='testdata')  # the point to classify
    plt.legend(loc='upper right')
    plt.show()
```

Test code

```python
dataSet, labels = loadDataSet('dataSet.txt')
data = [0.6767, 0.2122]
drawPoints(data, dataSet, labels)
newlabels = knn(data, dataSet, labels, 5)
print(newlabels)  # Python 3 print function
```
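
To try the classifier without creating `dataSet.txt` or opening a plot window, the sketch below embeds the eighteen labeled points listed above and calls `knn` directly. It repeats the corrected `calDist` and `knn` so that it runs on its own. The constant names `DATA_SET` and `LABELS` are introduced only for this example.

```python
# Hypothetical self-contained check: the labeled points above, embedded instead of read from dataSet.txt
DATA_SET = [
    [0.697, 0.460], [0.774, 0.376], [0.720, 0.330], [0.634, 0.264], [0.608, 0.318],
    [0.556, 0.215], [0.403, 0.237], [0.481, 0.149], [0.437, 0.211], [0.525, 0.186],
    [0.666, 0.091], [0.639, 0.161], [0.657, 0.198], [0.593, 0.042], [0.719, 0.103],
    [0.671, 0.196], [0.703, 0.121], [0.614, 0.116],
]
# Same nested-list format as loadDataSet: ten points of class 1, then eight of class 0
LABELS = [[1.0] for _ in range(10)] + [[0.0] for _ in range(8)]


def calDist(X1, X2):
    # Euclidean distance between two vectors
    total = 0
    for x1, x2 in zip(X1, X2):
        total += (x1 - x2) ** 2
    return total ** 0.5


def knn(data, dataSet, labels, k):
    # (label, distance) pairs in a new list, so `labels` is left unchanged
    distLabels = [(labels[i][0], calDist(data, dataSet[i])) for i in range(len(dataSet))]
    distLabels.sort(key=lambda x: x[1])
    count = {}
    for label, _ in distLabels[:k]:
        count[label] = count.get(label, 0) + 1
    sortCount = sorted(count.items(), key=lambda item: item[1], reverse=True)
    return sortCount[0][0]


data = [0.6767, 0.2122]
# Show the five nearest neighbours that take part in the vote: point, label, distance
ranked = sorted(zip(DATA_SET, LABELS), key=lambda pair: calDist(data, pair[0]))
for point, label in ranked[:5]:
    print(point, label[0], round(calDist(data, point), 4))
print(knn(data, DATA_SET, LABELS, 5))  # prints 0.0
```

Four of the five nearest neighbours belong to class 0, so the predicted class for [0.6767, 0.2122] is 0.0.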

Run results

That is all for "Implementing the KNN algorithm in Python". Source: utcz.com/z/342431.html
