advent-of-code/2024/d21/ex1/ex1.py

109 lines
2.7 KiB
Python
Raw Normal View History

2024-12-21 18:49:56 +01:00
#!/usr/bin/env python
import enum
import functools
import itertools
import sys
from typing import Literal, NamedTuple
class Point(NamedTuple):
x: int
y: int
class Direction(enum.StrEnum):
UP = "^"
DOWN = "v"
LEFT = "<"
RIGHT = ">"
Instruction = Direction | Literal["A"]
PAD = {
"7": Point(-3, -2),
"8": Point(-3, -1),
"9": Point(-3, 0),
"4": Point(-2, -2),
"5": Point(-2, -1),
"6": Point(-2, 0),
"1": Point(-1, -2),
"2": Point(-1, -1),
"3": Point(-1, 0),
"0": Point(0, -1),
"A": Point(0, 0),
}
ARROWS = {
"^": Point(0, -1),
"A": Point(0, 0),
"<": Point(1, -2),
"v": Point(1, -1),
">": Point(1, 0),
}
# Needs to be hash-able
InstructionList = tuple[Instruction, ...]
def solve(input: str) -> int:
def button_paths(
start_button: str,
end_button: str,
buttons: dict[str, Point],
) -> set[InstructionList]:
start, end = buttons[start_button], buttons[end_button]
sequences: set[InstructionList] = set()
dx, dy = (end.x - start.x), (end.y - start.y)
a_button: list[Instruction] = ["A"] # Work around MyPy limitation
move_x = [Direction.UP if dx < 0 else Direction.DOWN] * abs(dx)
move_y = [Direction.LEFT if dy < 0 else Direction.RIGHT] * abs(dy)
# Avoid moving over the gap
if Point(end.x, start.y) in buttons.values():
sequences.add(tuple(a_button + move_x + move_y + a_button))
if Point(start.x, end.y) in buttons.values():
sequences.add(tuple(a_button + move_y + move_x + a_button))
return sequences
@functools.cache
def path_cost(path: InstructionList, depth: int) -> int:
# Have we reached the actual keypad robot
if depth == 0:
# We start on 'A' so don't count it
return len(path) - 1
# Otherwise, intermediate robot must use arrow pad
cost = sum(
min(path_cost(path, depth - 1) for path in button_paths(start, end, ARROWS))
for start, end in itertools.pairwise(path)
)
return cost
def code_cost(code: str, depth: int) -> int:
# We start on 'A'
code = "A" + code
cost = sum(
min(path_cost(path, depth) for path in button_paths(start, end, PAD))
for start, end in itertools.pairwise(code)
)
return cost
def complexity(code: str, seq_len: int) -> int:
return int(code.replace("A", "")) * seq_len
codes = input.splitlines()
return sum(complexity(code, code_cost(code, 2)) for code in codes)
def main() -> None:
input = sys.stdin.read()
print(solve(input))
if __name__ == "__main__":
main()