2021: d16: ex2: add solution
This commit is contained in:
parent
b42fd7e459
commit
44d90b4153
158
2021/d16/ex2/ex2.py
Executable file
158
2021/d16/ex2/ex2.py
Executable file
|
@ -0,0 +1,158 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import enum
|
||||
import functools
|
||||
import itertools
|
||||
import math
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, Iterable, Iterator, List, Tuple, TypeVar
|
||||
|
||||
RawPacket = List[bool]
|
||||
|
||||
|
||||
class PacketType(enum.IntEnum):
|
||||
SUM = 0
|
||||
PRODUCT = 1
|
||||
MINIMUM = 2
|
||||
MAXIMUM = 3
|
||||
LITTERAL = 4
|
||||
GREATER = 5
|
||||
LESS = 6
|
||||
EQUAL = 7
|
||||
|
||||
|
||||
class PacketLengthType(enum.IntEnum):
|
||||
TOTAL_BITS = 0 # Next 15 bits are the number of bits in the sub-packets
|
||||
TOTAL_PACKETS = 1 # Next 11 bits are the number of sub-packets
|
||||
|
||||
|
||||
@dataclass
|
||||
class PacketLength:
|
||||
type: PacketLengthType
|
||||
length: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Packet:
|
||||
version: int
|
||||
type: PacketType
|
||||
|
||||
|
||||
@dataclass
|
||||
class OperatorPacket(Packet):
|
||||
packets: List[Packet]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LitteralPacket(Packet):
|
||||
number: int
|
||||
|
||||
|
||||
def solve(input: List[str]) -> int:
|
||||
def to_raw(packet: str) -> RawPacket:
|
||||
def bits(n: int) -> Iterator[bool]:
|
||||
for i in range(3, -1, -1):
|
||||
yield bool(n & (1 << i))
|
||||
|
||||
nums = [int(c, 16) for c in input[0]]
|
||||
return list(itertools.chain.from_iterable(bits(n) for n in nums))
|
||||
|
||||
def bits_to_int(bits: Iterable[bool]) -> int:
|
||||
return functools.reduce(lambda a, b: (a << 1) + b, bits, 0)
|
||||
|
||||
def parse_packet(p: RawPacket) -> Tuple[int, Packet]:
|
||||
def packet_version(p: RawPacket) -> int:
|
||||
return bits_to_int(p[:3])
|
||||
|
||||
def packet_type(p: RawPacket) -> PacketType:
|
||||
return PacketType(bits_to_int(p[3:6]))
|
||||
|
||||
def parse_length(p: RawPacket) -> Tuple[int, PacketLength]:
|
||||
assert packet_type(p) != PacketType.LITTERAL # Sanity check
|
||||
type = PacketLengthType(bits_to_int(p[6:7]))
|
||||
if type == PacketLengthType.TOTAL_BITS:
|
||||
index = 7 + 15
|
||||
length = bits_to_int(p[7:index])
|
||||
else:
|
||||
index = 7 + 11
|
||||
length = bits_to_int(p[7:index])
|
||||
return index, PacketLength(type, length)
|
||||
|
||||
def parse_litteral(p: RawPacket) -> Tuple[int, LitteralPacket]:
|
||||
version, type = packet_version(p), packet_type(p)
|
||||
assert type == PacketType.LITTERAL # Sanity check
|
||||
index = 6
|
||||
bits: List[bool] = []
|
||||
while True:
|
||||
bits += p[index + 1 : index + 5]
|
||||
index += 5
|
||||
# Check if we were at the last one
|
||||
if p[index - 5] == 0:
|
||||
break
|
||||
return index, LitteralPacket(version, type, bits_to_int(bits))
|
||||
|
||||
def parse_operator(p: RawPacket) -> Tuple[int, OperatorPacket]:
|
||||
version, type = packet_version(p), packet_type(p)
|
||||
assert type != PacketType.LITTERAL # Sanity check
|
||||
|
||||
index, length = parse_length(p)
|
||||
packets: List[Packet] = []
|
||||
|
||||
if length.type == PacketLengthType.TOTAL_BITS:
|
||||
sub_index = 0
|
||||
while sub_index < length.length:
|
||||
parsed, packet = parse_packet(p[index:])
|
||||
sub_index += parsed
|
||||
index += parsed
|
||||
packets.append(packet)
|
||||
else:
|
||||
while len(packets) < length.length:
|
||||
parsed, packet = parse_packet(p[index:])
|
||||
index += parsed
|
||||
packets.append(packet)
|
||||
|
||||
return index, OperatorPacket(version, type, packets)
|
||||
|
||||
if packet_type(p) == PacketType.LITTERAL:
|
||||
return parse_litteral(p)
|
||||
return parse_operator(p)
|
||||
|
||||
def eval(p: Packet) -> int:
|
||||
if p.type == PacketType.LITTERAL:
|
||||
assert isinstance(p, LitteralPacket) # Sanity check
|
||||
return p.number
|
||||
assert isinstance(p, OperatorPacket) # Sanity check
|
||||
|
||||
packet_values = [eval(c) for c in p.packets]
|
||||
|
||||
ops: Dict[PacketType, Callable[[List[int]], int]] = {
|
||||
PacketType.SUM: sum,
|
||||
PacketType.PRODUCT: math.prod,
|
||||
PacketType.MINIMUM: min,
|
||||
PacketType.MAXIMUM: max,
|
||||
PacketType.SUM: sum,
|
||||
PacketType.GREATER: lambda values: values[0] > values[1],
|
||||
PacketType.LESS: lambda values: values[0] < values[1],
|
||||
PacketType.EQUAL: lambda values: values[0] == values[1],
|
||||
}
|
||||
|
||||
assert len(packet_values) >= 1 # Sanity check
|
||||
if p.type in (PacketType.GREATER, PacketType.LESS, PacketType.EQUAL):
|
||||
assert len(packet_values) == 2 # Sanity check
|
||||
|
||||
return ops[p.type](packet_values)
|
||||
|
||||
raw = to_raw(input[0])
|
||||
__, packet = parse_packet(raw)
|
||||
|
||||
return eval(packet)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
input = [line.strip() for line in sys.stdin.readlines()]
|
||||
print(solve(input))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in a new issue