diff --git a/2023/d20/ex2/ex2.py b/2023/d20/ex2/ex2.py new file mode 100755 index 0000000..16d81b5 --- /dev/null +++ b/2023/d20/ex2/ex2.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python + +import enum +import itertools +import math +import sys +from collections import defaultdict, deque +from typing import NamedTuple + + +class ModuleType(enum.StrEnum): + FLIP_FLOP = "%" + CONJUNCTION = "&" + BROADCAST = "broadcaster" + + +class Pulse(enum.IntEnum): + LOW = 0 + HIGH = 1 + + +class Rule(NamedTuple): + module_type: ModuleType + destinations: list[str] + + +Modules = dict[str, Rule] + + +def solve(input: list[str]) -> int: + def parse_rule(line: str) -> tuple[str, Rule]: + module, outputs = line.split(" -> ") + + name: str + module_type: ModuleType + + if module != "broadcaster": + name = module[1:] + module_type = ModuleType(module[0]) + else: + name = module + module_type = ModuleType(module) + + return name, Rule(module_type, outputs.split(", ")) + + def parse(input: list[str]) -> Modules: + return dict(map(parse_rule, input)) + + def compute_inputs(modules: Modules) -> dict[str, list[str]]: + inputs: dict[str, list[str]] = defaultdict(list) + + for src, rule in modules.items(): + for dst in rule.destinations: + inputs[dst].append(src) + + return inputs + + def count_buttons(modules: Modules) -> int: + def count_buttons_for( + wanted_src: str, + wanted_dst: str, + pulse_wanted: Pulse, + ) -> int: + inputs = compute_inputs(modules) + last_pulse: dict[str, Pulse] = defaultdict(lambda: Pulse.LOW) + + for i in itertools.count(start=1): + queue: deque[tuple[Pulse, str]] = deque([(Pulse.LOW, "broadcaster")]) + + while queue: + pulse, name = queue.popleft() + + mod = modules.get(name) + + # This is for unknown outputs + if mod is None: + continue + + new_pulse: Pulse + match mod.module_type: + case ModuleType.FLIP_FLOP: + if pulse == Pulse.HIGH: + continue + new_pulse = Pulse(1 - last_pulse[name]) + case ModuleType.CONJUNCTION: + high_inputs = all( + last_pulse[input] == Pulse.HIGH + for input in inputs[name] + ) + new_pulse = Pulse.LOW if high_inputs else Pulse.HIGH + case ModuleType.BROADCAST: + new_pulse = pulse + + last_pulse[name] = new_pulse + for dst in mod.destinations: + queue.append((new_pulse, dst)) + # We found the pulse we wanted, report the number of button presses + if ( + new_pulse == pulse_wanted + and name == wanted_src + and dst == wanted_dst + ): + return i + + assert False # Sanity check + + inputs = compute_inputs(modules) + # The input has a single conjunction leading to "rx" + # So we want to compute when all of *its* inputs are high at the same time + assert len(inputs["rx"]) == 1 # Sanity check + rx_input = inputs["rx"][0] + assert modules[rx_input].module_type == ModuleType.CONJUNCTION # Sanity check + # Shortcut: assume that the high pulse output is cyclic + return math.lcm( + *(count_buttons_for(mod, rx_input, Pulse.HIGH) for mod in inputs[rx_input]) + ) + + modules = parse(input) + return count_buttons(modules) + + +def main() -> None: + input = sys.stdin.read().splitlines() + print(solve(input)) + + +if __name__ == "__main__": + main()