diff --git a/2023/d22/ex2/ex2.py b/2023/d22/ex2/ex2.py new file mode 100755 index 0000000..c9ce5d4 --- /dev/null +++ b/2023/d22/ex2/ex2.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python + +import dataclasses +import sys +from collections import defaultdict +from collections.abc import Iterator +from typing import NamedTuple + + +def sign(x: int) -> int: + if x == 0: + return 0 + return 1 if x > 0 else -1 + + +class Point(NamedTuple): + x: int + y: int + z: int + + def fall(self, delta: int = 0) -> "Point": + assert delta <= self.z # Sanity check + return self._replace(z=self.z - delta) + + +@dataclasses.dataclass +class Brick: + top_left: Point + bot_right: Point + + def __post_init__(self) -> None: + assert self.top_left.z >= self.bot_right.z # Sanity check + + def orientation(self) -> Point: + return Point( + sign(self.bot_right.x - self.top_left.x), + sign(self.bot_right.y - self.top_left.y), + sign(self.bot_right.z - self.top_left.z), + ) + + def blocks(self) -> Iterator[Point]: + p = self.top_left + dx, dy, dz = self.orientation() + while p != self.bot_right: + yield p + p = Point(p.x + dx, p.y + dy, p.z + dz) + yield self.bot_right + + def fall(self, delta: int = 0) -> "Brick": + assert delta >= 0 # Sanity check + return Brick(self.top_left.fall(delta), self.bot_right.fall(delta)) + + +class TowerMap(NamedTuple): + supports: dict[int, set[int]] + supported_by: dict[int, set[int]] + num_bricks: int + + @classmethod + def compute_support(cls, tower: dict[Point, int]) -> "TowerMap": + supports: dict[int, set[int]] = defaultdict(set) + supported_by: dict[int, set[int]] = defaultdict(set) + + for p, i in tower.items(): + under = p.fall(1) + support = tower.get(under) + # No supporting brick + if support is None: + continue + # Don't count the brick as supporting itself + if support == i: + continue + supports[support].add(i) + supported_by[i].add(support) + + return cls( + supports=dict(supports), + supported_by=dict(supported_by), + num_bricks=max(supports.keys() | supported_by.keys()) + 1, + ) + + def roots(self) -> set[int]: + return {p for p in range(self.num_bricks) if p not in self.supported_by} + + # From bottom to top of tower + def topo_sort(self) -> list[int]: + res: list[int] = [] + nodes = self.roots() + seen: set[int] = set() + + while nodes: + node = nodes.pop() + res.append(node) + seen.add(node) + for child in self.supports.get(node, set()): + if len(self.supported_by[child] - seen) == 0: + nodes.add(child) + + assert set(res) == set(range(self.num_bricks)) # Sanity check + # NOTE: from construction, the topo_sort is just list(range(self.num_bricks)) + # But I'd rather do the actual algorithm for completeness + return res + + +def solve(input: list[str]) -> int: + def parse_brick(line: str) -> Brick: + a, b = (Point._make(map(int, p.split(","))) for p in line.split("~")) + if a < b: + a, b = b, a + return Brick(a, b) + + # Returns which point in space belongs to which brick index + def drop(snapshots: list[Brick]) -> dict[Point, int]: + # Re-order by lowest height + snapshots = sorted(snapshots, key=lambda b: b.bot_right.z) + # By default the ground is at 0, index with Point(p.x, p.y, 0) + heights: dict[Point, int] = defaultdict(int) + res: dict[Point, int] = {} + + for i, brick in enumerate(snapshots): + z = max(heights[p.fall(p.z)] for p in brick.blocks()) + 1 + assert brick.bot_right.z >= z # Sanity check + delta = brick.bot_right.z - z # Drop it to the top of the pile + brick = brick.fall(delta) + # Record the height of the brick for every block composing it + for p in brick.blocks(): + res[p] = i + heights[p.fall(p.z)] = brick.top_left.z + + return res + + def disintegrate(tower_map: TowerMap, brick: int) -> int: + fallen = {brick} + + for b in tower_map.topo_sort(): + parents = tower_map.supported_by.get(b, set()) + # Bricks on the floor shouldn't fall + if len(parents) == 0: + continue + if all(parent in fallen for parent in parents): + fallen.add(b) + + return len(fallen) - 1 # Don't count the disintegrated brick + + snapshots = [parse_brick(line) for line in input] + tower = drop(snapshots) + tower_map = TowerMap.compute_support(tower) + return sum(disintegrate(tower_map, i) for i in range(tower_map.num_bricks)) + + +def main() -> None: + input = sys.stdin.read().splitlines() + print(solve(input)) + + +if __name__ == "__main__": + main()