[알고리즘] KDTree로 가장 가까운 포인트 찾기

2024. 12. 28. 05:45·연구 Research/데이터과학 Data Science

 

K-neighbor 알고리즘에서 근접한 포인트를 찾을 때 다양한 방법으로 구할 수 있다.

그 중 KD tree 방식을 이용한 효율적인 근접 포인트 찾는 알고리즘을 기록한다.

 

import jax.numpy as jnp
from jax import vmap

class KDTreeJAX:
    def __init__(self, data, depth=0):
        self.axis = depth % data.shape[1]
        self.median_idx = len(data) // 2

        # Sort data along the current axis
        sorted_data = data[jnp.argsort(data[:, self.axis])]
        self.median_point = sorted_data[self.median_idx]

        # Recursively build left and right subtrees
        self.left = None
        self.right = None
        if self.median_idx > 0:
            self.left = KDTreeJAX(sorted_data[:self.median_idx], depth + 1)
        if self.median_idx + 1 < len(sorted_data):
            self.right = KDTreeJAX(sorted_data[self.median_idx + 1:], depth + 1)

    def nearest_neighbor(self, query_point, best=None, depth=0):
        if best is None:
            best = [None, float('inf')]  # [best_point, best_distance]

        axis = depth % query_point.shape[0]
        dist = jnp.linalg.norm(query_point - self.median_point)

        # Update best point if the current point is closer and not the excluded index
        if dist < best[1]:
            # best = [self.median_point, dist]
            best = self.median_point

        # Determine which subtree to search first
        if query_point[axis] < self.median_point[axis]:
            closer, farther = self.left, self.right
        else:
            closer, farther = self.right, self.left

        # Search the closer branch
        if closer is not None:
            best = closer.nearest_neighbor(query_point, best, depth + 1)

        # Check the farther branch if needed
        if farther is not None and abs(query_point[axis] - self.median_point[axis]) < best[1]:
            best = farther.nearest_neighbor(query_point, best, depth + 1)

        return best


# Find nearest neighbors for all points in the dataset
def find_all_nearest_neighbors(data):
    # data = (n_data, n_dim)
    indices = jnp.arange(data.shape[0])
    def nearest_neighbor_fn(idx):
        # Exclude the current point from the search by passing its index
        indices_del = jnp.delete( indices, idx )
        tree = KDTreeJAX(data[indices_del])
        return tree.nearest_neighbor(data[idx])
    
    close_points = []
    for i in range(data.shape[0]):
        close_points.append(nearest_neighbor_fn(indices[i]))

    # return vmap(nearest_neighbor_fn)(indices, data)
    return jnp.array( close_points )


# test
points = jnp.array([[0.1, 0.2], [0.4, 0.7], [0.6, 0.8], [0.9, 0.5], [0.3, 0.9]])

# Find nearest neighbors for all points
neighbors = find_all_nearest_neighbors(points)

# Display results
for i, (neighbor, dist) in enumerate(neighbors):
    print(f"Point {i}: {points[i]}")
    print(f"Nearest Neighbor: {neighbor}, Distance: {dist}")

 

저작자표시 비영리 변경금지 (새창열림)

'연구 Research > 데이터과학 Data Science' 카테고리의 다른 글

[데이터과학] scipy interpolation 종류 정리  (0) 2023.08.25
[matplotlib] x,y축 format 지정하는 방법  (0) 2023.06.08
[Matplotlib] 3D scatter plot 그리는 코드  (0) 2023.04.28
[데이터과학] Pandas에서 dataframe 생성 및 export  (0) 2023.04.27
[데이터과학] Unbalancing data 처리  (0) 2021.05.26
'연구 Research/데이터과학 Data Science' 카테고리의 다른 글
  • [데이터과학] scipy interpolation 종류 정리
  • [matplotlib] x,y축 format 지정하는 방법
  • [Matplotlib] 3D scatter plot 그리는 코드
  • [데이터과학] Pandas에서 dataframe 생성 및 export
보통의공대생
보통의공대생
수학,프로그래밍,기계항공우주 등 공부하는 기록들을 남깁니다.
  • 보통의공대생
    뛰는 놈 위에 나는 공대생
    보통의공대생
  • 전체
    오늘
    어제
    • 분류 전체보기 (468)
      • 공지 (1)
      • 영어 공부 English Study (40)
        • 텝스 TEPS (7)
        • 글 Article (21)
        • 영상 Video (10)
      • 연구 Research (99)
        • 최적화 Optimization (3)
        • 데이터과학 Data Science (7)
        • 인공지능 Artificial Intelligent (40)
        • 제어 Control (45)
      • 프로그래밍 Programming (103)
        • 매트랩 MATLAB (25)
        • 파이썬 Python (33)
        • 줄리아 Julia (2)
        • C++ (3)
        • 리눅스 우분투 Ubuntu (6)
      • 항공우주 Aeronautical engineeri.. (21)
        • 항법 Navigation (0)
        • 유도 Guidance (0)
      • 기계공학 Mechanical engineering (13)
        • 열역학 Thermodynamics (0)
        • 고체역학 Statics & Solid mechan.. (10)
        • 동역학 Dynamics (1)
        • 유체역학 Fluid Dynamics (0)
      • 수학 Mathematics (34)
        • 선형대수학 Linear Algebra (18)
        • 미분방정식 Differential Equation (3)
        • 확률및통계 Probability &amp; Sta.. (2)
        • 미적분학 Calculus (1)
        • 복소해석학 Complex Analysis (5)
        • 실해석학 Real Analysis (0)
      • 수치해석 Numerical Analysis (27)
      • 확률 및 랜덤프로세스 Random process (2)
      • 추론 & 추정 이론 Estimation (3)
      • 기타 (26)
        • 설계 프로젝트 System Design (8)
        • 논문작성 Writing (55)
        • 세미나 Seminar (2)
        • 생산성 Productivity (3)
      • 실험 Experiment (1)
      • 유학 생활 Daily (8)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    텝스공부
    IEEE
    인공지능
    서버
    teps
    논문작성법
    Numerical Analysis
    Zotero
    matplotlib
    수치해석
    WOX
    Dear abby
    딥러닝
    Python
    옵시디언
    Statics
    JAX
    ChatGPT
    고체역학
    MATLAB
    LaTeX
    논문작성
    에러기록
    Julia
    pytorch
    obsidian
    텝스
    우분투
    Linear algebra
    생산성
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
보통의공대생
[알고리즘] KDTree로 가장 가까운 포인트 찾기
상단으로

티스토리툴바