advent-of-code/2024/d20/ex1/ex1.py

108 lines
2.6 KiB
Python
Executable file

#!/usr/bin/env python
import itertools
import sys
from collections.abc import Iterator
from typing import NamedTuple
class Point(NamedTuple):
    """An integer coordinate on the puzzle grid.

    In this file's parser ``x`` is the line (row) index and ``y`` the
    column index.
    """

    x: int
    y: int

    def neighbours(self) -> Iterator["Point"]:
        """Yield the four orthogonally adjacent points.

        The order is: x - 1, x + 1, y - 1, y + 1.
        """
        row, col = self
        yield Point(row - 1, col)
        yield Point(row + 1, col)
        yield Point(row, col - 1)
        yield Point(row, col + 1)
class ParsedMap(NamedTuple):
    """The race-track grid reduced to the data the solver needs."""

    start: Point  # the cell marked "S"
    end: Point  # the cell marked "E"
    tracks: set[Point]  # every non-wall cell, including start and end


# A cheat is only counted when it saves at least this many steps.
MIN_SAVE: int = 100
def solve(
    input: str,
    min_save: int | None = None,
    cheat_radius: int = 2,
) -> int:
    """Count the cheats that save at least ``min_save`` steps.

    The input is a grid where ``#`` is a wall and every other character
    (including ``S`` for the start and ``E`` for the end) is track. A cheat
    leaves the track at cell ``a`` and re-enters it at cell ``b``, costing
    the Manhattan distance between them, which may be at most
    ``cheat_radius``. Each distinct ``(a, b)`` pair counts once.

    Args:
        input: The raw puzzle text.
        min_save: Minimum number of steps a cheat must save to be counted.
            Defaults to the module-level ``MIN_SAVE`` (read at call time).
        cheat_radius: Maximum Manhattan length of a cheat.

    Returns:
        The number of cheats saving at least ``min_save`` steps.

    Raises:
        ValueError: If the map lacks an ``S`` or an ``E`` cell.
    """
    if min_save is None:
        min_save = MIN_SAVE

    def parse(lines: list[str]) -> ParsedMap:
        """Collect the start, end and all non-wall cells of the grid."""
        start: Point | None = None
        end: Point | None = None
        tracks: set[Point] = set()
        for x, line in enumerate(lines):
            for y, c in enumerate(line):
                if c == "#":
                    continue
                p = Point(x, y)
                if c == "S":
                    start = p
                elif c == "E":
                    end = p
                tracks.add(p)
        # An explicit check rather than `assert`, which is stripped under -O.
        if start is None or end is None:
            raise ValueError("map must contain both an 'S' and an 'E' cell")
        return ParsedMap(start, end, tracks)

    def flood_distance(start: Point, points: set[Point]) -> dict[Point, int]:
        """Return the shortest step count from ``start`` to each reachable point."""
        # Layer-by-layer breadth-first search: every point in `frontier` is
        # exactly `steps - 1` away from `start`. The first discovery of a
        # point is therefore along a shortest path. Popping from a set would
        # explore in arbitrary order and can record a longer distance
        # whenever the track branches or has open areas.
        res = {start: 0}
        frontier = [start]
        steps = 0
        while frontier:
            steps += 1
            next_frontier = []
            for p in frontier:
                for n in p.neighbours():
                    if n in res or n not in points:
                        continue
                    res[n] = steps
                    next_frontier.append(n)
            frontier = next_frontier
        return res

    def dist(a: Point, b: Point) -> int:
        """Manhattan distance between two points."""
        return abs(a.x - b.x) + abs(a.y - b.y)

    def disk(p: Point, radius: int) -> Iterator[Point]:
        """Yield every point within Manhattan distance ``radius`` of ``p``."""
        for dx, dy in itertools.product(range(-radius, radius + 1), repeat=2):
            n = Point(p.x + dx, p.y + dy)
            if dist(p, n) > radius:
                continue
            yield n

    def find_cheats(
        start: Point,
        end: Point,
        tracks: set[Point],
        min_save: int,
        radius: int,
    ) -> int:
        """Count (entry, exit) cheat pairs saving at least ``min_save`` steps."""
        start_dist = flood_distance(start, tracks)
        end_dist = flood_distance(end, tracks)
        fastest = start_dist[end]
        assert fastest == end_dist[start]  # Sanity check: BFS is symmetric
        res = 0
        # Only cells reachable from the start can begin a cheat, and only
        # cells that can still reach the end can finish one. Iterating the
        # distance maps avoids KeyError on disconnected track cells.
        for a, from_start in start_dist.items():
            for b in disk(a, radius):
                to_end = end_dist.get(b)
                if to_end is None:
                    continue
                saving = fastest - (from_start + dist(a, b) + to_end)
                # `saving > 0` excludes ordinary moves (and a == b), which
                # would otherwise be counted if min_save were <= 0.
                if saving > 0 and saving >= min_save:
                    res += 1
        return res

    start, end, tracks = parse(input.splitlines())
    return find_cheats(start, end, tracks, min_save, cheat_radius)
def main() -> None:
    """Solve the puzzle read from standard input and print the answer."""
    print(solve(sys.stdin.read()))


if __name__ == "__main__":
    main()