Spaces:
Sleeping
Sleeping
| ''' | |
| BOTTOM UP ENUMERATIVE SYNTHESIS | |
| Ayush Noori | |
| CS252R, Fall 2020 | |
| Example of usage: | |
| python synthesis.py --domain arithmetic --examples addition | |
| ''' | |
| # load libraries | |
| import numpy as np | |
| import argparse | |
| import itertools | |
| import time | |
| # import examples | |
| from arithmetic import * | |
| from strings import * | |
| from abstract_syntax_tree import * | |
| from examples import example_set, check_examples | |
| import config | |
| # PARSE ARGUMENTS | |
def parse_args(argv=None):
    '''
    Parse command line arguments.

    Parameters:
        argv: optional list of argument strings to parse. When None (the default),
              argparse reads sys.argv[1:], exactly as before.

    Returns:
        argparse.Namespace with attributes `domain`, `examples_key` and `max_weight`.
    '''
    parser = argparse.ArgumentParser(description="Bottom-up enumerative synthesis in Python.")
    # define valid choices for the 'domain' argument
    valid_domain_choices = ["arithmetic", "strings"]
    # register command line arguments
    parser.add_argument('--domain', type=str, required=True,
                        choices=valid_domain_choices,
                        help='Domain of synthesis (either "arithmetic" or "strings").')
    parser.add_argument('--examples', dest='examples_key', type=str, required=True,
                        choices=example_set.keys(),
                        help='Examples to synthesize program from. Must be a valid key in the "example_set" dictionary.')
    parser.add_argument('--max-weight', type=int, required=False, default=3,
                        help='Maximum weight of programs to consider before terminating search.')
    return parser.parse_args(argv)
| # EXTRACT CONSTANTS AND VARIABLES | |
def extract_constants(examples):
    '''
    Extract the constants and variables the synthesizer starts from.

    Every distinct input value appearing in the examples becomes a constant
    (IntegerConstant for int, StringConstant for str), and the integer 1 is
    always added. One variable is then created per argument position, typed
    according to the argument types returned by check_examples.

    Parameters:
        examples: the input-output examples; example[0] is iterated as the
                  argument values of one example.

    Returns:
        A list of constants followed by variables. Constants are ordered
        deterministically (integers ascending, then strings ascending) so the
        enumeration order does not depend on set iteration order, which for
        strings changes between runs because of hash randomization.

    Raises:
        Exception: if an input value or an argument type is neither int nor str.
    '''
    # check validity of provided examples; if valid, extract argument types
    _arity, arg_types = check_examples(examples)
    # unique input values across all examples, plus the constant 1
    values = {value for example in examples for value in example[0]}
    values.add(1)
    # exact type checks (not isinstance), so e.g. bool values are rejected
    if any(type(value) not in (int, str) for value in values):
        raise Exception("Input of unknown type.")
    constants = [IntegerConstant(v) for v in sorted(v for v in values if type(v) is int)]
    constants += [StringConstant(v) for v in sorted(v for v in values if type(v) is str)]
    # one variable per argument position
    variables = []
    for position, arg_type in enumerate(arg_types):
        if arg_type == int:
            variables.append(IntegerVariable(position))
        elif arg_type == str:
            variables.append(StringVariable(position))
        else:
            raise Exception("Input of unknown type.")
    return constants + variables
| # CHECK OBSERVATIONAL EQUIVALENCE | |
def observationally_equivalent(program_a, program_b, examples):
    """
    Decide whether two programs are indistinguishable on the example inputs.

    Program A is evaluated on the input (example[0]) of every example, then
    Program B is. The programs count as equivalent exactly when the two output
    lists are equal. Expected outputs (example[1]) are not consulted.
    """
    example_inputs = [ex[0] for ex in examples]
    outputs_a = list(map(program_a.evaluate, example_inputs))
    outputs_b = list(map(program_b.evaluate, example_inputs))
    return outputs_a == outputs_b
| # CHECK CORRECTNESS | |
def check_program(program, examples):
    '''
    Return True when `program` reproduces every expected output.

    Each example is read as example[0] (the input passed to program.evaluate)
    and example[1] (the expected output). The program is evaluated on every
    example in order, and the resulting output list is compared with the
    expected one.
    '''
    expected = [ex[1] for ex in examples]
    actual = [program.evaluate(ex[0]) for ex in examples]
    return actual == expected
| # RUN SYNTHESIZER | |
def run_synthesizer(args):
    '''
    Run bottom-up enumerative synthesis.

    Starting from the constants and variables extracted from the examples, the
    search applies every operator of the selected domain to tuples of programs
    already in the bank, for levels 2..args.max_weight (inclusive). A new program
    is kept only if it is syntactically new and observationally distinct
    (different outputs on the example inputs) from every program in the bank.

    Parameters:
        args: namespace with `examples_key` (key into example_set), `domain`
              ("arithmetic" or "strings") and `max_weight` (int).

    Returns:
        The first program (possibly a bare constant or variable) whose outputs
        match every example, or None if none is found up to args.max_weight.

    Raises:
        Exception: if args.domain is not a recognized domain.
    '''
    # retrieve selected input-output examples
    examples = example_set[args.examples_key]
    example_inputs = [example[0] for example in examples]
    target_outputs = tuple(example[1] for example in examples)

    def outputs_of(program):
        # Output signature of a program on the example inputs; two programs are
        # observationally equivalent exactly when their signatures are equal.
        # NOTE(review): assumes evaluate() returns hashable values (int/str in
        # the provided domains) so signatures can be kept in a set.
        return tuple(program.evaluate(inp) for inp in example_inputs)

    # extract constants (and variables) from examples
    program_bank = extract_constants(examples)
    print("\nSynthesis Log:")
    print(f"- Extracted {len(program_bank)} constants from examples.")
    # define operators
    if args.domain == "arithmetic":
        operators = arithmetic_operators
    elif args.domain == "strings":
        operators = string_operators
    else:
        raise Exception('Domain not recognized. Must be either "arithmetic" or "strings".')

    # Sets make the syntactic-duplicate and equivalence checks O(1) per candidate,
    # instead of re-evaluating every bank program for every candidate.
    seen_strs = {p.str() for p in program_bank}
    seen_outputs = set()
    for p in program_bank:
        signature = outputs_of(p)
        # A bare constant or variable may already solve the task (e.g. identity).
        # It must be checked here: otherwise every equivalent operator program
        # would be rejected below and the task could never be solved.
        if signature == target_outputs:
            return p
        seen_outputs.add(signature)

    # levels include max_weight itself, matching the --max-weight help text
    for weight in range(2, args.max_weight + 1):
        print(f"- Searching level {weight} with {len(program_bank)} primitives.")
        for op in operators:
            expected_types = tuple(op.arg_types)
            # product (not combinations): an argument may repeat (x + x), and
            # order matters for non-commutative operators (x - y vs y - x).
            # product() consumes the bank before yielding, so programs appended
            # below are only combined by later operators/levels.
            for operands in itertools.product(program_bank, repeat=op.arity):
                # operand types must match the operator's signature
                if tuple(p.type for p in operands) != expected_types:
                    continue
                # sum of operand weights must not exceed the current level
                if sum(p.weight for p in operands) > weight:
                    continue
                program = OperatorNode(op, operands)
                # skip programs already in the bank (by string representation)
                program_str = program.str()
                if program_str in seen_strs:
                    continue
                # skip programs observationally equivalent to one in the bank
                signature = outputs_of(program)
                if signature in seen_outputs:
                    continue
                program_bank.append(program)
                seen_strs.add(program_str)
                seen_outputs.add(signature)
                # check if program passes all examples
                if signature == target_outputs:
                    return program
    # return None if no program is found
    return None
if __name__ == '__main__':
    # parse command line arguments
    args = parse_args()
    # Run bottom-up enumerative synthesis, timed with time.perf_counter: it is a
    # monotonic clock intended for measuring intervals, whereas time.time can
    # jump when the system clock is adjusted.
    start_time = time.perf_counter()
    program = run_synthesizer(args)
    elapsed_time = round(time.perf_counter() - start_time, 4)
    # report whether a program was found
    print("\nSynthesis Results:")
    if program is None:
        print(f"- Max weight of {args.max_weight} reached, no program found in {elapsed_time}s.")
    else:
        print(f"- Program found in {elapsed_time}s.")
        print(f"- Program: {program.str()}")
        print(f"- Program weight: {program.weight}")
        print(f"- Program return type: {program.type.__name__}")