diff --git a/calculator/calculator/print/__init__.py b/calculator/calculator/print/__init__.py new file mode 100644 index 0000000..b4bd652 --- /dev/null +++ b/calculator/calculator/print/__init__.py @@ -0,0 +1 @@ +from .printer import Printer diff --git a/calculator/calculator/print/printer.py b/calculator/calculator/print/printer.py new file mode 100644 index 0000000..8d26202 --- /dev/null +++ b/calculator/calculator/print/printer.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import dataclasses +from typing import TYPE_CHECKING + +from calculator.ast.visit import Visitor +from calculator.core.operations import ( + identity, + int_div, + minus, + negate, + plus, + pow, + times, +) +from pydantic.dataclasses import dataclass + +if TYPE_CHECKING: + from calculator.ast import BinOp, Constant, Node, UnaryOp + +OP_TO_STR = { + plus: "+", + minus: "-", + times: "*", + int_div: "/", + pow: "^", + identity: "+", + negate: "-", +} + + +@dataclass +class Printer(Visitor): + """ + Print a Tree + """ + + indent: int = dataclasses.field(default=0, init=False) + + def print(self, n: Node) -> None: + n.accept(self) + + def visit_constant(self, c: Constant) -> None: + print(" " * self.indent + str(c.value)) + + def visit_binop(self, b: BinOp) -> None: + print(" " * self.indent + OP_TO_STR[b.op]) + self.indent += 2 + self.print(b.lhs) + self.print(b.rhs) + self.indent -= 2 + + def visit_unaryop(self, b: UnaryOp) -> None: + print(" " * self.indent + OP_TO_STR[b.op]) + self.indent += 2 + self.print(b.rhs) + self.indent -= 2 diff --git a/calculator/calculator/print/test_printer.py b/calculator/calculator/print/test_printer.py new file mode 100644 index 0000000..22f7670 --- /dev/null +++ b/calculator/calculator/print/test_printer.py @@ -0,0 +1,28 @@ +import pytest +from calculator.parse import parse_infix + +from .printer import Printer + + +def print_test_helper(input: str, expected: str, capsys): + Printer().print(parse_infix(input)) + out, __ = capsys.readouterr() + assert out == expected + + +def test_printer_constant(capsys): + print_test_helper("42", "42\n", capsys) + + +def test_printer_negated_constant(capsys): + print_test_helper("-42", "-\n 42\n", capsys) + + +def test_printer_binaryop(capsys): + print_test_helper("12 + 27", "+\n 12\n 27\n", capsys) + + +def test_print_complex_expression(capsys): + print_test_helper( + "12 + 27 * 42 ^51", "+\n 12\n *\n 27\n ^\n 42\n 51\n", capsys + )