#!/usr/bin/env python
"""Apply bitmask instructions to memory writes and print the memory sum.

The program reads two kinds of lines from standard input:

* ``mask = <string of 0/1/X>`` replaces the active mask.
* ``mem[<addr>] = <value>`` stores ``value`` at ``addr`` after the active
  mask has been applied to it.

After all lines are processed, the sum of every stored value is printed.
"""
import re
import sys
from dataclasses import dataclass
from typing import Dict, List

# Width of the default (no-op) mask, taken from the shift amount used by the
# original constant.
# NOTE(review): presumably mask strings in the input are this wide; confirm.
_MASK_BITS = 36

# Compiled once at import time; raw strings keep the escapes readable.
_MEM_PATTERN = re.compile(r"mem\[([0-9]+)\] = ([0-9]+)")
_MASK_PATTERN = re.compile(r"mask = ([01X]+)")


@dataclass
class Mask:
    """A value mask split into the two integers used to apply it."""

    # Bits set here are forced to 1 (OR-ed into the value).
    ones: int
    # Bits cleared here are forced to 0 (AND-ed with the value).
    zeros: int


Memory = Dict[int, int]


def solve(raw: List[str]) -> int:
    """Run the instructions in *raw* and return the sum of stored values.

    Each ``mem[...]`` line has the active mask applied to its value before
    it is stored; a later write to the same address overwrites the earlier
    one. Lines that match neither instruction form are ignored.

    Args:
        raw: Instruction lines, one instruction per element.

    Returns:
        The sum of all values left in memory, or 0 if nothing was written.
    """
    # Until the first "mask = ..." line the mask is a no-op over _MASK_BITS
    # bits: force nothing to 1 and keep every bit. The all-ones value must be
    # (1 << n) - 1; (2 << n) - 1 would be one bit too wide.
    mask = Mask(ones=0, zeros=(1 << _MASK_BITS) - 1)
    mem: Memory = {}

    for instr in raw:
        if (mem_match := _MEM_PATTERN.match(instr)) is not None:
            addr = int(mem_match.group(1))
            val = int(mem_match.group(2))
            # Set the forced-one bits first, then clear the forced-zero bits.
            mem[addr] = (val | mask.ones) & mask.zeros
        elif (mask_match := _MASK_PATTERN.match(instr)) is not None:
            bits = mask_match.group(1)
            # "X" means "leave the bit alone": 0 in the OR mask and 1 in the
            # AND mask.
            mask = Mask(
                ones=int(bits.replace("X", "0"), 2),
                zeros=int(bits.replace("X", "1"), 2),
            )

    return sum(mem.values())


def main() -> None:
    """Read instructions from standard input and print the memory sum."""
    # Named "lines" rather than "input" so the builtin is not shadowed.
    lines = [line.strip() for line in sys.stdin.readlines()]
    print(solve(lines))


if __name__ == "__main__":
    main()