posts: kd-tree: add insertion
This commit is contained in:
parent
b821a20ba1
commit
95658484f3
|
@ -124,3 +124,80 @@ class KdTree[T]:
|
|||
# Tree starts out empty
|
||||
self._root = KdNode()
|
||||
```
|
||||
|
||||
### Inserting a point
|
||||
|
||||
To add a point to the tree, we simply recurse from node to node, similar to a
|
||||
_BST_'s insertion algorithm. Once we've found the correct leaf node to insert
|
||||
our point into, we simply do so.
|
||||
|
||||
If that leaf node goes over the maximum number of points it can store, we must
|
||||
then split it along an axis, cycling between `X`, `Y`, and `Z` at each level of
|
||||
the tree (i.e: splitting along the `X` axis on the first level, then `Y` on the
|
||||
second, then `Z` after that, and then `X`, etc...).
|
||||
|
||||
```python
|
||||
# How many points should be stored in a leaf node before being split
|
||||
MAX_CAPACITY = 32
|
||||
|
||||
def median(values: Iterable[float]) -> float:
|
||||
sorted_values = sorted(values)
|
||||
mid_point = len(sorted_values) // 2
|
||||
if len(sorted_values) % 2 == 1:
|
||||
return sorted_values[mid_point]
|
||||
a, b = sorted_values[mid_point], sorted_values[mid_point + 1]
|
||||
return a + (b - a) / 2
|
||||
|
||||
def partition[T](
|
||||
pred: Callable[[T], bool],
|
||||
iterable: Iterable[T]
|
||||
) -> tuple[list[T], list[T]]:
|
||||
truths, falses = [], []
|
||||
for v in iterable:
|
||||
(truths if pred(v) else falses).append(v)
|
||||
return truths, falses
|
||||
|
||||
def split_leaf[T](node: KdLeafNode[T], axis: Axis) -> KdSplitNode[T]:
|
||||
# Find the median value for the given axis
|
||||
mid = median(p[axis] for p in node.points)
|
||||
# Split into left/right children according to the mid-point and axis
|
||||
left, right = partition(lambda kv: kv[0][axis] <= mid, node.points.items())
|
||||
return KdSplitNode(
|
||||
split_axis,
|
||||
mid,
|
||||
(KdLeafNode.from_items(left), KdLeafNode.from_items(right)),
|
||||
)
|
||||
|
||||
class KdTree[T]:
|
||||
def insert(self, point: Point, val: T) -> bool:
|
||||
# Forward to the root node, choose `X` as the first split axis
|
||||
return self._root.insert(point, val, Axis.X)
|
||||
|
||||
class KdLeafNode[T]:
|
||||
def insert(self, point: Point, val: T, split_axis: Axis) -> bool:
|
||||
# Check whether we're overwriting a previous value
|
||||
was_mapped = point in self.points
|
||||
# Store the corresponding value
|
||||
self.points[point] = val
|
||||
# Return whether we've performed an overwrite
|
||||
return was_mapped
|
||||
|
||||
class KdSplitNode[T]:
|
||||
def insert(self, point: Point, val: T, split_axis: Axis) -> bool:
|
||||
# Find the child which contains the point
|
||||
child = self.children[self._index(point)]
|
||||
# Recurse into it, choosing the next split axis
|
||||
return child.insert(point, val, split_axis.next())
|
||||
|
||||
class KdNode[T]:
|
||||
def insert(self, point: Point, val: T, split_axis: Axis) -> bool:
|
||||
# Add the point to the wrapped node...
|
||||
res = self.inner.insert(point, val, split_axis)
|
||||
# ... And take care of splitting leaf nodes when necessary
|
||||
if (
|
||||
isinstance(self.inner, KdLeafNode)
|
||||
and len(self.inner.points) > MAX_CAPACITY
|
||||
):
|
||||
self.inner = split_leaf(self.inner, split_axis)
|
||||
return res
|
||||
```
|
||||
|
|
Loading…
Reference in a new issue