posts: kd-tree-revisited: add implementation
This commit is contained in:
parent f35bad4c89
commit 2c1cd7df1c

@@ -38,3 +38,75 @@ us a minimal bound on the distance of the points stored on the other side.

This can be easily computed from the `axis` and `mid` values which are stored in
the inner nodes: to project a point onto the splitting plane, we simply replace
its coordinate for this axis by `mid`.
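
For illustration, here is a minimal sketch of what such a projection helper could
look like, assuming `Point` simply wraps a tuple of coordinates (the actual `Point`
type from earlier in the series may differ):

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Point:
    # Hypothetical stand-in for the series' own `Point` type
    coords: tuple[float, ...]

    def replace(self, axis: int, value: float) -> "Point":
        # Projection onto the plane `x[axis] == value`: copy the coordinates
        # and overwrite the one along `axis`
        coords = list(self.coords)
        coords[axis] = value
        return Point(tuple(coords))
```
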
## Simplified search

With that out of the way, let's now see how `closest` can be implemented without
needing to track the tree's `AABB` at the root:

```python
# Wrapper type for closest points, ordered by `distance`
@dataclasses.dataclass(order=True)
class ClosestPoint[T]:
    point: Point = dataclasses.field(compare=False)
    value: T = dataclasses.field(compare=False)
    distance: float


class KdTree[T]:
    def closest(self, point: Point, n: int = 1) -> list[ClosestPoint[T]]:
        assert n > 0
        res = MaxHeap()
        # Instead of passing an `AABB`, we give an initial projection point,
        # the query origin itself (since we haven't visited any split node yet)
        self._root.closest(point, res, n, point)
        return sorted(res)


class KdNode[T]:
    def closest(
        self,
        point: Point,
        out: MaxHeap[ClosestPoint[T]],
        n: int,
        projection: Point,
    ) -> None:
        # Same implementation as before, just forwarding `projection` along
        self.inner.closest(point, out, n, projection)


class KdLeafNode[T]:
    def closest(
        self,
        point: Point,
        out: MaxHeap[ClosestPoint[T]],
        n: int,
        projection: Point,
    ) -> None:
        # Same implementation as before: consider every point stored in the leaf
        for p, val in self.points.items():
            item = ClosestPoint(p, val, dist(p, point))
            if len(out) < n:
                out.push(item)
            elif out.peek().distance > item.distance:
                out.pushpop(item)


class KdSplitNode[T]:
    def closest(
        self,
        point: Point,
        out: MaxHeap[ClosestPoint[T]],
        n: int,
        projection: Point,
    ) -> None:
        index = self._index(point)
        self.children[index].closest(point, out, n, projection)
        # Project onto the splitting plane, for a minimum distance to its points
        projection = projection.replace(self.axis, self.mid)
        # If we're at capacity and can't possibly find any closer points, exit
        if len(out) == n and dist(point, projection) > out.peek().distance:
            return
        # Otherwise recurse on the other side to check for nearer neighbours
        self.children[1 - index].closest(point, out, n, projection)
```

As you can see, the main difference is in `KdSplitNode`'s implementation, where
we can quickly compute the minimum distance between the search's origin and all
potential points in that subspace.
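
To round things off, here is a quick usage sketch, reusing the hypothetical `Point`
wrapper from above; the construction call is an assumption, since the actual
`KdTree` constructor is defined earlier in the series and may accept its points
differently:

```python
# Hypothetical usage sketch: the `KdTree(...)` construction call below is an
# assumption, the real constructor is defined earlier in the series.
points = {Point((x / 4, y / 4)): f"cell-{x}-{y}" for x in range(5) for y in range(5)}
tree = KdTree(points)

# The three stored points closest to the query, sorted by increasing distance
for match in tree.closest(Point((0.3, 0.6)), n=3):
    print(match.point, match.value, match.distance)
```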