diff --git a/calculator/calculator/parse/__init__.py b/calculator/calculator/parse/__init__.py new file mode 100644 index 0000000..0213ec2 --- /dev/null +++ b/calculator/calculator/parse/__init__.py @@ -0,0 +1 @@ +from .postfix import parse_postfix diff --git a/calculator/calculator/parse/parsed_string.py b/calculator/calculator/parse/parsed_string.py new file mode 100644 index 0000000..0965970 --- /dev/null +++ b/calculator/calculator/parse/parsed_string.py @@ -0,0 +1,37 @@ +from typing import List, Union + +from calculator.ast import BinOp, Constant, Node, UnaryOp +from pydantic.dataclasses import dataclass + + +def begins_with_digit(d: str) -> bool: + return "0" <= d[0] <= "9" + + +@dataclass +class ParsedString: + input: str + + def _is_done(self) -> bool: + return len(self.input) == 0 + + def _get_token(self) -> Union[int, str]: + ans = "" + self.input = self.input.strip() # Remove whitespace + while begins_with_digit(self.input): + ans += self.input[0] + self.input = self.input[1:] + if len(self.input) == 0 or not begins_with_digit(self.input): + break + if len(ans) != 0: # Was a number, return the converted int + return int(ans) + # Was not a number, return the (presumed) symbol + ans = self.input[0] + self.input = self.input[1:] + return ans + + def tokenize(self) -> List[Union[int, str]]: + ans: List[str] = [] + while not self._is_done(): + ans.append(self._get_token()) + return ans diff --git a/calculator/calculator/parse/postfix.py b/calculator/calculator/parse/postfix.py new file mode 100644 index 0000000..0297069 --- /dev/null +++ b/calculator/calculator/parse/postfix.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Union, cast + +from calculator.ast import BinOp, Constant, UnaryOp +from calculator.core import operations + +from .parsed_string import ParsedString + +if TYPE_CHECKING: + from calculator.ast import Node + + +def stack_to_tree(s: List[Union[int, str]]) -> Node: + top = s.pop() + if type(top) is int: + return Constant(top) + top = cast(str, top) + if top == "@": + rhs = stack_to_tree(s) + return UnaryOp(operations.negate, rhs) + rhs = stack_to_tree(s) + lhs = stack_to_tree(s) + return BinOp(operations.STR_TO_BIN[top], lhs, rhs) + + +def parse_postfix(input: str) -> Node: + """ + Parses the given string in postfix notation. + Negation is represented by the '@' sign. + """ + parsed = ParsedString(input).tokenize() + ans = stack_to_tree(parsed) + return ans diff --git a/calculator/calculator/parse/test_postfix.py b/calculator/calculator/parse/test_postfix.py new file mode 100644 index 0000000..a5784f7 --- /dev/null +++ b/calculator/calculator/parse/test_postfix.py @@ -0,0 +1,32 @@ +from calculator.ast import BinOp, Constant, UnaryOp +from calculator.core import operations + +from .postfix import parse_postfix + + +def test_parse_constant(): + assert parse_postfix("42") == Constant(42) + + +def test_parse_negated_constant(): + assert parse_postfix("42 @") == UnaryOp(operations.negate, Constant(42)) + + +def test_parse_doubly_negated_constant(): + assert parse_postfix("42@@") == UnaryOp( + operations.negate, UnaryOp(operations.negate, Constant(42)) + ) + + +def test_parse_binary_operation(): + assert parse_postfix("12 27 +") == BinOp( + operations.plus, Constant(12), Constant(27) + ) + + +def test_parse_complete_expression_tree(): + assert parse_postfix("12 27 + 42 51 - *") == BinOp( + operations.times, + BinOp(operations.plus, Constant(12), Constant(27)), + BinOp(operations.minus, Constant(42), Constant(51)), + )