Spaces:
Running
Running
| import logging | |
| import pathlib | |
| import gradio as gr | |
| import pandas as pd | |
| from gt4sd.algorithms.conditional_generation.regression_transformer import ( | |
| RegressionTransformer, | |
| ) | |
| from gt4sd.algorithms.registry import ApplicationsRegistry | |
| from utils import ( | |
| draw_grid_generate, | |
| draw_grid_predict, | |
| get_application, | |
| get_inference_dict, | |
| get_rt_name, | |
| ) | |
# Module-level logger. The NullHandler means this module never configures
# logging output on its own; records still propagate to whatever handlers the
# hosting application installs.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
def _split_csv(raw: str) -> list:
    """Split a comma-separated UI field into its non-empty entries.

    All spaces are removed before splitting, so ``"C(=O), C#C"`` becomes
    ``["C(=O)", "C#C"]``. Empty entries (blank field, trailing or doubled
    commas) are dropped.
    """
    return [item for item in raw.replace(" ", "").split(",") if item]


def _parse_property_goals(property_goal: str) -> dict:
    """Parse ``"<name>:<value>, <name>:<value>"`` into ``{name: float(value)}``.

    Raises:
        ValueError: if no goal is given, if an entry lacks the ``name:value``
            shape, or if a value cannot be converted to ``float``.
    """
    goals = {}
    for entry in property_goal.split(","):
        if not entry.strip():
            # Tolerate blank pieces such as the one produced by a trailing comma.
            continue
        fields = entry.split(":")
        name = fields[0].strip()
        if len(fields) < 2 or not name:
            raise ValueError(
                f"Malformed `property_goal` entry {entry.strip()!r}; "
                "expected `<property>:<value>`."
            )
        try:
            goals[name] = float(fields[1])
        except ValueError as err:
            raise ValueError(
                f"Could not read a numeric value from `property_goal` entry "
                f"{entry.strip()!r}."
            ) from err
    if not goals:
        raise ValueError(
            "For conditional generation you have to specify `property_goal`."
        )
    return goals


def regression_transformer(
    algorithm: str,
    task: str,
    target: str,
    number_of_samples: int,
    search: str,
    temperature: float,
    tolerance: int,
    wrapper: bool,
    fraction_to_mask: float,
    property_goal: str,
    tokens_to_mask: str,
    substructures_to_mask: str,
    substructures_to_keep: str,
):
    """Run the Regression Transformer for one request from the Gradio form.

    Args:
        algorithm: Entry chosen in the "Algorithm version" dropdown. The text
            before the first ``":"`` selects the application (and is passed on
            as ``domain`` to the drawing helpers); the last space-separated
            word, lower-cased, is used as ``algorithm_version``.
        task: ``"Predict"`` or ``"Generate"``.
        target: Input handed to ``RegressionTransformer`` as ``target``.
        number_of_samples: Number of samples requested from the model.
        search: Search strategy; lower-cased before being passed on.
        temperature: Decoding temperature.
        tolerance: Tolerance passed to the application configuration.
        wrapper: Whether to build a sampling wrapper from the remaining fields.
            Ignored (with a warning) when ``task`` is ``"Predict"``.
        fraction_to_mask: Forwarded in the sampling wrapper.
        property_goal: Comma-separated ``<property>:<value>`` pairs; required
            whenever the sampling wrapper is used.
        tokens_to_mask: Comma-separated tokens.
        substructures_to_mask: Comma-separated substructures.
        substructures_to_keep: Comma-separated substructures.

    Returns:
        The result of ``draw_grid_predict`` (for ``"Predict"``, using the first
        sample only) or ``draw_grid_generate`` (otherwise, using all samples).

    Raises:
        ValueError: if the sampling wrapper is used and ``property_goal`` is
            missing or malformed.
    """
    sampling_wrapper = {}
    if wrapper and task == "Predict":
        logger.warning(
            f"For prediction, no sampling_wrapper will be used, ignoring: fraction_to_mask: {fraction_to_mask}, "
            f"tokens_to_mask: {tokens_to_mask}, substructures_to_mask={substructures_to_mask}, "
            f"substructures_to_keep: {substructures_to_keep}."
        )
    elif wrapper:
        sampling_wrapper = {
            "substructures_to_keep": _split_csv(substructures_to_keep),
            "substructures_to_mask": _split_csv(substructures_to_mask),
            "text_filtering": False,
            "fraction_to_mask": fraction_to_mask,
            "property_goal": _parse_property_goals(property_goal),
        }
        # NOTE(review): this field used to be parsed and then dropped, so the
        # "Tokens to mask" input had no effect. It is forwarded under the
        # `tokens_to_mask` key on the assumption that the GT4SD sampling
        # wrapper reads that key -- confirm against the GT4SD documentation.
        # The key is only added when tokens were supplied, so requests that
        # leave the field empty are unchanged.
        masked_tokens = _split_csv(tokens_to_mask)
        if masked_tokens:
            sampling_wrapper["tokens_to_mask"] = masked_tokens

    domain = algorithm.split(":")[0]
    config = get_application(domain)(
        algorithm_version=algorithm.split(" ")[-1].lower(),
        search=search.lower(),
        temperature=temperature,
        tolerance=tolerance,
        sampling_wrapper=sampling_wrapper,
    )
    model = RegressionTransformer(configuration=config, target=target)
    samples = list(model.sample(number_of_samples))

    if task == "Predict":
        return draw_grid_predict(samples[0], target, domain=domain)
    return draw_grid_generate(samples, domain=domain)
if __name__ == "__main__":
    # Preparation: collect every registered algorithm whose name mentions the
    # Regression Transformer and derive the dropdown labels from them.
    all_algos = ApplicationsRegistry.list_available()
    rt_algos = [
        algo for algo in all_algos if "RegressionTransformer" in algo["algorithm_name"]
    ]
    rt_names = [get_rt_name(algo) for algo in rt_algos]

    # NOTE(review): `properties` is filled here but not read anywhere below.
    # The loop is kept because `get_inference_dict` is defined elsewhere and may
    # have effects not visible from this file -- confirm, then remove if unused.
    properties = {}
    for algo in rt_algos:
        application = get_application(
            algo["algorithm_application"].split("Transformer")[-1]
        )
        data = get_inference_dict(
            application=application, algorithm_version=algo["algorithm_version"]
        )
        properties[get_rt_name(algo)] = data

    # Load metadata shipped next to this script. The Markdown files are read
    # with an explicit encoding so the result does not depend on the locale.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
    examples = pd.read_csv(
        metadata_root.joinpath("regression_transformer_examples.csv"), header=None
    ).fillna("")
    article = metadata_root.joinpath("regression_transformer_article.md").read_text(
        encoding="utf-8"
    )
    description = metadata_root.joinpath(
        "regression_transformer_description.md"
    ).read_text(encoding="utf-8")

    # The order of `inputs` must match the positional parameters of
    # `regression_transformer`.
    demo = gr.Interface(
        fn=regression_transformer,
        title="Regression Transformer",
        inputs=[
            gr.Dropdown(rt_names, label="Algorithm version", value="Molecules: Qed"),
            gr.Radio(choices=["Predict", "Generate"], label="Task", value="Generate"),
            gr.Textbox(
                label="Input", placeholder="CC(C#C)N(C)C(=O)NC1=CC=C(Cl)C=C1", lines=1
            ),
            gr.Slider(
                minimum=1, maximum=50, value=10, label="Number of samples", step=1
            ),
            gr.Radio(choices=["Sample", "Greedy"], label="Search", value="Sample"),
            gr.Slider(minimum=0.5, maximum=2, value=1, label="Decoding temperature"),
            gr.Slider(minimum=5, maximum=100, value=30, label="Tolerance", step=1),
            gr.Radio(choices=[True, False], label="Sampling Wrapper", value=True),
            gr.Slider(minimum=0, maximum=1, value=0.5, label="Fraction to mask"),
            gr.Textbox(label="Property goal", placeholder="<qed>:0.75", lines=1),
            gr.Textbox(label="Tokens to mask", placeholder="N, C", lines=1),
            gr.Textbox(
                label="Substructures to mask", placeholder="C(=O), C#C", lines=1
            ),
            gr.Textbox(
                label="Substructures to keep", placeholder="C1=CC=C(Cl)C=C1", lines=1
            ),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)