advent-of-code/2020/d14/ex2/ex2.py

57 lines
1.4 KiB
Python
Raw Normal View History

2020-12-14 09:53:44 +01:00
#!/usr/bin/env python
import itertools
import re
import sys
from dataclasses import dataclass
from typing import Dict, Iterator, List
@dataclass
class Mask:
    """Bitmask split into two integers, as parsed from a ``mask = ...`` line.

    Both fields are built from the same mask string, read as a binary
    number, with the ``X`` characters substituted differently.
    """

    # Mask string with every "X" replaced by "0": a bit is set only where
    # the mask character is "1".
    ones: int
    # Mask string with every "X" replaced by "1": a bit is set where the
    # mask character is "1" or "X".
    zeros: int
# Sparse memory image: decoded address -> value most recently written there.
Memory = Dict[int, int]
def gen_floats(xs: int) -> Iterator[int]:
    """Yield every value obtainable by keeping a subset of the set bits of *xs*.

    Only bit positions 0 through 35 are examined; higher bits of *xs* are
    ignored. With ``k`` set bits in that range, ``2 ** k`` values are
    produced, starting with 0. If no bit is set, the single value 0 is
    produced.
    """
    # Weight (1 << position) of each floating bit, lowest position first.
    bit_weights = [1 << pos for pos in range(36) if xs & (1 << pos)]
    # One 0/1 choice per floating bit; the last bit varies fastest.
    choices = itertools.product([0, 1], repeat=len(bit_weights))
    return (
        sum(weight for weight, chosen in zip(bit_weights, combo) if chosen == 1)
        for combo in choices
    )
def solve(raw: List[str]) -> int:
    """Execute the program in *raw* and return the sum of all stored values.

    Two kinds of lines are recognised; any other line is silently ignored:

    * ``mask = <bits>`` where ``<bits>`` consists of ``0``, ``1`` and ``X``:
      replaces the current mask.
    * ``mem[<addr>] = <val>``: writes ``<val>`` to every address derived
      from ``<addr>`` by forcing the mask's ``1`` bits on and substituting
      every combination of 0/1 for the mask's ``X`` (floating) bits.

    Args:
        raw: Program lines, one instruction per element.

    Returns:
        The sum of the values left in memory after the last instruction.
    """
    # Until a mask line is seen, use the identity mask: no bit is forced on
    # and no bit floats, so a write lands on exactly the address given.
    # (Treating every bit as floating here would make an early write try to
    # touch 2**36 addresses.)
    mask = Mask(0, 0)
    floating = 0

    mem_pattern = re.compile(r"mem\[([0-9]+)\] = ([0-9]+)")
    mask_pattern = re.compile(r"mask = ([01X]+)")

    mem: Memory = {}
    for instr in raw:
        if (mem_match := mem_pattern.match(instr)) is not None:
            addr, val = int(mem_match.group(1)), int(mem_match.group(2))
            # Force the "1" bits on and clear the floating bits, so that
            # OR-ing in each combination below sets exactly those bits.
            base = (addr | mask.ones) & ~floating
            for x in gen_floats(floating):
                mem[base | x] = val
        elif (mask_match := mask_pattern.match(instr)) is not None:
            bits = mask_match.group(1)
            ones = int(bits.replace("X", "0"), 2)
            zeros = int(bits.replace("X", "1"), 2)
            mask = Mask(ones, zeros)
            # "X" positions are those set in `zeros` but not in `ones`;
            # computed once per mask rather than once per write.
            floating = zeros & ~ones
    return sum(mem.values())
def main() -> None:
    """Read the program from standard input and print the solver's answer."""
    # Named `lines` rather than `input` so the builtin is not shadowed.
    lines = [line.strip() for line in sys.stdin]
    print(solve(lines))
# Run the solver only when executed as a script, not when imported.
if __name__ == "__main__":
    main()