advent-of-code/2024/d17/ex1/ex1.py

108 lines
2.9 KiB
Python
Raw Normal View History

2024-12-17 19:10:51 +01:00
#!/usr/bin/env python
import dataclasses
import enum
import sys
@dataclasses.dataclass
class Registers:
    """Mutable register file of the computer: three integer registers."""

    # Register A: divided in place by ADV, tested by JNZ, numerator for BDV/CDV.
    reg_a: int
    # Register B: written by BXL, BST, BXC and BDV.
    reg_b: int
    # Register C: written by CDV, read by BXC.
    reg_c: int
class Instruction(enum.IntEnum):
    """Opcodes understood by `Computer.step`.

    The integer value of each member is the opcode as it appears in the
    program, so `Instruction(n)` decodes a raw program value.
    """

    # reg_a = reg_a // 2**combo
    ADV = 0
    # reg_b = reg_b XOR literal operand
    BXL = 1
    # reg_b = combo % 8
    BST = 2
    # if reg_a != 0: jump to literal operand
    JNZ = 3
    # reg_b = reg_b XOR reg_c
    BXC = 4
    # append combo % 8 to the output
    OUT = 5
    # reg_b = reg_a // 2**combo
    BDV = 6
    # reg_c = reg_a // 2**combo
    CDV = 7
@dataclasses.dataclass
class Computer:
registers: Registers
program: list[int]
ip: int = 0
def _resolve_combo_operand(self, operand: int) -> int:
assert operand != 7 # Sanity check
if 0 <= operand <= 3:
return operand
if operand == 4:
return self.registers.reg_a
if operand == 5:
return self.registers.reg_b
if operand == 6:
return self.registers.reg_c
assert False # Sanity check
# Returns False if the computer is halted
# `output` is an out parameter
def step(self, output: list[int]) -> bool:
# NOTE: also accounting for operand in overflow check here
if (self.ip + 1) >= len(self.program):
return False
instr, literal_operand = (
Instruction(self.program[self.ip]),
self.program[self.ip + 1],
)
combo_operand = self._resolve_combo_operand(literal_operand)
ip_delta = 2
match instr:
case Instruction.ADV:
self.registers.reg_a //= 2**combo_operand
case Instruction.BXL:
self.registers.reg_b ^= literal_operand
case Instruction.BST:
self.registers.reg_b = combo_operand % 8
case Instruction.JNZ:
if self.registers.reg_a != 0:
self.ip = literal_operand
ip_delta = 0
case Instruction.BXC:
self.registers.reg_b ^= self.registers.reg_c
case Instruction.OUT:
output.append(combo_operand % 8)
case Instruction.BDV:
self.registers.reg_b = self.registers.reg_a // 2**combo_operand
case Instruction.CDV:
self.registers.reg_c = self.registers.reg_a // 2**combo_operand
self.ip += ip_delta
return True
def solve(input: str) -> str:
    """Run the program described by `input` and return its output.

    `input` holds the register lines and the `Program: ...` line, separated
    by one blank line. The result is every value the program output, joined
    with commas.
    """

    def parse_register(line: str) -> int:
        # A register line is "<label>: <value>"; only the value matters.
        return int(line.split(": ")[1])

    def parse(text: str) -> Computer:
        register_block, program_block = text.split("\n\n")
        opcodes = [
            int(token)
            for token in program_block.removeprefix("Program: ").split(",")
        ]
        values = [parse_register(line) for line in register_block.splitlines()]
        return Computer(Registers(*values), opcodes)

    machine = parse(input)
    emitted: list[int] = []
    # `step` reports False once the machine has halted.
    running = True
    while running:
        running = machine.step(emitted)
    return ",".join(map(str, emitted))
def main() -> None:
    """Read the whole puzzle input from stdin and print the solution."""
    puzzle_text = sys.stdin.read()
    answer = solve(puzzle_text)
    print(answer)
# Run the solver only when executed as a script, not when imported.
if __name__ == "__main__":
    main()