57 lines
1.4 KiB
Python
57 lines
1.4 KiB
Python
|
#!/usr/bin/env python
|
||
|
import itertools
|
||
|
import re
|
||
|
import sys
|
||
|
from dataclasses import dataclass
|
||
|
from typing import Dict, Iterator, List
|
||
|
|
||
|
|
||
|
@dataclass
class Mask:
    """A bitmask split into two integers derived from a ``0``/``1``/``X`` string.

    Both fields are built from the same mask text; only the substitution used
    for the ``X`` characters differs.
    """

    # Mask text read as binary with every ``X`` replaced by ``0``:
    # a set bit here means the mask character was a literal ``1``.
    ones: int
    # Mask text read as binary with every ``X`` replaced by ``1``:
    # a set bit here means the mask character was ``1`` or ``X``.
    zeros: int
|
||
|
|
||
|
|
||
|
# Sparse memory image: maps an integer address to the integer value stored there.
Memory = Dict[int, int]
|
||
|
|
||
|
|
||
|
def gen_floats(xs: int) -> Iterator[int]:
    """Yield every value obtainable by toggling the set bits of *xs*.

    Only the low 36 bit positions of *xs* are considered.  For ``k`` set bits
    the iterator produces ``2 ** k`` integers, each the OR of one subset of
    those bits; the empty subset (``0``) always comes first.  Subsets are
    enumerated in ``itertools.product`` order, with the lowest set bit acting
    as the slowest-changing digit.

    Args:
        xs: Bit field whose set bits mark the positions to vary.

    Returns:
        A lazy iterator over the resulting integers.
    """
    # Value contributed by each floating position, lowest bit first.
    weights = [1 << bit for bit in range(36) if (xs >> bit) & 1]
    selections = itertools.product((False, True), repeat=len(weights))
    return (sum(itertools.compress(weights, chosen)) for chosen in selections)
|
||
|
|
||
|
|
||
|
def solve(raw: List[str]) -> int:
    """Run a floating-address decoder program and return the memory total.

    Each entry of *raw* is either ``mask = <bits>`` (bits drawn from ``0``,
    ``1`` and ``X``) or ``mem[<addr>] = <value>``.  On a memory write the
    current mask is applied to the *address*: a ``1`` forces that bit to 1,
    a ``0`` leaves it unchanged, and every ``X`` is floating, so the value is
    written to each address formed by every 0/1 combination of the floating
    bits.  Patterns are matched from the start of the line; lines matching
    neither pattern are ignored.

    Args:
        raw: Program lines, one instruction per entry.

    Returns:
        The sum of all values held in memory after the last instruction.
    """
    # Before the first ``mask`` line use the identity mask: no bit forced to
    # 1 and no bit floating.  An all-ones ``zeros`` field would instead mark
    # every bit as floating, making an early write enumerate 2**36 addresses.
    mask = Mask(0, 0)
    mem_pattern = re.compile(r"mem\[([0-9]+)\] = ([0-9]+)")
    mask_pattern = re.compile(r"mask = ([01X]+)")
    mem: Memory = {}
    for instr in raw:
        if (mem_match := mem_pattern.match(instr)) is not None:
            addr, val = int(mem_match.group(1)), int(mem_match.group(2))
            # A bit floats where the mask had ``X``: 0 in ``ones``, 1 in ``zeros``.
            xs = ~mask.ones & mask.zeros
            # Force the ``1`` bits, then clear the floating bits so that each
            # combination from gen_floats can simply be OR-ed in.
            addr = (addr | mask.ones) & ~xs
            for x in gen_floats(xs):
                mem[addr | x] = val
        elif (mask_match := mask_pattern.match(instr)) is not None:
            bits = mask_match.group(1)
            mask = Mask(
                int(bits.replace("X", "0"), 2),
                int(bits.replace("X", "1"), 2),
            )
    return sum(mem.values())
|
||
|
|
||
|
|
||
|
def main() -> None:
    """Read the program from standard input and print the result of ``solve``."""
    # Strip surrounding whitespace (including the newline) from each line.
    program = [text.strip() for text in sys.stdin]
    print(solve(program))
|
||
|
|
||
|
|
||
|
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()
|