posts: kd-tree: add nearest neighbour

This commit is contained in:
Bruno BELANYI 2024-08-10 16:48:32 +01:00
parent 74d4aa87e6
commit 8acb675b16

View file

@ -229,3 +229,244 @@ class KdSplitNode[T]:
# Recurse into the child which contains the point
return self.children[self._index(point)].lookup(point)
```
### Closest points
Now to look at the most interesting operation one can do on a _k-d Tree_:
querying for the objects which are closest to a given point (i.e. the [Nearest
neighbour search][nns]).
This is a more complicated algorithm, which will also need some modifications to
the current _k-d Tree_ implementation in order to track just a bit more
information about the points it contains.
[nns]: https://en.wikipedia.org/wiki/Nearest_neighbor_search
#### A notion of distance
To search for the closest points to a given origin, we first need to define
which [distance](https://en.wikipedia.org/wiki/Distance) we are using in our
space.
For this example, we'll simply be using the usual definition of [(Euclidean)
distance][euclidean-distance].
[euclidean-distance]: https://en.wikipedia.org/wiki/Euclidean_distance
```python
# Euclidean distance between two points: the square root of the sum of the
# squared per-axis differences
def dist(point: Point, other: Point) -> float:
    return sqrt(sum((a - b) ** 2 for a, b in zip(point, other)))
```
#### Tracking the tree's boundaries
To make the query efficient, we'll need to track the tree's boundaries: the
bounding box of all points contained therein. This will allow us to stop the
search early once we've found enough points and can be sure that the rest of the
tree is too far away to qualify.
For this, let's define the `AABB` (Axis-Aligned Bounding Box) class.
```python
class Point(NamedTuple):
    # Convenience function to replace the coordinate along a given dimension.
    # Returns a new point: the original is left untouched.
    def replace(self, axis: Axis, new_coord: float) -> "Point":
        coords = list(self)
        coords[axis] = new_coord
        # `Point` takes one argument per coordinate, so unpack the list
        return Point(*coords)
class AABB(NamedTuple):
    # Lowest coordinates in the box
    low: Point
    # Highest coordinates in the box
    high: Point

    # An empty box: its bounds are inverted (+inf/-inf), so that extending it
    # with any point yields the box containing exactly that point
    @classmethod
    def empty(cls) -> "AABB":
        return cls(
            Point(*(float("inf"),) * 3),
            Point(*(float("-inf"),) * 3),
        )

    # Split the box into two along a given axis for a given mid-point.
    # Returns the (lower, upper) halves, which share the `mid` boundary.
    def split(self, axis: Axis, mid: float) -> "tuple[AABB, AABB]":
        assert self.low[axis] <= mid <= self.high[axis]
        return (
            AABB(self.low, self.high.replace(axis, mid)),
            AABB(self.low.replace(axis, mid), self.high),
        )

    # Extend a box to contain a given point.
    # Returns the extended box: `self` is not modified.
    def extend(self, point: Point) -> "AABB":
        low = Point(*map(min, self.low, point))
        high = Point(*map(max, self.high, point))
        return AABB(low, high)

    # Return the shortest distance between a given point and the box
    def dist_to_point(self, point: Point) -> float:
        # Per-axis gap between the point and the box, which is 0 when the
        # coordinate lies within the box's bounds along that axis
        deltas = (
            max(self.low[axis] - point[axis], 0, point[axis] - self.high[axis])
            for axis in Axis
        )
        return dist(Point(0, 0, 0), Point(*deltas))
```
We then make the necessary modifications to the `KdTree` to store the bounding
box and update it as we add new points.
```python
class KdTree[T]:
_root: KdNode[T]
# New field: to keep track of the tree's boundaries
_aabb: AABB
def __init__(self):
self._root = KdNode()
# Initialize the empty tree with an empty bounding box
self._aabb = AABB.empty()
def insert(self, point: Point, val: T) -> bool:
# Extend the AABB for our k-d Tree when adding a point to it
self._aabb = self._aabb.extend(point)
return self._root.insert(point, val, Axis.X)
```
#### `MaxHeap`
Python's builtin [`heapq`][heapq] module provides the necessary functions to
create and interact with a [_Priority Queue_][priority-queue], in the form of a
[_Binary Heap_][binary-heap].
Unfortunately, Python's library maintains a _min-heap_, which keeps the minimum
element at the root. For this algorithm, we're interested in having a
_max-heap_, with the maximum at the root.
Thankfully, one can just reverse the comparison function for each element to
convert between the two. Let's write a `MaxHeap` class making use of this
library, with a `Reverse` wrapper class to reverse the order of elements
contained within it (similar to [Rust's `Reverse`][reverse]).
[binary-heap]: https://en.wikipedia.org/wiki/Binary_heap
[heapq]: https://docs.python.org/3/library/heapq.html
[priority-queue]: https://en.wikipedia.org/wiki/Priority_queue
[reverse]: https://doc.rust-lang.org/std/cmp/struct.Reverse.html
```python
# Reverses the wrapped value's ordering
@functools.total_ordering
class Reverse[T]:
value: T
def __init__(self, value: T):
self.value = value
def __lt__(self, other: Reverse[T]) -> bool:
return self.value > other.value
def __eq__(self, other: Reverse[T]) -> bool:
return self.value == other.value
class MaxHeap[T]:
_heap: list[Reverse[T]]
def __init__(self):
self._heap = []
def __len__(self) -> int:
return len(self._heap)
def __iter__(self) -> Iterator[T]:
yield from (item.value for item in self._heap)
# Push a value on the heap
def push(self, value: T) -> None:
heapq.heappush(self._heap, Reverse(value))
# Peek at the current maximum value
def peek(self) -> T:
return self._heap[0].value
# Pop and return the highest value
def pop(self) -> T:
return heapq.heappop(self._heap).value
# Pushes a value onto the heap, pops and returns the highest value
def pushpop(self, value: T) -> None:
return heapq.heappushpop(self._heap, Reverse(value)).value
```
#### The actual implementation
Now that we have written the necessary building blocks, let's tackle the
implementation of `closest` for our _k-d Tree_.
```python
# Wrapper type for closest points, ordered by `distance`
@dataclasses.dataclass(order=True)
class ClosestPoint[T](NamedTuple):
point: Point = field(compare=False)
value: T = field(compare=False)
distance: float
class KdTree[T]:
def closest(self, point: Point, n: int = 1) -> list[ClosestPoint[T]]:
assert n > 0
# Create the output heap
res = MaxHeap()
# Recurse onto the root node
self._root.closest(point, res, n, self._aabb)
# Return the resulting list, from closest to farthest
return sorted(res)
class KdNode[T]:
def closest(
self,
point: Point,
out: MaxHeap[ClosestPoint[T]],
n: int,
bounds: AABB,
) -> None:
# Forward to the wrapped node
self.inner.closest(point, out, n, bounds)
class KdLeafNode[T]:
def closest(
self,
point: Point,
out: MaxHeap[ClosestPoint[T]],
n: int,
bounds: AABB,
) -> None:
# At the leaf, simply iterate over all points and add them to the heap
for p, val in self.points.items():
item = ClosestPoint(p, val, dist(p, point))
if len(out) < n:
# If the heap isn't full, just push
out.push(item)
elif out.peek().distance > item.distance:
# Otherwise, push and pop to keep the heap at `n` elements
out.pushpop(item)
class KdSplitNode[T]:
def closest(
self,
point: Point,
out: list[ClosestPoint[T]],
n: int,
bounds: AABB,
) -> None:
index = self._index(point)
children_bounds = bounds.split(self.axis, self.mid)
# Iterate over the child which contains the point, then its neighbour
for i in (index, 1 - index):
child, bounds = self.children[i], children_bounds[i]
# `min_dist` is 0 for the first child, and the minimum distance of
# all points contained in the second child
min_dist = bounds.dist_to_point(point)
# If the heap is at capacity and the child to inspect too far, stop
if len(out) == n and min_dist > out.peek().distance:
return
# Otherwise, recurse
child.closest(point, out, n, bounds)
```