# eval_results / app.py
# Last commit 22c5ba1 by davidr70: "add source reason"
import asyncio
import logging
import gradio as gr
import pandas as pd
from data_access import get_questions, get_source_finders, get_run_ids, get_baseline_rankers, \
get_unified_sources, get_source_text, calculate_cumulative_statistics_for_all_questions, get_metadata, \
get_async_connection
logger = logging.getLogger(__name__)

# Sentinel entry at the top of the question dropdown; choosing it asks for
# cumulative statistics over every question in ``questions_dict``.
ALL_QUESTIONS_STR = "All questions"

# --- Shared UI state -------------------------------------------------------
# Module-level lookups (re)populated by the handlers below; dropdowns show
# names and these dicts translate a chosen name back to its database id.

# Source finders / baseline rankers
source_finders: list = []
source_finders_dict: dict = {}
finder_options: list = []
baseline_rankers_dict: dict = {}
baseline_ranker_options: list = []

# Questions
questions: list = []
questions_dict: dict = {}
question_options: list = []

# Run ids
run_ids: list = []
available_run_id_dict: dict = {}
run_id_options: list = []
previous_run_id = "initial_run"

# Handle to the run-id dropdown component (assigned inside main()).
run_id_dropdown = None

# Last source runs for retrieving full baseline_reason on selection.
last_source_runs: list = []

# Maximum length for baseline_reason display.
TRUNCATE_REASON_LEN = 50
# Load the static dropdown data (source finders and baseline rankers).
async def initialize_data():
    """Populate the source-finder and baseline-ranker lookup tables.

    Opens a single async connection, fetches both lists, and rebuilds the
    module globals ``source_finders``, ``source_finders_dict``,
    ``finder_options``, ``baseline_rankers_dict`` and
    ``baseline_ranker_options``.
    """
    global source_finders, source_finders_dict, finder_options
    global baseline_rankers_dict, baseline_ranker_options

    async with get_async_connection() as conn:
        source_finders = await get_source_finders(conn)
        rankers = await get_baseline_rankers(conn)

        # name -> id maps used to turn a dropdown selection back into an id
        baseline_rankers_dict = {row["name"]: row["id"] for row in rankers}
        source_finders_dict = {row["name"]: row["id"] for row in source_finders}

        # dropdown choice lists, kept in the order the queries returned them
        finder_options = [row["name"] for row in source_finders]
        baseline_ranker_options = [row["name"] for row in rankers]
def update_run_ids(question_option, source_finder_name, baseline_ranker_name):
    """Synchronous Gradio handler: run ``update_run_ids_async`` to completion."""
    pending = update_run_ids_async(question_option, source_finder_name, baseline_ranker_name)
    return asyncio.run(pending)
async def update_run_ids_async(question_option, source_finder_name, baseline_ranker_name):
    """Refresh the run-id choices after the source finder selection changes.

    Looks up the run ids available for the selected source finder, stores
    them in the module globals ``available_run_id_dict`` and
    ``run_id_options``, and resets the downstream widgets.

    Args:
        question_option: Current question dropdown value (unused; kept so the
            Gradio ``inputs`` list still matches).
        source_finder_name: Display name of the selected source finder.
        baseline_ranker_name: Current baseline ranker value (unused; kept so
            the Gradio ``inputs`` list still matches).

    Returns:
        A 7-tuple of updates in this order: question dropdown (choices
        cleared), results table, statistics table, run-id dropdown (new
        choices, no value), result message, metadata text, source text.
    """
    global available_run_id_dict, run_id_options

    finder_id_int = source_finders_dict.get(source_finder_name)
    async with get_async_connection() as conn:
        available_run_id_dict = await get_run_ids(conn, finder_id_int)
    run_id_options = list(available_run_id_dict)

    return (
        gr.Dropdown(choices=[]),
        None,
        None,
        gr.Dropdown(choices=run_id_options, value=None),
        # Was "results.csv": a find/replace artifact in a user-facing message.
        "Select Question to see results",
        "",
        "",
    )
def update_questions_list(source_finder_name, run_id, baseline_ranker_name):
    """Synchronous Gradio handler: run ``update_questions_list_async`` to completion."""
    pending = update_questions_list_async(source_finder_name, run_id, baseline_ranker_name)
    return asyncio.run(pending)
async def update_questions_list_async(source_finder_name, run_id, baseline_ranker_name):
    """Rebuild the question dropdown once finder, run id and baseline are all chosen.

    Returns a 6-tuple of updates for (question dropdown, result text,
    metadata text, results table, statistics table, source text). If any of
    the three selections is missing, every output except the source text
    (which gets ``""``) is set to ``None``.
    """
    if not all((source_finder_name, run_id, baseline_ranker_name)):
        return None, None, None, None, None, ""

    async with get_async_connection() as conn:
        # Translate dropdown selections to database ids.
        run_id_int = available_run_id_dict.get(run_id)
        ranker_id = baseline_rankers_dict.get(baseline_ranker_name)
        choices = await get_updated_question_list(conn, ranker_id, run_id_int)
        return gr.Dropdown(choices=choices, value=None), None, None, None, None, ""
async def get_updated_question_list(conn, baseline_ranker_id, finder_id_int):
    """Fetch the matching questions and rebuild the question lookups.

    Refreshes the module globals ``questions`` (raw rows) and
    ``questions_dict`` (question text -> id). Both are always updated
    together, so the lookup never refers to an earlier selection.

    Args:
        conn: Open async database connection.
        baseline_ranker_id: Id of the selected baseline ranker.
        finder_id_int: Forwarded as the second argument of ``get_questions``.
            NOTE(review): the caller in this file passes the *run id* here,
            so the parameter name looks stale; confirm against ``get_questions``.

    Returns:
        ``[ALL_QUESTIONS_STR, <question texts>...]``, or ``[]`` when no
        question matches.
    """
    global questions_dict, questions
    questions = await get_questions(conn, finder_id_int, baseline_ranker_id)
    if not questions:
        # Previously the old mapping was kept here, leaving stale state behind.
        questions_dict = {}
        return []
    questions_dict = {q["text"]: q["id"] for q in questions}
    return [ALL_QUESTIONS_STR] + [q["text"] for q in questions]
def update_sources_list(question_option, source_finder_id, run_id: str, baseline_ranker_id: str,
                        evt: gr.EventData = None):
    """Synchronous Gradio handler that refreshes the sources view.

    Despite their names, ``source_finder_id`` and ``baseline_ranker_id``
    receive the dropdown *names*. They are forwarded unchanged to
    ``update_sources_list_async``.

    If the event comes from the ``run_id_dropdown`` element and the run id is
    a list or equals ``previous_run_id``, all five outputs are skipped.
    Otherwise a string run id is remembered in ``previous_run_id``, and the
    async update is run to completion.
    """
    global previous_run_id
    if evt:
        logger.info("event: %s", evt.target.elem_id)
        run_id_unchanged = isinstance(run_id, list) or run_id == previous_run_id
        if evt.target.elem_id == "run_id_dropdown" and run_id_unchanged:
            return gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
    if isinstance(run_id, str):
        previous_run_id = run_id
    return asyncio.run(update_sources_list_async(question_option, source_finder_id, run_id, baseline_ranker_id))
# Main handler behind the question dropdown.
async def update_sources_list_async(question_option, source_finder_name, run_id, baseline_ranker_name: str):
    """Load the sources, statistics and metadata for the selected question.

    Returns a 5-tuple of updates for (results table, statistics table,
    result message, metadata text, source text).

    Selecting ``ALL_QUESTIONS_STR`` computes cumulative statistics over every
    question in ``questions_dict`` instead of listing individual sources.

    On the per-question path this also overwrites the globals
    ``available_run_id_dict`` (run ids for the finder and question),
    ``previous_run_id`` and ``last_source_runs``.
    """
    global available_run_id_dict, previous_run_id, last_source_runs

    if not question_option:
        return gr.skip(), gr.skip(), "No question selected", "", ""
    if not source_finder_name or not run_id or not baseline_ranker_name:
        return gr.skip(), gr.skip(), "Need to select source finder and baseline", "", ""

    logger.info("processing update")
    async with get_async_connection() as conn:
        # The ranker value may arrive as a list; use its first entry.
        if isinstance(baseline_ranker_name, list):
            baseline_ranker_name = baseline_ranker_name[0]
        # NOTE(review): an empty ranker name falls back to hard-coded id 1; confirm intended default.
        baseline_ranker_id_int = 1 if len(baseline_ranker_name) == 0 else baseline_rankers_dict.get(
            baseline_ranker_name)
        finder_id_int = source_finders_dict.get(source_finder_name) if len(source_finder_name) else None

        if question_option == ALL_QUESTIONS_STR:
            all_stats = None
            if finder_id_int:
                run_id_int = available_run_id_dict.get(run_id)
                all_stats = await calculate_cumulative_statistics_for_all_questions(
                    conn, list(questions_dict.values()), run_id_int, baseline_ranker_id_int)
            return None, all_stats, "Select Run Id and source finder to see results", "", ""

        question_id = questions_dict.get(question_option)
        available_run_id_dict = await get_run_ids(conn, finder_id_int, question_id)
        previous_run_id = run_id
        run_id_int = available_run_id_dict.get(run_id)

        source_runs, stats = None, None
        if finder_id_int:
            source_runs, stats = await get_unified_sources(conn, question_id, run_id_int, baseline_ranker_id_int)
        last_source_runs = source_runs
        if not source_runs:
            return None, None, "No results found for the selected filters", "", ""

        # Build the table only once we know there is data to show.
        df = pd.DataFrame(source_runs)
        # NOTE(review): the results table headers in main() use 'id', 'source_reason' and
        # 'baseline_reason'. If get_unified_sources no longer returns 'sugya_id'/'reason',
        # this subset check fails and the full frame is shown instead.
        columns_to_display = ['sugya_id', 'in_baseline', 'baseline_rank', 'in_source_run', 'source_run_rank',
                              'tractate', 'folio', 'reason']
        df_display = df[columns_to_display] if all(col in df.columns for col in columns_to_display) else df

        metadata = await get_metadata(conn, question_id, run_id_int)
        return df_display, stats, f"Found {len(source_runs)} results", metadata, ""
# Row-selection handler for the results table.
async def handle_row_selection_async(evt: gr.SelectData):
    """Return the full text of the source in the clicked results-table row.

    The first cell of the selected row is passed to ``get_source_text`` as
    the chunk id. Failures are logged with their traceback and reported to
    the user as a message instead of being raised into the UI.
    """
    if evt is None or evt.value is None:
        return "No source selected"
    try:
        # First column of the displayed row is treated as the chunk id.
        tractate_chunk_id = evt.row_value[0]
        async with get_async_connection() as conn:
            return await get_source_text(conn, tractate_chunk_id)
    except Exception as e:
        # UI boundary: show the error in the text box but keep the traceback in the logs.
        logger.exception("Failed to retrieve source text")
        return f"Error retrieving source text: {str(e)}"
def handle_row_selection(evt: gr.SelectData):
    """Synchronous Gradio select handler: run ``handle_row_selection_async`` to completion."""
    pending = handle_row_selection_async(evt)
    return asyncio.run(pending)
# Build the Gradio UI, wire the event handlers, and launch the app.
async def main():
    """Build and launch the Source Runs Explorer Gradio app.

    Loads the source-finder and baseline-ranker dropdown data via
    ``initialize_data()``. Then it lays out the selection widgets and the
    statistics, metadata and results panels, wires the change/select
    handlers, enables the queue, and launches the app.
    """
    # The run-id dropdown component is also stored in a module global.
    global run_id_dropdown
    await initialize_data()
    with gr.Blocks(title="Source Runs Explorer", theme=gr.themes.Citrus()) as app:
        gr.Markdown("# Source Runs Explorer")
        with gr.Row():
            with gr.Column(scale=3):
                # Selection controls: source finder, run id, baseline ranker.
                with gr.Row():
                    with gr.Column(scale=1):
                        source_finder_dropdown = gr.Dropdown(
                            choices=finder_options,
                            value=None,
                            label="Source Finder",
                            interactive=True,
                            elem_id="source_finder_dropdown"
                        )
                    with gr.Column(scale=1):
                        # Choices are replaced by update_run_ids when the source finder changes.
                        run_id_dropdown = gr.Dropdown(
                            choices=run_id_options,
                            value=None,
                            allow_custom_value=True,
                            label="source finder Run ID",
                            interactive=True,
                            elem_id="run_id_dropdown"
                        )
                    with gr.Column(scale=1):
                        baseline_rankers_dropdown = gr.Dropdown(
                            choices=baseline_ranker_options,
                            value=None,
                            label="Select Baseline Ranker",
                            interactive=True,
                            elem_id="baseline_rankers_dropdown"
                        )
                with gr.Row():
                    with gr.Column(scale=1):
                        # Main content area
                        # question_options is not reassigned at module level in this file, so at
                        # startup this holds only the "All questions" entry. Real choices come
                        # from update_questions_list.
                        question_dropdown = gr.Dropdown(
                            choices=[ALL_QUESTIONS_STR] + question_options,
                            label="Select Question (if list is empty this means there is no overlap between source run and baseline)",
                            value=None,
                            interactive=True,
                            elem_id="question_dropdown"
                        )
            with gr.Column(scale=1):
                # Sidebar area: usage instructions.
                gr.Markdown("""To Get started select the following:
* Source Finder
* Source Finder Run ID (corresponds to a run of the source finder for a group of questions)
* Baseline Ranker (corresponds to a run of the baseline ranker for a group of questions)
**Note: if there is no overlap between the baseline questions and the source finder questions, the question list will be empty.**
""")
        # Status message shown after each selection.
        with gr.Row():
            result_text = gr.Markdown("Select a question to view source runs")
        with gr.Row():
            gr.Markdown("# Source Run Statistics")
        with gr.Row():
            statistics_table = gr.DataFrame(
                headers=["num_high_ranked_baseline_sources",
                         "num_high_ranked_found_sources",
                         "overlap_count",
                         "overlap_percentage",
                         "high_ranked_overlap_count",
                         "high_ranked_overlap_percentage"
                         ],
                interactive=False,
            )
        with gr.Row():
            metadata_text = gr.TextArea(
                label="Metadata of Source Finder for Selected Question",
                elem_id="metadata",
                lines=2
            )
        with gr.Row():
            gr.Markdown("# Sources Found")
        # Results table with the text of the clicked source beside it.
        with gr.Row():
            with gr.Column(scale=3):
                results_table = gr.DataFrame(
                    headers=['id', 'tractate', 'folio', 'in_baseline', 'baseline_rank', 'in_source_run',
                             'source_run_rank', 'source_reason', 'baseline_reason'],
                    interactive=False
                )
            with gr.Column(scale=1):
                source_text = gr.TextArea(
                    value="Text of the source will appear here",
                    lines=15,
                    label="Source Text",
                    interactive=False,
                    elem_id="source_text"
                )
        # download_button = gr.DownloadButton(
        # label="Download Results as CSV",
        # interactive=True,
        # visible=True
        # )
        # Set up event handlers. Selection cascade:
        #   row click in results table -> handle_row_selection (full source text)
        #   baseline ranker / run id   -> update_questions_list (new question choices)
        #   question                   -> update_sources_list (sources, statistics, metadata)
        #   source finder              -> update_run_ids (new run-id choices, downstream reset)
        # Each handler's returned tuple must line up with its `outputs` list.
        results_table.select(
            handle_row_selection,
            inputs=None,
            outputs=source_text
        )
        baseline_rankers_dropdown.change(
            update_questions_list,
            inputs=[source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
            outputs=[question_dropdown, result_text, metadata_text, results_table, statistics_table, source_text]
        )
        run_id_dropdown.change(
            update_questions_list,
            inputs=[source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
            outputs=[question_dropdown, result_text, metadata_text, results_table, statistics_table, source_text]
        )
        question_dropdown.change(
            update_sources_list,
            inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
            outputs=[results_table, statistics_table, result_text, metadata_text, source_text]
        )
        source_finder_dropdown.change(
            update_run_ids,
            inputs=[question_dropdown, source_finder_dropdown, baseline_rankers_dropdown],
            # outputs=[run_id_dropdown, results_table, result_text, download_button]
            outputs=[question_dropdown, results_table, statistics_table, run_id_dropdown, result_text, metadata_text, source_text]
        )
    app.queue()
    app.launch()
if __name__ == "__main__":
    # Enable INFO-level logging so handler log lines are visible, then run the app.
    logging.basicConfig(level="INFO")
    entry_point = main()
    asyncio.run(entry_point)