diff --git a/2020/d14/ex2/ex2.py b/2020/d14/ex2/ex2.py new file mode 100755 index 0000000..55c254c --- /dev/null +++ b/2020/d14/ex2/ex2.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python +import itertools +import re +import sys +from dataclasses import dataclass +from typing import Dict, Iterator, List + + +@dataclass +class Mask: + ones: int + zeros: int + + +Memory = Dict[int, int] + + +def gen_floats(xs: int) -> Iterator[int]: + x_positions = [i for i in range(36) if (1 << i) & xs] + + return ( + sum(1 << x_positions[i] for i, v in enumerate(val) if v == 1) + for val in itertools.product([0, 1], repeat=len(x_positions)) + ) + + +def solve(raw: List[str]) -> int: + mask = Mask( + 0, + (2 << 36) - 1, + ) + mem_pattern = re.compile("mem\\[([0-9]+)\\] = ([0-9]+)") + mask_pattern = re.compile("mask = ([01X]+)") + mem: Memory = {} + for instr in raw: + if (mem_match := mem_pattern.match(instr)) is not None: + addr, val = int(mem_match.group(1)), int(mem_match.group(2)) + addr |= mask.ones + xs = ~mask.ones & mask.zeros + addr &= ~xs # Put floating bits to 0 + for x in gen_floats(xs): + mem[addr | x] = val + elif (mask_match := mask_pattern.match(instr)) is not None: + ones = int(mask_match.group(1).replace("X", "0"), 2) + zeros = int(mask_match.group(1).replace("X", "1"), 2) + mask = Mask(ones, zeros) + return sum(mem.values()) + + +def main() -> None: + input = [line.strip() for line in sys.stdin.readlines()] + print(solve(input)) + + +if __name__ == "__main__": + main()