Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| # @Date : 8/23/2024 20:00 PM | |
| # @Author : didi | |
| # @Desc : Entrance of AFlow. | |
| import argparse | |
| from typing import Dict, List | |
| from metagpt.configs.models_config import ModelsConfig | |
| from metagpt.ext.aflow.data.download_data import download | |
| from metagpt.ext.aflow.scripts.optimizer import Optimizer | |
class ExperimentConfig:
    """Per-dataset settings that the entry script hands to the AFlow ``Optimizer``.

    Args:
        dataset: Dataset name, e.g. ``"GSM8K"``.
        question_type: Task category; the configs in this file use
            ``"qa"``, ``"math"`` or ``"code"``.
        operators: Names of the operators passed to the optimizer for this
            dataset. Stored as given (not copied).
    """

    def __init__(self, dataset: str, question_type: str, operators: List[str]):
        self.dataset = dataset
        self.question_type = question_type
        self.operators = operators

    def __repr__(self) -> str:
        # The default object repr only shows a memory address, which is useless
        # when a config shows up in logs or tracebacks.
        return (
            f"{type(self).__name__}(dataset={self.dataset!r}, "
            f"question_type={self.question_type!r}, operators={self.operators!r})"
        )
# Supported benchmarks as (dataset, question_type, operators) rows. The dataset
# name doubles as the lookup key, so it is spelled out only once per row; each
# config receives its own fresh operators list.
EXPERIMENT_CONFIGS: Dict[str, ExperimentConfig] = {
    name: ExperimentConfig(dataset=name, question_type=kind, operators=list(ops))
    for name, kind, ops in (
        ("DROP", "qa", ("Custom", "AnswerGenerate", "ScEnsemble")),
        ("HotpotQA", "qa", ("Custom", "AnswerGenerate", "ScEnsemble")),
        ("MATH", "math", ("Custom", "ScEnsemble", "Programmer")),
        ("GSM8K", "math", ("Custom", "ScEnsemble", "Programmer")),
        ("MBPP", "code", ("Custom", "CustomCodeGenerate", "ScEnsemble", "Test")),
        ("HumanEval", "code", ("Custom", "CustomCodeGenerate", "ScEnsemble", "Test")),
    )
}
def _str2bool(value: str) -> bool:
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` calls ``bool(value)``, which is ``True`` for
    every non-empty string (including ``"False"``), so flags need an explicit
    parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
            spelling; argparse turns this into a usage error.
    """
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value (true/false), got {value!r}")


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the AFlow command-line options.

    Args:
        argv: Argument list to parse (without the program name). ``None``, the
            default, makes argparse read ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``, one attribute per option below.
    """
    parser = argparse.ArgumentParser(description="AFlow Optimizer")
    parser.add_argument(
        "--dataset",
        type=str,
        choices=list(EXPERIMENT_CONFIGS),
        required=True,
        help="Dataset type",
    )
    parser.add_argument("--sample", type=int, default=4, help="Sample count")
    parser.add_argument(
        "--optimized_path",
        type=str,
        default="metagpt/ext/aflow/scripts/optimized",
        help="Optimized result save path",
    )
    parser.add_argument("--initial_round", type=int, default=1, help="Initial round")
    parser.add_argument("--max_rounds", type=int, default=20, help="Max iteration rounds")
    # Was `type=bool`, under which "--check_convergence False" still evaluated to True.
    parser.add_argument(
        "--check_convergence",
        type=_str2bool,
        default=True,
        help="Whether to enable early stop",
    )
    parser.add_argument("--validation_rounds", type=int, default=5, help="Validation rounds")
    parser.add_argument(
        "--if_first_optimize",
        type=_str2bool,
        default=True,
        help="Whether to download dataset for the first time",
    )
    parser.add_argument(
        "--opt_model_name",
        type=str,
        default="claude-3-5-sonnet-20240620",
        help="Specifies the name of the model used for optimization tasks.",
    )
    parser.add_argument(
        "--exec_model_name",
        type=str,
        default="gpt-4o-mini",
        help="Specifies the name of the model used for execution tasks.",
    )
    return parser.parse_args(argv)
def _resolve_llm_config(models_config, model_name, role, flag):
    """Look up ``model_name`` in ``models_config`` and fail loudly if it is absent.

    Args:
        models_config: Object exposing ``get(name)`` that returns ``None`` for
            unknown models (the ``ModelsConfig`` loaded by the script).
        model_name: Model name given on the command line.
        role: Word used in the error message ("optimization" / "execution").
        flag: CLI flag the user should correct, quoted in the error message.

    Returns:
        The config object returned by ``models_config.get(model_name)``.

    Raises:
        ValueError: if no config exists for ``model_name``.
    """
    llm_config = models_config.get(model_name)
    if llm_config is None:
        raise ValueError(
            f"The {role} model '{model_name}' was not found in the 'models' section of the configuration file. "
            f"Please add it to the configuration file or specify a valid model using the {flag} flag. "
        )
    return llm_config


if __name__ == "__main__":
    cli_args = parse_args()
    experiment = EXPERIMENT_CONFIGS[cli_args.dataset]

    # Both models must be declared in the 'models' section of the config file.
    all_models = ModelsConfig.default()
    opt_llm_config = _resolve_llm_config(all_models, cli_args.opt_model_name, "optimization", "--opt_model_name")
    exec_llm_config = _resolve_llm_config(all_models, cli_args.exec_model_name, "execution", "--exec_model_name")

    download(["datasets", "initial_rounds"], if_first_download=cli_args.if_first_optimize)

    optimizer = Optimizer(
        dataset=experiment.dataset,
        question_type=experiment.question_type,
        opt_llm_config=opt_llm_config,
        exec_llm_config=exec_llm_config,
        check_convergence=cli_args.check_convergence,
        operators=experiment.operators,
        optimized_path=cli_args.optimized_path,
        sample=cli_args.sample,
        initial_round=cli_args.initial_round,
        max_rounds=cli_args.max_rounds,
        validation_rounds=cli_args.validation_rounds,
    )

    # Mode "Graph" runs the workflow optimization.
    optimizer.optimize("Graph")
    # To test the workflow instead, switch the mode to "Test":
    # optimizer.optimize("Test")