Spaces:
Sleeping
Sleeping
Commit
·
a417ea3
1
Parent(s):
771c860
Extract constants and variables
Browse files- README.md +3 -3
- abstract_syntax_trees.py +5 -0
- arithmetic.py +33 -21
- demonstration.ipynb +109 -39
- examples.py +41 -4
- synthesizer.py +73 -9
README.md
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
|
| 3 |
🚨🚨PLEASE DO NOT GRADE YET🚨🚨
|
| 4 |
|
| 5 |
-
Completed for [CS252R: Program Synthesis](https://synthesis.metareflection.club/) at the Harvard John A. Paulson School of Engineering and Applied Sciences, taught Fall 2023 by Prof. Nada Amin.
|
| 6 |
|
| 7 |
## 🛠️ Background
|
| 8 |
|
|
@@ -37,7 +37,7 @@ optional arguments:
|
|
| 37 |
--domain {arithmetic,string}
|
| 38 |
Domain of synthesis (either "arithmetic" or "string").
|
| 39 |
--examples {addition,subtraction,multiplication,division}
|
| 40 |
-
Examples to synthesize program from. Must be a valid key in the "
|
| 41 |
--max_weight MAX_WEIGHT
|
| 42 |
Maximum weight of programs to consider before terminating search.
|
| 43 |
```
|
|
@@ -47,7 +47,7 @@ For example, to synthesize programs in the arithmetic domain from the addition i
|
|
| 47 |
python3 synthesizer.py --domain arithmetic --examples addition
|
| 48 |
```
|
| 49 |
|
| 50 |
-
To add additional input-output examples, modify `examples.py`. Add a new key to the dictionary `
|
| 51 |
|
| 52 |
## 🔮 Virtual Environment
|
| 53 |
|
|
|
|
| 2 |
|
| 3 |
🚨🚨PLEASE DO NOT GRADE YET🚨🚨
|
| 4 |
|
| 5 |
+
Completed for [CS252R: Program Synthesis](https://synthesis.metareflection.club/) at the Harvard John A. Paulson School of Engineering and Applied Sciences, taught in Fall 2023 by Prof. Nada Amin.
|
| 6 |
|
| 7 |
## 🛠️ Background
|
| 8 |
|
|
|
|
| 37 |
--domain {arithmetic,string}
|
| 38 |
Domain of synthesis (either "arithmetic" or "string").
|
| 39 |
--examples {addition,subtraction,multiplication,division}
|
| 40 |
+
Examples to synthesize program from. Must be a valid key in the "example_set" dictionary.
|
| 41 |
--max_weight MAX_WEIGHT
|
| 42 |
Maximum weight of programs to consider before terminating search.
|
| 43 |
```
|
|
|
|
| 47 |
python3 synthesizer.py --domain arithmetic --examples addition
|
| 48 |
```
|
| 49 |
|
| 50 |
+
To add additional input-output examples, modify `examples.py`. Add a new key to the dictionary `example_set` and set the value to be a list of tuples.
|
| 51 |
|
| 52 |
## 🔮 Virtual Environment
|
| 53 |
|
abstract_syntax_trees.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
ABSTRACT SYNTAX TREES
|
| 3 |
+
This file contains Python classes that define the abstract syntax tree (AST) for program synthesis.
|
| 4 |
+
'''
|
| 5 |
+
|
arithmetic.py
CHANGED
|
@@ -7,22 +7,36 @@ This file contains Python classes that define the arithmetic operators for progr
|
|
| 7 |
CLASS DEFINITIONS
|
| 8 |
'''
|
| 9 |
|
| 10 |
-
class
|
| 11 |
'''
|
| 12 |
-
Class to represent an
|
|
|
|
| 13 |
'''
|
| 14 |
-
def __init__(self,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
self.value = value
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
class Add:
|
| 19 |
'''
|
| 20 |
Operator to add two numerical values.
|
| 21 |
'''
|
| 22 |
def __init__(self):
|
| 23 |
-
self.arity = 2
|
| 24 |
-
self.
|
| 25 |
-
self.return_type = int
|
|
|
|
| 26 |
|
| 27 |
def __call__(self, x, y):
|
| 28 |
return x + y
|
|
@@ -35,9 +49,10 @@ class Subtract:
|
|
| 35 |
Operator to subtract two numerical values.
|
| 36 |
'''
|
| 37 |
def __init__(self):
|
| 38 |
-
self.arity = 2
|
| 39 |
-
self.
|
| 40 |
-
self.return_type = int
|
|
|
|
| 41 |
|
| 42 |
def __call__(self, x, y):
|
| 43 |
return x - y
|
|
@@ -50,9 +65,10 @@ class Multiply:
|
|
| 50 |
Operator to multiply two numerical values.
|
| 51 |
'''
|
| 52 |
def __init__(self):
|
| 53 |
-
self.arity = 2
|
| 54 |
-
self.
|
| 55 |
-
self.return_type = int
|
|
|
|
| 56 |
|
| 57 |
def __call__(self, x, y):
|
| 58 |
return x * y
|
|
@@ -65,9 +81,10 @@ class Divide:
|
|
| 65 |
Operator to divide two numerical values.
|
| 66 |
'''
|
| 67 |
def __init__(self):
|
| 68 |
-
self.arity = 2
|
| 69 |
-
self.
|
| 70 |
-
self.return_type = int
|
|
|
|
| 71 |
|
| 72 |
def __call__(self, x, y):
|
| 73 |
try: # check for division by zero error
|
|
@@ -79,11 +96,6 @@ class Divide:
|
|
| 79 |
return f"{x} / {y}"
|
| 80 |
|
| 81 |
|
| 82 |
-
'''
|
| 83 |
-
FUNCTION DEFINITIONS
|
| 84 |
-
'''
|
| 85 |
-
|
| 86 |
-
|
| 87 |
'''
|
| 88 |
GLOBAL CONSTANTS
|
| 89 |
'''
|
|
|
|
| 7 |
CLASS DEFINITIONS
|
| 8 |
'''
|
| 9 |
|
| 10 |
+
class IntegerVariable:
|
| 11 |
'''
|
| 12 |
+
Class to represent an integer variable. Note that position is the position of the variable in the input.
|
| 13 |
+
For example, if the input is [4, 5, 6] and the variable is the third element (i.e., 6), then position = 2.
|
| 14 |
'''
|
| 15 |
+
def __init__(self, position):
|
| 16 |
+
self.value = None # value of the variable, initially None
|
| 17 |
+
self.position = position # position of the variable in the arguments to program
|
| 18 |
+
self.type = int # type of the variable
|
| 19 |
+
|
| 20 |
+
def assign(self, value):
|
| 21 |
self.value = value
|
| 22 |
+
|
| 23 |
+
class IntegerConstant:
|
| 24 |
+
'''
|
| 25 |
+
Class to represent an integer constant.
|
| 26 |
+
'''
|
| 27 |
+
def __init__(self, value):
|
| 28 |
+
self.value = value # value of the constant
|
| 29 |
+
self.type = int # type of the constant
|
| 30 |
|
| 31 |
class Add:
|
| 32 |
'''
|
| 33 |
Operator to add two numerical values.
|
| 34 |
'''
|
| 35 |
def __init__(self):
|
| 36 |
+
self.arity = 2 # number of arguments
|
| 37 |
+
self.arg_types = [int, int] # argument types
|
| 38 |
+
self.return_type = int # return type
|
| 39 |
+
self.weight = 1 # weight
|
| 40 |
|
| 41 |
def __call__(self, x, y):
|
| 42 |
return x + y
|
|
|
|
| 49 |
Operator to subtract two numerical values.
|
| 50 |
'''
|
| 51 |
def __init__(self):
|
| 52 |
+
self.arity = 2 # number of arguments
|
| 53 |
+
self.arg_types = [int, int] # argument types
|
| 54 |
+
self.return_type = int # return type
|
| 55 |
+
self.weight = 1 # weight
|
| 56 |
|
| 57 |
def __call__(self, x, y):
|
| 58 |
return x - y
|
|
|
|
| 65 |
Operator to multiply two numerical values.
|
| 66 |
'''
|
| 67 |
def __init__(self):
|
| 68 |
+
self.arity = 2 # number of arguments
|
| 69 |
+
self.arg_types = [int, int] # argument types
|
| 70 |
+
self.return_type = int # return type
|
| 71 |
+
self.weight = 1 # weight
|
| 72 |
|
| 73 |
def __call__(self, x, y):
|
| 74 |
return x * y
|
|
|
|
| 81 |
Operator to divide two numerical values.
|
| 82 |
'''
|
| 83 |
def __init__(self):
|
| 84 |
+
self.arity = 2 # number of arguments
|
| 85 |
+
self.arg_types = [int, int] # argument types
|
| 86 |
+
self.return_type = int # return type
|
| 87 |
+
self.weight = 1 # weight
|
| 88 |
|
| 89 |
def __call__(self, x, y):
|
| 90 |
try: # check for division by zero error
|
|
|
|
| 96 |
return f"{x} / {y}"
|
| 97 |
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
'''
|
| 100 |
GLOBAL CONSTANTS
|
| 101 |
'''
|
demonstration.ipynb
CHANGED
|
@@ -15,7 +15,7 @@
|
|
| 15 |
},
|
| 16 |
{
|
| 17 |
"cell_type": "code",
|
| 18 |
-
"execution_count":
|
| 19 |
"metadata": {},
|
| 20 |
"outputs": [],
|
| 21 |
"source": [
|
|
@@ -27,7 +27,7 @@
|
|
| 27 |
"\n",
|
| 28 |
"# import arithmetic module\n",
|
| 29 |
"# from arithmetic import *\n",
|
| 30 |
-
"from examples import
|
| 31 |
"import config"
|
| 32 |
]
|
| 33 |
},
|
|
@@ -40,16 +40,23 @@
|
|
| 40 |
},
|
| 41 |
{
|
| 42 |
"cell_type": "code",
|
| 43 |
-
"execution_count":
|
| 44 |
"metadata": {},
|
| 45 |
"outputs": [],
|
| 46 |
"source": [
|
| 47 |
"domain = \"arithmetic\"\n",
|
| 48 |
"examples_key = \"addition\"\n",
|
| 49 |
-
"examples =
|
| 50 |
"max_weight = 3"
|
| 51 |
]
|
| 52 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
{
|
| 54 |
"cell_type": "markdown",
|
| 55 |
"metadata": {},
|
|
@@ -59,35 +66,40 @@
|
|
| 59 |
},
|
| 60 |
{
|
| 61 |
"cell_type": "code",
|
| 62 |
-
"execution_count":
|
| 63 |
"metadata": {},
|
| 64 |
"outputs": [],
|
| 65 |
"source": [
|
| 66 |
-
"
|
| 67 |
-
"
|
| 68 |
-
"
|
| 69 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
"\n",
|
| 71 |
-
"
|
| 72 |
-
"
|
| 73 |
-
"''' \n",
|
| 74 |
"\n",
|
| 75 |
-
"class
|
| 76 |
" '''\n",
|
| 77 |
-
" Class to represent an
|
| 78 |
" '''\n",
|
| 79 |
" def __init__(self, value):\n",
|
| 80 |
-
" self.value = value\n",
|
| 81 |
-
" self.type = int\n",
|
| 82 |
"\n",
|
| 83 |
"class Add:\n",
|
| 84 |
" '''\n",
|
| 85 |
" Operator to add two numerical values.\n",
|
| 86 |
" '''\n",
|
| 87 |
" def __init__(self):\n",
|
| 88 |
-
" self.arity = 2
|
| 89 |
-
" self.
|
| 90 |
-
" self.return_type = int
|
|
|
|
| 91 |
"\n",
|
| 92 |
" def __call__(self, x, y):\n",
|
| 93 |
" return x + y\n",
|
|
@@ -100,9 +112,10 @@
|
|
| 100 |
" Operator to subtract two numerical values.\n",
|
| 101 |
" '''\n",
|
| 102 |
" def __init__(self):\n",
|
| 103 |
-
" self.arity = 2
|
| 104 |
-
" self.
|
| 105 |
-
" self.return_type = int
|
|
|
|
| 106 |
"\n",
|
| 107 |
" def __call__(self, x, y):\n",
|
| 108 |
" return x - y\n",
|
|
@@ -115,9 +128,10 @@
|
|
| 115 |
" Operator to multiply two numerical values.\n",
|
| 116 |
" '''\n",
|
| 117 |
" def __init__(self):\n",
|
| 118 |
-
" self.arity = 2
|
| 119 |
-
" self.
|
| 120 |
-
" self.return_type = int
|
|
|
|
| 121 |
"\n",
|
| 122 |
" def __call__(self, x, y):\n",
|
| 123 |
" return x * y\n",
|
|
@@ -130,9 +144,10 @@
|
|
| 130 |
" Operator to divide two numerical values.\n",
|
| 131 |
" '''\n",
|
| 132 |
" def __init__(self):\n",
|
| 133 |
-
" self.arity = 2
|
| 134 |
-
" self.
|
| 135 |
-
" self.return_type = int
|
|
|
|
| 136 |
"\n",
|
| 137 |
" def __call__(self, x, y):\n",
|
| 138 |
" try: # check for division by zero error\n",
|
|
@@ -145,11 +160,6 @@
|
|
| 145 |
"\n",
|
| 146 |
"\n",
|
| 147 |
"'''\n",
|
| 148 |
-
"FUNCTION DEFINITIONS\n",
|
| 149 |
-
"''' \n",
|
| 150 |
-
"\n",
|
| 151 |
-
"\n",
|
| 152 |
-
"'''\n",
|
| 153 |
"GLOBAL CONSTANTS\n",
|
| 154 |
"''' \n",
|
| 155 |
"\n",
|
|
@@ -161,7 +171,70 @@
|
|
| 161 |
"cell_type": "markdown",
|
| 162 |
"metadata": {},
|
| 163 |
"source": [
|
| 164 |
-
"I define
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
]
|
| 166 |
},
|
| 167 |
{
|
|
@@ -198,16 +271,13 @@
|
|
| 198 |
"metadata": {},
|
| 199 |
"outputs": [],
|
| 200 |
"source": [
|
| 201 |
-
"# initialize program bank\n",
|
| 202 |
-
"program_bank = []\n",
|
| 203 |
-
"\n",
|
| 204 |
"# iterate over each level\n",
|
| 205 |
-
"for i in range(
|
| 206 |
"\n",
|
| 207 |
" # define level program bank\n",
|
| 208 |
" level_program_bank = []\n",
|
| 209 |
"\n",
|
| 210 |
-
" for op in arithmetic_operators
|
| 211 |
"\n",
|
| 212 |
" break"
|
| 213 |
]
|
|
|
|
| 15 |
},
|
| 16 |
{
|
| 17 |
"cell_type": "code",
|
| 18 |
+
"execution_count": 1,
|
| 19 |
"metadata": {},
|
| 20 |
"outputs": [],
|
| 21 |
"source": [
|
|
|
|
| 27 |
"\n",
|
| 28 |
"# import arithmetic module\n",
|
| 29 |
"# from arithmetic import *\n",
|
| 30 |
+
"from examples import example_set, check_examples\n",
|
| 31 |
"import config"
|
| 32 |
]
|
| 33 |
},
|
|
|
|
| 40 |
},
|
| 41 |
{
|
| 42 |
"cell_type": "code",
|
| 43 |
+
"execution_count": 2,
|
| 44 |
"metadata": {},
|
| 45 |
"outputs": [],
|
| 46 |
"source": [
|
| 47 |
"domain = \"arithmetic\"\n",
|
| 48 |
"examples_key = \"addition\"\n",
|
| 49 |
+
"examples = example_set[examples_key]\n",
|
| 50 |
"max_weight = 3"
|
| 51 |
]
|
| 52 |
},
|
| 53 |
+
{
|
| 54 |
+
"cell_type": "markdown",
|
| 55 |
+
"metadata": {},
|
| 56 |
+
"source": [
|
| 57 |
+
"First, I define a function to check that, across all input-output pairs, all inputs are of the same length and that argument types are consistent across inputs."
|
| 58 |
+
]
|
| 59 |
+
},
|
| 60 |
{
|
| 61 |
"cell_type": "markdown",
|
| 62 |
"metadata": {},
|
|
|
|
| 66 |
},
|
| 67 |
{
|
| 68 |
"cell_type": "code",
|
| 69 |
+
"execution_count": 8,
|
| 70 |
"metadata": {},
|
| 71 |
"outputs": [],
|
| 72 |
"source": [
|
| 73 |
+
"class IntegerVariable:\n",
|
| 74 |
+
" '''\n",
|
| 75 |
+
" Class to represent an integer variable. Note that position is the position of the variable in the input.\n",
|
| 76 |
+
" For example, if the input is [4, 5, 6] and the variable is the third element (i.e., 6), then position = 2.\n",
|
| 77 |
+
" '''\n",
|
| 78 |
+
" def __init__(self, position):\n",
|
| 79 |
+
" self.value = None # value of the variable, initially None\n",
|
| 80 |
+
" self.position = position # position of the variable in the arguments to program\n",
|
| 81 |
+
" self.type = int # type of the variable\n",
|
| 82 |
"\n",
|
| 83 |
+
" def assign(self, value):\n",
|
| 84 |
+
" self.value = value\n",
|
|
|
|
| 85 |
"\n",
|
| 86 |
+
"class IntegerConstant:\n",
|
| 87 |
" '''\n",
|
| 88 |
+
" Class to represent an integer constant.\n",
|
| 89 |
" '''\n",
|
| 90 |
" def __init__(self, value):\n",
|
| 91 |
+
" self.value = value # value of the constant\n",
|
| 92 |
+
" self.type = int # type of the constant\n",
|
| 93 |
"\n",
|
| 94 |
"class Add:\n",
|
| 95 |
" '''\n",
|
| 96 |
" Operator to add two numerical values.\n",
|
| 97 |
" '''\n",
|
| 98 |
" def __init__(self):\n",
|
| 99 |
+
" self.arity = 2 # number of arguments\n",
|
| 100 |
+
" self.arg_types = [int, int] # argument types\n",
|
| 101 |
+
" self.return_type = int # return type\n",
|
| 102 |
+
" self.weight = 1 # weight\n",
|
| 103 |
"\n",
|
| 104 |
" def __call__(self, x, y):\n",
|
| 105 |
" return x + y\n",
|
|
|
|
| 112 |
" Operator to subtract two numerical values.\n",
|
| 113 |
" '''\n",
|
| 114 |
" def __init__(self):\n",
|
| 115 |
+
" self.arity = 2 # number of arguments\n",
|
| 116 |
+
" self.arg_types = [int, int] # argument types\n",
|
| 117 |
+
" self.return_type = int # return type\n",
|
| 118 |
+
" self.weight = 1 # weight\n",
|
| 119 |
"\n",
|
| 120 |
" def __call__(self, x, y):\n",
|
| 121 |
" return x - y\n",
|
|
|
|
| 128 |
" Operator to multiply two numerical values.\n",
|
| 129 |
" '''\n",
|
| 130 |
" def __init__(self):\n",
|
| 131 |
+
" self.arity = 2 # number of arguments\n",
|
| 132 |
+
" self.arg_types = [int, int] # argument types\n",
|
| 133 |
+
" self.return_type = int # return type\n",
|
| 134 |
+
" self.weight = 1 # weight\n",
|
| 135 |
"\n",
|
| 136 |
" def __call__(self, x, y):\n",
|
| 137 |
" return x * y\n",
|
|
|
|
| 144 |
" Operator to divide two numerical values.\n",
|
| 145 |
" '''\n",
|
| 146 |
" def __init__(self):\n",
|
| 147 |
+
" self.arity = 2 # number of arguments\n",
|
| 148 |
+
" self.arg_types = [int, int] # argument types\n",
|
| 149 |
+
" self.return_type = int # return type\n",
|
| 150 |
+
" self.weight = 1 # weight\n",
|
| 151 |
"\n",
|
| 152 |
" def __call__(self, x, y):\n",
|
| 153 |
" try: # check for division by zero error\n",
|
|
|
|
| 160 |
"\n",
|
| 161 |
"\n",
|
| 162 |
"'''\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
"GLOBAL CONSTANTS\n",
|
| 164 |
"''' \n",
|
| 165 |
"\n",
|
|
|
|
| 171 |
"cell_type": "markdown",
|
| 172 |
"metadata": {},
|
| 173 |
"source": [
|
| 174 |
+
"I define a function to extract constants from examples."
|
| 175 |
+
]
|
| 176 |
+
},
|
| 177 |
+
{
|
| 178 |
+
"cell_type": "code",
|
| 179 |
+
"execution_count": 9,
|
| 180 |
+
"metadata": {},
|
| 181 |
+
"outputs": [],
|
| 182 |
+
"source": [
|
| 183 |
+
"def extract_constants(examples):\n",
|
| 184 |
+
" '''\n",
|
| 185 |
+
" Extracts the constants from the input-output examples. Also constructs variables as needed\n",
|
| 186 |
+
" based on the input-output examples, and adds them to the list of constants.\n",
|
| 187 |
+
" '''\n",
|
| 188 |
+
"\n",
|
| 189 |
+
" # check validity of provided examples\n",
|
| 190 |
+
" # if valid, extract arity and argument types\n",
|
| 191 |
+
" arity, arg_types = check_examples(examples)\n",
|
| 192 |
+
"\n",
|
| 193 |
+
" # initialize list of constants\n",
|
| 194 |
+
" constants = []\n",
|
| 195 |
+
"\n",
|
| 196 |
+
" # get unique set of inputs\n",
|
| 197 |
+
" inputs = [input for example in examples for input in example[0]]\n",
|
| 198 |
+
" inputs = set(inputs)\n",
|
| 199 |
+
"\n",
|
| 200 |
+
" # add 1 to the set of inputs\n",
|
| 201 |
+
" inputs.add(1)\n",
|
| 202 |
+
"\n",
|
| 203 |
+
" # extract constants in input\n",
|
| 204 |
+
" for input in inputs:\n",
|
| 205 |
+
"\n",
|
| 206 |
+
" if type(input) == int:\n",
|
| 207 |
+
" constants.append(IntegerConstant(input))\n",
|
| 208 |
+
" elif type(input) == str:\n",
|
| 209 |
+
" # constants.append(StringConstant(input))\n",
|
| 210 |
+
" pass\n",
|
| 211 |
+
" else:\n",
|
| 212 |
+
" raise Exception(\"Input of unknown type.\")\n",
|
| 213 |
+
" \n",
|
| 214 |
+
" # initialize list of variables\n",
|
| 215 |
+
" variables = []\n",
|
| 216 |
+
"\n",
|
| 217 |
+
" # extract variables in input\n",
|
| 218 |
+
" for position, arg in enumerate(arg_types):\n",
|
| 219 |
+
" if arg == int:\n",
|
| 220 |
+
" variables.append(IntegerVariable(position))\n",
|
| 221 |
+
" elif arg == str:\n",
|
| 222 |
+
" # variables.append(StringVariable(position))\n",
|
| 223 |
+
" pass\n",
|
| 224 |
+
" else:\n",
|
| 225 |
+
" raise Exception(\"Input of unknown type.\")\n",
|
| 226 |
+
"\n",
|
| 227 |
+
" return constants + variables"
|
| 228 |
+
]
|
| 229 |
+
},
|
| 230 |
+
{
|
| 231 |
+
"cell_type": "code",
|
| 232 |
+
"execution_count": 16,
|
| 233 |
+
"metadata": {},
|
| 234 |
+
"outputs": [],
|
| 235 |
+
"source": [
|
| 236 |
+
"# initialize program bank\n",
|
| 237 |
+
"program_bank = extract_constants(examples)"
|
| 238 |
]
|
| 239 |
},
|
| 240 |
{
|
|
|
|
| 271 |
"metadata": {},
|
| 272 |
"outputs": [],
|
| 273 |
"source": [
|
|
|
|
|
|
|
|
|
|
| 274 |
"# iterate over each level\n",
|
| 275 |
+
"for i in range(2, max_weight):\n",
|
| 276 |
"\n",
|
| 277 |
" # define level program bank\n",
|
| 278 |
" level_program_bank = []\n",
|
| 279 |
"\n",
|
| 280 |
+
" for op in arithmetic_operators:\n",
|
| 281 |
"\n",
|
| 282 |
" break"
|
| 283 |
]
|
examples.py
CHANGED
|
@@ -1,13 +1,18 @@
|
|
| 1 |
'''
|
| 2 |
EXAMPLES
|
| 3 |
This file contains input-output examples for both arithmetic and string domain-specific languages (DSLs).
|
| 4 |
-
To add a new example, add a new key to the dictionary '
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
'''
|
| 6 |
|
| 7 |
# define examples
|
| 8 |
-
|
| 9 |
# arithmetic examples
|
| 10 |
-
'addition': [([7, 2], 9), ([
|
| 11 |
'subtraction': [([9, 2], 7), ([6, 1], 5), ([7, 3], 4), ([8, 4], 4), ([10, 2], 8)],
|
| 12 |
'multiplication': [([2, 3], 6), ([4, 5], 20), ([7, 8], 56), ([9, 2], 18), ([3, 4], 12)],
|
| 13 |
'division': [([6, 2], 3), ([8, 4], 2), ([9, 3], 3), ([10, 5], 2), ([12, 6], 2)]
|
|
@@ -15,4 +20,36 @@ examples = {
|
|
| 15 |
# string examples
|
| 16 |
|
| 17 |
# custom user examples
|
| 18 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
'''
|
| 2 |
EXAMPLES
|
| 3 |
This file contains input-output examples for both arithmetic and string domain-specific languages (DSLs).
|
| 4 |
+
To add a new example, add a new key to the dictionary 'example_set' and set the value to be a list of tuples.
|
| 5 |
+
|
| 6 |
+
Note that we synthesize programs with a consistent arity. Therefore, in each set of input-output examples, all
|
| 7 |
+
input examples must be of the same length. Further, argument types must remain consistent across examples. We
|
| 8 |
+
test for these conditions in the `check_examples` function below, which is called by the `extract_constants`
|
| 9 |
+
function in the synthesizer.
|
| 10 |
'''
|
| 11 |
|
| 12 |
# define examples
|
| 13 |
+
example_set = {
|
| 14 |
# arithmetic examples
|
| 15 |
+
'addition': [([7, 2], 9), ([8, 1], 9), ([3, 9], 12), ([5, 8], 13)], # ([4, 6], 10),
|
| 16 |
'subtraction': [([9, 2], 7), ([6, 1], 5), ([7, 3], 4), ([8, 4], 4), ([10, 2], 8)],
|
| 17 |
'multiplication': [([2, 3], 6), ([4, 5], 20), ([7, 8], 56), ([9, 2], 18), ([3, 4], 12)],
|
| 18 |
'division': [([6, 2], 3), ([8, 4], 2), ([9, 3], 3), ([10, 5], 2), ([12, 6], 2)]
|
|
|
|
| 20 |
# string examples
|
| 21 |
|
| 22 |
# custom user examples
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# CHECK EXAMPLE VALIDITY
|
| 27 |
+
def check_examples(examples):
|
| 28 |
+
'''
|
| 29 |
+
Checks that all input examples are of same length and that argument types are consistent across examples.
|
| 30 |
+
If valid, returns arity and argument types of function to be generated.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
examples (list): list of tuples, where each tuple is of the form (input, output)
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
input_lengths[0] (int): arity of function
|
| 37 |
+
arg_types[0] (list): argument types of function
|
| 38 |
+
'''
|
| 39 |
+
|
| 40 |
+
# get input examples
|
| 41 |
+
inputs = [example[0] for example in examples]
|
| 42 |
+
|
| 43 |
+
# check all inputs are of same length
|
| 44 |
+
input_lengths = [len(input) for input in inputs]
|
| 45 |
+
if len(set(input_lengths)) != 1:
|
| 46 |
+
raise ValueError("All input examples must be of same length.")
|
| 47 |
+
|
| 48 |
+
# check that types of arguments are same
|
| 49 |
+
arg_types = [[type(arg) for arg in input] for input in inputs]
|
| 50 |
+
consistent_types = all([arg_types[0] == arg_type for arg_type in arg_types])
|
| 51 |
+
if not consistent_types:
|
| 52 |
+
raise ValueError("Argument types must be consistent across inputs.")
|
| 53 |
+
|
| 54 |
+
# return arity and argument types
|
| 55 |
+
return input_lengths[0], arg_types[0]
|
synthesizer.py
CHANGED
|
@@ -12,11 +12,13 @@ import numpy as np
|
|
| 12 |
import argparse
|
| 13 |
|
| 14 |
# import examples
|
| 15 |
-
from
|
|
|
|
| 16 |
import config
|
| 17 |
|
| 18 |
|
| 19 |
-
|
|
|
|
| 20 |
'''
|
| 21 |
Parse command line arguments.
|
| 22 |
'''
|
|
@@ -32,8 +34,8 @@ def parse_args(examples):
|
|
| 32 |
help='Domain of synthesis (either "arithmetic" or "string").')
|
| 33 |
|
| 34 |
parser.add_argument('--examples', dest='examples_key', type=str, required=True, # default="addition",
|
| 35 |
-
choices=
|
| 36 |
-
help='Examples to synthesize program from. Must be a valid key in the "
|
| 37 |
|
| 38 |
parser.add_argument('--max_weight', type=int, required=False, default=3,
|
| 39 |
help='Maximum weight of programs to consider before terminating search.')
|
|
@@ -42,13 +44,75 @@ def parse_args(examples):
|
|
| 42 |
return args
|
| 43 |
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
if __name__ == '__main__':
|
| 46 |
|
| 47 |
# parse command line arguments
|
| 48 |
-
args = parse_args(
|
| 49 |
-
print(args
|
| 50 |
-
print(args.examples_key)
|
| 51 |
-
print(args.max_weight)
|
| 52 |
|
| 53 |
# run bottom-up enumerative synthesis
|
| 54 |
-
|
|
|
|
| 12 |
import argparse
|
| 13 |
|
| 14 |
# import examples
|
| 15 |
+
from arithmetic import *
|
| 16 |
+
from examples import example_set, check_examples
|
| 17 |
import config
|
| 18 |
|
| 19 |
|
| 20 |
+
# PARSE ARGUMENTS
|
| 21 |
+
def parse_args():
|
| 22 |
'''
|
| 23 |
Parse command line arguments.
|
| 24 |
'''
|
|
|
|
| 34 |
help='Domain of synthesis (either "arithmetic" or "string").')
|
| 35 |
|
| 36 |
parser.add_argument('--examples', dest='examples_key', type=str, required=True, # default="addition",
|
| 37 |
+
choices=example_set.keys(),
|
| 38 |
+
help='Examples to synthesize program from. Must be a valid key in the "example_set" dictionary.')
|
| 39 |
|
| 40 |
parser.add_argument('--max_weight', type=int, required=False, default=3,
|
| 41 |
help='Maximum weight of programs to consider before terminating search.')
|
|
|
|
| 44 |
return args
|
| 45 |
|
| 46 |
|
| 47 |
+
# EXTRACT CONSTANTS AND VARIABLES
|
| 48 |
+
def extract_constants(examples):
|
| 49 |
+
'''
|
| 50 |
+
Extracts the constants from the input-output examples. Also constructs variables as needed
|
| 51 |
+
based on the input-output examples, and adds them to the list of constants.
|
| 52 |
+
'''
|
| 53 |
+
|
| 54 |
+
# check validity of provided examples
|
| 55 |
+
# if valid, extract arity and argument types
|
| 56 |
+
arity, arg_types = check_examples(examples)
|
| 57 |
+
|
| 58 |
+
# initialize list of constants
|
| 59 |
+
constants = []
|
| 60 |
+
|
| 61 |
+
# get unique set of inputs
|
| 62 |
+
inputs = [input for example in examples for input in example[0]]
|
| 63 |
+
inputs = set(inputs)
|
| 64 |
+
|
| 65 |
+
# add 1 to the set of inputs
|
| 66 |
+
inputs.add(1)
|
| 67 |
+
|
| 68 |
+
# extract constants in input
|
| 69 |
+
for input in inputs:
|
| 70 |
+
|
| 71 |
+
if type(input) == int:
|
| 72 |
+
constants.append(IntegerConstant(input))
|
| 73 |
+
elif type(input) == str:
|
| 74 |
+
# constants.append(StringConstant(input))
|
| 75 |
+
pass
|
| 76 |
+
else:
|
| 77 |
+
raise Exception("Input of unknown type.")
|
| 78 |
+
|
| 79 |
+
# initialize list of variables
|
| 80 |
+
variables = []
|
| 81 |
+
|
| 82 |
+
# extract variables in input
|
| 83 |
+
for position, arg in enumerate(arg_types):
|
| 84 |
+
if arg == int:
|
| 85 |
+
variables.append(IntegerVariable(position))
|
| 86 |
+
elif arg == str:
|
| 87 |
+
# variables.append(StringVariable(position))
|
| 88 |
+
pass
|
| 89 |
+
else:
|
| 90 |
+
raise Exception("Input of unknown type.")
|
| 91 |
+
|
| 92 |
+
return constants + variables
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# RUN SYNTHESIZER
|
| 96 |
+
def run_synthesizer(args):
|
| 97 |
+
'''
|
| 98 |
+
Run bottom-up enumerative synthesis.
|
| 99 |
+
'''
|
| 100 |
+
|
| 101 |
+
# retrieve selected input-output examples
|
| 102 |
+
examples = example_set[args.examples_key]
|
| 103 |
+
|
| 104 |
+
# extract constants from examples
|
| 105 |
+
program_bank = extract_constants(examples)
|
| 106 |
+
print(examples)
|
| 107 |
+
|
| 108 |
+
pass
|
| 109 |
+
|
| 110 |
+
|
| 111 |
if __name__ == '__main__':
|
| 112 |
|
| 113 |
# parse command line arguments
|
| 114 |
+
args = parse_args()
|
| 115 |
+
# print(args)
|
|
|
|
|
|
|
| 116 |
|
| 117 |
# run bottom-up enumerative synthesis
|
| 118 |
+
run_synthesizer(args)
|