# eval_results/data_access.py
# Last change: commit 22c5ba1 by davidr70 - "add source reason"
import asyncio
import logging
import os
from contextlib import asynccontextmanager
from typing import Optional

import asyncpg
import pandas as pd
import psycopg2
from cachetools import TTLCache, cached
from dotenv import load_dotenv
# Load settings from a .env file into the process environment; the pg_*
# connection variables read via os.getenv below come from here.
# No global connection pool is kept: get_async_connection opens and closes
# its own connection on every call.
load_dotenv()
@asynccontextmanager
async def get_async_connection(schema="talmudexplore", auto_commit=True):
    """
    Open a dedicated asyncpg connection for the current request.

    A new connection is created on every call (no shared pool). It uses the
    ``pg_dbname``, ``pg_user``, ``pg_password``, ``pg_host`` and ``pg_port``
    environment variables, and it is always closed when the context exits.

    Args:
        schema: Value placed on the session's ``search_path``. It is sent to
            the server as a bind parameter, not interpolated into SQL text.
        auto_commit: If True (default), each statement auto-commits.
            If False, the whole ``async with`` body runs inside one
            transaction. That transaction is committed automatically when the
            body finishes without raising. If the body raises, the commit is
            skipped and closing the connection discards the transaction.

    Yields:
        The open asyncpg connection.
    """
    conn = None
    tx = None
    try:
        # Create a single connection without relying on a shared pool.
        conn = await asyncpg.connect(
            database=os.getenv("pg_dbname"),
            user=os.getenv("pg_user"),
            password=os.getenv("pg_password"),
            host=os.getenv("pg_host"),
            port=os.getenv("pg_port"),
        )
        # set_config(..., false) is the session-scoped equivalent of
        # ``SET search_path TO ...``. Unlike SET, it accepts the value as a bind
        # parameter, so ``schema`` is never spliced into the SQL string.
        await conn.execute("SELECT set_config('search_path', $1, false)", schema)
        if not auto_commit:
            # One transaction spans the whole body of the ``async with``.
            tx = conn.transaction()
            await tx.start()
        yield conn
        # Only reached when the body did not raise.
        if tx is not None:
            await tx.commit()
    finally:
        if conn is not None:
            await conn.close()
async def get_questions(conn: asyncpg.Connection, source_finder_run_id: int, baseline_source_finder_run_id: int):
    """
    List the questions that have metadata rows for BOTH given runs.

    Args:
        conn: Open database connection.
        source_finder_run_id: Run being evaluated.
        baseline_source_finder_run_id: Run used as the baseline.

    Returns:
        list[dict]: One ``{"id": ..., "text": ...}`` entry per distinct
        question. The query has no ORDER BY, so the order is whatever the
        server returns.
    """
    sql = """
    select distinct q.id, question_text from talmudexplore.questions q
    join (select question_id from talmudexplore.source_finder_run_question_metadata where source_finder_run_id = $1) sfrqm1
    on sfrqm1.question_id = q.id
    join (select question_id from talmudexplore.source_finder_run_question_metadata where source_finder_run_id = $2) sfrqm2
    on sfrqm2.question_id = q.id;
    """
    rows = await conn.fetch(sql, source_finder_run_id, baseline_source_finder_run_id)
    return [{"id": row["id"], "text": row["question_text"]} for row in rows]
# Awaited results of get_metadata, keyed by (question_id, source_finder_id_run_id).
# cachetools' @cached is not used here: on an ``async def`` it caches the
# coroutine object, so a cache hit hands back an already-awaited coroutine and
# ``await`` raises RuntimeError.
_metadata_cache = TTLCache(ttl=1800, maxsize=1024)


async def get_metadata(conn: asyncpg.Connection, question_id: int, source_finder_id_run_id: int):
    """
    Return the ``metadata`` value stored for a question within one run.

    Results are cached for 1800 s (at most 1024 entries). The cache key is
    ``(question_id, source_finder_id_run_id)`` only; the connection is not part
    of the key. The table is looked up unqualified through the connection's
    search_path.
    NOTE(review): this assumes every connection passed in sees the same
    schema. Confirm this if callers ever use a non-default schema.

    Args:
        conn: Open database connection.
        question_id: Question to look up.
        source_finder_id_run_id: Source finder run to look up.

    Returns:
        The row's ``metadata`` column, or ``""`` when no row matches.
    """
    key = (question_id, source_finder_id_run_id)
    try:
        return _metadata_cache[key]
    except KeyError:
        pass  # Miss or expired entry: fall through to the database.

    row = await conn.fetchrow('''
    SELECT metadata
    FROM source_finder_run_question_metadata sfrqm
    WHERE sfrqm.question_id = $1 and sfrqm.source_finder_run_id = $2;
    ''', question_id, source_finder_id_run_id)
    value = "" if row is None else row.get('metadata')
    _metadata_cache[key] = value
    return value
async def get_source_finders(conn: asyncpg.Connection):
    """
    Return every source finder that has at least one run with stored results.

    A finder qualifies when one of its runs has a row in
    ``talmudexplore.source_run_results``.

    Returns:
        list[dict]: ``{"id": ..., "name": ...}`` entries ordered by finder id,
        where ``name`` is the finder's ``source_finder_type``.
    """
    sql = """
    SELECT distinct sf.id, sf.source_finder_type as name from talmudexplore.source_finder_runs sfr
    join talmudexplore.source_finders sf on sf.id = sfr.source_finder_id
    WHERE EXISTS (
    SELECT 1
    FROM talmudexplore.source_run_results srr
    WHERE srr.source_finder_run_id = sfr.id
    )
    ORDER BY sf.id
    """
    rows = await conn.fetch(sql)
    return [dict(id=row["id"], name=row["name"]) for row in rows]
# Awaited results of get_run_ids, keyed by (source_finder_id, question_id).
# cachetools' @cached is not used here: on an ``async def`` it caches the
# coroutine object, so a cache hit hands back an already-awaited coroutine and
# ``await`` raises RuntimeError.
_run_ids_cache = TTLCache(ttl=1800, maxsize=1024)


async def get_run_ids(conn: asyncpg.Connection, source_finder_id: int, question_id: Optional[int] = None):
    """
    Map run descriptions to run ids for a source finder's runs that have results.

    Results are cached for 1800 s (at most 1024 entries), keyed on
    ``(source_finder_id, question_id)``; the connection is not part of the key.
    NOTE(review): this assumes every connection passed in sees the same schema.

    Args:
        conn: Open database connection (tables resolved via its search_path).
        source_finder_id: Source finder whose runs are listed.
        question_id: If given, only runs with results for this question are
            included. If None, all runs with any results are included.

    Returns:
        dict: ``{description: run_id}``. Rows are fetched in descending run id
        order. If two runs share a description, the later-inserted entry wins,
        which is the smaller run id. A fresh copy is returned on every call.
    """
    key = (source_finder_id, question_id)
    try:
        # Copy so a caller mutating the result cannot corrupt the cached value.
        return dict(_run_ids_cache[key])
    except KeyError:
        pass  # Miss or expired entry: fall through to the database.

    query = """
    select distinct sfr.description, srs.source_finder_run_id as run_id
    from source_run_results srs
    join source_finder_runs sfr on srs.source_finder_run_id = sfr.id
    join source_finders sf on sfr.source_finder_id = sf.id
    where sfr.source_finder_id = $1
    """
    params = [source_finder_id]
    if question_id is not None:
        query += " and srs.question_id = $2"
        params.append(question_id)
    query += " order by run_id DESC;"

    rows = await conn.fetch(query, *params)
    run_ids = {r["description"]: r["run_id"] for r in rows}
    _run_ids_cache[key] = run_ids
    return dict(run_ids)
async def get_baseline_rankers(conn: asyncpg.Connection):
    """
    List the source finder runs that can be picked as a baseline.

    A run qualifies when it has at least one row in ``source_run_results``.

    Returns:
        list[dict]: ``{"id": <run id>, "name": "<source_finder_type> : <description>"}``
        entries, ordered by source finder id descending.
    """
    rows = await conn.fetch("""
    SELECT sfr.id, sf.source_finder_type, sfr.description from source_finder_runs sfr
    join source_finders sf on sf.id = sfr.source_finder_id
    WHERE EXISTS (
    SELECT 1
    FROM source_run_results srr
    WHERE srr.source_finder_run_id = sfr.id
    )
    ORDER BY sf.id DESC
    """)
    return [
        {"id": row["id"], "name": f"{row['source_finder_type']} : {row['description']}"}
        for row in rows
    ]
async def calculate_baseline_vs_source_stats_for_question(conn: asyncpg.Connection, baseline_sources, source_runs_sources):
    """
    Compare one question's sources from an evaluated run against a baseline run.

    Sources are matched on their ``id`` field. A source counts as "high ranked"
    when ``int()`` of its rank is at least 4. The rank is ``source_rank`` for
    run rows and ``baseline_rank`` for baseline rows.

    Args:
        conn: Not used by this function; kept for interface compatibility.
        baseline_sources: Rows exposing ``id`` and ``baseline_rank``.
        source_runs_sources: Rows exposing ``id`` and ``source_rank``.

    Returns:
        pd.DataFrame: A single row with the total counts (row counts of each
        input), overlap count and percentage, and the same figures for
        high-ranked sources only. Each percentage is relative to the larger of
        the two id sets and is 0 when both sets are empty.
    """

    def _pct(part, size_a, size_b):
        # Percentage of ``part`` over the larger set; 0 if both are empty.
        denominator = max(size_a, size_b)
        if denominator > 0:
            return round(part * 100 / denominator, 2)
        return 0

    found_ids = {row["id"] for row in source_runs_sources}
    baseline_ids = {row["id"] for row in baseline_sources}
    shared_ids = found_ids & baseline_ids

    found_top = {row["id"] for row in source_runs_sources if int(row["source_rank"]) >= 4}
    baseline_top = {row["id"] for row in baseline_sources if int(row["baseline_rank"]) >= 4}
    shared_top = found_top & baseline_top

    summary = {
        "total_baseline_sources": len(baseline_sources),
        "total_found_sources": len(source_runs_sources),
        "overlap_count": len(shared_ids),
        "overlap_percentage": _pct(len(shared_ids), len(found_ids), len(baseline_ids)),
        "num_high_ranked_baseline_sources": len(baseline_top),
        "num_high_ranked_found_sources": len(found_top),
        "high_ranked_overlap_count": len(shared_top),
        "high_ranked_overlap_percentage": _pct(len(shared_top), len(found_top), len(baseline_top)),
    }
    # One-row frame so callers can read it with .iloc[0].
    return pd.DataFrame([summary])
async def calculate_cumulative_statistics_for_all_questions(conn: asyncpg.Connection, question_ids, source_finder_run_id: int, ranker_id: int):
    """
    Calculate cumulative statistics across all questions for a specific source finder, run, and ranker.

    Questions where ``get_unified_sources`` raises are logged (with traceback)
    and skipped. Questions with no sources in either run are skipped and not
    counted in ``total_questions_analyzed``.

    Args:
        conn (asyncpg.Connection): Database connection
        question_ids (list): List of question IDs to analyze
        source_finder_run_id (int): ID of the source finder and run as appears in source runs
        ranker_id (int): ID of the baseline ranker

    Returns:
        pd.DataFrame: A single row of aggregated statistics. Overall
        percentages are relative to the larger of the baseline and found
        totals (0 if both are 0). Averages are per analyzed question (0 if
        none were analyzed).
    """
    # Per-question stats column -> running total across questions.
    totals = {
        "total_baseline_sources": 0,
        "total_found_sources": 0,
        "overlap_count": 0,
        "num_high_ranked_baseline_sources": 0,
        "num_high_ranked_found_sources": 0,
        "high_ranked_overlap_count": 0,
    }
    valid_questions = 0

    for question_id in question_ids:
        try:
            sources, stats = await get_unified_sources(conn, question_id, source_finder_run_id, ranker_id)
        except Exception:
            # Best effort: one failing question must not abort the whole
            # summary, but the failure should not be silent either.
            logging.getLogger(__name__).warning(
                "Skipping question %s: could not load unified sources",
                question_id,
                exc_info=True,
            )
            continue
        if not sources:
            continue

        # iloc[0] upcasts a row mixing int and float columns to float64, so
        # counts are converted back to int before summing.
        stats_dict = stats.iloc[0].to_dict()
        for column in totals:
            totals[column] += int(stats_dict.get(column, 0))
        valid_questions += 1

    total_baseline_sources = totals["total_baseline_sources"]
    total_found_sources = totals["total_found_sources"]
    total_overlap = totals["overlap_count"]
    total_high_ranked_baseline = totals["num_high_ranked_baseline_sources"]
    total_high_ranked_found = totals["num_high_ranked_found_sources"]
    total_high_ranked_overlap = totals["high_ranked_overlap_count"]

    overlap_denominator = max(total_baseline_sources, total_found_sources)
    overlap_percentage = (
        round(total_overlap * 100 / overlap_denominator, 2) if overlap_denominator > 0 else 0
    )
    high_ranked_denominator = max(total_high_ranked_baseline, total_high_ranked_found)
    high_ranked_overlap_percentage = (
        round(total_high_ranked_overlap * 100 / high_ranked_denominator, 2)
        if high_ranked_denominator > 0
        else 0
    )

    cumulative_stats = {
        "total_questions_analyzed": valid_questions,
        "total_baseline_sources": total_baseline_sources,
        "total_found_sources": total_found_sources,
        "total_overlap_count": total_overlap,
        "overall_overlap_percentage": overlap_percentage,
        "total_high_ranked_baseline_sources": total_high_ranked_baseline,
        "total_high_ranked_found_sources": total_high_ranked_found,
        "total_high_ranked_overlap_count": total_high_ranked_overlap,
        "overall_high_ranked_overlap_percentage": high_ranked_overlap_percentage,
        "avg_baseline_sources_per_question": round(total_baseline_sources / valid_questions, 2) if valid_questions > 0 else 0,
        "avg_found_sources_per_question": round(total_found_sources / valid_questions, 2) if valid_questions > 0 else 0,
    }
    return pd.DataFrame([cumulative_stats])
async def get_unified_sources(conn: asyncpg.Connection, question_id: int, source_finder_run_id: int, ranker_id: int):
    """
    Build a side-by-side view of one question's sources in two runs.

    The evaluated run (``source_finder_run_id``) and the baseline run
    (``ranker_id``) are both read from ``source_run_results`` joined to
    ``talmud_bavli``. Rows are matched on ``tractate_chunk_id``, which is
    returned as ``id``.

    Returns:
        tuple: ``(unified, stats_df)``.
        ``unified`` holds one dict per distinct id with:
        - tractate and folio, taken from the baseline row when present;
        - "Yes"/"No" membership flags;
        - ranks and reasons, with "N/A" for a run the source is absent from.
        ``stats_df`` is the frame returned by
        ``calculate_baseline_vs_source_stats_for_question``.
    """

    def _select_for_run(rank_alias, reason_alias):
        # Same query for both runs; only the rank/reason column aliases differ
        # so the two result sets stay distinguishable. Aliases are the fixed
        # literals passed below, never caller input.
        return f"""
    SELECT tb.tractate_chunk_id as id,
    sr.rank as {rank_alias},
    sr.tractate,
    sr.folio,
    sr.reason as {reason_alias}
    FROM source_run_results sr
    join talmud_bavli tb on sr.sugya_id = tb.xml_id
    WHERE sr.question_id = $1
    AND sr.source_finder_run_id = $2
    """

    run_rows = await conn.fetch(_select_for_run("source_rank", "source_reason"), question_id, source_finder_run_id)
    baseline_rows = await conn.fetch(_select_for_run("baseline_rank", "baseline_reason"), question_id, ranker_id)
    stats_df = await calculate_baseline_vs_source_stats_for_question(conn, baseline_rows, run_rows)

    run_by_id = {row["id"]: dict(row) for row in run_rows}
    baseline_by_id = {row["id"]: dict(row) for row in baseline_rows}

    unified = []
    for chunk_id in set(run_by_id) | set(baseline_by_id):
        found_in_run = chunk_id in run_by_id
        found_in_baseline = chunk_id in baseline_by_id
        run_entry = run_by_id.get(chunk_id, {})
        baseline_entry = baseline_by_id.get(chunk_id, {})
        location = baseline_entry if found_in_baseline else run_entry
        unified.append({
            "id": chunk_id,
            "tractate": location.get("tractate"),
            "folio": location.get("folio"),
            "in_baseline": "Yes" if found_in_baseline else "No",
            "baseline_rank": baseline_entry.get("baseline_rank", "N/A"),
            "in_source_run": "Yes" if found_in_run else "No",
            "source_run_rank": run_entry.get("source_rank", "N/A"),
            "source_reason": run_entry.get("source_reason", "N/A"),
            "baseline_reason": baseline_entry.get("baseline_reason", "N/A"),
        })
    return unified, stats_df
# Awaited results of get_source_text, keyed by tractate_chunk_id.
# cachetools' @cached is not used here: on an ``async def`` it caches the
# coroutine object, so a cache hit hands back an already-awaited coroutine and
# ``await`` raises RuntimeError.
_source_text_cache = TTLCache(ttl=1800, maxsize=1024)


async def get_source_text(conn: asyncpg.Connection, tractate_chunk_id: int):
    """
    Retrieves the text content for a given tractate chunk ID.

    Results, including the not-found placeholder, are cached for 1800 s
    (at most 1024 entries). The cache key is the chunk id only; the
    connection is not part of the key.
    NOTE(review): this assumes every connection passed in sees the same schema.

    Args:
        conn: Open database connection (``talmud_bavli`` resolved via its search_path).
        tractate_chunk_id: Chunk to fetch.

    Returns:
        The ``text`` column of the matching ``talmud_bavli`` row, or the string
        "Source text not found" when no row matches.
    """
    try:
        return _source_text_cache[tractate_chunk_id]
    except KeyError:
        pass  # Miss or expired entry: fall through to the database.

    query = """
    SELECT tb.text as text
    FROM talmud_bavli tb
    WHERE tb.tractate_chunk_id = $1
    """
    result = await conn.fetchrow(query, tractate_chunk_id)
    text = result["text"] if result else "Source text not found"
    _source_text_cache[tractate_chunk_id] = text
    return text
def get_pg_sync_connection(schema="talmudexplore"):
    """
    Open a blocking psycopg2 connection configured from the environment.

    Connection parameters come from ``pg_dbname``, ``pg_user``,
    ``pg_password``, ``pg_host`` and ``pg_port``. ``schema`` is applied
    through the libpq startup option ``-c search_path=<schema>``. The
    connection is returned open; this function does not close it.
    """
    settings = {
        "dbname": os.getenv("pg_dbname"),
        "user": os.getenv("pg_user"),
        "password": os.getenv("pg_password"),
        "host": os.getenv("pg_host"),
        "port": os.getenv("pg_port"),
    }
    return psycopg2.connect(options=f"-c search_path={schema}", **settings)