129 lines
3.8 KiB
Python
Executable file
129 lines
3.8 KiB
Python
Executable file
#!/usr/bin/env python
|
|
|
|
import enum
|
|
import itertools
|
|
import math
|
|
import sys
|
|
from collections import defaultdict, deque
|
|
from typing import NamedTuple
|
|
|
|
|
|
class ModuleType(enum.StrEnum):
|
|
FLIP_FLOP = "%"
|
|
CONJUNCTION = "&"
|
|
BROADCAST = "broadcaster"
|
|
|
|
|
|
class Pulse(enum.IntEnum):
    """Level of a pulse travelling between modules.

    The members are the integers 0 and 1, so ``Pulse(1 - pulse)`` yields the
    opposite level.
    """

    LOW = 0
    HIGH = 1
|
|
|
|
|
|
class Rule(NamedTuple):
    """Description of one module: what kind it is and where its output goes."""

    # Kind of the module; decides how it reacts to an incoming pulse.
    module_type: ModuleType
    # Names of the modules that receive every pulse this module sends.
    destinations: list[str]
|
|
|
|
|
|
# Maps a module name (without its type prefix) to its parsed rule.
Modules = dict[str, Rule]
|
|
|
|
|
|
def solve(input: list[str]) -> int:
|
|
def parse_rule(line: str) -> tuple[str, Rule]:
|
|
module, outputs = line.split(" -> ")
|
|
|
|
name: str
|
|
module_type: ModuleType
|
|
|
|
if module != "broadcaster":
|
|
name = module[1:]
|
|
module_type = ModuleType(module[0])
|
|
else:
|
|
name = module
|
|
module_type = ModuleType(module)
|
|
|
|
return name, Rule(module_type, outputs.split(", "))
|
|
|
|
def parse(input: list[str]) -> Modules:
|
|
return dict(map(parse_rule, input))
|
|
|
|
def compute_inputs(modules: Modules) -> dict[str, list[str]]:
|
|
inputs: dict[str, list[str]] = defaultdict(list)
|
|
|
|
for src, rule in modules.items():
|
|
for dst in rule.destinations:
|
|
inputs[dst].append(src)
|
|
|
|
return inputs
|
|
|
|
def count_buttons(modules: Modules) -> int:
|
|
def count_buttons_for(
|
|
wanted_src: str,
|
|
wanted_dst: str,
|
|
pulse_wanted: Pulse,
|
|
) -> int:
|
|
inputs = compute_inputs(modules)
|
|
last_pulse: dict[str, Pulse] = defaultdict(lambda: Pulse.LOW)
|
|
|
|
for i in itertools.count(start=1):
|
|
queue: deque[tuple[Pulse, str]] = deque([(Pulse.LOW, "broadcaster")])
|
|
|
|
while queue:
|
|
pulse, name = queue.popleft()
|
|
|
|
mod = modules.get(name)
|
|
|
|
# This is for unknown outputs
|
|
if mod is None:
|
|
continue
|
|
|
|
new_pulse: Pulse
|
|
match mod.module_type:
|
|
case ModuleType.FLIP_FLOP:
|
|
if pulse == Pulse.HIGH:
|
|
continue
|
|
new_pulse = Pulse(1 - last_pulse[name])
|
|
case ModuleType.CONJUNCTION:
|
|
high_inputs = all(
|
|
last_pulse[input] == Pulse.HIGH
|
|
for input in inputs[name]
|
|
)
|
|
new_pulse = Pulse.LOW if high_inputs else Pulse.HIGH
|
|
case ModuleType.BROADCAST:
|
|
new_pulse = pulse
|
|
|
|
last_pulse[name] = new_pulse
|
|
for dst in mod.destinations:
|
|
queue.append((new_pulse, dst))
|
|
# We found the pulse we wanted, report the number of button presses
|
|
if (
|
|
new_pulse == pulse_wanted
|
|
and name == wanted_src
|
|
and dst == wanted_dst
|
|
):
|
|
return i
|
|
|
|
assert False # Sanity check
|
|
|
|
inputs = compute_inputs(modules)
|
|
# The input has a single conjunction leading to "rx"
|
|
# So we want to compute when all of *its* inputs are high at the same time
|
|
assert len(inputs["rx"]) == 1 # Sanity check
|
|
rx_input = inputs["rx"][0]
|
|
assert modules[rx_input].module_type == ModuleType.CONJUNCTION # Sanity check
|
|
# Shortcut: assume that the high pulse output is cyclic
|
|
return math.lcm(
|
|
*(count_buttons_for(mod, rx_input, Pulse.HIGH) for mod in inputs[rx_input])
|
|
)
|
|
|
|
modules = parse(input)
|
|
return count_buttons(modules)
|
|
|
|
|
|
def main() -> None:
    """Read the puzzle input from standard input and print the answer."""
    puzzle_lines = sys.stdin.read().splitlines()
    answer = solve(puzzle_lines)
    print(answer)
|
|
|
|
|
|
# Run the solver only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|