K-近邻法

距离度量标准:欧式距离或更一般的 LpL_p 距离。

k 值的选择:k 值小时,k 近邻模型更复杂;k 值大时,模型更简单(当 k=N 时,最简单);用交叉验证法取得最合适的 k 值。

分类决策原则:多数表决。

构造kd树

kd树用于搜索与规定点空间距离最小的点。

kd树与线段树类似,线段树存某一区间,kd树存某一k维空间。

构造kd树的方法也是从一整个空间开始,递归往下分配空间。

每个节点存其所在维度,父子节点,以便之后搜索。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#include<iostream>
#include<algorithm>
#include<vector>
using namespace std;
const int k = 2; //k维度
const int maxn = 100; //最大节点数
int cmp_dim; //排序时的维度
vector<vector<int> >dat; //存节点坐标
struct Node { //节点
int value[k]; //坐标
int dim; //节点所在维度
Node* parant=NULL;
Node* left_ch=NULL; //父节点,子节点
Node* right_ch=NULL;
}kdnode[maxn];
bool cmp(const Node& a,const Node& b) {
return a.value[cmp_dim] < b.value[cmp_dim];
}

Node* build(int l,int r,int d) {
if (l > r)return NULL;
int m;
cmp_dim = d;
if ((l + r) & 1)m = 1+(l + r) >> 1; //中位点
else m = (l + r) >> 1;
nth_element(kdnode + l, kdnode + m, kdnode + r + 1,cmp);
kdnode[m].dim = d;
Node* lch = build(l, m-1, (d + 1) % k); //递归建树
if(lch!=NULL)
lch->parant = kdnode + m;
Node* rch = build(m + 1, r, (d + 1) % k);
if(rch!=NULL)
rch->parant = kdnode + m;
return kdnode + m;

}

int main() {
int n = 6;
dat.push_back({ 2,3 });
dat.push_back({ 5,4 });
dat.push_back({ 9,6 });
dat.push_back({ 4,7 });
dat.push_back({ 8,1 });
dat.push_back({ 7,2 });
for (int i = 0; i < n; i++) {
for (int j = 0; j < k; j++) {
kdnode[i].value[j] = dat[i][j];
}
}
build(0, n - 1, 0);

return 0;
}

搜索kd树

  1. 找到包含目标点的叶节点

  2. 以此节点为当前最近节点

  3. 向上回退。同时检查:

    1. 若父节点更近,更新;

    2. 若兄弟节点空间可能存在更近的点(到轴距离小于当前最小距离),检查兄弟节点。

  4. 当回到根节点,搜索结束。最终最近节点即为结果。

所有代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#include<iostream>
#include<algorithm>
#include<vector>
#include<cmath>
using namespace std;
const int k = 2; //k维度
const int maxn = 100; //最大节点数
int cmp_dim; //排序时的维度
vector<vector<int> >dat; //存节点坐标
struct Node { //节点
int value[k]; //坐标
int dim; //节点所在维度
Node* parent=NULL;
Node* left_ch=NULL; //父节点,子节点
Node* right_ch=NULL;
}kdnode[maxn];
bool cmp(const Node& a,const Node& b) {
return a.value[cmp_dim] < b.value[cmp_dim];
}

Node* build(int l,int r,int d) {
if (l > r)return NULL;
int m;
cmp_dim = d;
if ((l + r) & 1)m = 1+(l + r) >> 1; //中位点
else m = (l + r) >> 1;
nth_element(kdnode + l, kdnode + m, kdnode + r + 1,cmp);
kdnode[m].dim = d;
Node* lch = build(l, m-1, (d + 1) % k); //递归建树
if (lch != NULL)
{
lch->parent = kdnode + m;
kdnode[m].left_ch = lch;
}
Node* rch = build(m + 1, r, (d + 1) % k);
if (rch != NULL)
{
rch->parent = kdnode + m;
kdnode[m].right_ch = rch;
}
return kdnode + m;

}

//distance返回欧式距离的平方
int distance(int* a,int* b) {
int res = 0;
for (int i = 0; i < k; i++)
res += (a[i]-b[i])*(a[i]-b[i]);
return res;
}

Node* root; //kd树的根
int min_dis; //最近点与目标点距离
Node* nearest; //最近点

//找到包含目标点的叶节点
Node* the_one(int* arr) {
int d = 0;
Node* r=root;
Node* tmp=r;
while (r != NULL) {
tmp = r;
if (r->value[d] > arr[d]) {
r = r->left_ch;
}
else
r = r->right_ch;
d = (d + 1) % k;
}
min_dis = distance(arr, tmp->value);
nearest = tmp;
return tmp;
}
//在某一节点 从上往下 找最近点
void search_son(Node* son, int* tar) {
if (son == NULL)
return;
if (distance(son->value, tar) < min_dis) { //如果当前点更近,则更新
nearest = son;
min_dis = distance(son->value, tar);
}
//若轴与目标点距离小于当前最小距离,则两边都要找
if (pow(son->value[son->dim] - tar[son->dim], 2) < min_dis) {
search_son(son->left_ch, tar);
search_son(son->right_ch, tar);
}
//只要找一边
else {
//目标点在轴左边,找左边空间
if (son->value[son->dim] > tar[son->dim])
search_son(son->left_ch, tar);
else
search_son(son->right_ch, tar);
}
}

//在某一点 从下往上 找最近点
void search_all(Node* now,int tar[k]) {
Node* par = now->parent;
while (now->parent != NULL) {
now = now->parent;
//若父节点更近,则更新
if (distance(tar, now->value) < min_dis) {
nearest = now;
min_dis = distance(tar, now->value);
}
Node* son;
//若轴与目标点距离小于当前最小距离,还要在另一边找
if (pow(now->value[now->dim], 2) < min_dis) {
//目标点在轴右边,找右边空间
if (now->value[now->dim] > tar[now->dim])
son = now->right_ch;
else
son = now->left_ch;
//确定在哪边空间后,从上往下找
search_son(son,tar);
}
}
}
int main() {
int n = 6;
dat.push_back({ 2,3 });
dat.push_back({ 5,4 });
dat.push_back({ 9,6 });
dat.push_back({ 4,7 });
dat.push_back({ 8,1 });
dat.push_back({ 7,2 });
for (int i = 0; i < n; i++) {
for (int j = 0; j < k; j++) {
kdnode[i].value[j] = dat[i][j];
}
}
root=build(0, n - 1, 0);
//测试
int a[2] = { 8,2 };
Node* now = the_one(a);
search_all(now, a);
//distance返回的是最小距离的平方,结果还应开根号
cout << distance(nearest->value, a);
return 0;
}