162 lines
4.8 KiB
Python
Executable file
162 lines
4.8 KiB
Python
Executable file
#!/usr/bin/env python
|
|
|
|
import enum
|
|
import functools
|
|
import itertools
|
|
import math
|
|
import sys
|
|
from dataclasses import dataclass
|
|
from typing import Iterable, Iterator, List, Tuple, TypeVar
|
|
|
|
# A transmission as a flat list of bits (True = 1); each hexadecimal digit
# contributes four entries, most significant bit first.
RawPacket = List[bool]
|
|
|
|
|
|
class PacketType(enum.IntEnum):
    """Type ID stored in a packet header (bits 3-5 of the packet).

    Every value except ``LITTERAL`` marks an operator packet whose result is
    computed from the values of its sub-packets.
    """

    # Operators that fold any number (at least one) of sub-packet values.
    SUM = 0
    PRODUCT = 1
    MINIMUM = 2
    MAXIMUM = 3

    # A packet that carries a number directly instead of sub-packets.
    LITTERAL = 4

    # Operators that compare exactly two sub-packet values, yielding 1 or 0.
    GREATER = 5
    LESS = 6
    EQUAL = 7
|
|
|
|
|
|
class PacketLengthType(enum.IntEnum):
    """Length type ID of an operator packet (the bit right after the header)."""

    # The following 15 bits give the total bit length of all sub-packets.
    TOTAL_BITS = 0
    # The following 11 bits give the number of immediate sub-packets.
    TOTAL_PACKETS = 1
|
|
|
|
|
|
@dataclass
class PacketLength:
    """Decoded length field of an operator packet."""

    # How ``length`` must be interpreted (bit count vs. sub-packet count).
    type: PacketLengthType
    # The value read from the 15-bit or 11-bit length field.
    length: int
|
|
|
|
|
|
@dataclass
class Packet:
    """Fields shared by every packet: the two header values."""

    # Value of the first three bits of the packet.
    version: int
    # Value of the next three bits, identifying the kind of packet.
    type: PacketType
|
|
|
|
|
|
@dataclass
class OperatorPacket(Packet):
    """A packet whose value is computed from its nested sub-packets."""

    # Immediate sub-packets, in the order they appear in the transmission.
    packets: List[Packet]
|
|
|
|
|
|
@dataclass
class LitteralPacket(Packet):
    """A packet that directly encodes a single number."""

    # The number assembled from the packet's 4-bit payload groups.
    number: int
|
|
|
|
|
|
def solve(input: List[str]) -> int:
    """Evaluate the hexadecimal packet transmission on the first input line.

    The first element of ``input`` is expanded from hexadecimal into a flat
    list of bits, parsed into a tree of literal and operator packets, and the
    value of the outermost packet is returned.

    Args:
        input: Lines of puzzle input; only ``input[0]`` is read.

    Returns:
        The integer value of the outermost packet.

    Raises:
        ValueError: If ``input`` is empty, its first line is empty, the line
            contains a non-hexadecimal character, or an operator packet has
            an invalid number of sub-packets.
        IndexError: If a literal packet is cut off before its last group.
    """

    def to_raw(packet: str) -> RawPacket:
        """Expand each hex digit of *packet* into four bits, MSB first."""

        def nibble_bits(n: int) -> Iterator[bool]:
            for shift in range(3, -1, -1):
                yield bool(n & (1 << shift))

        # Use the argument itself rather than reaching back into ``input``.
        nibbles = (int(digit, 16) for digit in packet)
        return list(itertools.chain.from_iterable(map(nibble_bits, nibbles)))

    def bits_to_int(bits: Iterable[bool]) -> int:
        """Interpret *bits* as a big-endian unsigned integer (empty -> 0)."""
        return functools.reduce(lambda acc, bit: (acc << 1) | bit, bits, 0)

    # All parsers below take the full bit list plus a start offset and return
    # the offset of the first bit *after* what they consumed. Working with
    # offsets avoids copying the tail of the list for every sub-packet.

    def parse_length(p: RawPacket, index: int) -> Tuple[int, PacketLength]:
        """Parse an operator's length-type bit and length field at *index*."""
        length_type = PacketLengthType(bits_to_int(p[index : index + 1]))
        index += 1
        width = 15 if length_type == PacketLengthType.TOTAL_BITS else 11
        length = bits_to_int(p[index : index + width])
        return index + width, PacketLength(length_type, length)

    def parse_litteral(
        p: RawPacket, index: int, version: int
    ) -> Tuple[int, LitteralPacket]:
        """Parse the 5-bit groups of a literal packet starting at *index*."""
        payload: List[bool] = []
        while True:
            # The first bit of each group is 0 only on the final group.
            is_last_group = not p[index]
            payload += p[index + 1 : index + 5]
            index += 5
            if is_last_group:
                break
        number = bits_to_int(payload)
        return index, LitteralPacket(version, PacketType.LITTERAL, number)

    def parse_operator(
        p: RawPacket, index: int, version: int, kind: PacketType
    ) -> Tuple[int, OperatorPacket]:
        """Parse the length field and sub-packets of an operator packet."""
        index, length = parse_length(p, index)
        packets: List[Packet] = []

        if length.type == PacketLengthType.TOTAL_BITS:
            # Keep reading sub-packets until the announced bit span is used.
            end = index + length.length
            while index < end:
                index, sub_packet = parse_packet(p, index)
                packets.append(sub_packet)
        else:
            # Keep reading until the announced number of sub-packets is met.
            while len(packets) < length.length:
                index, sub_packet = parse_packet(p, index)
                packets.append(sub_packet)

        return index, OperatorPacket(version, kind, packets)

    def parse_packet(p: RawPacket, start: int = 0) -> Tuple[int, Packet]:
        """Parse the packet whose header begins at bit *start* of *p*."""
        version = bits_to_int(p[start : start + 3])
        kind = PacketType(bits_to_int(p[start + 3 : start + 6]))
        body = start + 6
        if kind == PacketType.LITTERAL:
            return parse_litteral(p, body, version)
        return parse_operator(p, body, version, kind)

    def evaluate(p: Packet) -> int:
        """Recursively compute the value of packet *p*."""
        if p.type == PacketType.LITTERAL:
            assert isinstance(p, LitteralPacket)  # Narrow for type checkers
            return p.number
        assert isinstance(p, OperatorPacket)  # Narrow for type checkers

        values = [evaluate(child) for child in p.packets]
        # Raise explicitly: ``assert`` would vanish under ``python -O``.
        if not values:
            raise ValueError(f"{p.type.name} packet has no sub-packets")

        if p.type == PacketType.SUM:
            return sum(values)
        if p.type == PacketType.PRODUCT:
            return math.prod(values)
        if p.type == PacketType.MINIMUM:
            return min(values)
        if p.type == PacketType.MAXIMUM:
            return max(values)

        # Only the binary comparison operators remain.
        if len(values) != 2:
            raise ValueError(
                f"{p.type.name} packet needs exactly 2 sub-packets, "
                f"got {len(values)}"
            )
        first, second = values
        if p.type == PacketType.GREATER:
            return int(first > second)
        if p.type == PacketType.LESS:
            return int(first < second)
        if p.type == PacketType.EQUAL:
            return int(first == second)

        # Unreachable unless a new PacketType is added without handling it.
        raise AssertionError(f"unhandled packet type: {p.type!r}")

    if not input or not input[0]:
        raise ValueError(
            "expected a hexadecimal transmission on the first input line"
        )

    _, packet = parse_packet(to_raw(input[0]))
    return evaluate(packet)
|
|
|
|
|
|
def main() -> None:
    """Read the puzzle input from standard input and print the answer."""
    lines = []
    for raw_line in sys.stdin:
        lines.append(raw_line.strip())
    answer = solve(lines)
    print(answer)
|
|
|
|
|
|
# Run the solver only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|