advent-of-code/2021/d08/ex2/ex2.py

79 lines
1.9 KiB
Python
Raw Normal View History

2021-12-08 10:57:37 +01:00
#!/usr/bin/env python
import itertools
import sys
from dataclasses import dataclass
from typing import Dict, List, Set
@dataclass
class Entry:
    """One parsed puzzle line: scrambled signal patterns plus the outputs to decode.

    Every pattern is stored as the set of characters of one whitespace-separated
    word, so the order of the letters inside a word does not matter.
    """

    # Patterns found before the " | " separator of the line.
    signals: List[Set[str]]
    # Patterns found after the " | " separator of the line.
    outputs: List[Set[str]]
def solve(input: List[str]) -> int:
    """Decode every seven-segment entry and return the sum of the displayed numbers.

    Each non-blank line must look like ``<patterns> | <outputs>``: ten signal
    patterns, a `` | `` separator, then the output patterns to decode. A
    pattern is a word of segment letters; patterns are compared as sets, so
    the order of the letters inside a word is irrelevant.

    Args:
        input: Puzzle lines. Blank (or whitespace-only) lines are ignored.

    Returns:
        The sum of the decoded output values (0 when there are no entries).

    Raises:
        ValueError: If a line is malformed or its patterns cannot be mapped
            onto the ten digits.
    """

    # Digits 0, 2, 3, 5, 6 and 9 are told apart by this fingerprint:
    # (segment count, segments shared with 4, segments shared with 1).
    fingerprints = {
        (6, 3, 2): 0,
        (5, 2, 1): 2,
        (5, 3, 2): 3,
        (5, 3, 1): 5,
        (6, 3, 1): 6,
        (6, 4, 2): 9,
    }

    def parse_entry(line: str) -> Entry:
        """Split one line into its signal patterns and its output patterns."""
        parts = line.split(" | ")
        if len(parts) != 2:
            raise ValueError(f"expected exactly one ' | ' separator: {line!r}")
        signals, outputs = parts
        return Entry(
            [set(s) for s in signals.split()],
            [set(o) for o in outputs.split()],
        )

    def deduce_signals(entry: Entry) -> Dict[int, Set[str]]:
        """Map each digit 0-9 to the signal pattern that displays it."""
        patterns = sorted(entry.signals, key=len)
        if len(patterns) != 10:
            raise ValueError(f"expected 10 signal patterns, got {len(patterns)}")

        # 1, 7, 4 and 8 are the only digits lit with 2, 3, 4 and 7 segments,
        # so after sorting by length they sit at fixed positions.
        _1, _7, _4, *signals_to_deduce, _8 = patterns
        if [len(p) for p in (_1, _7, _4, _8)] != [2, 3, 4, 7]:
            raise ValueError("signal patterns for 1, 7, 4 and 8 not found")

        signals = {1: _1, 4: _4, 7: _7, 8: _8}
        for sig in signals_to_deduce:
            key = (len(sig), len(sig & _4), len(sig & _1))
            if key not in fingerprints:
                raise ValueError(
                    f"unrecognised signal pattern: {''.join(sorted(sig))!r}"
                )
            signals[fingerprints[key]] = sig

        # Two patterns resolving to the same digit would leave a digit missing.
        if len(signals) != 10:
            raise ValueError("signal patterns do not cover all ten digits")
        return signals

    def deduce_entry(entry: Entry) -> int:
        """Return the number shown by the entry's outputs, read left to right."""
        # Invert the mapping once so each output is a single hash lookup.
        digit_for = {
            frozenset(sig): digit for digit, sig in deduce_signals(entry).items()
        }
        res = 0
        for output in entry.outputs:
            digit = digit_for.get(frozenset(output))
            if digit is None:  # 0 is a valid digit, so test against None
                raise ValueError(
                    f"output {''.join(sorted(output))!r} matches no signal pattern"
                )
            res = res * 10 + digit
        return res

    # Skip blank lines (e.g. a trailing empty line) instead of failing on them.
    return sum(deduce_entry(parse_entry(line)) for line in input if line.strip())
def main() -> None:
    """Read the puzzle input from stdin and print the value returned by solve()."""
    # Strip surrounding whitespace (including the newline) from every line read.
    lines = [raw.strip() for raw in sys.stdin]
    answer = solve(lines)
    print(answer)
# Run the solver only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()