diff --git a/2023/d03/ex2/ex2.py b/2023/d03/ex2/ex2.py new file mode 100755 index 0000000..6cf3d80 --- /dev/null +++ b/2023/d03/ex2/ex2.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python + +import sys +from collections import defaultdict +from collections.abc import Iterator +from typing import NamedTuple + + +class Point(NamedTuple): + x: int + y: int + + +class SchematicNumber(NamedTuple): + value: int + start: Point + + def neighbours(self) -> Iterator["Point"]: + # How long is the number + length = len(str(self.value)) + + # Every point to the left + for dx in range(-1, 1 + 1): + yield Point(self.start.x + dx, self.start.y - 1) + + # Every point below/above + for dy in range(length): + for dx in (-1, 1): + yield Point(self.start.x + dx, self.start.y + dy) + + # Every point to the right + for dx in range(-1, 1 + 1): + yield Point(self.start.x + dx, self.start.y + length) + + +SymbolsMap = dict[Point, str] + + +def parse_line(x: int, line: str) -> tuple[set[SchematicNumber], SymbolsMap]: + numbers: set[SchematicNumber] = set() + symbols: SymbolsMap = {} + + y = 0 + while y != len(line): + char = line[y] + pos = Point(x, y) + if char.isdigit(): + dy = 0 + while (y + dy) < len(line) and (line[y + dy]).isdigit(): + dy += 1 + numbers.add(SchematicNumber(value=int(line[y : y + dy]), start=pos)) + y += dy + continue + elif char == ".": + pass + else: + symbols[pos] = char + y += 1 + + return numbers, symbols + + +def parse(input: list[str]) -> tuple[set[SchematicNumber], SymbolsMap]: + numbers: set[SchematicNumber] = set() + symbols: SymbolsMap = {} + + for x, line in enumerate(input): + new_numbers, new_symbols = parse_line(x, line) + numbers |= new_numbers + symbols |= new_symbols + + return numbers, symbols + + +def solve(input: list[str]) -> int: + numbers, symbols = parse(input) + gear_adjacency: dict[Point, list[SchematicNumber]] = defaultdict(list) + + for n in numbers: + for p in n.neighbours(): + if not symbols.get(p) == "*": + continue + gear_adjacency[p].append(n) + + gear_power = [ + nums[0].value * nums[1].value + for nums in gear_adjacency.values() + if len(nums) == 2 + ] + + return sum(gear_power) + + +def main() -> None: + input = sys.stdin.read().splitlines() + print(solve(input)) + + +if __name__ == "__main__": + main()