import traceback
from typing import Optional

import gradio as gr

from src.pipeline import generate_report

# ------------------------------------------------------------------
# 1. Pre-load models on Space start-up
# ------------------------------------------------------------------
# Best-effort warm-up: a failure here is reported but must not prevent the
# UI from starting, so the exception is caught and logged, not re-raised.
print("Pre-loading models for fast inference …")
try:
    from src.tools_loader import get_tools  # downloads BiomedCLIP + SPECTER-2

    _ = get_tools()
    print("Models pre-loaded successfully!")
except Exception as e:
    print(f"Model pre-loading failed: {e}")
    # Keep the full traceback in the server log; the one-line message above
    # is rarely enough to diagnose a failed import or download.
    traceback.print_exc()


# ------------------------------------------------------------------
# 2. Inference wrapper
# ------------------------------------------------------------------
def process_upload(image_path: Optional[str]):
    """Run the multi-agent pipeline on an uploaded chest X-ray.

    Args:
        image_path: Filesystem path of the uploaded image as supplied by the
            ``gr.Image(type="filepath")`` input, or ``None`` when nothing has
            been uploaded.

    Returns:
        The value returned by ``generate_report`` on success. When no image
        was provided, a prompt asking the user to upload one. When the
        pipeline raises, a string of the form ``"Error processing image: ..."``.
        The result is rendered by the ``gr.Markdown`` output component.
    """
    # Treat an empty path the same as a missing upload.
    if not image_path:
        return "Please upload a chest X-ray image."

    try:
        return generate_report(image_path)
    except Exception as e:
        # UI boundary: show a readable message to the user instead of
        # crashing the callback, but keep the traceback in the server log.
        traceback.print_exc()
        return f"Error processing image: {e}"


# ------------------------------------------------------------------
# 3. Gradio UI
# ------------------------------------------------------------------
with gr.Blocks(title="Multi-Agent Radiology Assistant") as demo:
    gr.Markdown(
        """
        # Multi-Agent Radiology Assistant
        Upload a chest X-ray and receive an AI-generated report produced by a multi-agent pipeline.
        """
    )

    # --- Upload widget + button ------------------------------------------------
    with gr.Column():
        input_image = gr.Image(
            type="filepath",
            label="Upload Chest X-ray",
            height=400,
        )
        process_btn = gr.Button("Generate Report", variant="primary")

    # --- Report output ---------------------------------------------------------
    output_report = gr.Markdown(label="Radiology Report", show_label=True)

    # --- Wire everything together ---------------------------------------------
    # Event listeners have to be registered inside the Blocks context.
    process_btn.click(
        fn=process_upload,
        inputs=input_image,
        outputs=output_report,
    )

    gr.Markdown(
        "### Need an example? \nUse any frontal CXR PNG file and click **Generate Report**."
    )

# ------------------------------------------------------------------
if __name__ == "__main__":
    demo.launch()