From cb5b497a30d70c6e9cefddb0fed268d5d88f3687 Mon Sep 17 00:00:00 2001 From: Bruno BELANYI Date: Fri, 9 May 2025 18:40:14 +0100 Subject: [PATCH] 2017: d07: ex2: add solution --- 2017/d07/ex2/ex2.py | 73 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100755 2017/d07/ex2/ex2.py diff --git a/2017/d07/ex2/ex2.py b/2017/d07/ex2/ex2.py new file mode 100755 index 0000000..c3e2db9 --- /dev/null +++ b/2017/d07/ex2/ex2.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python + +import sys +from collections import Counter +from typing import NamedTuple + + +class ProgramYell(NamedTuple): + weight: int + children: set[str] + + +Tower = dict[str, ProgramYell] + + +def solve(input: str) -> int: + def parse_line(line: str) -> tuple[str, ProgramYell]: + name, rest = line.split(" ", 1) + weight, rest = rest.split(")") + children = set(rest.removeprefix(" -> ").split(", ")) if rest else set() + return name, ProgramYell(int(weight[1:]), children) + + def parse(input: str) -> Tower: + return {name: yell for name, yell in map(parse_line, input.splitlines())} + + def find_base(tower: Tower) -> str: + candidates = set(tower.keys()) + for yell in tower.values(): + candidates -= yell.children + assert len(candidates) == 1 + return candidates.pop() + + def compute_weights(tower: Tower, base: str) -> dict[str, int]: + def helper(base: str) -> int: + total = ( + sum(helper(child) for child in tower[base].children) + + tower[base].weight + ) + res[base] = total + return total + + res: dict[str, int] = {} + helper(base) + return res + + def fix_balance(tower: Tower) -> int: + base = find_base(tower) + weights = compute_weights(tower, base) + + balanced_weight = -1 # Will be updated to the correct value afterwards + while True: + children = tower[base].children + subtowers = Counter(weights[child] for child in children) + if len(subtowers) == 1: + assert balanced_weight != -1 # Sanity check + return balanced_weight - sum(weights[child] for child in children) + assert len(subtowers) == 2 # Sanity check + balanced_weight, unbalanced_weight = (v for v, _ in subtowers.most_common()) + base = next( + child for child in children if weights[child] == unbalanced_weight + ) + + tower = parse(input) + return fix_balance(tower) + + +def main() -> None: + input = sys.stdin.read() + print(solve(input)) + + +if __name__ == "__main__": + main()