posts: kd-tree: add nearest neighbour
This commit is contained in:
parent
8ae274d5b2
commit
4796157b65
|
@ -229,3 +229,244 @@ class KdSplitNode[T]:
|
||||||
# Recurse into the child which contains the point
|
# Recurse into the child which contains the point
|
||||||
return self.children[self._index(point)].lookup(point)
|
return self.children[self._index(point)].lookup(point)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Closest points
|
||||||
|
|
||||||
|
Now to look at the most interesting operation one can do on a _k-d Tree_:
|
||||||
|
querying for the objects which are closest to a given point (i.e: the [Nearest
|
||||||
|
neighbour search][nns].
|
||||||
|
|
||||||
|
This is a more complicated algorithm, which will also need some modifications to
|
||||||
|
current _k-d Tree_ implementation in order to track just a bit more information
|
||||||
|
about the points it contains.
|
||||||
|
|
||||||
|
[nns]: https://en.wikipedia.org/wiki/Nearest_neighbor_search
|
||||||
|
|
||||||
|
#### A notion of distance
|
||||||
|
|
||||||
|
To search for the closest points to a given origin, we first need to define
|
||||||
|
which [distance](https://en.wikipedia.org/wiki/Distance) we are using in our
|
||||||
|
space.
|
||||||
|
|
||||||
|
For this example, we'll simply be using the usual definition of [(Euclidean)
|
||||||
|
distance][euclidean-distance].
|
||||||
|
|
||||||
|
[euclidean-distance]: https://en.wikipedia.org/wiki/Euclidean_distance
|
||||||
|
|
||||||
|
```python
|
||||||
|
def dist(point: Point, other: Point) -> float:
|
||||||
|
return sqrt(sum((a - b) ** 2 for a, b in zip(self, other)))
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Tracking the tree's boundaries
|
||||||
|
|
||||||
|
To make the query efficient, we'll need to track the tree's boundaries: the
|
||||||
|
bounding box of all points contained therein. This will allow us to stop the
|
||||||
|
search early once we've found enough points and can be sure that the rest of the
|
||||||
|
tree is too far away to qualify.
|
||||||
|
|
||||||
|
For this, let's define the `AABB` (Axis-Aligned Bounding Box) class.
|
||||||
|
|
||||||
|
```python
|
||||||
|
class Point(NamedTuple):
|
||||||
|
# Convenience function to replace the coordinate along a given dimension
|
||||||
|
def replace(self, axis: Axis, new_coord: float) -> Point:
|
||||||
|
coords = list(self)
|
||||||
|
coords[axis] = new_coord
|
||||||
|
return Point(coords)
|
||||||
|
|
||||||
|
class AABB(NamedTuple):
|
||||||
|
# Lowest coordinates in the box
|
||||||
|
low: Point
|
||||||
|
# Highest coordinates in the box
|
||||||
|
high: Point
|
||||||
|
|
||||||
|
# An empty box
|
||||||
|
@classmethod
|
||||||
|
def empty(cls) -> AABB:
|
||||||
|
return cls(
|
||||||
|
Point(*(float("inf"),) * 3),
|
||||||
|
Point(*(float("-inf"),) * 3),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Split the box into two along a given axis for a given mid-point
|
||||||
|
def split(axis: Axis, mid: float) -> tuple[AABB, AABB]:
|
||||||
|
assert self.low[axis] <= mid <= self.high[axis]
|
||||||
|
return (
|
||||||
|
AABB(self.low, self.high.replace(axis, mid)),
|
||||||
|
AABB(self.low.replace(axis, mid), self.high),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extend a box to contain a given point
|
||||||
|
def extend(self, point: Point) -> None:
|
||||||
|
low = NamedTuple(*(map(min, zip(self.low, point))))
|
||||||
|
high = NamedTuple(*(map(max, zip(self.high, point))))
|
||||||
|
return AABB(low, high)
|
||||||
|
|
||||||
|
# Return the shortest between a given point and the box
|
||||||
|
def dist_to_point(self, point: Point) -> float:
|
||||||
|
deltas = (
|
||||||
|
max(self.low[axis] - point[axis], 0, point[axis] - self.high[axis])
|
||||||
|
for axis in Axis
|
||||||
|
)
|
||||||
|
return dist(Point(0, 0, 0), Point(*deltas))
|
||||||
|
```
|
||||||
|
|
||||||
|
And do the necessary modifications to the `KdTree` to store the bounding box and
|
||||||
|
update it as we add new points.
|
||||||
|
|
||||||
|
```python
|
||||||
|
class KdTree[T]:
|
||||||
|
_root: KdNode[T]
|
||||||
|
# New field: to keep track of the tree's boundaries
|
||||||
|
_aabb: AABB
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._root = KdNode()
|
||||||
|
# Initialize the empty tree with an empty bounding box
|
||||||
|
self._aabb = AABB.empty()
|
||||||
|
|
||||||
|
def insert(self, point: Point, val: T) -> bool:
|
||||||
|
# Extend the AABB for our k-d Tree when adding a point to it
|
||||||
|
self._aabb = self._aabb.extend(point)
|
||||||
|
return self._root.insert(point, val, Axis.X)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### `MaxHeap`
|
||||||
|
|
||||||
|
Python's builtin [`heapq`][heapq] module provides the necessary functions to
|
||||||
|
create and interact with a [_Priority Queue_][priority-queue], in the form of a
|
||||||
|
[_Binary Heap_][binary-heap].
|
||||||
|
|
||||||
|
Unfortunately, Python's library maintains a _min-heap_, which keeps the minimum
|
||||||
|
element at the root. For this algorithm, we're interested in having a
|
||||||
|
_max-heap_, with the maximum at the root.
|
||||||
|
|
||||||
|
Thankfully, one can just reverse the comparison function for each element to
|
||||||
|
convert between the two. Let's write a `MaxHeap` class making use of this
|
||||||
|
library, with a `Reverse` wrapper class to reverse the order of elements
|
||||||
|
contained within it (similar to [Rust's `Reverse`][reverse]).
|
||||||
|
|
||||||
|
[binary-heap]: https://en.wikipedia.org/wiki/Binary_heap
|
||||||
|
[heapq]: https://docs.python.org/3/library/heapq.html
|
||||||
|
[priority-queue]: https://en.wikipedia.org/wiki/Priority_queue
|
||||||
|
[reverse]: https://doc.rust-lang.org/std/cmp/struct.Reverse.html
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Reverses the wrapped value's ordering
|
||||||
|
@functools.total_ordering
|
||||||
|
class Reverse[T]:
|
||||||
|
value: T
|
||||||
|
|
||||||
|
def __init__(self, value: T):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def __lt__(self, other: Reverse[T]) -> bool:
|
||||||
|
return self.value > other.value
|
||||||
|
|
||||||
|
def __eq__(self, other: Reverse[T]) -> bool:
|
||||||
|
return self.value == other.value
|
||||||
|
|
||||||
|
class MaxHeap[T]:
|
||||||
|
_heap: list[Reverse[T]]
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._heap = []
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self._heap)
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[T]:
|
||||||
|
yield from (item.value for item in self._heap)
|
||||||
|
|
||||||
|
# Push a value on the heap
|
||||||
|
def push(self, value: T) -> None:
|
||||||
|
heapq.heappush(self._heap, Reverse(value))
|
||||||
|
|
||||||
|
# Peek at the current maximum value
|
||||||
|
def peek(self) -> T:
|
||||||
|
return self._heap[0].value
|
||||||
|
|
||||||
|
# Pop and return the highest value
|
||||||
|
def pop(self) -> T:
|
||||||
|
return heapq.heappop(self._heap).value
|
||||||
|
|
||||||
|
# Pushes a value onto the heap, pops and returns the highest value
|
||||||
|
def pushpop(self, value: T) -> None:
|
||||||
|
return heapq.heappushpop(self._heap, Reverse(value)).value
|
||||||
|
```
|
||||||
|
|
||||||
|
#### The actual Implementation
|
||||||
|
|
||||||
|
Now that we have written the necessary building blocks, let's tackle the
|
||||||
|
Implementation of `closest` for our _k-d Tree_.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Wrapper type for closest points, ordered by `distance`
|
||||||
|
@dataclasses.dataclass(order=True)
|
||||||
|
class ClosestPoint[T](NamedTuple):
|
||||||
|
point: Point = field(compare=False)
|
||||||
|
value: T = field(compare=False)
|
||||||
|
distance: float
|
||||||
|
|
||||||
|
class KdTree[T]:
|
||||||
|
def closest(self, point: Point, n: int = 1) -> list[ClosestPoint[T]]:
|
||||||
|
assert n > 0
|
||||||
|
# Create the output heap
|
||||||
|
res = MaxHeap()
|
||||||
|
# Recurse onto the root node
|
||||||
|
self._root.closest(point, res, n, self._aabb)
|
||||||
|
# Return the resulting list, from closest to farthest
|
||||||
|
return sorted(res)
|
||||||
|
|
||||||
|
class KdNode[T]:
|
||||||
|
def closest(
|
||||||
|
self,
|
||||||
|
point: Point,
|
||||||
|
out: MaxHeap[ClosestPoint[T]],
|
||||||
|
n: int,
|
||||||
|
bounds: AABB,
|
||||||
|
) -> None:
|
||||||
|
# Forward to the wrapped node
|
||||||
|
self.inner.closest(point, out, n, bounds)
|
||||||
|
|
||||||
|
class KdLeafNode[T]:
|
||||||
|
def closest(
|
||||||
|
self,
|
||||||
|
point: Point,
|
||||||
|
out: MaxHeap[ClosestPoint[T]],
|
||||||
|
n: int,
|
||||||
|
bounds: AABB,
|
||||||
|
) -> None:
|
||||||
|
# At the leaf, simply iterate over all points and add them to the heap
|
||||||
|
for p, val in self.points.items():
|
||||||
|
item = ClosestPoint(p, val, dist(p, point))
|
||||||
|
if len(out) < n:
|
||||||
|
# If the heap isn't full, just push
|
||||||
|
out.push(item)
|
||||||
|
elif out.peek().distance > item.distance:
|
||||||
|
# Otherwise, push and pop to keep the heap at `n` elements
|
||||||
|
out.pushpop(item)
|
||||||
|
|
||||||
|
class KdSplitNode[T]:
|
||||||
|
def closest(
|
||||||
|
self,
|
||||||
|
point: Point,
|
||||||
|
out: list[ClosestPoint[T]],
|
||||||
|
n: int,
|
||||||
|
bounds: AABB,
|
||||||
|
) -> None:
|
||||||
|
index = self._index(point)
|
||||||
|
children_bounds = bounds.split(self.axis, self.mid)
|
||||||
|
# Iterate over the child which contains the point, then its neighbour
|
||||||
|
for i in (index, 1 - index):
|
||||||
|
child, bounds = self.children[i], children_bounds[i]
|
||||||
|
# `min_dist` is 0 for the first child, and the minimum distance of
|
||||||
|
# all points contained in the second child
|
||||||
|
min_dist = bounds.dist_to_point(point)
|
||||||
|
# If the heap is at capacity and the child to inspect too far, stop
|
||||||
|
if len(out) == n and min_dist > out.peek().distance:
|
||||||
|
return
|
||||||
|
# Otherwise, recurse
|
||||||
|
child.closest(point, out, n, bounds)
|
||||||
|
```
|
||||||
|
|
Loading…
Reference in a new issue