|
|
import logging
import time

import gradio as gr

from src.pipeline import generate_report
from src.tools_loader import get_tools
|
|
|
|
|
|
|
|
# Called once at import time, before the UI below is built; the return value
# is discarded (bound to the throwaway name `_`). Presumably this pre-loads or
# validates the tool set so problems surface at startup rather than on the
# first request — NOTE(review): confirm against src.tools_loader.get_tools.
_ = get_tools()
|
|
|
|
|
def process_inputs(target_variable: str, image_path: str) -> str:
    """Gradio callback that generates a Markdown SHAP explanation report.

    Args:
        target_variable: Name of the prediction target entered by the user
            (e.g. "life expectancy"). Surrounding whitespace is ignored and
            ``None`` is treated as empty.
        image_path: Filesystem path of the uploaded SHAP summary plot, as
            delivered by a ``gr.Image(type="filepath")`` component. Empty or
            ``None`` means nothing was uploaded.

    Returns:
        Markdown for the report panel: a prompt asking for missing input, an
        error notice if report generation raised, or the generated report
        followed by the time it took to produce.
    """
    if not image_path:
        return "**Please upload a SHAP summary plot image to begin.**"

    # A cleared/absent textbox can arrive as None (e.g. via the API), so
    # normalize before stripping; strip once and reuse the result.
    target = (target_variable or "").strip()
    if not target:
        return "**Please enter a target variable (e.g., life expectancy).**"

    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # wall-clock adjustments the way a time.time() difference can.
    start = time.perf_counter()
    try:
        report = generate_report(target, image_path)
    except Exception:
        # UI boundary: log the full traceback for operators and show the user
        # a readable message instead of Gradio's bare "Error" badge.
        logging.getLogger(__name__).exception(
            "SHAP report generation failed for target %r", target
        )
        return "**Report generation failed. Please check the image and try again.**"
    elapsed = time.perf_counter() - start

    return (
        f"### SHAP Explanation Report for **{target}**\n\n"
        f"{report}\n\n"
        "---\n\n"
        f"*Generated in {elapsed:.1f} seconds*\n"
    )
|
|
|
|
|
|
|
|
# Page styling: center the input column and leave room above the report.
_APP_CSS = """
.input-section { max-width: 600px; margin: 0 auto; }
.report-output { margin-top: 30px; }
"""

# Heading and one-line usage hint shown at the top of the page.
_INTRO_MARKDOWN = (
    "# SHAP Summary Plot Explainer\n\n"
    "Upload a SHAP plot and specify your prediction target to get a detailed explanation."
)

with gr.Blocks(
    title="SHAP Summary Plot Explainer",
    theme=gr.themes.Soft(),
    css=_APP_CSS,
) as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    # Inputs: the prediction target's name, the SHAP plot (handed to the
    # callback as a file path), and the button that triggers generation.
    with gr.Column(elem_classes=["input-section"]):
        target_input = gr.Textbox(
            placeholder="e.g., life expectancy, credit score, disease risk...",
            label="Target Variable",
        )
        shap_image = gr.Image(
            label="Upload SHAP Summary Plot Image",
            height=350,
            type="filepath",
        )
        generate_button = gr.Button("Generate Explanation", variant="primary")

    # Output: a Markdown panel that process_inputs overwrites with its result.
    with gr.Column(elem_classes=["report-output"]):
        report_output = gr.Markdown("**Awaiting input...**")

    generate_button.click(
        process_inputs,
        inputs=[target_input, shap_image],
        outputs=report_output,
        show_progress="full",
    )


if __name__ == "__main__":
    # Serve the UI only when executed as a script, not when imported.
    demo.launch()