update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
d868d2e
verified
import hydra | |
import pyrootutils | |
from omegaconf import DictConfig, OmegaConf, SCMode | |
# Locate the project root (the directory containing a ".project-root" marker, searched
# upwards from this file). With pythonpath=True / dotenv=True pyrootutils presumably also
# makes the root importable and loads a ".env" file from it (see pyrootutils docs). The
# `src.*` imports below come after this call, presumably because they depend on that.
root = pyrootutils.setup_root(
    dotenv=True,
    pythonpath=True,
    indicator=[".project-root"],
    search_from=__file__,
)
import json | |
import logging | |
import gradio as gr | |
import torch | |
import yaml | |
from src.demo.annotation_utils import load_argumentation_model | |
from src.demo.backend_utils import ( | |
download_processed_documents, | |
load_acl_anthology_venues, | |
process_text_from_arxiv, | |
process_uploaded_files, | |
process_uploaded_pdf_files, | |
render_annotated_document, | |
upload_processed_documents, | |
wrapped_add_annotated_pie_documents_from_dataset, | |
wrapped_process_text, | |
) | |
from src.demo.frontend_utils import ( | |
change_tab, | |
escape_regex, | |
get_cell_for_fixed_column_from_df, | |
open_accordion, | |
open_accordion_with_stats, | |
unescape_regex, | |
) | |
from src.demo.rendering_utils import AVAILABLE_RENDER_MODES, HIGHLIGHT_SPANS_JS | |
from src.demo.retriever_utils import ( | |
get_document_as_dict, | |
get_span_annotation, | |
load_retriever, | |
retrieve_all_relevant_spans, | |
retrieve_all_similar_spans, | |
retrieve_relevant_spans, | |
retrieve_similar_spans, | |
) | |
def load_yaml_config(path: str) -> str:
    """Load a YAML file and return it re-serialized as a YAML string.

    The content is parsed with ``yaml.safe_load`` (which does not construct arbitrary
    Python objects) and emitted again with ``yaml.dump``, so comments and the original
    formatting of the file are not preserved.

    Args:
        path: Path to the YAML file.

    Returns:
        The parsed document, dumped back to a YAML string.
    """
    # Decode explicitly as UTF-8: without an encoding, open() uses the locale's default
    # encoding, which breaks non-ASCII content on e.g. Windows.
    with open(path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    return yaml.dump(config)
def resolve_config(cfg) -> dict:
    """Convert an OmegaConf config into plain Python containers.

    Interpolations are resolved (``resolve=True``) and structured configs are converted
    to plain dicts (``SCMode.DICT``), so the result no longer contains OmegaConf objects.
    """
    conversion_options = dict(resolve=True, structured_config_mode=SCMode.DICT)
    plain_container = OmegaConf.to_container(cfg, **conversion_options)
    return plain_container
def main(cfg: DictConfig) -> None:
    """Build and launch the Gradio demo for argumentation structure analysis and retrieval.

    The config is first resolved into plain Python containers. It must provide the keys
    ``example_text``, ``retriever``, ``argumentation_model``, ``handle_parts_of_same``,
    ``default_arxiv_id``, ``default_load_pie_dataset_kwargs``, ``default_render_mode``,
    ``default_render_kwargs``, ``default_split_regex``, ``render_mode_captions``,
    ``default_min_similarity``, ``default_top_k``, ``default_min_score``,
    ``layer_caption_mapping`` and ``relation_name_mapping``. The optional keys
    ``pdf_fulltext_extractor``, ``acl_anthology_data_dir`` and ``acl_anthology_pdf_dir``
    enable the PDF upload and ACL Anthology import features.

    Args:
        cfg: The demo configuration. NOTE(review): presumably injected by a
            ``@hydra.main`` decorator; none is visible in this file — confirm.

    Raises:
        KeyError: If a required config key is missing.
        ValueError: If ``default_render_mode`` is not one of ``AVAILABLE_RENDER_MODES``.
    """
    # configure logging
    logging.basicConfig()

    # resolve everything in the config to prevent any issues with (e.g. JSON) serialization
    cfg = resolve_config(cfg)

    example_text = cfg["example_text"]

    default_device = "cuda:0" if torch.cuda.is_available() else "cpu"
    default_retriever_config_str = yaml.dump(cfg["retriever"])
    default_argumentation_model_config_str = yaml.dump(cfg["argumentation_model"])
    handle_parts_of_same = cfg["handle_parts_of_same"]

    default_arxiv_id = cfg["default_arxiv_id"]
    default_load_pie_dataset_kwargs_str = json.dumps(
        cfg["default_load_pie_dataset_kwargs"], indent=2
    )

    default_render_mode = cfg["default_render_mode"]
    if default_render_mode not in AVAILABLE_RENDER_MODES:
        raise ValueError(
            f"Invalid default render mode '{default_render_mode}'. "
            f"Choose one of {AVAILABLE_RENDER_MODES}."
        )
    default_render_kwargs = cfg["default_render_kwargs"]
    # serialized once: used both as the editor content and to size the editor
    default_render_kwargs_str = json.dumps(default_render_kwargs, indent=2)

    default_split_regex = cfg["default_split_regex"]

    # captions for better readability: map each render mode to its configured caption
    # (falling back to the mode name itself) and keep the inverse for the dropdown value
    render_mode2caption = {
        render_mode: cfg["render_mode_captions"].get(render_mode, render_mode)
        for render_mode in AVAILABLE_RENDER_MODES
    }
    render_caption2mode = {v: k for k, v in render_mode2caption.items()}

    default_min_similarity = cfg["default_min_similarity"]
    default_top_k = cfg["default_top_k"]
    default_min_score = cfg["default_min_score"]

    layer_caption_mapping = cfg["layer_caption_mapping"]
    relation_name_mapping = cfg["relation_name_mapping"]

    indexed_documents_label = "Indexed Documents"
    # caption -> column of the docstore overview (passed to open_accordion_with_stats)
    indexed_documents_caption2column = {
        "documents": "TOTAL",
        "ADUs": "num_adus",
        "Relations": "num_relations",
    }

    def _unescape_split_regex_or_none(split_regex_escaped_value):
        """Unescape the regex entered in the UI; an empty textbox is passed on as None."""
        if split_regex_escaped_value:
            return unescape_regex(split_regex_escaped_value)
        return None

    gr.Info("Loading models ...")
    argumentation_model = load_argumentation_model(
        config_str=default_argumentation_model_config_str,
        device=default_device,
    )
    retriever = load_retriever(
        config_str=default_retriever_config_str, device=default_device, config_format="yaml"
    )

    if cfg.get("pdf_fulltext_extractor"):
        gr.Info("Loading PDF fulltext extractor ...")
        pdf_fulltext_extractor = hydra.utils.instantiate(cfg["pdf_fulltext_extractor"])
    else:
        pdf_fulltext_extractor = None

    with gr.Blocks() as demo:
        # Wrap the models in 1-tuples: gr.State treats a callable initial value as a
        # factory and calls it, which must not happen to the (callable) models.
        argumentation_model_state = gr.State((argumentation_model,))
        retriever_state = gr.State((retriever,))

        with gr.Row():
            with gr.Tabs() as left_tabs:
                with gr.Tab("User Input", id="user_input"):
                    doc_id = gr.Textbox(
                        label="Document ID",
                        value="user_input",
                    )
                    doc_text = gr.Textbox(
                        label="Text",
                        lines=20,
                        value=example_text,
                    )

                    with gr.Accordion("Model Configuration", open=False):
                        with gr.Accordion("argumentation structure", open=True):
                            argumentation_model_config_str = gr.Code(
                                language="yaml",
                                label="Argumentation Model Configuration",
                                value=default_argumentation_model_config_str,
                                lines=len(default_argumentation_model_config_str.split("\n")),
                            )
                            load_arg_model_btn = gr.Button("Load Argumentation Model")

                        with gr.Accordion("retriever", open=True):
                            retriever_config_str = gr.Code(
                                language="yaml",
                                label="Retriever Configuration",
                                value=default_retriever_config_str,
                                lines=len(default_retriever_config_str.split("\n")),
                            )
                            load_retriever_btn = gr.Button("Load Retriever")

                        device = gr.Textbox(
                            label="Device (e.g. 'cuda' or 'cpu')",
                            value=default_device,
                        )
                        # results are wrapped in 1-tuples to match the state layout above
                        load_arg_model_btn.click(
                            fn=lambda _argumentation_model_config_str, _device: (
                                load_argumentation_model(
                                    config_str=_argumentation_model_config_str,
                                    device=_device,
                                ),
                            ),
                            inputs=[argumentation_model_config_str, device],
                            outputs=argumentation_model_state,
                        )
                        load_retriever_btn.click(
                            fn=lambda _retriever_config, _device, _previous_retriever: (
                                load_retriever(
                                    config_str=_retriever_config,
                                    device=_device,
                                    previous_retriever=_previous_retriever[0],
                                    config_format="yaml",
                                ),
                            ),
                            inputs=[retriever_config_str, device, retriever_state],
                            outputs=retriever_state,
                        )

                        split_regex_escaped = gr.Textbox(
                            label="Regex to partition the text",
                            placeholder="Regular expression pattern to split the text into partitions",
                            value=escape_regex(default_split_regex),
                        )

                    predict_btn = gr.Button("Analyse")

                with gr.Tab("Analysed Document", id="analysed_document") as analysed_document_tab:
                    selected_document_id = gr.Textbox(
                        label="Document ID", max_lines=1, interactive=False
                    )
                    rendered_output = gr.HTML(label="Rendered Output")

                    with gr.Accordion("Render Options", open=False):
                        render_as = gr.Dropdown(
                            label="Render with",
                            choices=list(render_mode2caption.values()),
                            value=render_mode2caption[default_render_mode],
                        )
                        render_kwargs = gr.Code(
                            language="json",
                            label="Render Arguments",
                            lines=len(default_render_kwargs_str.split("\n")),
                            value=default_render_kwargs_str,
                        )
                        render_btn = gr.Button("Re-render")

                    with gr.Accordion("See plain result ...", open=False):
                        get_document_json_btn = gr.Button("Fetch annotated document as JSON")
                        document_json = gr.JSON(label="Model Output")

            with gr.Tabs() as right_tabs:
                with gr.Tab("Retrieval", id="retrieval") as retrieval_tab:
                    with gr.Accordion(
                        indexed_documents_label, open=False
                    ) as processed_documents_accordion:
                        processed_documents_df = gr.DataFrame(
                            headers=["id", "num_adus", "num_relations"],
                            interactive=False,
                            elem_classes="df-docstore",
                        )
                        gr.Markdown("Data Snapshot:")
                        with gr.Row():
                            download_processed_documents_btn = gr.DownloadButton("Download")
                            upload_processed_documents_btn = gr.UploadButton(
                                "Upload", file_types=["file"]
                            )

                    # currently not used
                    # relation_types = set_relation_types(
                    #     argumentation_model_state.value[0], default=["supports_reversed", "contradicts_reversed"]
                    # )

                    # Dummy textbox to hold the hover adu id. On click on the rendered output,
                    # its content will be copied to selected_adu_id which will trigger the retrieval.
                    # (Both are addressed from the frontend JS via their elem_id.)
                    hover_adu_id = gr.Textbox(
                        label="ID (hover)",
                        elem_id="hover_adu_id",
                        interactive=False,
                        visible=False,
                    )
                    selected_adu_id = gr.Textbox(
                        label="ID (selected)",
                        elem_id="selected_adu_id",
                        interactive=False,
                        visible=False,
                    )
                    selected_adu_text = gr.Textbox(label="Selected ADU", interactive=False)

                    with gr.Accordion("Relevant ADUs from other documents", open=True):
                        relevant_adus_df = gr.DataFrame(
                            headers=[
                                "relation",
                                "adu",
                                "reference_adu",
                                "doc_id",
                                "sim_score",
                                "rel_score",
                            ],
                            interactive=False,
                        )

                    with gr.Accordion("Retrieval Configuration", open=False):
                        min_similarity = gr.Slider(
                            label="Minimum Similarity",
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            value=default_min_similarity,
                        )
                        top_k = gr.Slider(
                            label="Top K",
                            minimum=2,
                            maximum=50,
                            step=1,
                            value=default_top_k,
                        )
                        min_score = gr.Slider(
                            label="Minimum Score",
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            value=default_min_score,
                        )
                        retrieve_similar_adus_btn = gr.Button(
                            "Retrieve *similar* ADUs for *selected* ADU"
                        )
                        similar_adus_df = gr.DataFrame(
                            headers=["doc_id", "adu_id", "score", "text"], interactive=False
                        )
                        retrieve_all_similar_adus_btn = gr.Button(
                            "Retrieve *similar* ADUs for *all* ADUs in the document"
                        )
                        all_similar_adus_df = gr.DataFrame(
                            headers=["doc_id", "query_adu_id", "adu_id", "score", "text"],
                            interactive=False,
                        )
                        retrieve_all_relevant_adus_btn = gr.Button(
                            "Retrieve *relevant* ADUs for *all* ADUs in the document"
                        )
                        all_relevant_adus_df = gr.DataFrame(
                            headers=["doc_id", "adu_id", "score", "text", "query_span_id"],
                            interactive=False,
                        )
                        # id of the document the "all relevant" result above was computed for
                        all_relevant_adus_query_doc_id = gr.Textbox(visible=False)

                with gr.Tab("Import Documents", id="import_documents"):
                    upload_btn = gr.UploadButton(
                        "Batch Analyse Texts",
                        file_types=["text"],
                        file_count="multiple",
                    )
                    upload_pdf_btn = gr.UploadButton(
                        "Batch Analyse PDFs",
                        # file_types=["pdf"],
                        file_count="multiple",
                        visible=pdf_fulltext_extractor is not None,
                    )

                    # ACL Anthology import needs both the PDF extractor and the anthology data dir
                    enable_acl_venue_loading = (
                        pdf_fulltext_extractor is not None
                        and cfg.get("acl_anthology_data_dir") is not None
                    )
                    acl_anthology_venues = gr.Textbox(
                        label="ACL Anthology Venues",
                        value="wiesp",
                        max_lines=1,
                        visible=enable_acl_venue_loading,
                    )
                    load_acl_anthology_venues_btn = gr.Button(
                        "Import from ACL Anthology",
                        variant="secondary",
                        visible=enable_acl_venue_loading,
                    )

                    with gr.Accordion("Import text from arXiv", open=False):
                        arxiv_id = gr.Textbox(
                            label="arXiv paper ID",
                            placeholder=f"e.g. {default_arxiv_id}",
                            max_lines=1,
                        )
                        load_arxiv_only_abstract = gr.Checkbox(label="abstract only", value=False)
                        load_arxiv_btn = gr.Button(
                            "Load & Analyse from arXiv", variant="secondary"
                        )

                    with gr.Accordion(
                        "Import argument structure annotated PIE dataset", open=False
                    ):
                        load_pie_dataset_kwargs_str = gr.Code(
                            language="json",
                            label="Parameters for Loading the PIE Dataset",
                            value=default_load_pie_dataset_kwargs_str,
                            lines=len(default_load_pie_dataset_kwargs_str.split("\n")),
                        )
                        load_pie_dataset_btn = gr.Button("Load & Embed PIE Dataset")

        def render_selected_document(
            _rendered_output,
            _retriever,
            _document_id,
            _render_as,
            _render_kwargs,
            _all_relevant_adus_df,
            _all_relevant_adus_query_doc_id,
        ):
            """Render the selected document, or keep the current output if none is selected."""
            if _document_id.strip() == "":
                return _rendered_output
            return render_annotated_document(
                retriever=_retriever[0],
                document_id=_document_id,
                render_with=render_caption2mode[_render_as],
                render_kwargs_json=_render_kwargs,
                # highlight the query spans of the "retrieve all relevant" result, but only if
                # that result was computed for the document being rendered
                highlight_span_ids=(
                    _all_relevant_adus_df["query_span_id"].tolist()
                    if _document_id == _all_relevant_adus_query_doc_id
                    else None
                ),
            )

        render_event_kwargs = dict(
            fn=render_selected_document,
            inputs=[
                rendered_output,
                retriever_state,
                selected_document_id,
                render_as,
                render_kwargs,
                all_relevant_adus_df,
                all_relevant_adus_query_doc_id,
            ],
            outputs=rendered_output,
        )

        show_overview_kwargs = dict(
            fn=lambda _retriever: _retriever[0].docstore.overview(
                layer_captions=layer_caption_mapping, use_predictions=True
            ),
            inputs=[retriever_state],
            outputs=[processed_documents_df],
        )
        show_stats_kwargs = dict(
            fn=lambda _processed_documents_df: open_accordion_with_stats(
                _processed_documents_df,
                base_label=indexed_documents_label,
                caption2column=indexed_documents_caption2column,
                total_column="TOTAL",
            ),
            inputs=[processed_documents_df],
            outputs=[processed_documents_accordion],
        )

        # Explicit re-render at the end of the chain: selected_document_id.change presumably
        # does not fire when the same document id is processed again.
        predict_btn.click(
            fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
        ).then(
            fn=lambda _doc_text, _doc_id, _argumentation_model, _retriever, _split_regex_escaped: wrapped_process_text(
                text=_doc_text,
                doc_id=_doc_id,
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=_unescape_split_regex_or_none(_split_regex_escaped),
                handle_parts_of_same=handle_parts_of_same,
            ),
            inputs=[
                doc_text,
                doc_id,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[selected_document_id],
            api_name="predict",
        ).success(
            **show_overview_kwargs
        ).success(
            **show_stats_kwargs
        ).success(
            **render_event_kwargs
        )

        render_btn.click(**render_event_kwargs, api_name="render")

        # Uses its own API name: "predict" is already taken by the predict_btn event above.
        load_arxiv_btn.click(
            fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
        ).then(
            fn=lambda _arxiv_id, _load_arxiv_only_abstract, _argumentation_model, _retriever, _split_regex_escaped: process_text_from_arxiv(
                arxiv_id=_arxiv_id.strip() or default_arxiv_id,
                abstract_only=_load_arxiv_only_abstract,
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=_unescape_split_regex_or_none(_split_regex_escaped),
                handle_parts_of_same=handle_parts_of_same,
            ),
            inputs=[
                arxiv_id,
                load_arxiv_only_abstract,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[selected_document_id],
            api_name="predict_arxiv",
        ).success(
            **show_overview_kwargs
        ).success(
            **show_stats_kwargs
        ).success(
            **render_event_kwargs
        )

        load_pie_dataset_btn.click(
            fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
        ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
            fn=lambda _retriever, _load_pie_dataset_kwargs_str: wrapped_add_annotated_pie_documents_from_dataset(
                retriever=_retriever[0],
                verbose=True,
                layer_captions=layer_caption_mapping,
                **json.loads(_load_pie_dataset_kwargs_str),
            ),
            inputs=[retriever_state, load_pie_dataset_kwargs_str],
            outputs=[processed_documents_df],
        ).success(
            **show_stats_kwargs
        )

        selected_document_id.change(
            fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
        ).then(**render_event_kwargs)

        get_document_json_btn.click(
            fn=lambda _retriever, _document_id: get_document_as_dict(
                retriever=_retriever[0], doc_id=_document_id
            ),
            inputs=[retriever_state, selected_document_id],
            outputs=[document_json],
        )

        upload_btn.upload(
            fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
        ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
            fn=lambda _file_names, _argumentation_model, _retriever, _split_regex_escaped: process_uploaded_files(
                file_names=_file_names,
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=_unescape_split_regex_or_none(_split_regex_escaped),
                handle_parts_of_same=handle_parts_of_same,
                layer_captions=layer_caption_mapping,
            ),
            inputs=[
                upload_btn,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[processed_documents_df],
        ).success(
            **show_stats_kwargs
        )

        upload_pdf_btn.upload(
            fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
        ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
            fn=lambda _file_names, _argumentation_model, _retriever, _split_regex_escaped: process_uploaded_pdf_files(
                file_names=_file_names,
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=_unescape_split_regex_or_none(_split_regex_escaped),
                handle_parts_of_same=handle_parts_of_same,
                layer_captions=layer_caption_mapping,
                pdf_fulltext_extractor=pdf_fulltext_extractor,
            ),
            inputs=[
                upload_pdf_btn,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[processed_documents_df],
        ).success(
            **show_stats_kwargs
        )

        load_acl_anthology_venues_btn.click(
            fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
        ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
            fn=lambda _acl_anthology_venues, _argumentation_model, _retriever, _split_regex_escaped: load_acl_anthology_venues(
                pdf_fulltext_extractor=pdf_fulltext_extractor,
                # comma-separated list; skip empty entries (e.g. from a trailing comma)
                venues=[
                    venue.strip()
                    for venue in _acl_anthology_venues.split(",")
                    if venue.strip()
                ],
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=_unescape_split_regex_or_none(_split_regex_escaped),
                handle_parts_of_same=handle_parts_of_same,
                layer_captions=layer_caption_mapping,
                acl_anthology_data_dir=cfg.get("acl_anthology_data_dir"),
                pdf_output_dir=cfg.get("acl_anthology_pdf_dir"),
            ),
            inputs=[
                acl_anthology_venues,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[processed_documents_df],
        ).success(
            **show_stats_kwargs
        )

        processed_documents_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[processed_documents_df, gr.State("doc_id")],
            outputs=[selected_document_id],
        )

        download_processed_documents_btn.click(
            fn=lambda _retriever: download_processed_documents(
                _retriever[0], file_name="processed_documents"
            ),
            inputs=[retriever_state],
            outputs=[download_processed_documents_btn],
        )
        upload_processed_documents_btn.upload(
            fn=lambda file_name, _retriever: upload_processed_documents(
                file_name, retriever=_retriever[0], layer_captions=layer_caption_mapping
            ),
            inputs=[upload_processed_documents_btn, retriever_state],
            outputs=[processed_documents_df],
        ).success(**show_stats_kwargs)

        retrieve_relevant_adus_event_kwargs = dict(
            fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k, _min_score: retrieve_relevant_spans(
                retriever=_retriever[0],
                query_span_id=_selected_adu_id,
                k=_top_k,
                min_score=_min_score,
                score_threshold=_min_similarity,
                relation_label_mapping=relation_name_mapping,
                # columns=relevant_adus.headers
            ),
            inputs=[
                retriever_state,
                selected_adu_id,
                min_similarity,
                top_k,
                min_score,
            ],
            outputs=[relevant_adus_df],
        )

        relevant_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[relevant_adus_df, gr.State("doc_id")],
            outputs=[selected_document_id],
        )

        # selecting an ADU (set from the frontend) shows its text and retrieves relevant ADUs
        selected_adu_id.change(
            fn=lambda _retriever, _selected_adu_id: get_span_annotation(
                retriever=_retriever[0], span_id=_selected_adu_id
            ),
            inputs=[retriever_state, selected_adu_id],
            outputs=[selected_adu_text],
        ).success(**retrieve_relevant_adus_event_kwargs)

        retrieve_similar_adus_btn.click(
            fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k, _min_score: retrieve_similar_spans(
                retriever=_retriever[0],
                query_span_id=_selected_adu_id,
                k=_top_k,
                min_score=_min_score,
                score_threshold=_min_similarity,
            ),
            inputs=[
                retriever_state,
                selected_adu_id,
                min_similarity,
                top_k,
                min_score,
            ],
            outputs=[similar_adus_df],
        )

        similar_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[similar_adus_df, gr.State("doc_id")],
            outputs=[selected_document_id],
        )

        retrieve_all_similar_adus_btn.click(
            fn=lambda _retriever, _document_id, _min_similarity, _top_k, _min_score: retrieve_all_similar_spans(
                retriever=_retriever[0],
                query_doc_id=_document_id,
                k=_top_k,
                min_score=_min_score,
                score_threshold=_min_similarity,
                query_span_id_column="query_span_id",
            ),
            inputs=[
                retriever_state,
                selected_document_id,
                min_similarity,
                top_k,
                min_score,
            ],
            outputs=[all_similar_adus_df],
        )

        # also store the query document id so that rendering only highlights the query spans
        # when the shown document is the one the result was computed for
        retrieve_all_relevant_adus_btn.click(
            fn=lambda _retriever, _document_id, _min_similarity, _top_k, _min_score: (
                retrieve_all_relevant_spans(
                    retriever=_retriever[0],
                    query_doc_id=_document_id,
                    k=_top_k,
                    min_score=_min_score,
                    score_threshold=_min_similarity,
                    query_span_id_column="query_span_id",
                    query_span_text_column="query_span_text",
                ),
                _document_id,
            ),
            inputs=[
                retriever_state,
                selected_document_id,
                min_similarity,
                top_k,
                min_score,
            ],
            outputs=[all_relevant_adus_df, all_relevant_adus_query_doc_id],
        )
        all_relevant_adus_df.change(**render_event_kwargs)

        # select query span id from the "retrieve all" result data frames
        all_similar_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[all_similar_adus_df, gr.State("query_span_id")],
            outputs=[selected_adu_id],
        )
        all_relevant_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[all_relevant_adus_df, gr.State("query_span_id")],
            outputs=[selected_adu_id],
        )

        # argumentation_model_state.change(
        #     fn=lambda _argumentation_model: set_relation_types(_argumentation_model[0]),
        #     inputs=[argumentation_model_state],
        #     outputs=[relation_types],
        # )

        # client-side only: run the span highlighting JS whenever the rendered HTML changes
        rendered_output.change(fn=None, js=HIGHLIGHT_SPANS_JS, inputs=[], outputs=[])

    demo.launch()
if __name__ == "__main__":
    # NOTE(review): `main` declares a required `cfg` parameter, but no argument is passed
    # here and no `@hydra.main(...)` decorator is visible on `main` in this file, so this
    # call raises a TypeError as written. Presumably `main` is meant to be wrapped by
    # `@hydra.main` (which supplies `cfg`) — confirm and restore the decorator.
    main()