Spaces:
Sleeping
Sleeping
New changes
Browse files- .gitignore +2 -0
- app.py +47 -23
.gitignore
CHANGED
@@ -1 +1,3 @@
|
|
1 |
.env
|
|
|
|
|
|
1 |
.env
|
2 |
+
app copy.py
|
3 |
+
new_gradio.py
|
app.py
CHANGED
@@ -505,6 +505,22 @@ def compute_ragbench_metrics(judge_response: dict, retrieved_sentence_keys: list
|
|
505 |
"Adherence": adherence
|
506 |
}
|
507 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
508 |
|
509 |
def evaluate_rag_pipeline(domain, q_indices):
|
510 |
import torch
|
@@ -597,39 +613,47 @@ def evaluate_rag_pipeline(domain, q_indices):
|
|
597 |
|
598 |
# Updated wrapper
|
599 |
def evaluate_rag_gradio(domain, q_indices_str):
|
600 |
-
# Capture logs
|
601 |
log_stream = io.StringIO()
|
602 |
sys.stdout = log_stream
|
603 |
|
604 |
try:
|
605 |
-
# Parse comma-separated indices
|
606 |
q_indices = [int(x.strip()) for x in q_indices_str.split(",") if x.strip().isdigit()]
|
607 |
results = evaluate_rag_pipeline(domain, q_indices)
|
608 |
-
|
609 |
logs = log_stream.getvalue()
|
610 |
return results, logs
|
611 |
-
|
612 |
except Exception as e:
|
613 |
traceback.print_exc()
|
614 |
return {"error": str(e)}, log_stream.getvalue()
|
615 |
-
|
616 |
finally:
|
617 |
-
sys.stdout = sys.__stdout__
|
618 |
-
|
619 |
-
# Gradio
|
620 |
-
|
621 |
-
|
622 |
-
|
623 |
-
|
624 |
-
|
625 |
-
|
626 |
-
|
627 |
-
|
628 |
-
|
629 |
-
|
630 |
-
|
631 |
-
|
632 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
633 |
|
634 |
-
# Launch
|
635 |
-
|
|
|
505 |
"Adherence": adherence
|
506 |
}
|
507 |
|
508 |
+
# --- Dataset dictionary ---
|
509 |
+
domain_datasets = {
|
510 |
+
"Legal": legal_dataset,
|
511 |
+
"Medical": med_dataset,
|
512 |
+
"GK": gk_dataset,
|
513 |
+
"CS": cs_dataset,
|
514 |
+
"Finance": fin_dataset
|
515 |
+
}
|
516 |
+
|
517 |
+
# --- Get questions for selected domain ---
|
518 |
+
def get_questions_for_domain(domain):
|
519 |
+
dataset = domain_datasets.get(domain, [])
|
520 |
+
if not dataset:
|
521 |
+
return "β οΈ No dataset found for the selected domain."
|
522 |
+
|
523 |
+
return "\n".join([f"{i}. {item['question']}" for i, item in enumerate(dataset)])
|
524 |
|
525 |
def evaluate_rag_pipeline(domain, q_indices):
|
526 |
import torch
|
|
|
613 |
|
614 |
# Updated wrapper
|
615 |
def evaluate_rag_gradio(domain, q_indices_str):
|
|
|
616 |
log_stream = io.StringIO()
|
617 |
sys.stdout = log_stream
|
618 |
|
619 |
try:
|
|
|
620 |
q_indices = [int(x.strip()) for x in q_indices_str.split(",") if x.strip().isdigit()]
|
621 |
results = evaluate_rag_pipeline(domain, q_indices)
|
|
|
622 |
logs = log_stream.getvalue()
|
623 |
return results, logs
|
|
|
624 |
except Exception as e:
|
625 |
traceback.print_exc()
|
626 |
return {"error": str(e)}, log_stream.getvalue()
|
|
|
627 |
finally:
|
628 |
+
sys.stdout = sys.__stdout__
|
629 |
+
|
630 |
+
# === Gradio UI using Blocks ===
|
631 |
+
with gr.Blocks(title="RAG Evaluation Dashboard") as demo:
|
632 |
+
gr.Markdown("## π RAG Evaluation Dashboard")
|
633 |
+
gr.Markdown("Evaluate your RAG pipeline and also browse the questions available for each domain.")
|
634 |
+
|
635 |
+
with gr.Row():
|
636 |
+
domain_input = gr.Dropdown(choices=list(domain_datasets.keys()), label="Select Domain")
|
637 |
+
q_index_input = gr.Textbox(label="Enter Query Indices (e.g., 89,121,245)", lines=1)
|
638 |
+
|
639 |
+
with gr.Row():
|
640 |
+
view_btn = gr.Button("π View Questions for Selected Domain")
|
641 |
+
questions_display = gr.Textbox(label="Domain Questions", lines=10, interactive=False)
|
642 |
+
|
643 |
+
with gr.Row():
|
644 |
+
run_btn = gr.Button("π Run Evaluation")
|
645 |
+
|
646 |
+
result_output = gr.JSON(label="Evaluation Metrics (RMSE & AUC-ROC)")
|
647 |
+
log_output = gr.Textbox(label="Execution Log", lines=10, interactive=True)
|
648 |
+
|
649 |
+
# Bindings
|
650 |
+
view_btn.click(fn=get_questions_for_domain, inputs=domain_input, outputs=questions_display)
|
651 |
+
|
652 |
+
run_btn.click(
|
653 |
+
fn=evaluate_rag_gradio,
|
654 |
+
inputs=[domain_input, q_index_input],
|
655 |
+
outputs=[result_output, log_output]
|
656 |
+
)
|
657 |
|
658 |
+
# === Launch ===
|
659 |
+
demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
|