本文实例讲述了java实现的knn算法。分享给大家供大家参考,具体如下:
提起knn算法大家应该都不会陌生,对于数据挖掘来说算是十大经典算法之一。
算法的思想是:对于训练数据集中已经归类的分组,来对于未知的数据进行分组归类。其中是根据该未知点与其训练数据中的点计算距离,求出距离最短的点,并将其归入该点的那一类。
看看算法的工程吧:
1. 准备数据,对数据进行预处理
2. 选用合适的数据结构存储训练数据和测试元组
3. 设定参数,如k
4.维护一个大小为k的的按距离由大到小的优先级队列,用于存储最近邻训练元组。随机从训练元组中选取k个元组作为初始的最近邻元组,分别计算测试元组到这k个元组的距离,将训练元组标号和距离存入优先级队列
5. 遍历训练元组集,计算当前训练元组与测试元组的距离,将所得距离l 与优先级队列中的最大距离lmax
6. 进行比较。若l>=lmax,则舍弃该元组,遍历下一个元组。若l < lmax,删除优先级队列中最大距离的元组,将当前训练元组存入优先级队 列。
7. 遍历完毕,计算优先级队列中k 个元组的多数类,并将其作为测试元组的类别。
8. 测试元组集测试完毕后计算误差率,继续设定不同的k值重新进行训练,最后取误差率最小的k 值。
根据算法的过程我们进行java语言实现:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
|
package knn; /** * 点的坐标 x 、y * @author administrator * */ public class pointbean { int x; int y; public int getx() { return x; } public void setx( int x) { this .x = x; } public int gety() { return y; } public void sety( int y) { this .y = y; } public pointbean( int x, int y) { super (); this .x = x; this .y = y; } public pointbean() { super (); } @override public string tostring() { return "pointbean [x=" + x + ", y=" + y + "]" ; } } |
knn算法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
|
package knn; import java.util.arraylist; /** * knn实现的方法 * @author administrator * */ public class knnmain { public double getpointlength(arraylist<pointbean> list,pointbean bb){ int b_x=bb.getx(); int b_y=bb.gety(); double temp=(b_x -list.get( 0 ).getx())*(b_x -list.get( 0 ).getx())+ (b_y -list.get( 0 ).gety())*(b_y -list.get( 0 ).gety()); // 找出最小的距离 for ( int i= 1 ;i<list.size();i++){ if (temp<((b_x -list.get(i).getx())*(b_x -list.get(i).getx())+ (b_y -list.get(i).gety())*(b_y -list.get(i).gety()))){ temp=(b_x -list.get(i).getx())*(b_x -list.get(i).getx())+ (b_y -list.get(i).gety())*(b_y -list.get(i).gety()); } } return math.sqrt(temp); } /** * 获取长度,找出最小的一个进行归类 * @param list1 * @param list2 * @param list3 * @param bb */ public void getcontent(arraylist<pointbean> list1,arraylist<pointbean> list2, arraylist<pointbean> list3,pointbean bb){ double a=getpointlength(list1,bb); double b=getpointlength(list2,bb); double c=getpointlength(list3,bb); //做出比较 if (a>b){ if (b>c){ system.out.println( "这个点:" +bb.getx()+ " , " +bb.gety()+ " " + "属于c" ); } else { system.out.println( "这个点:" +bb.getx()+ " , " +bb.gety()+ " " + "属于b" ); } } else { if (a>c){ system.out.println( "这个点:" +bb.getx()+ " , " +bb.gety()+ " " + "属于c" ); } else { system.out.println( "这个点:" +bb.getx()+ " , " +bb.gety()+ " " + "属于a" ); } } } } |
主函数
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
|
package knn; import java.util.arraylist; /* * 主函数 knn */ public class testjava { static arraylist< pointbean> lista; static arraylist< pointbean> listb; static arraylist< pointbean> listc; static arraylist< pointbean> listd; public static void main(string[] args) { //创佳arraylist lista=new arraylist<pointbean>(); listb=new arraylist<pointbean>(); listc=new arraylist<pointbean>(); listd=new arraylist<pointbean>(); //写入数据 setdate(); gettestresult(); } /** * 得到结果 */ private static void gettestresult() { //创建对象 knnmain km=new knnmain(); for(int i=0;i<listd.size();i++){ km.getcontent(lista, listb, listc, listd.get(i)); } } /** * 写入数据 */ private static void setdate() { //a的坐标点 int a_x[]={ 1 , 1 , 2 , 2 , 1 }; int a_y[]={ 0 , 1 , 1 , 0 , 2 }; //b的坐标点 int b_x[]={ 2 , 3 , 3 , 3 , 4 }; int b_y[]={ 4 , 4 , 3 , 2 , 3 }; //c的坐标点 int c_x[]={ 4 , 5 , 5 , 6 , 6 }; int c_y[]={ 1 , 2 , 0 , 2 , 1 }; // 测试数据 //b的坐标点 int d_x[]={ 3 , 3 , 3 , 0 , 5 }; int d_y[]={ 0 , 1 , 5 , 0 , 1 }; // pointbean ba; for ( int i= 0 ;i< 5 ;i++){ ba= new pointbean(a_x[i], a_y[i]); lista.add(ba); } // pointbean bb ; for ( int i= 0 ;i< 5 ;i++){ bb= new pointbean(b_x[i], b_y[i]); listb.add(bb); } // pointbean bc ; for ( int i= 0 ;i< 5 ;i++){ bc= new pointbean(c_x[i], c_y[i]); listc.add(bc); } // pointbean bd ; for ( int i= 0 ;i< 5 ;i++){ bd= new pointbean(d_x[i], d_y[i]); listd.add(bd); } } } |
测试的结果:
这个点:3 , 1 属于a
这个点:3 , 5 属于b
这个点:0 , 0 属于a
这个点:5 , 1 属于c
到此简单的knn算法已经实现对于未知点的划分,有助于大家对于knn算法的理解。对于改进knn的一些算法java实现会在后面进行贴出。共同学习共同进步!
希望本文所述对大家java程序设计有所帮助。
原文链接:https://blog.csdn.net/u011015260/article/details/53392194