advent-of-code/2018/d16/ex2/ex2.py

134 lines
4 KiB
Python
Executable file

#!/usr/bin/env python
import copy
import enum
import sys
from typing import NamedTuple
class OpCode(enum.StrEnum):
ADDR = "addr"
ADDI = "addi"
MULR = "mulr"
MULI = "muli"
BANR = "banr"
BANI = "bani"
BORR = "borr"
BORI = "bori"
SETR = "setr"
SETI = "seti"
GTIR = "gtir"
GTRI = "gtri"
GTRR = "gtrr"
EQIR = "eqir"
EQRI = "eqri"
EQRR = "eqrr"
def apply(self, registers: list[int], a: int, b: int, c: int) -> list[int]:
registers = copy.deepcopy(registers)
if self == OpCode.ADDR:
registers[c] = registers[a] + registers[b]
if self == OpCode.ADDI:
registers[c] = registers[a] + b
if self == OpCode.MULR:
registers[c] = registers[a] * registers[b]
if self == OpCode.MULI:
registers[c] = registers[a] * b
if self == OpCode.BANR:
registers[c] = registers[a] & registers[b]
if self == OpCode.BANI:
registers[c] = registers[a] & b
if self == OpCode.BORR:
registers[c] = registers[a] | registers[b]
if self == OpCode.BORI:
registers[c] = registers[a] | b
if self == OpCode.SETR:
registers[c] = registers[a]
if self == OpCode.SETI:
registers[c] = a
if self == OpCode.GTIR:
registers[c] = a > registers[b]
if self == OpCode.GTRI:
registers[c] = registers[a] > b
if self == OpCode.GTRR:
registers[c] = registers[a] > registers[b]
if self == OpCode.EQIR:
registers[c] = a == registers[b]
if self == OpCode.EQRI:
registers[c] = registers[a] == b
if self == OpCode.EQRR:
registers[c] = registers[a] == registers[b]
return registers
# A raw instruction as four integers: [opcode number, A, B, C].
Instruction = list[int]


class Example(NamedTuple):
    """One sample: register contents around a single executed instruction."""

    before: list[int]  # registers before the instruction runs
    data: Instruction  # the instruction: [opcode number, A, B, C]
    after: list[int]  # registers after the instruction runs
def solve(input: str) -> int:
    """Decode the opcode numbers from the samples, run the program, return register 0.

    ``input`` holds two sections separated by a run of blank lines:

    * samples, each three lines (``Before: [...]``, ``opcode a b c``,
      ``After: [...]``), separated from one another by one blank line;
    * the test program, one ``opcode a b c`` line per instruction.

    Any amount of whitespace may follow the ``Before:``/``After:`` labels (so
    the column-aligned ``After:  [...]`` form parses), and CRLF line endings
    are accepted.

    Raises:
        ValueError: if there is no blank-line break between the sections, or
            the samples do not determine a unique opcode for every number.
    """

    def parse_registers(line: str, label: str) -> list[int]:
        # Tolerate extra spaces after the label, e.g. "After:  [3, 2, 2, 1]".
        body = line.strip().removeprefix(label).strip()
        body = body.removeprefix("[").removesuffix("]")
        return [int(n) for n in body.split(",")]

    def parse_example(lines: list[str]) -> Example:
        return Example(
            parse_registers(lines[0], "Before:"),
            [int(n) for n in lines[1].split()],
            parse_registers(lines[2], "After:"),
        )

    def parse_examples(text: str) -> list[Example]:
        return [
            parse_example(chunk.splitlines())
            for chunk in text.strip().split("\n\n")
        ]

    def parse_data(lines: list[str]) -> list[Instruction]:
        # Skip blank lines (leftovers of the section break, trailing newlines).
        return [[int(n) for n in line.split()] for line in lines if line.strip()]

    def parse(text: str) -> tuple[list[Example], list[Instruction]]:
        text = text.replace("\r\n", "\n")
        # Samples are separated by a single blank line, so the first run of two
        # or more blank lines marks the start of the program.
        examples, sep, data = text.partition("\n\n\n")
        if not sep:
            raise ValueError("expected blank lines between the samples and the program")
        return parse_examples(examples), parse_data(data.splitlines())

    def find_opcodes(examples: list[Example]) -> dict[int, OpCode]:
        # Start with every opcode possible for every number, then keep only the
        # opcodes consistent with each sample.
        candidates: dict[int, set[OpCode]] = {
            n: set(OpCode) for n in range(len(OpCode))
        }
        for example in examples:
            number, a, b, c = example.data
            candidates[number] = {
                op
                for op in candidates[number]
                if op.apply(example.before, a, b, c) == example.after
            }

        # Constraint propagation: a number with a single candidate fixes that
        # opcode, which is then ruled out for every other number. Stop with an
        # error (instead of looping forever) when no progress can be made.
        resolved: dict[int, OpCode] = {}
        while len(resolved) < len(candidates):
            newly = {
                n: next(iter(ops))
                for n, ops in candidates.items()
                if n not in resolved and len(ops) == 1
            }
            if not newly or len(set(newly.values())) != len(newly):
                unresolved = {
                    n: sorted(ops) for n, ops in candidates.items() if n not in resolved
                }
                raise ValueError(
                    f"samples do not determine a unique opcode mapping: {unresolved}"
                )
            resolved.update(newly)
            taken = set(newly.values())
            for n, ops in candidates.items():
                if n not in resolved:
                    ops -= taken
        return resolved

    def run_program(data: list[Instruction], opcodes: dict[int, OpCode]) -> list[int]:
        registers = [0] * 4  # four registers, all starting at 0
        for number, a, b, c in data:
            registers = opcodes[number].apply(registers, a, b, c)
        return registers

    examples, data = parse(input)
    opcodes = find_opcodes(examples)
    registers = run_program(data, opcodes)
    # int() in case a comparison opcode left a bool in register 0.
    return int(registers[0])
def main() -> None:
    """Read the whole puzzle input from stdin and print the value ``solve`` returns."""
    puzzle_input = sys.stdin.read()
    answer = solve(puzzle_input)
    print(answer)


if __name__ == "__main__":
    main()