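# NOTE(review): the first three lines of this file originally read "Spaces:", "Sleeping",
# "Sleeping"; they look like web-page residue captured when the file was copied and are not Python.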
import asyncio | |
import logging | |
import gradio as gr | |
import pandas as pd | |
from data_access import get_questions, get_source_finders, get_run_ids, get_baseline_rankers, \ | |
get_unified_sources, get_source_text, calculate_cumulative_statistics_for_all_questions, get_metadata, \ | |
get_async_connection | |
logger = logging.getLogger(__name__)

# Sentinel entry placed at the top of the question dropdown; selecting it requests
# cumulative statistics over all questions instead of a single question's sources.
ALL_QUESTIONS_STR = "All questions"
# Maximum length for baseline_reason display (not read in the code visible here — confirm usage).
TRUNCATE_REASON_LEN = 50

# --- Module-level UI state, filled by initialize_data() and the event handlers ---
# Question data for the current finder-run / baseline selection.
questions = []
questions_dict = {}          # question text -> question id
question_options = []
# Source finders and baseline rankers (name -> id maps plus dropdown choices).
source_finders = []
source_finders_dict = {}
finder_options = []
baseline_rankers_dict = {}
baseline_ranker_options = []
# Run ids for the selected source finder (label -> id map plus dropdown choices).
run_ids = []
available_run_id_dict = {}
run_id_options = []
# Last run id seen by the question handler; used to skip redundant refreshes.
previous_run_id = "initial_run"
# Replaced by the run-id Dropdown component once main() builds the UI.
run_id_dropdown = None
# Most recent source-run rows; intended for retrieving the full baseline_reason on selection.
last_source_runs = []
# Load the source-finder and baseline-ranker lookups used by the dropdowns.
async def initialize_data():
    """Fetch source finders and baseline rankers and populate the module-level lookups.

    Rebinds ``source_finders``, ``source_finders_dict``, ``finder_options``,
    ``baseline_rankers_dict`` and ``baseline_ranker_options``. ``main`` awaits this
    before building the UI so the dropdowns get their choices.
    """
    global source_finders, source_finders_dict, finder_options, baseline_rankers_dict, baseline_ranker_options
    async with get_async_connection() as conn:
        source_finders = await get_source_finders(conn)
        baseline_rankers = await get_baseline_rankers(conn)
    # Name -> id maps, used to translate dropdown selections back into ids.
    baseline_rankers_dict = {f["name"]: f["id"] for f in baseline_rankers}
    source_finders_dict = {f["name"]: f["id"] for f in source_finders}
    # Dropdown choices are the display names.
    finder_options = [s["name"] for s in source_finders]
    baseline_ranker_options = [b["name"] for b in baseline_rankers]
def update_run_ids(question_option, source_finder_name, baseline_ranker_name):
    """Sync Gradio handler: run ``update_run_ids_async`` to completion on a new event loop."""
    pending = update_run_ids_async(question_option, source_finder_name, baseline_ranker_name)
    return asyncio.run(pending)
async def update_run_ids_async(question_option, source_finder_name, baseline_ranker_name):
    """Reload the run-id choices after the source finder changes.

    ``question_option`` and ``baseline_ranker_name`` are accepted to match the Gradio
    input wiring but are not used. Rebinds the module-level ``available_run_id_dict``
    (run-id label -> run id) and ``run_id_options``.

    Returns a 7-tuple for the outputs wired in ``main``:
    ``(question_dropdown, results_table, statistics_table, run_id_dropdown,
    result_text, metadata_text, source_text)``. It empties the question choices,
    clears both tables, offers the new run ids with nothing selected, and shows a
    prompt message.
    """
    global available_run_id_dict, run_id_options
    async with get_async_connection() as conn:
        finder_id_int = source_finders_dict.get(source_finder_name)
        available_run_id_dict = await get_run_ids(conn, finder_id_int)
        run_id_options = list(available_run_id_dict.keys())
    return (
        gr.Dropdown(choices=[]),
        None,
        None,
        gr.Dropdown(choices=run_id_options, value=None),
        "Select Question to see results",
        "",
        "",
    )
def update_questions_list(source_finder_name, run_id, baseline_ranker_name):
    """Sync Gradio handler: run ``update_questions_list_async`` to completion on a new event loop."""
    pending = update_questions_list_async(source_finder_name, run_id, baseline_ranker_name)
    return asyncio.run(pending)
async def update_questions_list_async(source_finder_name, run_id, baseline_ranker_name):
    """Rebuild the question dropdown for the selected finder, run id and baseline ranker.

    Returns a 6-tuple for ``(question_dropdown, result_text, metadata_text,
    results_table, statistics_table, source_text)``. If any of the three selections
    is missing, nothing is queried and every output is cleared (``None`` / ``""``).
    """
    ready = source_finder_name and run_id and baseline_ranker_name
    if not ready:
        return None, None, None, None, None, ""
    async with get_async_connection() as conn:
        # Translate the dropdown labels into ids for the question query.
        run_key = available_run_id_dict.get(run_id)
        ranker_key = baseline_rankers_dict.get(baseline_ranker_name)
        choices = await get_updated_question_list(conn, ranker_key, run_key)
        return gr.Dropdown(choices=choices, value=None), None, None, None, None, ""
async def get_updated_question_list(conn, baseline_ranker_id, finder_id_int):
    """Fetch the questions for a run/baseline pair and build the question dropdown choices.

    NOTE: the caller in this module passes the source-finder *run* id as
    ``finder_id_int``; it is forwarded unchanged to ``get_questions``.

    Rebinds the module-level ``questions`` and ``questions_dict`` (question text -> id).
    ``questions_dict`` is reset when no questions come back, so ids from a previous
    selection are not reused later (e.g. by the "All questions" statistics).

    Returns:
        ``[ALL_QUESTIONS_STR, <question texts>...]``, or ``[]`` when there are none.
    """
    global questions_dict, questions
    questions = await get_questions(conn, finder_id_int, baseline_ranker_id)
    if not questions:
        questions_dict = {}
        return []
    questions_dict = {q["text"]: q["id"] for q in questions}
    return [ALL_QUESTIONS_STR] + [q["text"] for q in questions]
def update_sources_list(question_option, source_finder_id, run_id: str, baseline_ranker_id: str,
                        evt: gr.EventData = None):
    """Sync Gradio handler for question changes; delegates to ``update_sources_list_async``.

    Despite their names, ``source_finder_id`` and ``baseline_ranker_id`` carry the
    selected dropdown *names*; they are passed through unchanged.

    If the triggering event comes from the run-id dropdown and the run id is a list
    or equals the previously seen run id, five ``gr.skip()`` values are returned and
    no query runs. Otherwise a string ``run_id`` is remembered in ``previous_run_id``
    before delegating.
    """
    global previous_run_id
    if evt:
        target_id = evt.target.elem_id
        logger.info(f"event: {target_id}")
        run_id_unchanged = type(run_id) is list or run_id == previous_run_id
        if target_id == "run_id_dropdown" and run_id_unchanged:
            return tuple(gr.skip() for _ in range(5))
    if type(run_id) is str:
        previous_run_id = run_id
    pending = update_sources_list_async(question_option, source_finder_id, run_id, baseline_ranker_id)
    return asyncio.run(pending)
# Main handler: load sources, statistics and metadata for the selected question.
async def update_sources_list_async(question_option, source_finder_name, run_id, baseline_ranker_name: str):
    """Load the results table, statistics and metadata for the current selections.

    Args:
        question_option: Selected question text, or ``ALL_QUESTIONS_STR`` to compute
            cumulative statistics over every question id in ``questions_dict``.
        source_finder_name: Selected source-finder name (key of ``source_finders_dict``).
        run_id: Selected run-id label (key of ``available_run_id_dict``).
        baseline_ranker_name: Selected baseline-ranker name; if a list is given,
            its first item is used.

    Returns:
        A 5-tuple for ``(results_table, statistics_table, result_text,
        metadata_text, source_text)``.

    Side effects: for a single question, rebinds the module-level
    ``available_run_id_dict`` (run ids for this finder and question),
    ``previous_run_id`` and ``last_source_runs``.
    """
    global available_run_id_dict, previous_run_id, questions_dict, last_source_runs
    if not question_option:
        return gr.skip(), gr.skip(), "No question selected", "", ""
    if not source_finder_name or not run_id or not baseline_ranker_name:
        return gr.skip(), gr.skip(), "Need to select source finder and baseline", "", ""
    logger.info("processing update")
    async with get_async_connection() as conn:
        if isinstance(baseline_ranker_name, list):
            baseline_ranker_name = baseline_ranker_name[0]
        # An empty ranker name (e.g. a list holding "") falls back to ranker id 1.
        baseline_ranker_id_int = 1 if len(baseline_ranker_name) == 0 else baseline_rankers_dict.get(
            baseline_ranker_name)
        # source_finder_name is non-empty here (guarded above); unknown names map to None.
        finder_id_int = source_finders_dict.get(source_finder_name)

        if question_option == ALL_QUESTIONS_STR:
            all_stats = None
            if finder_id_int:
                run_id_int = available_run_id_dict.get(run_id)
                all_stats = await calculate_cumulative_statistics_for_all_questions(
                    conn, list(questions_dict.values()), run_id_int, baseline_ranker_id_int)
            return None, all_stats, "Select Run Id and source finder to see results", "", ""

        question_id = questions_dict.get(question_option)
        # Re-resolve run ids for this finder/question before mapping the selected label.
        available_run_id_dict = await get_run_ids(conn, finder_id_int, question_id)
        previous_run_id = run_id
        run_id_int = available_run_id_dict.get(run_id)

        source_runs = None
        stats = None
        if finder_id_int:
            source_runs, stats = await get_unified_sources(conn, question_id, run_id_int, baseline_ranker_id_int)
        last_source_runs = source_runs
        if not source_runs:
            return None, None, "No results found for the selected filters", "", ""

        df = pd.DataFrame(source_runs)
        # Show a trimmed column set only when every expected column is present.
        columns_to_display = ['sugya_id', 'in_baseline', 'baseline_rank', 'in_source_run', 'source_run_rank',
                              'tractate',
                              'folio', 'reason']
        df_display = df[columns_to_display] if all(col in df.columns for col in columns_to_display) else df
        metadata = await get_metadata(conn, question_id, run_id_int)
        result_message = f"Found {len(source_runs)} results"
        return df_display, stats, result_message, metadata, ""
# Row-selection handler for the results table.
async def handle_row_selection_async(evt: gr.SelectData):
    """Return the text of the source in the clicked results-table row.

    The first cell of the selected row is passed to ``get_source_text`` as the chunk id.
    Lookup failures are logged with their traceback and reported as a message in the
    source-text box instead of being raised.
    """
    if evt is None or evt.value is None:
        return "No source selected"
    try:
        # NOTE(review): with the trimmed column set built in update_sources_list_async the
        # first column is 'sugya_id' — confirm that is the id get_source_text expects.
        tractate_chunk_id = evt.row_value[0]
        async with get_async_connection() as conn:
            return await get_source_text(conn, tractate_chunk_id)
    except Exception as e:
        # UI boundary: keep the app responsive, but do not lose the traceback.
        logger.exception("Failed to retrieve source text")
        return f"Error retrieving source text: {str(e)}"
def handle_row_selection(evt: gr.SelectData):
    """Sync Gradio handler: run ``handle_row_selection_async`` to completion on a new event loop."""
    pending = handle_row_selection_async(evt)
    return asyncio.run(pending)
# Build and launch the Gradio app.
async def main():
    """Load lookup data, build the Source Runs Explorer UI, wire its events and launch it.

    Components render in the order they are created inside ``gr.Blocks``, so the
    statement order below defines the page layout. Also rebinds the module-level
    ``run_id_dropdown`` to the run-id Dropdown component.
    """
    global run_id_dropdown
    # Populate finder/ranker lookups before the dropdowns read their choices.
    await initialize_data()
    with gr.Blocks(title="Source Runs Explorer", theme=gr.themes.Citrus()) as app:
        gr.Markdown("# Source Runs Explorer")
        with gr.Row():
            with gr.Column(scale=3):
                with gr.Row():
                    with gr.Column(scale=1):
                        source_finder_dropdown = gr.Dropdown(
                            choices=finder_options,
                            value=None,
                            label="Source Finder",
                            interactive=True,
                            elem_id="source_finder_dropdown"
                        )
                    with gr.Column(scale=1):
                        run_id_dropdown = gr.Dropdown(
                            choices=run_id_options,
                            value=None,
                            allow_custom_value=True,
                            label="source finder Run ID",
                            interactive=True,
                            elem_id="run_id_dropdown"
                        )
                    with gr.Column(scale=1):
                        baseline_rankers_dropdown = gr.Dropdown(
                            choices=baseline_ranker_options,
                            value=None,
                            label="Select Baseline Ranker",
                            interactive=True,
                            elem_id="baseline_rankers_dropdown"
                        )
                with gr.Row():
                    with gr.Column(scale=1):
                        # Main content area
                        question_dropdown = gr.Dropdown(
                            choices=[ALL_QUESTIONS_STR] + question_options,
                            label="Select Question (if list is empty this means there is no overlap between source run and baseline)",
                            value=None,
                            interactive=True,
                            elem_id="question_dropdown"
                        )
                    with gr.Column(scale=1):
                        # Sidebar area
                        gr.Markdown("""To Get started select the following:
* Source Finder
* Source Finder Run ID (corresponds to a run of the source finder for a group of questions)
* Baseline Ranker (corresponds to a run of the baseline ranker for a group of questions)
**Note: if there is no overlap between the baseline questions and the source finder questions, the question list will be empty.**
""")
        with gr.Row():
            result_text = gr.Markdown("Select a question to view source runs")
        with gr.Row():
            gr.Markdown("# Source Run Statistics")
        with gr.Row():
            statistics_table = gr.DataFrame(
                headers=["num_high_ranked_baseline_sources",
                         "num_high_ranked_found_sources",
                         "overlap_count",
                         "overlap_percentage",
                         "high_ranked_overlap_count",
                         "high_ranked_overlap_percentage"
                ],
                interactive=False,
            )
        with gr.Row():
            metadata_text = gr.TextArea(
                label="Metadata of Source Finder for Selected Question",
                elem_id="metadata",
                lines=2
            )
        with gr.Row():
            gr.Markdown("# Sources Found")
        with gr.Row():
            with gr.Column(scale=3):
                results_table = gr.DataFrame(
                    headers=['id', 'tractate', 'folio', 'in_baseline', 'baseline_rank', 'in_source_run',
                    'source_run_rank', 'source_reason', 'baseline_reason'],
                    interactive=False
                )
            with gr.Column(scale=1):
                source_text = gr.TextArea(
                    value="Text of the source will appear here",
                    lines=15,
                    label="Source Text",
                    interactive=False,
                    elem_id="source_text"
                )
        # Set up event handlers.
        # Clicking a results-table row shows that source's text.
        results_table.select(
            handle_row_selection,
            inputs=None,
            outputs=source_text
        )
        # Changing the baseline ranker or the run id rebuilds the question list and clears the results.
        baseline_rankers_dropdown.change(
            update_questions_list,
            inputs=[source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
            outputs=[question_dropdown, result_text, metadata_text, results_table, statistics_table, source_text]
        )
        run_id_dropdown.change(
            update_questions_list,
            inputs=[source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
            outputs=[question_dropdown, result_text, metadata_text, results_table, statistics_table, source_text]
        )
        # Selecting a question loads its sources, statistics and metadata.
        question_dropdown.change(
            update_sources_list,
            inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
            outputs=[results_table, statistics_table, result_text, metadata_text, source_text]
        )
        # Changing the source finder reloads the run ids and clears the downstream outputs.
        source_finder_dropdown.change(
            update_run_ids,
            inputs=[question_dropdown, source_finder_dropdown, baseline_rankers_dropdown],
            outputs=[question_dropdown, results_table, statistics_table, run_id_dropdown, result_text, metadata_text, source_text]
        )
    app.queue()
    app.launch()
if __name__ == "__main__":
    # Script entry point: enable INFO-level logging first so the handlers' log lines
    # are emitted, then run the async setup that builds and launches the UI.
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())