| 
							 | 
						import re | 
					
					
						
						| 
							 | 
						import os | 
					
					
						
						| 
							 | 
						from io import BytesIO | 
					
					
						
						| 
							 | 
						from dotenv import load_dotenv | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						import gradio as gr | 
					
					
						
						| 
							 | 
						import pandas as pd | 
					
					
						
						| 
							 | 
						from pandas import DataFrame as PandasDataFrame | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						from llm import MessageChatCompletion | 
					
					
						
						| 
							 | 
						from customization import css, js | 
					
					
						
						| 
							 | 
						from examples import example_1, example_2, example_3, example_4 | 
					
					
						
						| 
							 | 
						from prompt_template import system_message_template, user_message_template | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						load_dotenv() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						API_KEY = os.getenv("API_KEY") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						df = pd.read_csv('subsectors.csv') | 
					
					
						
						| 
							 | 
						logs_columns = ['Abstract', 'Model', 'Results'] | 
					
					
						
						| 
							 | 
						logs_df = PandasDataFrame(columns=logs_columns) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def download_logs(): | 
					
					
						
						| 
							 | 
						    global logs_df | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    output = BytesIO() | 
					
					
						
						| 
							 | 
						    logs_df.to_csv(output, index=False) | 
					
					
						
						| 
							 | 
						    output.seek(0)   | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    return output, "classification_logs.csv" | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def build_context(row): | 
					
					
						
						| 
							 | 
						    subsector_name = row['Subsector'] | 
					
					
						
						| 
							 | 
						    context = f"Subsector name: {subsector_name}. " | 
					
					
						
						| 
							 | 
						    context += f"{subsector_name} Definition: {row['Definition']}. " | 
					
					
						
						| 
							 | 
						    context += f"{subsector_name} keywords: {row['Keywords']}. " | 
					
					
						
						| 
							 | 
						    context += f"{subsector_name} Does include: {row['Does include']}. " | 
					
					
						
						| 
							 | 
						    context += f"{subsector_name} Does not include: {row['Does not include']}.\n" | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    return context | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def click_button(model, api_key, abstract): | 
					
					
						
						| 
							 | 
						    labels = df['Subsector'].tolist() | 
					
					
						
						| 
							 | 
						    prompt_context = [build_context(row) for _, row in df.iterrows()] | 
					
					
						
						| 
							 | 
						    language_model = MessageChatCompletion(model=model, api_key=api_key) | 
					
					
						
						| 
							 | 
						    system_message = system_message_template.format(prompt_context=prompt_context) | 
					
					
						
						| 
							 | 
						    user_message = user_message_template.format(labels=labels, abstract=abstract) | 
					
					
						
						| 
							 | 
						    language_model.new_system_message(content=system_message) | 
					
					
						
						| 
							 | 
						    language_model.new_user_message(content=user_message) | 
					
					
						
						| 
							 | 
						    language_model.send_message() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    response_reasoning = language_model.get_last_message() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    dict_pattern = r'\{.*?\}' | 
					
					
						
						| 
							 | 
						    match = re.search(dict_pattern, response_reasoning, re.DOTALL) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    if match and language_model.error is False: | 
					
					
						
						| 
							 | 
						        match_score_dict = eval(match.group(0)) | 
					
					
						
						| 
							 | 
						    else: | 
					
					
						
						| 
							 | 
						        match_score_dict = {} | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    new_log_entry = pd.DataFrame({'Abstract': [abstract], 'Model': [model], 'Results': [str(match_score_dict)]}) | 
					
					
						
						| 
							 | 
						    global logs_df | 
					
					
						
						| 
							 | 
						    logs_df = pd.concat([logs_df, new_log_entry], ignore_index=True) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    return match_score_dict, response_reasoning, logs_df | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def on_select(evt: gr.SelectData):   | 
					
					
						
						| 
							 | 
						    selected = df.iloc[[evt.index[0]]].iloc[0] | 
					
					
						
						| 
							 | 
						    name, definition, keywords, does_include, does_not_include = selected['Subsector'], selected['Definition'], selected['Keywords'], selected['Does include'], selected['Does not include'] | 
					
					
						
						| 
							 | 
						    name_accordion = gr.Accordion(label=name) | 
					
					
						
						| 
							 | 
						    return name_accordion, definition, keywords, does_include, does_not_include | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						with gr.Blocks(css=css, js=js) as demo: | 
					
					
						
						| 
							 | 
						    state_lotto = gr.State() | 
					
					
						
						| 
							 | 
						    selected_x_labels = gr.State() | 
					
					
						
						| 
							 | 
						    with gr.Tab("Patent Discovery"): | 
					
					
						
						| 
							 | 
						        with gr.Row(): | 
					
					
						
						| 
							 | 
						            with gr.Column(scale=5): | 
					
					
						
						| 
							 | 
						                dropdown_model = gr.Dropdown( | 
					
					
						
						| 
							 | 
						                    label="Model", | 
					
					
						
						| 
							 | 
						                    choices=["gpt-4", "gpt-4-turbo-preview", "gpt-3.5-turbo", "gpt-3.5-turbo-0125"], | 
					
					
						
						| 
							 | 
						                    value="gpt-3.5-turbo-0125", | 
					
					
						
						| 
							 | 
						                    multiselect=False, | 
					
					
						
						| 
							 | 
						                    interactive=True | 
					
					
						
						| 
							 | 
						                ) | 
					
					
						
						| 
							 | 
						            with gr.Column(scale=5): | 
					
					
						
						| 
							 | 
						                api_key = gr.Textbox( | 
					
					
						
						| 
							 | 
						                    label="API Key", | 
					
					
						
						| 
							 | 
						                    interactive=True, | 
					
					
						
						| 
							 | 
						                    lines=1, | 
					
					
						
						| 
							 | 
						                    max_lines=1, | 
					
					
						
						| 
							 | 
						                    type="password", | 
					
					
						
						| 
							 | 
						                    value=API_KEY | 
					
					
						
						| 
							 | 
						                ) | 
					
					
						
						| 
							 | 
						        with gr.Row(equal_height=True): | 
					
					
						
						| 
							 | 
						            abstract_description = gr.Textbox( | 
					
					
						
						| 
							 | 
						                label="Abstract description", | 
					
					
						
						| 
							 | 
						                lines=5, | 
					
					
						
						| 
							 | 
						                max_lines=10000, | 
					
					
						
						| 
							 | 
						                interactive=True, | 
					
					
						
						| 
							 | 
						                placeholder="Input a patent abstract" | 
					
					
						
						| 
							 | 
						            ) | 
					
					
						
						| 
							 | 
						        with gr.Row(): | 
					
					
						
						| 
							 | 
						            with gr.Accordion(label="Example Abstracts", open=False): | 
					
					
						
						| 
							 | 
						                gr.Examples( | 
					
					
						
						| 
							 | 
						                    examples=[example_1, example_2, example_3, example_4], | 
					
					
						
						| 
							 | 
						                    inputs=abstract_description, | 
					
					
						
						| 
							 | 
						                    fn=click_button, | 
					
					
						
						| 
							 | 
						                    label="", | 
					
					
						
						| 
							 | 
						                     | 
					
					
						
						| 
							 | 
						                ) | 
					
					
						
						| 
							 | 
						        with gr.Row(): | 
					
					
						
						| 
							 | 
						            btn_get_result = gr.Button("Classify") | 
					
					
						
						| 
							 | 
						        with gr.Row(elem_classes=['all_results']): | 
					
					
						
						| 
							 | 
						            with gr.Column(scale=4): | 
					
					
						
						| 
							 | 
						                label_result = gr.Label(num_top_classes=None) | 
					
					
						
						| 
							 | 
						            with gr.Column(scale=6): | 
					
					
						
						| 
							 | 
						                reasoning = gr.Markdown(label="Reasoning", elem_classes=['reasoning_results']) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    with gr.Tab("Subsector definitions"): | 
					
					
						
						| 
							 | 
						        with gr.Row(): | 
					
					
						
						| 
							 | 
						            with gr.Column(scale=4): | 
					
					
						
						| 
							 | 
						                df_subsectors = gr.DataFrame(df[['Subsector']], interactive=False, height=800) | 
					
					
						
						| 
							 | 
						            with gr.Column(scale=6): | 
					
					
						
						| 
							 | 
						                with gr.Accordion(label='Artificial Intelligence, Big Data and Analytics') as subsector_name: | 
					
					
						
						| 
							 | 
						                    s1_definition = gr.Textbox(label="Definition", lines=5, max_lines=100, value="Virtual reality (VR) is an artificial, computer-generated simulation or recreation of a real life environment or situation. Augmented reality (AR) is a technology that layers computer-generated enhancements atop an existing reality in order to make it more meaningful through the ability to interact with it. ") | 
					
					
						
						| 
							 | 
						                    s1_keywords = gr.Textbox(label="Keywords", lines=5, max_lines=100, | 
					
					
						
						| 
							 | 
						                                             value="Mixed Reality, 360 video, frame rate, metaverse, virtual world, cross reality, Artificial intelligence, computer vision") | 
					
					
						
						| 
							 | 
						                    does_include = gr.Textbox(label="Does include", lines=4) | 
					
					
						
						| 
							 | 
						                    does_not_include = gr.Textbox(label="Does not include", lines=3) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    with gr.Tab("Logs"): | 
					
					
						
						| 
							 | 
						        output_dataframe = gr.Dataframe( | 
					
					
						
						| 
							 | 
						            value=logs_df, | 
					
					
						
						| 
							 | 
						            type="pandas", | 
					
					
						
						| 
							 | 
						            height=500, | 
					
					
						
						| 
							 | 
						            headers=['Abstract', 'Model', 'Results'], | 
					
					
						
						| 
							 | 
						            interactive=False, | 
					
					
						
						| 
							 | 
						            column_widths=["45%", "10%", "45%"], | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						        btn_export = gr.Button( | 
					
					
						
						| 
							 | 
						            value="Export to CSV", | 
					
					
						
						| 
							 | 
						            size="sm", | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    btn_get_result.click( | 
					
					
						
						| 
							 | 
						        fn=click_button, | 
					
					
						
						| 
							 | 
						        inputs=[dropdown_model, api_key, abstract_description], | 
					
					
						
						| 
							 | 
						        outputs=[label_result, reasoning, output_dataframe]) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    btn_export.click( | 
					
					
						
						| 
							 | 
						        fn=download_logs, | 
					
					
						
						| 
							 | 
						        outputs=[gr.outputs.File(label="Download CSV")]   | 
					
					
						
						| 
							 | 
						    ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    df_subsectors.select( | 
					
					
						
						| 
							 | 
						        fn=on_select, | 
					
					
						
						| 
							 | 
						        outputs=[subsector_name, s1_definition, s1_keywords, does_include, does_not_include] | 
					
					
						
						| 
							 | 
						    ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						if __name__ == "__main__": | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    demo.launch() | 
					
					
						
						| 
							 | 
						
 |