#!/usr/bin/env python
"""Sum the version numbers of a hex-encoded packet and all its sub-packets.

The transmission is read from the first line of standard input, decoded into
a tree of literal and operator packets, and the total of every ``version``
field in that tree is printed.

Packet header layout, as decoded by ``solve``:

* bits 0-2: packet version
* bits 3-5: packet type (see ``PacketType``)
* literal packets: groups of 5 bits follow, each made of a continuation flag
  and 4 payload bits; the group whose flag is 0 is the last one
* operator packets: bit 6 is the length type (see ``PacketLengthType``),
  followed by a 15-bit or 11-bit length field and then the sub-packets
"""

import enum
import functools
import itertools
import sys
from dataclasses import dataclass
from typing import Iterable, Iterator, List, Tuple, TypeVar

# A transmission expanded to one bool per bit, most significant bit first.
RawPacket = List[bool]


class PacketType(enum.IntEnum):
    """Value of the 3-bit type field present in every packet header."""

    SUM = 0
    PRODUCT = 1
    MINIMUM = 2
    MAXIMUM = 3
    LITTERAL = 4
    GREATER = 5
    LESS = 6
    EQUAL = 7


class PacketLengthType(enum.IntEnum):
    """Value of the 1-bit length-type field of an operator packet."""

    TOTAL_BITS = 0  # Next 15 bits are the number of bits in the sub-packets
    TOTAL_PACKETS = 1  # Next 11 bits are the number of sub-packets


@dataclass
class PacketLength:
    """Decoded length field of an operator packet."""

    type: PacketLengthType  # How ``length`` must be interpreted
    length: int  # Bit count or sub-packet count, depending on ``type``


@dataclass
class Packet:
    """Fields shared by every packet."""

    version: int
    type: PacketType


@dataclass
class OperatorPacket(Packet):
    """Packet whose payload is a list of nested packets."""

    packets: List[Packet]


@dataclass
class LitteralPacket(Packet):
    """Packet whose payload is a single integer."""

    number: int


def solve(input: List[str]) -> int:
    """Return the sum of all packet versions in the transmission ``input[0]``.

    Args:
        input: Lines of puzzle input. Only the first line is used; it must
            consist of hexadecimal digits.

    Returns:
        The version of the outermost packet plus the versions of every
        packet nested inside it.

    Raises:
        IndexError: If ``input`` is empty, or if the bits run out where the
            next group of a literal packet should start.
        ValueError: If the first line contains a non-hexadecimal character,
            or the bits decode to an invalid enum value.
    """

    def to_raw(packet: str) -> RawPacket:
        """Expand each hex digit of ``packet`` into four bits, MSB first."""

        def bits(n: int) -> Iterator[bool]:
            for i in range(3, -1, -1):
                yield bool(n & (1 << i))

        # Decode the argument itself rather than the enclosing ``input``, so
        # the helper honours its own parameter.
        nums = [int(c, 16) for c in packet]
        return list(itertools.chain.from_iterable(bits(n) for n in nums))

    def bits_to_int(bits: Iterable[bool]) -> int:
        """Interpret ``bits`` as an unsigned integer, first bit most significant."""
        return functools.reduce(lambda a, b: (a << 1) + b, bits, 0)

    def parse_packet(p: RawPacket) -> Tuple[int, Packet]:
        """Parse the packet that starts at ``p[0]``.

        Returns the number of bits consumed together with the parsed packet.
        Bits located after the packet are left untouched.
        """

        def packet_version(p: RawPacket) -> int:
            return bits_to_int(p[:3])

        def packet_type(p: RawPacket) -> PacketType:
            return PacketType(bits_to_int(p[3:6]))

        def parse_length(p: RawPacket) -> Tuple[int, PacketLength]:
            """Return the offset of the first sub-packet and the length field."""
            assert packet_type(p) != PacketType.LITTERAL  # Sanity check
            length_type = PacketLengthType(bits_to_int(p[6:7]))
            # The length field is 15 bits wide when it counts bits and 11 bits
            # wide when it counts sub-packets.
            width = 15 if length_type == PacketLengthType.TOTAL_BITS else 11
            index = 7 + width
            return index, PacketLength(length_type, bits_to_int(p[7:index]))

        def parse_litteral(p: RawPacket) -> Tuple[int, LitteralPacket]:
            """Parse a literal packet; return bits consumed and the packet."""
            version, kind = packet_version(p), packet_type(p)
            assert kind == PacketType.LITTERAL  # Sanity check
            index = 6  # Payload starts right after version and type
            bits: List[bool] = []
            while True:
                group = p[index : index + 5]
                bits += group[1:]
                index += 5
                # A cleared leading bit marks the last group
                if not group[0]:
                    break
            return index, LitteralPacket(version, kind, bits_to_int(bits))

        def parse_operator(p: RawPacket) -> Tuple[int, OperatorPacket]:
            """Parse an operator packet; return bits consumed and the packet."""
            version, kind = packet_version(p), packet_type(p)
            assert kind != PacketType.LITTERAL  # Sanity check
            index, length = parse_length(p)
            packets: List[Packet] = []
            if length.type == PacketLengthType.TOTAL_BITS:
                # Keep parsing until the announced number of bits is consumed
                sub_index = 0
                while sub_index < length.length:
                    parsed, packet = parse_packet(p[index:])
                    sub_index += parsed
                    index += parsed
                    packets.append(packet)
            else:
                # Keep parsing until the announced number of packets is read
                while len(packets) < length.length:
                    parsed, packet = parse_packet(p[index:])
                    index += parsed
                    packets.append(packet)
            return index, OperatorPacket(version, kind, packets)

        if packet_type(p) == PacketType.LITTERAL:
            return parse_litteral(p)
        return parse_operator(p)

    def score(p: Packet) -> int:
        """Return ``p.version`` plus the versions of all packets below it."""
        if p.type == PacketType.LITTERAL:
            return p.version
        assert isinstance(p, OperatorPacket)  # Sanity check
        return p.version + sum(score(c) for c in p.packets)

    raw = to_raw(input[0])
    __, packet = parse_packet(raw)
    return score(packet)


def main() -> None:
    """Read the transmission from standard input and print its version sum."""
    input = [line.strip() for line in sys.stdin.readlines()]
    print(solve(input))


if __name__ == "__main__":
    main()