# Source: advent-of-code/2021/d18/ex2/ex2.py
# (Python, 212 lines, 7 KiB; last modified 2021-12-18 16:05:53 +01:00)
#!/usr/bin/env python
import functools
import itertools
import sys
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Iterator, List, Optional, Tuple
@dataclass
class Tree:
    """Base node of a snailfish-number tree.

    ``parent`` is a back-reference to the enclosing :class:`Pair` (``None``
    for the root). It is excluded from the generated ``__eq__`` and
    ``__repr__``. Because every child points back at its parent, including it
    would make comparing two distinct, structurally equal trees bounce between
    child and parent until ``RecursionError``. It would also make every node's
    repr dump the entire enclosing tree.
    """

    parent: Optional["Pair"] = field(compare=False, repr=False)


@dataclass
class Pair(Tree):
    """Inner node: the snailfish pair ``[left, right]``."""

    left: Tree
    right: Tree


@dataclass
class Num(Tree):
    """Leaf node: a regular number."""

    val: int


# A path from the root to a node: True means "go left", False means "go right".
Path = List[bool]
def solve(input: List[str]) -> int:
    """Return the largest magnitude of the sum of two distinct snailfish numbers.

    Snailfish addition is not commutative, so every ordered pair of distinct
    input numbers is tried. The tree operations below mutate and re-parent
    nodes in place, so each addition works on deep copies of its operands.

    :param input: one snailfish number per entry, e.g. ``"[[1,2],3]"``;
        surrounding whitespace is stripped and blank entries are skipped.
    :returns: the maximum magnitude over all ordered pairs.
    :raises ValueError: from ``max`` when fewer than two numbers are given.
    """

    def make_pair(left: Tree, right: Tree, parent: Optional[Pair] = None) -> Pair:
        """Build a pair and point both children's ``parent`` at it."""
        pair = Pair(parent=parent, left=left, right=right)
        pair.left.parent = pair
        pair.right.parent = pair
        return pair

    def make_num(val: int, parent: Optional[Pair] = None) -> Num:
        """Build a regular-number leaf."""
        return Num(parent=parent, val=val)

    def parse() -> List[Tree]:
        """Parse every non-blank entry of ``input`` into a tree."""

        def parse_index(text: str, index: int) -> Tuple[int, Tree]:
            """Parse the number starting at ``text[index]``.

            Returns the index just past the parsed number, and its tree.
            """
            if text[index] == "[":
                left_index, left = parse_index(text, index + 1)
                assert text[left_index] == ","  # Sanity check
                right_index, right = parse_index(text, left_index + 1)
                assert text[right_index] == "]"  # Sanity check
                return right_index + 1, make_pair(left, right)
            end = index
            while end < len(text) and text[end] in "0123456789":
                end += 1
            # Sanity check: fail on malformed input rather than silently
            # reading a missing regular number as 0.
            assert end > index, f"expected a digit at position {index} of {text!r}"
            return end, make_num(int(text[index:end]))

        stripped = (line.strip() for line in input)
        return [parse_index(line, 0)[1] for line in stripped if line]

    def explosion_path(number: Tree) -> Optional[Path]:
        """Return the path to the leftmost pair that must explode, or None.

        A pair explodes when it is nested inside at least four pairs and both
        of its children are regular numbers.
        """

        def dfs(node: Pair, path: Path) -> Optional[Path]:
            if (
                len(path) >= 4
                and isinstance(node.left, Num)
                and isinstance(node.right, Num)
            ):
                return path
            # Search the left subtree first so the leftmost candidate wins.
            if isinstance(node.left, Pair):
                left_path = dfs(node.left, path + [True])
                if left_path is not None:
                    return left_path
            if isinstance(node.right, Pair):
                right_path = dfs(node.right, path + [False])
                if right_path is not None:
                    return right_path
            return None

        assert isinstance(number, Pair)  # Sanity check
        return dfs(number, [])

    def split_path(number: Tree) -> Optional[Path]:
        """Return the path to the leftmost regular number >= 10, or None."""

        def dfs(node: Pair, path: Path) -> Optional[Path]:
            if isinstance(node.left, Num):
                if node.left.val >= 10:
                    return path + [True]
            else:
                assert isinstance(node.left, Pair)  # Sanity check
                if (left_path := dfs(node.left, path + [True])) is not None:
                    return left_path
            if isinstance(node.right, Num):
                if node.right.val >= 10:
                    return path + [False]
            else:
                assert isinstance(node.right, Pair)  # Sanity check
                if (right_path := dfs(node.right, path + [False])) is not None:
                    return right_path
            return None

        assert isinstance(number, Pair)  # Sanity check
        return dfs(number, [])

    def explode(number: Tree, path: Path) -> Tree:
        """Explode the pair at ``path`` below ``number``; return the rebuilt tree.

        The exploding pair's left value is added to the nearest regular number
        on its left, and its right value to the nearest one on its right (when
        they exist). The pair itself is then replaced by 0. Neighbour values
        are updated in place; the pairs along ``path`` are rebuilt.
        """

        def next_num(node: Tree, reverse: bool) -> Optional[Num]:
            """Return the nearest regular number left (reverse) or right of node."""
            while node.parent is not None:
                parent = node.parent
                sibling = parent.left if reverse else parent.right
                if sibling is not node:
                    # Descend to the edge of the sibling subtree facing `node`.
                    while isinstance(sibling, Pair):
                        sibling = sibling.right if reverse else sibling.left
                    assert isinstance(sibling, Num)  # Sanity check
                    return sibling
                node = parent
            return None

        assert isinstance(number, Pair)  # Sanity check
        if not path:
            # Neighbours are located through the original parent links, before
            # any ancestor is rebuilt below.
            previous = next_num(number, reverse=True)
            following = next_num(number, reverse=False)
            if previous is not None:
                assert isinstance(number.left, Num)  # Safety check
                previous.val += number.left.val
            if following is not None:
                assert isinstance(number.right, Num)  # Safety check
                following.val += number.right.val
            return make_num(0)
        parent, left, right = number.parent, number.left, number.right
        if path[0]:
            left = explode(number.left, path[1:])
        else:
            right = explode(number.right, path[1:])
        return make_pair(parent=parent, left=left, right=right)

    def split(number: Tree, path: Path) -> Tree:
        """Split the regular number at ``path``; return the rebuilt tree.

        A value n becomes the pair [n // 2, n - n // 2] (rounded down, then up).
        """
        if not path:
            assert isinstance(number, Num)  # Sanity check
            assert number.val >= 0  # Sanity check
            half = number.val // 2
            return make_pair(
                left=make_num(half),
                right=make_num(number.val - half),
                parent=number.parent,
            )
        assert isinstance(number, Pair)  # Sanity check
        parent, left, right = number.parent, number.left, number.right
        if path[0]:
            left = split(number.left, path[1:])
        else:
            right = split(number.right, path[1:])
        return make_pair(parent=parent, left=left, right=right)

    def reduce(number: Tree) -> Tree:
        """Apply explode (preferred) or split actions until neither applies.

        Written as a loop: recursing once per action would add a stack frame
        per action and could hit Python's recursion limit on long reductions.
        """
        while True:
            path = explosion_path(number)
            if path is not None:
                number = explode(number, path)
                continue
            path = split_path(number)
            if path is None:
                return number
            number = split(number, path)

    def add(left: Tree, right: Tree) -> Tree:
        """Snailfish addition: pair the operands, then reduce.

        Both operands are re-parented and may be mutated, so pass copies.
        """
        return reduce(make_pair(left=left, right=right))

    def magnitude(number: Tree) -> int:
        """Return a leaf's value, or 3 * magnitude(left) + 2 * magnitude(right)."""
        if isinstance(number, Num):
            return number.val
        assert isinstance(number, Pair)  # Safety check
        return 3 * magnitude(number.left) + 2 * magnitude(number.right)

    numbers = parse()
    # Operations mutate trees in place, so every addition gets fresh copies.
    return max(
        magnitude(add(deepcopy(first), deepcopy(second)))
        for first, second in itertools.permutations(numbers, 2)
    )
def main() -> None:
    """Read one snailfish number per line from stdin and print the answer."""
    lines = [raw.strip() for raw in sys.stdin]
    answer = solve(lines)
    print(answer)


if __name__ == "__main__":
    main()