K-최근접 이웃(K-nearest neighbor) 알고리즘에서 가장 가까운 포인트를 찾는 방법은 여러 가지가 있다.
그중 KD 트리(k-d tree)를 이용해 최근접 포인트를 효율적으로 찾는 알고리즘을 기록한다.
import jax.numpy as jnp
from jax import vmap
class KDTreeJAX:
    """Node of a k-d tree built recursively over an array of points.

    Each node stores the median point along its splitting axis and owns the
    subtrees built from the points sorted before (``left``) and after
    (``right``) that median. The tree uses ordinary Python recursion and
    control flow, so it is meant for eager execution, not for ``jit``/``vmap``.
    """

    def __init__(self, data, depth=0):
        """Build the (sub)tree for ``data``.

        Args:
            data: Non-empty array-like of shape ``(n_points, n_dim)``.
            depth: Depth of this node; the splitting axis is
                ``depth % n_dim``.

        Raises:
            ValueError: If ``data`` is not a non-empty 2-D array.
        """
        data = jnp.asarray(data)
        if data.ndim != 2 or data.shape[0] == 0 or data.shape[1] == 0:
            raise ValueError(
                "data must be a non-empty 2-D array of shape (n_points, n_dim)"
            )

        n_points = data.shape[0]
        self.axis = depth % data.shape[1]
        self.median_idx = n_points // 2

        # Sort along the splitting axis so the median separates the halves.
        sorted_data = data[jnp.argsort(data[:, self.axis])]
        self.median_point = sorted_data[self.median_idx]

        # Recursively build the subtrees; an empty half stays None.
        self.left = None
        self.right = None
        if self.median_idx > 0:
            self.left = KDTreeJAX(sorted_data[:self.median_idx], depth + 1)
        if self.median_idx + 1 < n_points:
            self.right = KDTreeJAX(sorted_data[self.median_idx + 1:], depth + 1)

    def nearest_neighbor(self, query_point, best=None, depth=0):
        """Return the stored point closest to ``query_point``.

        Args:
            query_point: Array-like of shape ``(n_dim,)``.
            best: Optional ``[best_point, best_distance]`` pair found so far;
                used by the recursion. The passed object is not mutated.
            depth: Kept for backward compatibility and forwarded to the
                recursive calls. The splitting axis is taken from the node
                itself, so the search always matches how the tree was built.

        Returns:
            A list ``[best_point, best_distance]`` where ``best_distance`` is
            the Euclidean distance from ``query_point`` to ``best_point``.
        """
        query_point = jnp.asarray(query_point)
        if best is None:
            best = [None, float('inf')]  # [best_point, best_distance]

        dist = jnp.linalg.norm(query_point - self.median_point)
        # Keep point and distance together: the pruning test below relies on
        # best[1] being the best distance seen so far.
        if dist < best[1]:
            best = [self.median_point, dist]

        # Signed offset of the query from this node's splitting plane.
        gap = query_point[self.axis] - self.median_point[self.axis]
        if gap < 0:
            closer, farther = self.left, self.right
        else:
            closer, farther = self.right, self.left

        # Search the side of the plane that contains the query first.
        if closer is not None:
            best = closer.nearest_neighbor(query_point, best, depth + 1)

        # The other side can only hold a closer point if the splitting plane
        # itself is nearer than the best distance found so far.
        if farther is not None and abs(gap) < best[1]:
            best = farther.nearest_neighbor(query_point, best, depth + 1)

        return best


def find_all_nearest_neighbors(data):
    """Find, for every point in ``data``, its nearest *other* point.

    For each point a fresh ``KDTreeJAX`` is built from all remaining points
    (so a point is never reported as its own neighbor) and then queried.

    Args:
        data: Array-like of shape ``(n_data, n_dim)``.

    Returns:
        A list with one ``(neighbor_point, distance)`` tuple per input point,
        in input order. Empty input yields an empty list.

    Raises:
        ValueError: If ``data`` contains exactly one point, since there is no
            other point to return.
    """
    data = jnp.asarray(data)
    n_points = data.shape[0]
    if n_points == 0:
        return []
    if n_points == 1:
        raise ValueError("need at least two points to find a nearest neighbor")

    results = []
    for i in range(n_points):
        # Exclude the current point from the search by dropping its row.
        tree = KDTreeJAX(jnp.delete(data, i, axis=0))
        neighbor, dist = tree.nearest_neighbor(data[i])
        # Point and distance differ in shape, so keep them as a pair instead
        # of packing everything into one array.
        results.append((neighbor, dist))
    return results
# Demo: nearest-neighbor lookup on a small 2-D sample set.
points = jnp.array(
    [
        [0.1, 0.2],
        [0.4, 0.7],
        [0.6, 0.8],
        [0.9, 0.5],
        [0.3, 0.9],
    ]
)
# One result per input point, in the same order as `points`.
neighbors = find_all_nearest_neighbors(points)
# Report each point followed by the result computed for it.
for i, result in enumerate(neighbors):
    neighbor, dist = result
    print(f"Point {i}: {points[i]}")
    print(f"Nearest Neighbor: {neighbor}, Distance: {dist}")
'연구 Research > 데이터과학 Data Science' 카테고리의 다른 글
[데이터과학] scipy interpolation 종류 정리 (0) | 2023.08.25 |
---|---|
[matplotlib] x,y축 format 지정하는 방법 (0) | 2023.06.08 |
[Matplotlib] 3D scatter plot 그리는 코드 (0) | 2023.04.28 |
[데이터과학] Pandas에서 dataframe 생성 및 export (0) | 2023.04.27 |
[데이터과학] Unbalancing data 처리 (0) | 2021.05.26 |