2021: d24: ex2: add solution
This commit is contained in:
parent
c40b08063a
commit
cb8cb691c4
112
2021/d24/ex2/ex2.py
Executable file
112
2021/d24/ex2/ex2.py
Executable file
|
@ -0,0 +1,112 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import enum
|
||||
import functools
|
||||
import sys
|
||||
from typing import List, Literal, NamedTuple, Optional, Union, cast
|
||||
|
||||
import z3
|
||||
|
||||
Register = Literal["w", "x", "y", "z"]
|
||||
|
||||
|
||||
class InstructionType(enum.Enum):
|
||||
INPUT = "inp"
|
||||
ADD = "add"
|
||||
MULTIPLY = "mul"
|
||||
DIVIDE = "div"
|
||||
MODULO = "mod"
|
||||
EQUAL = "eql"
|
||||
|
||||
|
||||
class Instruction(NamedTuple):
|
||||
type: InstructionType
|
||||
destination: Register
|
||||
source: Optional[Union[Register, int]]
|
||||
|
||||
|
||||
def solve(input: List[str]) -> int:
|
||||
def parse() -> List[Instruction]:
|
||||
def parse_instruction(line: str) -> Instruction:
|
||||
type, *registers = line.split()
|
||||
|
||||
assert registers[0] in ("w", "x", "y", "z") # Sanity check
|
||||
destination = cast(Register, registers[0])
|
||||
|
||||
source: Optional[Union[Register, int]] = None
|
||||
if type != "inp":
|
||||
if registers[1] in ("w", "x", "y", "z"):
|
||||
source = cast(Register, registers[1])
|
||||
else:
|
||||
source = int(registers[1])
|
||||
|
||||
return Instruction(InstructionType(type), destination, source)
|
||||
|
||||
return [parse_instruction(line) for line in input]
|
||||
|
||||
def run_z3(instructions: List[Instruction]) -> int:
|
||||
solver = z3.Optimize()
|
||||
|
||||
BITS = 64
|
||||
|
||||
digits = [z3.BitVec(f"input_{i}", BITS) for i in range(14)]
|
||||
next_input = iter(digits).__next__
|
||||
|
||||
for d in digits:
|
||||
solver.add(z3.And(1 <= d, d <= 9))
|
||||
|
||||
zero, one = z3.BitVecVal(0, BITS), z3.BitVecVal(1, BITS)
|
||||
|
||||
registers = {r: zero for r in ("w", "x", "y", "z")}
|
||||
|
||||
for i, instr in enumerate(instructions):
|
||||
if instr.type == InstructionType.INPUT:
|
||||
registers[instr.destination] = next_input()
|
||||
continue
|
||||
|
||||
assert instr.source is not None # Sanity check
|
||||
|
||||
value = registers[instr.destination]
|
||||
if isinstance(instr.source, int):
|
||||
source = z3.BitVecVal(instr.source, BITS)
|
||||
else:
|
||||
source = registers[instr.source]
|
||||
|
||||
res = z3.BitVec(f"result_{i}", BITS)
|
||||
if instr.type == InstructionType.ADD:
|
||||
solver.add(res == (value + source))
|
||||
elif instr.type == InstructionType.MULTIPLY:
|
||||
solver.add(res == (value * source))
|
||||
elif instr.type == InstructionType.DIVIDE:
|
||||
solver.add(source > zero) # Sanity check
|
||||
solver.add(res == (value / source))
|
||||
elif instr.type == InstructionType.MODULO:
|
||||
solver.add(value >= zero) # Sanity check
|
||||
solver.add(source > zero) # Sanity check
|
||||
solver.add(res == (value % source))
|
||||
elif instr.type == InstructionType.EQUAL:
|
||||
solver.add(res == z3.If(value == source, one, zero))
|
||||
else:
|
||||
assert False # Sanity check
|
||||
registers[instr.destination] = res
|
||||
|
||||
solver.add(registers["z"] == zero)
|
||||
|
||||
digits_to_num = lambda digits: functools.reduce(lambda a, b: a * 10 + b, digits)
|
||||
|
||||
solver.minimize(digits_to_num(digits))
|
||||
assert solver.check() == z3.sat # Sanity check
|
||||
|
||||
model = solver.model()
|
||||
return digits_to_num(model[d].as_long() for d in digits)
|
||||
|
||||
return run_z3(parse())
|
||||
|
||||
|
||||
def main() -> None:
|
||||
input = [line.rstrip("\n") for line in sys.stdin.readlines()]
|
||||
print(solve(input))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in a new issue