From 95658484f39c62dd03488918968780292c921f8d Mon Sep 17 00:00:00 2001 From: Bruno BELANYI Date: Sat, 10 Aug 2024 16:47:11 +0100 Subject: [PATCH] posts: kd-tree: add insertion --- content/posts/2024-08-10-kd-tree/index.md | 77 +++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/content/posts/2024-08-10-kd-tree/index.md b/content/posts/2024-08-10-kd-tree/index.md index 7719fdd..eb35d28 100644 --- a/content/posts/2024-08-10-kd-tree/index.md +++ b/content/posts/2024-08-10-kd-tree/index.md @@ -124,3 +124,80 @@ class KdTree[T]: # Tree starts out empty self._root = KdNode() ``` + +### Inserting a point + +To add a point to the tree, we simply recurse from node to node, similar to a +_BST_'s insertion algorithm. Once we've found the correct leaf node to insert +our point into, we simply do so. + +If that leaf node goes over the maximum number of points it can store, we must +then split it along an axis, cycling between `X`, `Y`, and `Z` at each level of +the tree (i.e: splitting along the `X` axis on the first level, then `Y` on the +second, then `Z` after that, and then `X`, etc...). + +```python +# How many points should be stored in a leaf node before being split +MAX_CAPACITY = 32 + +def median(values: Iterable[float]) -> float: + sorted_values = sorted(values) + mid_point = len(sorted_values) // 2 + if len(sorted_values) % 2 == 1: + return sorted_values[mid_point] + a, b = sorted_values[mid_point], sorted_values[mid_point + 1] + return a + (b - a) / 2 + +def partition[T]( + pred: Callable[[T], bool], + iterable: Iterable[T] +) -> tuple[list[T], list[T]]: + truths, falses = [], [] + for v in iterable: + (truths if pred(v) else falses).append(v) + return truths, falses + +def split_leaf[T](node: KdLeafNode[T], axis: Axis) -> KdSplitNode[T]: + # Find the median value for the given axis + mid = median(p[axis] for p in node.points) + # Split into left/right children according to the mid-point and axis + left, right = partition(lambda kv: kv[0][axis] <= mid, node.points.items()) + return KdSplitNode( + split_axis, + mid, + (KdLeafNode.from_items(left), KdLeafNode.from_items(right)), + ) + +class KdTree[T]: + def insert(self, point: Point, val: T) -> bool: + # Forward to the root node, choose `X` as the first split axis + return self._root.insert(point, val, Axis.X) + +class KdLeafNode[T]: + def insert(self, point: Point, val: T, split_axis: Axis) -> bool: + # Check whether we're overwriting a previous value + was_mapped = point in self.points + # Store the corresponding value + self.points[point] = val + # Return whether we've performed an overwrite + return was_mapped + +class KdSplitNode[T]: + def insert(self, point: Point, val: T, split_axis: Axis) -> bool: + # Find the child which contains the point + child = self.children[self._index(point)] + # Recurse into it, choosing the next split axis + return child.insert(point, val, split_axis.next()) + +class KdNode[T]: + def insert(self, point: Point, val: T, split_axis: Axis) -> bool: + # Add the point to the wrapped node... + res = self.inner.insert(point, val, split_axis) + # ... And take care of splitting leaf nodes when necessary + if ( + isinstance(self.inner, KdLeafNode) + and len(self.inner.points) > MAX_CAPACITY + ): + self.inner = split_leaf(self.inner, split_axis) + return res +```