# advent-of-code/2018/d15/ex1/ex1.py
# (repository-viewer metadata: 193 lines, 5.4 KiB, Python, 2024-12-30 22:22:52 +01:00)
#!/usr/bin/env python
import dataclasses
import enum
import sys
from typing import Iterator, NamedTuple
class Point(NamedTuple):
    """A grid coordinate; `x` is the row index and `y` the column index.

    Because the row comes first, the natural tuple ordering of points is
    reading order (top-to-bottom, then left-to-right).
    """

    x: int
    y: int

    def neighbours(self) -> Iterator["Point"]:
        """Yield the four orthogonally adjacent points, in reading order."""
        row, col = self
        yield Point(row - 1, col)  # above
        yield Point(row, col - 1)  # left
        yield Point(row, col + 1)  # right
        yield Point(row + 1, col)  # below
class Unit(enum.StrEnum):
ELF = "E"
GOBLIN = "G"
def ennemy(self) -> "Unit":
if self == Unit.ELF:
return Unit.GOBLIN
if self == Unit.GOBLIN:
return Unit.ELF
assert False # Sanity check
@dataclasses.dataclass
class UnitData:
    """Mutable combat state of a single unit."""

    hp: int = dataclasses.field(default=200)  # remaining hit points
    power: int = dataclasses.field(default=3)  # damage dealt per attack
def solve(input: str) -> int:
    """Simulate the elves-vs-goblins combat and return its outcome.

    The outcome is the number of *full* rounds completed before combat ends,
    multiplied by the summed hit points of every surviving unit.

    Points are built as ``Point(row, column)``, so sorting them yields reading
    order (top-to-bottom, then left-to-right), which breaks every tie.
    """
    # Plain-string map characters denoting a unit. `c in Unit` with a
    # non-member string only works on Python 3.12+ (3.11 raises TypeError).
    unit_chars = {u.value for u in Unit}

    def parse(lines: list[str]) -> tuple[set[Point], dict[Unit, set[Point]]]:
        """Return the wall positions and the positions of each faction."""
        walls: set[Point] = set()
        units: dict[Unit, set[Point]] = {u: set() for u in Unit}
        for x, line in enumerate(lines):
            for y, c in enumerate(line):
                p = Point(x, y)
                if c in unit_chars:
                    units[Unit(c)].add(p)
                elif c == "#":
                    walls.add(p)
        return walls, units

    def double_bfs(
        unit_type: Unit,
        unit_pos: Point,
        walls: set[Point],
        units: dict[Unit, set[Point]],
    ) -> Point | None:
        """Return the square the unit at `unit_pos` should step onto, or None.

        Chooses the nearest reachable open square adjacent to an enemy
        (reading order breaks ties), then the first step of a shortest path
        towards it (reading order breaks ties between equal paths).
        """

        def bfs(
            start: Point,
            targets: set[Point],
            blockers: set[Point],
        ) -> Point | None:
            """Return the target closest to `start`, first in reading order.

            `start` itself is never tested against `blockers`. Returns None
            if no target is reachable.
            """
            frontier = [start]
            seen: set[Point] = {start}
            while frontier:
                # `frontier` holds every point at the current distance, sorted,
                # so the first target found wins the reading-order tie-break.
                for p in frontier:
                    if p in targets:
                        return p
                new_frontier: set[Point] = set()
                for p in frontier:
                    for n in p.neighbours():
                        if n in seen or n in blockers:
                            continue
                        # Mark on discovery so no point is expanded twice.
                        seen.add(n)
                        new_frontier.add(n)
                frontier = sorted(new_frontier)
            return None

        # Walls and every unit, of either faction, block movement.
        blockers = walls.union(*units.values())
        enemies = units[unit_type.ennemy()]
        # First BFS: from the unit to the closest open square next to an enemy
        targets_in_range = {
            n for enemy in enemies for n in enemy.neighbours() if n not in blockers
        }
        if (target := bfs(unit_pos, targets_in_range, blockers)) is None:
            return None
        # Then back from the chosen target: the first open square next to the
        # unit that is reached starts a shortest path, and the reading-order
        # tie-break selects the correct one among equally short paths.
        movement_squares = {n for n in unit_pos.neighbours() if n not in blockers}
        return bfs(target, movement_squares, blockers)

    def do_move(
        unit_type: Unit,
        unit_pos: Point,
        walls: set[Point],
        units: dict[Unit, set[Point]],
        unit_data: dict[Point, UnitData],
    ) -> Point:
        """Move the unit one step towards the enemy, in place.

        Returns the unit's (possibly unchanged) position.
        """
        # If already next to an ennemy, do not move
        if any(n in units[unit_type.ennemy()] for n in unit_pos.neighbours()):
            return unit_pos
        new_pos = double_bfs(unit_type, unit_pos, walls, units)
        # Nowhere to move to, no-op
        if new_pos is None:
            return unit_pos
        assert new_pos != unit_pos  # Sanity check
        assert unit_pos in units[unit_type]  # Sanity check
        assert new_pos not in units[unit_type]  # Sanity check
        # Make the movement in-place; the same UnitData object follows the unit
        units[unit_type] ^= {unit_pos, new_pos}
        unit_data[new_pos] = unit_data.pop(unit_pos)
        return new_pos

    def do_attack(
        unit_type: Unit,
        unit_pos: Point,
        units: dict[Unit, set[Point]],
        unit_data: dict[Point, UnitData],
    ) -> None:
        """Hit the adjacent enemy with the fewest HP, removing it if killed."""
        # `neighbours()` yields reading order and `min` keeps the first
        # minimum, so HP ties go to the first enemy in reading order.
        target = min(
            (n for n in unit_pos.neighbours() if n in units[unit_type.ennemy()]),
            key=lambda p: unit_data[p].hp,
            default=None,
        )
        # If not in range, no-op
        if target is None:
            return
        assert target not in units[unit_type]  # Sanity check
        assert target in units[unit_type.ennemy()]  # Sanity check
        assert unit_data[target].hp > 0  # Sanity check
        # Make the attack in-place
        unit_data[target].hp -= unit_data[unit_pos].power
        # And if we killed it, remove it from `units` and `unit_data`
        if unit_data[target].hp <= 0:
            units[unit_type.ennemy()].remove(target)
            unit_data.pop(target)

    def turn(
        walls: set[Point],
        units: dict[Unit, set[Point]],
        unit_data: dict[Point, UnitData],
    ) -> bool:
        """Play one round; return False if combat ended before it completed."""
        # With a faction already wiped out (or no units at all), the first
        # unit to act would end combat; also avoids looping forever on an
        # empty map.
        if not all(units.values()):
            return False
        # Remember each unit's state object, not just its square: a unit that
        # dies can have its square taken by another unit of the same faction
        # before its slot in the turn order comes up.
        turn_order = sorted(
            ((p, u, unit_data[p]) for u, points in units.items() for p in points),
            key=lambda entry: entry[0],
        )
        for p, u, data in turn_order:
            # Units only move on their own turn, so a living unit is still on
            # its starting square; anything else there means it died.
            if unit_data.get(p) is not data:
                continue
            # If no ennemies left, finish the turn early and indicate that we're done
            if not units[u.ennemy()]:
                return False
            # Movements and attacks are made in-place
            p = do_move(u, p, walls, units, unit_data)
            do_attack(u, p, units, unit_data)
        return True

    walls, units = parse(input.splitlines())
    unit_data = {p: UnitData() for points in units.values() for p in points}
    rounds = 0
    while turn(walls, units, unit_data):
        rounds += 1
    return rounds * sum(data.hp for data in unit_data.values())
def main() -> None:
    """Read the puzzle map from standard input and print the combat outcome."""
    print(solve(sys.stdin.read()))


if __name__ == "__main__":
    main()