Spaces:
Sleeping
Sleeping
Commit
·
3872a55
1
Parent(s):
c9abdca
Add variable assignment and checking
Browse files- README.md +4 -0
- abstract_syntax_tree.py +21 -9
- arithmetic.py +35 -11
- demonstration.ipynb +2 -25
README.md
CHANGED
|
@@ -49,6 +49,10 @@ python3 synthesizer.py --domain arithmetic --examples addition
|
|
| 49 |
|
| 50 |
To add additional input-output examples, modify `examples.py`. Add a new key to the dictionary `example_set` and set the value to be a list of tuples.
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
## 🔮 Virtual Environment
|
| 53 |
|
| 54 |
To create a virtual environment, run:
|
|
|
|
| 49 |
|
| 50 |
To add additional input-output examples, modify `examples.py`. Add a new key to the dictionary `example_set` and set the value to be a list of tuples.
|
| 51 |
|
| 52 |
+
## 🔎 Abstract Syntax Tree
|
| 53 |
+
|
| 54 |
+
The most important data structure in this implementation is the abstract syntax tree (AST). The AST is a tree representation of a program, where each node is either a primitive or a compound expression. The AST is represented by the `OperatorNode` class in `abstract_syntax_tree.py`. My AST implementation includes functions to recursively evaluate the operator and its operands, and also to generate a string representation of the program.
|
| 55 |
+
|
| 56 |
## 🔮 Virtual Environment
|
| 57 |
|
| 58 |
To create a virtual environment, run:
|
abstract_syntax_tree.py
CHANGED
|
@@ -11,28 +11,40 @@ class OperatorNode:
|
|
| 11 |
operator (object): operator object (e.g., Add, Subtract, etc.)
|
| 12 |
children (list): list of children nodes (operands)
|
| 13 |
|
| 14 |
-
Example
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
| 19 |
'''
|
|
|
|
| 20 |
def __init__(self, operator, children):
|
| 21 |
self.operator = operator # Operator object (e.g., Add, Subtract, etc.)
|
| 22 |
self.children = children # list of children nodes (operands)
|
| 23 |
|
| 24 |
-
def evaluate(self):
|
| 25 |
-
|
|
|
|
| 26 |
if len(self.children) != self.operator.arity:
|
| 27 |
raise ValueError("Invalid number of operands for operator")
|
|
|
|
| 28 |
# recursively evaluate the operator and its operands
|
| 29 |
-
operands = [child.evaluate() for child in self.children]
|
| 30 |
-
return self.operator.evaluate(*operands)
|
| 31 |
|
| 32 |
def str(self):
|
| 33 |
-
|
|
|
|
| 34 |
if len(self.children) != self.operator.arity:
|
| 35 |
raise ValueError("Invalid number of operands for operator")
|
|
|
|
| 36 |
# recursively generate a string representation of the AST
|
| 37 |
operand_strings = [child.str() for child in self.children]
|
| 38 |
return self.operator.str(*operand_strings)
|
|
|
|
| 11 |
operator (object): operator object (e.g., Add, Subtract, etc.)
|
| 12 |
children (list): list of children nodes (operands)
|
| 13 |
|
| 14 |
+
Example:
|
| 15 |
+
add_node = OperatorNode(Add(), [IntegerConstant(7), IntegerConstant(5)])
|
| 16 |
+
subtract_node = OperatorNode(Subtract(), [IntegerConstant(3), IntegerConstant(1)])
|
| 17 |
+
multiply_node = OperatorNode(Multiply(), [add_node, subtract_node])
|
| 18 |
+
multiply_node.evaluate() # returns 24
|
| 19 |
+
multiply_node.str() # returns "((7 + 5) * (3 - 1))"
|
| 20 |
|
| 21 |
+
For variable computation, the input arguments are passed to the evaluate() method.
|
| 22 |
+
For example, if instead:
|
| 23 |
+
|
| 24 |
+
add_node = OperatorNode(Add(), [IntegerVariable(0), IntegerConstant(5)])
|
| 25 |
+
multiply_node.evaluate([7]) # returns 24
|
| 26 |
'''
|
| 27 |
+
|
| 28 |
def __init__(self, operator, children):
|
| 29 |
self.operator = operator # Operator object (e.g., Add, Subtract, etc.)
|
| 30 |
self.children = children # list of children nodes (operands)
|
| 31 |
|
| 32 |
+
def evaluate(self, input = None):
|
| 33 |
+
|
| 34 |
+
# check arity of operator in AST
|
| 35 |
if len(self.children) != self.operator.arity:
|
| 36 |
raise ValueError("Invalid number of operands for operator")
|
| 37 |
+
|
| 38 |
# recursively evaluate the operator and its operands
|
| 39 |
+
operands = [child.evaluate(input) for child in self.children]
|
| 40 |
+
return self.operator.evaluate(*operands, input)
|
| 41 |
|
| 42 |
def str(self):
|
| 43 |
+
|
| 44 |
+
# check arity of operator in AST
|
| 45 |
if len(self.children) != self.operator.arity:
|
| 46 |
raise ValueError("Invalid number of operands for operator")
|
| 47 |
+
|
| 48 |
# recursively generate a string representation of the AST
|
| 49 |
operand_strings = [child.str() for child in self.children]
|
| 50 |
return self.operator.str(*operand_strings)
|
arithmetic.py
CHANGED
|
@@ -13,15 +13,39 @@ class IntegerVariable:
|
|
| 13 |
For example, if the input is [4, 5, 6] and the variable is the third element (i.e., 6), then position = 2.
|
| 14 |
'''
|
| 15 |
def __init__(self, position):
|
| 16 |
-
self.value = None # value of the variable, initially None
|
| 17 |
-
self.position = position # position of the variable in the arguments to program
|
| 18 |
self.type = int # type of the variable
|
| 19 |
|
| 20 |
-
def assign(self, value):
|
| 21 |
-
|
| 22 |
|
| 23 |
-
def evaluate(self):
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
def str(self):
|
| 27 |
return f"x{self.position}"
|
|
@@ -34,7 +58,7 @@ class IntegerConstant:
|
|
| 34 |
self.value = value # value of the constant
|
| 35 |
self.type = int # type of the constant
|
| 36 |
|
| 37 |
-
def evaluate(self):
|
| 38 |
return self.value
|
| 39 |
|
| 40 |
def str(self):
|
|
@@ -50,7 +74,7 @@ class Add:
|
|
| 50 |
self.return_type = int # return type
|
| 51 |
self.weight = 1 # weight
|
| 52 |
|
| 53 |
-
def evaluate(self, x, y):
|
| 54 |
return x + y
|
| 55 |
|
| 56 |
def str(self, x, y):
|
|
@@ -66,7 +90,7 @@ class Subtract:
|
|
| 66 |
self.return_type = int # return type
|
| 67 |
self.weight = 1 # weight
|
| 68 |
|
| 69 |
-
def evaluate(self, x, y):
|
| 70 |
return x - y
|
| 71 |
|
| 72 |
def str(self, x, y):
|
|
@@ -82,7 +106,7 @@ class Multiply:
|
|
| 82 |
self.return_type = int # return type
|
| 83 |
self.weight = 1 # weight
|
| 84 |
|
| 85 |
-
def evaluate(self, x, y):
|
| 86 |
return x * y
|
| 87 |
|
| 88 |
def str(self, x, y):
|
|
@@ -98,7 +122,7 @@ class Divide:
|
|
| 98 |
self.return_type = int # return type
|
| 99 |
self.weight = 1 # weight
|
| 100 |
|
| 101 |
-
def evaluate(self, x, y):
|
| 102 |
try: # check for division by zero error
|
| 103 |
return x / y
|
| 104 |
except ZeroDivisionError:
|
|
|
|
| 13 |
For example, if the input is [4, 5, 6] and the variable is the third element (i.e., 6), then position = 2.
|
| 14 |
'''
|
| 15 |
def __init__(self, position):
|
| 16 |
+
# self.value = None # value of the variable, initially None
|
| 17 |
+
self.position = position # zero-indexed position of the variable in the arguments to program
|
| 18 |
self.type = int # type of the variable
|
| 19 |
|
| 20 |
+
# def assign(self, value):
|
| 21 |
+
# self.value = value
|
| 22 |
|
| 23 |
+
# def evaluate(self, input = None):
|
| 24 |
+
# # check that variable has been assigned a value
|
| 25 |
+
# if self.value is None:
|
| 26 |
+
# raise ValueError(f"Variable {self.position} has not been assigned a value.")
|
| 27 |
+
|
| 28 |
+
# return self.value
|
| 29 |
+
|
| 30 |
+
def evaluate(self, input = None):
|
| 31 |
+
|
| 32 |
+
# check that input is not None
|
| 33 |
+
if input is None:
|
| 34 |
+
raise ValueError("Input is None.")
|
| 35 |
+
|
| 36 |
+
# check that input is a list
|
| 37 |
+
if type(input) != list:
|
| 38 |
+
raise ValueError("Input is not a list.")
|
| 39 |
+
|
| 40 |
+
# check that input is not empty
|
| 41 |
+
if len(input) == 0:
|
| 42 |
+
raise ValueError("Input is empty.")
|
| 43 |
+
|
| 44 |
+
# check that position is valid
|
| 45 |
+
if self.position >= len(input):
|
| 46 |
+
raise ValueError(f"Position {self.position} is out of range for input of length {len(input)}.")
|
| 47 |
+
|
| 48 |
+
return input[self.position]
|
| 49 |
|
| 50 |
def str(self):
|
| 51 |
return f"x{self.position}"
|
|
|
|
| 58 |
self.value = value # value of the constant
|
| 59 |
self.type = int # type of the constant
|
| 60 |
|
| 61 |
+
def evaluate(self, input = None):
|
| 62 |
return self.value
|
| 63 |
|
| 64 |
def str(self):
|
|
|
|
| 74 |
self.return_type = int # return type
|
| 75 |
self.weight = 1 # weight
|
| 76 |
|
| 77 |
+
def evaluate(self, x, y, input = None):
|
| 78 |
return x + y
|
| 79 |
|
| 80 |
def str(self, x, y):
|
|
|
|
| 90 |
self.return_type = int # return type
|
| 91 |
self.weight = 1 # weight
|
| 92 |
|
| 93 |
+
def evaluate(self, x, y, input = None):
|
| 94 |
return x - y
|
| 95 |
|
| 96 |
def str(self, x, y):
|
|
|
|
| 106 |
self.return_type = int # return type
|
| 107 |
self.weight = 1 # weight
|
| 108 |
|
| 109 |
+
def evaluate(self, x, y, input = None):
|
| 110 |
return x * y
|
| 111 |
|
| 112 |
def str(self, x, y):
|
|
|
|
| 122 |
self.return_type = int # return type
|
| 123 |
self.weight = 1 # weight
|
| 124 |
|
| 125 |
+
def evaluate(self, x, y, input = None):
|
| 126 |
try: # check for division by zero error
|
| 127 |
return x / y
|
| 128 |
except ZeroDivisionError:
|
demonstration.ipynb
CHANGED
|
@@ -26,35 +26,12 @@
|
|
| 26 |
"import argparse\n",
|
| 27 |
"\n",
|
| 28 |
"# import arithmetic module\n",
|
| 29 |
-
"
|
| 30 |
-
"
|
| 31 |
"from examples import example_set, check_examples\n",
|
| 32 |
"import config"
|
| 33 |
]
|
| 34 |
},
|
| 35 |
-
{
|
| 36 |
-
"cell_type": "code",
|
| 37 |
-
"execution_count": 14,
|
| 38 |
-
"metadata": {},
|
| 39 |
-
"outputs": [
|
| 40 |
-
{
|
| 41 |
-
"data": {
|
| 42 |
-
"text/plain": [
|
| 43 |
-
"24"
|
| 44 |
-
]
|
| 45 |
-
},
|
| 46 |
-
"execution_count": 14,
|
| 47 |
-
"metadata": {},
|
| 48 |
-
"output_type": "execute_result"
|
| 49 |
-
}
|
| 50 |
-
],
|
| 51 |
-
"source": [
|
| 52 |
-
"add_node = OperatorNode(Add(), [IntegerConstant(7), IntegerConstant(5)])\n",
|
| 53 |
-
"subtract_node = OperatorNode(Subtract(), [IntegerConstant(3), IntegerConstant(1)])\n",
|
| 54 |
-
"multiply_node = OperatorNode(Multiply(), [add_node, subtract_node])\n",
|
| 55 |
-
"multiply_node.evaluate()"
|
| 56 |
-
]
|
| 57 |
-
},
|
| 58 |
{
|
| 59 |
"cell_type": "markdown",
|
| 60 |
"metadata": {},
|
|
|
|
| 26 |
"import argparse\n",
|
| 27 |
"\n",
|
| 28 |
"# import arithmetic module\n",
|
| 29 |
+
"from arithmetic import *\n",
|
| 30 |
+
"from abstract_syntax_tree import OperatorNode\n",
|
| 31 |
"from examples import example_set, check_examples\n",
|
| 32 |
"import config"
|
| 33 |
]
|
| 34 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
{
|
| 36 |
"cell_type": "markdown",
|
| 37 |
"metadata": {},
|