2021: d24: ex2: add solution

This commit is contained in:
Bruno BELANYI 2021-12-25 00:32:07 +01:00
parent c40b08063a
commit cb8cb691c4

112
2021/d24/ex2/ex2.py Executable file
View file

@ -0,0 +1,112 @@
#!/usr/bin/env python
import enum
import functools
import sys
from typing import List, Literal, NamedTuple, Optional, Union, cast
import z3
Register = Literal["w", "x", "y", "z"]
class InstructionType(enum.Enum):
INPUT = "inp"
ADD = "add"
MULTIPLY = "mul"
DIVIDE = "div"
MODULO = "mod"
EQUAL = "eql"
class Instruction(NamedTuple):
type: InstructionType
destination: Register
source: Optional[Union[Register, int]]
def solve(input: List[str]) -> int:
def parse() -> List[Instruction]:
def parse_instruction(line: str) -> Instruction:
type, *registers = line.split()
assert registers[0] in ("w", "x", "y", "z") # Sanity check
destination = cast(Register, registers[0])
source: Optional[Union[Register, int]] = None
if type != "inp":
if registers[1] in ("w", "x", "y", "z"):
source = cast(Register, registers[1])
else:
source = int(registers[1])
return Instruction(InstructionType(type), destination, source)
return [parse_instruction(line) for line in input]
def run_z3(instructions: List[Instruction]) -> int:
solver = z3.Optimize()
BITS = 64
digits = [z3.BitVec(f"input_{i}", BITS) for i in range(14)]
next_input = iter(digits).__next__
for d in digits:
solver.add(z3.And(1 <= d, d <= 9))
zero, one = z3.BitVecVal(0, BITS), z3.BitVecVal(1, BITS)
registers = {r: zero for r in ("w", "x", "y", "z")}
for i, instr in enumerate(instructions):
if instr.type == InstructionType.INPUT:
registers[instr.destination] = next_input()
continue
assert instr.source is not None # Sanity check
value = registers[instr.destination]
if isinstance(instr.source, int):
source = z3.BitVecVal(instr.source, BITS)
else:
source = registers[instr.source]
res = z3.BitVec(f"result_{i}", BITS)
if instr.type == InstructionType.ADD:
solver.add(res == (value + source))
elif instr.type == InstructionType.MULTIPLY:
solver.add(res == (value * source))
elif instr.type == InstructionType.DIVIDE:
solver.add(source > zero) # Sanity check
solver.add(res == (value / source))
elif instr.type == InstructionType.MODULO:
solver.add(value >= zero) # Sanity check
solver.add(source > zero) # Sanity check
solver.add(res == (value % source))
elif instr.type == InstructionType.EQUAL:
solver.add(res == z3.If(value == source, one, zero))
else:
assert False # Sanity check
registers[instr.destination] = res
solver.add(registers["z"] == zero)
digits_to_num = lambda digits: functools.reduce(lambda a, b: a * 10 + b, digits)
solver.minimize(digits_to_num(digits))
assert solver.check() == z3.sat # Sanity check
model = solver.model()
return digits_to_num(model[d].as_long() for d in digits)
return run_z3(parse())
def main() -> None:
input = [line.rstrip("\n") for line in sys.stdin.readlines()]
print(solve(input))
if __name__ == "__main__":
main()