2021: d16: ex1: add solution
This commit is contained in:
parent
918aa91aba
commit
e4e9042c0f
138
2021/d16/ex1/ex1.py
Executable file
138
2021/d16/ex1/ex1.py
Executable file
|
@ -0,0 +1,138 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import enum
|
||||
import functools
|
||||
import itertools
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Iterator, List, Tuple
|
||||
|
||||
RawPacket = List[bool]
|
||||
|
||||
|
||||
class PacketType(enum.IntEnum):
|
||||
SUM = 0
|
||||
PRODUCT = 1
|
||||
MINIMUM = 2
|
||||
MAXIMUM = 3
|
||||
LITTERAL = 4
|
||||
GREATER = 5
|
||||
LESS = 6
|
||||
EQUAL = 7
|
||||
|
||||
|
||||
class PacketLengthType(enum.IntEnum):
|
||||
TOTAL_BITS = 0 # Next 15 bits are the number of bits in the sub-packets
|
||||
TOTAL_PACKETS = 1 # Next 11 bits are the number of sub-packets
|
||||
|
||||
|
||||
@dataclass
|
||||
class PacketLength:
|
||||
type: PacketLengthType
|
||||
length: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Packet:
|
||||
version: int
|
||||
type: PacketType
|
||||
|
||||
|
||||
@dataclass
|
||||
class OperatorPacket(Packet):
|
||||
packets: List[Packet]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LitteralPacket(Packet):
|
||||
number: int
|
||||
|
||||
|
||||
def solve(input: List[str]) -> int:
|
||||
def to_raw(packet: str) -> RawPacket:
|
||||
def bits(n: int) -> Iterator[bool]:
|
||||
for i in range(3, -1, -1):
|
||||
yield bool(n & (1 << i))
|
||||
|
||||
nums = [int(c, 16) for c in input[0]]
|
||||
return list(itertools.chain.from_iterable(bits(n) for n in nums))
|
||||
|
||||
def bits_to_int(bits: Iterable[bool]) -> int:
|
||||
return functools.reduce(lambda a, b: (a << 1) + b, bits, 0)
|
||||
|
||||
def parse_packet(p: RawPacket) -> Tuple[int, Packet]:
|
||||
def packet_version(p: RawPacket) -> int:
|
||||
return bits_to_int(p[:3])
|
||||
|
||||
def packet_type(p: RawPacket) -> PacketType:
|
||||
return PacketType(bits_to_int(p[3:6]))
|
||||
|
||||
def parse_length(p: RawPacket) -> Tuple[int, PacketLength]:
|
||||
assert packet_type(p) != PacketType.LITTERAL # Sanity check
|
||||
type = PacketLengthType(bits_to_int(p[6:7]))
|
||||
if type == PacketLengthType.TOTAL_BITS:
|
||||
index = 7 + 15
|
||||
length = bits_to_int(p[7:index])
|
||||
else:
|
||||
index = 7 + 11
|
||||
length = bits_to_int(p[7:index])
|
||||
return index, PacketLength(type, length)
|
||||
|
||||
def parse_litteral(p: RawPacket) -> Tuple[int, LitteralPacket]:
|
||||
version, type = packet_version(p), packet_type(p)
|
||||
assert type == PacketType.LITTERAL # Sanity check
|
||||
index = 6
|
||||
bits: List[bool] = []
|
||||
while True:
|
||||
bits += p[index + 1 : index + 5]
|
||||
index += 5
|
||||
# Check if we were at the last one
|
||||
if p[index - 5] == 0:
|
||||
break
|
||||
return index, LitteralPacket(version, type, bits_to_int(bits))
|
||||
|
||||
def parse_operator(p: RawPacket) -> Tuple[int, OperatorPacket]:
|
||||
version, type = packet_version(p), packet_type(p)
|
||||
assert type != PacketType.LITTERAL # Sanity check
|
||||
|
||||
index, length = parse_length(p)
|
||||
packets: List[Packet] = []
|
||||
|
||||
if length.type == PacketLengthType.TOTAL_BITS:
|
||||
sub_index = 0
|
||||
while sub_index < length.length:
|
||||
parsed, packet = parse_packet(p[index:])
|
||||
sub_index += parsed
|
||||
index += parsed
|
||||
packets.append(packet)
|
||||
else:
|
||||
while len(packets) < length.length:
|
||||
parsed, packet = parse_packet(p[index:])
|
||||
index += parsed
|
||||
packets.append(packet)
|
||||
|
||||
return index, OperatorPacket(version, type, packets)
|
||||
|
||||
if packet_type(p) == PacketType.LITTERAL:
|
||||
return parse_litteral(p)
|
||||
return parse_operator(p)
|
||||
|
||||
def score(p: Packet) -> int:
|
||||
if p.type == PacketType.LITTERAL:
|
||||
return p.version
|
||||
assert isinstance(p, OperatorPacket) # Sanity check
|
||||
return p.version + sum(score(c) for c in p.packets)
|
||||
|
||||
raw = to_raw(input[0])
|
||||
__, packet = parse_packet(raw)
|
||||
|
||||
return score(packet)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
input = [line.strip() for line in sys.stdin.readlines()]
|
||||
print(solve(input))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in a new issue