1 import numpy as np
# Six sample points, one 2-D point per row, used to build and query the tree.
arr = np.array([
    (2, 3), (5, 4), (9, 6),
    (4, 7), (8, 1), (7, 2),
])
arr.shape  # bare expression: only displays a value in an interactive session
4
class KDTree():
    """A node of a k-d tree.

    Attributes:
        value: the point stored at this node (one row of the array it was
            built from), or None.
        left: subtree built from the points whose ``axis`` coordinate is
            <= ``value[axis]``, or None.
        right: subtree built from the points whose ``axis`` coordinate is
            >= ``value[axis]``, or None.
        axis: index of the coordinate this node splits on.

    Every attribute defaults to None, so ``KDTree()`` still creates an
    empty node exactly as before.
    """

    def __init__(self, value=None, left=None, right=None, axis=None):
        self.value = value
        self.left = left
        self.right = right
        self.axis = axis

    def __repr__(self):
        return f"{type(self).__name__}(value={self.value!r}, axis={self.axis!r})"


def create(arr, k=None, h=0):
    """Recursively build a k-d tree from the rows of *arr*.

    Args:
        arr: array-like of shape (n, d); each row is one point. Lists are
            accepted and converted with ``np.asarray``.
        k: number of dimensions to cycle through when choosing split axes;
            defaults to ``arr.shape[1]``.
        h: depth of the node being built; the split axis is ``h % k``.

    Returns:
        The root ``KDTree`` node, or None when *arr* has no rows.
    """
    arr = np.asarray(arr)
    if arr.shape[0] == 0:
        return None
    if k is None:
        k = arr.shape[1]
    axis = h % k

    # Order the rows by the split coordinate. A stable sort keeps rows with
    # equal keys in input order, matching the behaviour of Python's sorted().
    arr = arr[np.argsort(arr[:, axis], kind='stable')]
    i = arr.shape[0] // 2  # median row (upper median when n is even)

    # A single row needs no special case: both slices below are empty,
    # so both children come back as None.
    return KDTree(
        value=arr[i],
        left=create(arr[:i], k, h + 1),
        right=create(arr[i + 1:], k, h + 1),
        axis=axis,
    )
33
# Build the k-d tree over every row of arr, using all of its columns as dimensions.
k = create(arr, arr.shape[1])
37
def preOrder(k):
    """Print the point stored at every node of the tree rooted at *k*, in pre-order.

    Each node is printed before its left subtree, which is printed before its
    right subtree. Prints nothing when *k* is None (an empty tree).
    """
    if k is None:
        return
    print('当前节点:' + str(k.value))  # label means "current node:"
    preOrder(k.left)
    preOrder(k.right)
45
# Print every node's point: each node first, then its left subtree, then its right subtree.
preOrder(k)
47
def dis(a, b):
    """Return the Euclidean distance between points *a* and *b*.

    Both inputs are converted to float arrays first, so plain lists work and
    unsigned-integer arrays do not wrap around on subtraction.
    """
    return np.linalg.norm(np.asarray(a, dtype=float) - np.asarray(b, dtype=float))


def search(kd, goal, k, h=0):
    """Return the Euclidean distance from *goal* to its nearest point in a k-d tree.

    Args:
        kd: root node of the (sub)tree to search, or None for an empty tree.
        goal: query point, indexable by axis (e.g. a 1-D array of length k).
        k: number of dimensions; the split axis at depth ``h`` is ``h % k``.
        h: depth of *kd* in the whole tree; it must match the depth that was
            used to choose split axes when the tree was built.

    Returns:
        The smallest distance from *goal* to any stored point, or ``inf``
        when the tree is empty.
    """
    if kd is None:
        return float('inf')
    axis = h % k

    # Signed offset of the goal from this node's splitting plane.
    diff = float(goal[axis]) - float(kd.value[axis])

    # Descend first into the side of the plane that contains the goal.
    if diff < 0:
        near, far = kd.left, kd.right
    else:
        near, far = kd.right, kd.left
    best = min(dis(kd.value, goal), search(near, goal, k, h + 1))

    # The far side can only hold a closer point if the sphere of radius
    # `best` around the goal crosses the splitting plane, i.e. when the
    # ABSOLUTE offset is not larger than `best`; otherwise prune it.
    if abs(diff) <= best:
        best = min(best, search(far, goal, k, h + 1))
    return best
75
# Query the tree: print the distance from (9, 6) to its nearest stored point.
# Printing makes the result visible when run as a script; a bare expression
# would be silently discarded. The dimensionality comes from arr, as in the build step.
print(search(k, np.array([9, 6]), arr.shape[1]))