posts: kd-tree: add insertion
This commit is contained in:
parent
e3a3930ff3
commit
a12642a9bd
|
@ -124,3 +124,80 @@ class KdTree[T]:
|
||||||
# Tree starts out empty
|
# Tree starts out empty
|
||||||
self._root = KdNode()
|
self._root = KdNode()
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Inserting a point
|
||||||
|
|
||||||
|
To add a point to the tree, we simply recurse from node to node, similar to a
|
||||||
|
_BST_'s insertion algorithm. Once we've found the correct leaf node to insert
|
||||||
|
our point into, we simply do so.
|
||||||
|
|
||||||
|
If that leaf node goes over the maximum number of points it can store, we must
|
||||||
|
then split it along an axis, cycling between `X`, `Y`, and `Z` at each level of
|
||||||
|
the tree (i.e., splitting along the `X` axis on the first level, then `Y` on the
|
||||||
|
second, then `Z` after that, and then `X` again, etc.).
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Maximum number of points a leaf node may hold; going over it triggers a split
|
||||||
|
MAX_CAPACITY: int = 32
|
||||||
|
|
||||||
|
def median(values: Iterable[float]) -> float:
    """Return the median of ``values``.

    The input is sorted into a new list first, so any iterable (including a
    generator) is accepted and the caller's data is left untouched.

    For an odd number of values the middle element is returned as-is. For an
    even number of values the result is the point halfway between the two
    middle elements.

    Raises:
        IndexError: if ``values`` is empty.
    """
    sorted_values = sorted(values)
    mid_point = len(sorted_values) // 2
    if len(sorted_values) % 2 == 1:
        return sorted_values[mid_point]
    # Even count: the two middle elements sit at `mid_point - 1` and
    # `mid_point` (e.g. indices 1 and 2 out of 4), not `mid_point` and
    # `mid_point + 1`, which is off by one and overruns a 2-element list.
    a, b = sorted_values[mid_point - 1], sorted_values[mid_point]
    # Written as `a + (b - a) / 2` rather than `(a + b) / 2` to stay within
    # the range of the two operands
    return a + (b - a) / 2
|
||||||
|
|
||||||
|
def partition[T](
|
||||||
|
pred: Callable[[T], bool],
|
||||||
|
iterable: Iterable[T]
|
||||||
|
) -> tuple[list[T], list[T]]:
|
||||||
|
truths, falses = [], []
|
||||||
|
for v in iterable:
|
||||||
|
(truths if pred(v) else falses).append(v)
|
||||||
|
return truths, falses
|
||||||
|
|
||||||
|
def split_leaf[T](node: KdLeafNode[T], axis: Axis) -> KdSplitNode[T]:
|
||||||
|
# Find the median value for the given axis
|
||||||
|
mid = median(p[axis] for p in node.points)
|
||||||
|
# Split into left/right children according to the mid-point and axis
|
||||||
|
left, right = partition(lambda kv: kv[0][axis] <= mid, node.points.items())
|
||||||
|
return KdSplitNode(
|
||||||
|
split_axis,
|
||||||
|
mid,
|
||||||
|
(KdNode.from_items(left), KdNode.from_items(right)),
|
||||||
|
)
|
||||||
|
|
||||||
|
class KdTree[T]:
|
||||||
|
def insert(self, point: Point, val: T) -> bool:
|
||||||
|
# Forward to the root node, choose `X` as the first split axis
|
||||||
|
return self._root.insert(point, val, Axis.X)
|
||||||
|
|
||||||
|
class KdLeafNode[T]:
|
||||||
|
def insert(self, point: Point, val: T, split_axis: Axis) -> bool:
|
||||||
|
# Check whether we're overwriting a previous value
|
||||||
|
was_mapped = point in self.points
|
||||||
|
# Store the corresponding value
|
||||||
|
self.points[point] = val
|
||||||
|
# Return whether we've performed an overwrite
|
||||||
|
return was_mapped
|
||||||
|
|
||||||
|
class KdSplitNode[T]:
|
||||||
|
def insert(self, point: Point, val: T, split_axis: Axis) -> bool:
|
||||||
|
# Find the child which contains the point
|
||||||
|
child = self.children[self._index(point)]
|
||||||
|
# Recurse into it, choosing the next split axis
|
||||||
|
return child.insert(point, val, split_axis.next())
|
||||||
|
|
||||||
|
class KdNode[T]:
|
||||||
|
def insert(self, point: Point, val: T, split_axis: Axis) -> bool:
|
||||||
|
# Add the point to the wrapped node...
|
||||||
|
res = self.inner.insert(point, val, split_axis)
|
||||||
|
# ... And take care of splitting leaf nodes when necessary
|
||||||
|
if (
|
||||||
|
isinstance(self.inner, KdLeafNode)
|
||||||
|
and len(self.inner.points) > MAX_CAPACITY
|
||||||
|
):
|
||||||
|
self.inner = split_leaf(self.inner, split_axis)
|
||||||
|
return res
|
||||||
|
```
|
||||||
|
|
Loading…
Reference in a new issue