posts: kd-tree: add insertion

This commit is contained in:
Bruno BELANYI 2024-08-10 16:47:11 +01:00
parent b821a20ba1
commit 95658484f3

View file

@ -124,3 +124,80 @@ class KdTree[T]:
# Tree starts out empty # Tree starts out empty
self._root = KdNode() self._root = KdNode()
``` ```
### Inserting a point
To add a point to the tree, we simply recurse from node to node, similar to a
_BST_'s insertion algorithm. Once we've found the correct leaf node to insert
our point into, we simply do so.
If that leaf node goes over the maximum number of points it can store, we must
then split it along an axis, cycling between `X`, `Y`, and `Z` at each level of
the tree (i.e: splitting along the `X` axis on the first level, then `Y` on the
second, then `Z` after that, and then `X`, etc...).
```python
# How many points should be stored in a leaf node before being split
MAX_CAPACITY = 32
def median(values: Iterable[float]) -> float:
sorted_values = sorted(values)
mid_point = len(sorted_values) // 2
if len(sorted_values) % 2 == 1:
return sorted_values[mid_point]
a, b = sorted_values[mid_point], sorted_values[mid_point + 1]
return a + (b - a) / 2
def partition[T](
pred: Callable[[T], bool],
iterable: Iterable[T]
) -> tuple[list[T], list[T]]:
truths, falses = [], []
for v in iterable:
(truths if pred(v) else falses).append(v)
return truths, falses
def split_leaf[T](node: KdLeafNode[T], axis: Axis) -> KdSplitNode[T]:
# Find the median value for the given axis
mid = median(p[axis] for p in node.points)
# Split into left/right children according to the mid-point and axis
left, right = partition(lambda kv: kv[0][axis] <= mid, node.points.items())
return KdSplitNode(
split_axis,
mid,
(KdLeafNode.from_items(left), KdLeafNode.from_items(right)),
)
class KdTree[T]:
def insert(self, point: Point, val: T) -> bool:
# Forward to the root node, choose `X` as the first split axis
return self._root.insert(point, val, Axis.X)
class KdLeafNode[T]:
def insert(self, point: Point, val: T, split_axis: Axis) -> bool:
# Check whether we're overwriting a previous value
was_mapped = point in self.points
# Store the corresponding value
self.points[point] = val
# Return whether we've performed an overwrite
return was_mapped
class KdSplitNode[T]:
def insert(self, point: Point, val: T, split_axis: Axis) -> bool:
# Find the child which contains the point
child = self.children[self._index(point)]
# Recurse into it, choosing the next split axis
return child.insert(point, val, split_axis.next())
class KdNode[T]:
def insert(self, point: Point, val: T, split_axis: Axis) -> bool:
# Add the point to the wrapped node...
res = self.inner.insert(point, val, split_axis)
# ... And take care of splitting leaf nodes when necessary
if (
isinstance(self.inner, KdLeafNode)
and len(self.inner.points) > MAX_CAPACITY
):
self.inner = split_leaf(self.inner, split_axis)
return res
```