| import gradio as gr | |
| import os | |
| import tensorflow as tf | |
| import numpy as np | |
| import requests | |
| from langchain_groq import ChatGroq | |
| from langchain.agents import initialize_agent | |
| from langchain.prompts import PromptTemplate | |
| from langchain.chains import LLMChain | |
| from langchain.tools import StructuredTool | |
| from tensorflow.keras.preprocessing import image | |
# Local cache path for the pretrained UNet++ weights; downloaded on first run.
model_path = "unet_model.h5"
if not os.path.exists(model_path):
    hf_url = "https://huggingface.co/rishirajbal/UNET_plus_plus_Brain_segmentation/resolve/main/unet_model.h5"
    # Download into a temporary sibling file and only move it into place once
    # the transfer fully succeeded. Otherwise an HTTP error page or a truncated
    # download would be saved as the model and, because of the exists() check
    # above, never re-downloaded on later starts.
    tmp_path = model_path + ".part"
    # timeout= prevents startup from hanging forever on a stalled connection.
    with requests.get(hf_url, stream=True, timeout=60) as r:
        r.raise_for_status()
        with open(tmp_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    os.replace(tmp_path, model_path)
# compile=False: the model is only used for inference, so no optimizer/loss is needed.
model = tf.keras.models.load_model(model_path, compile=False)
def classify_image(image_input, threshold=0.5):
    """Run the segmentation model on one MRI slice and return a binary mask.

    Args:
        image_input: Image array from the ``gr.Image(type="numpy")`` input.
            Presumably (H, W, C) uint8 in 0-255 — TODO confirm the channel
            count matches what ``model`` was trained on.
        threshold: Model output value above which a pixel is marked positive.
            Defaults to 0.5, the previously hard-coded cut-off.

    Returns:
        A uint8 array at the 256x256 working resolution holding 255 where the
        model output exceeds ``threshold`` and 0 elsewhere. A trailing
        single-channel axis is dropped so the mask can be shown as grayscale.

    Raises:
        ValueError: if no image was provided.
    """
    if image_input is None:
        raise ValueError("No image provided; please upload a brain MRI slice.")
    # The model runs at a fixed 256x256 resolution on inputs scaled to [0, 1].
    img = tf.image.resize(image_input, (256, 256)) / 255.0
    batch = np.expand_dims(img, axis=0)  # add the batch axis expected by predict()
    # verbose=0 avoids printing a Keras progress bar on every request.
    prediction = model.predict(batch, verbose=0)[0]
    mask = (prediction > threshold).astype(np.uint8) * 255
    # PIL, which Gradio uses to render numpy images, cannot handle
    # (H, W, 1) uint8 arrays, so squeeze a lone channel axis to (H, W).
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = mask[..., 0]
    return mask
def summarize_mask(mask):
    """Describe a binary segmentation mask as one short, factual sentence.

    Any non-zero pixel counts as positive, so this works for 0/255 masks with
    or without a trailing channel axis. An empty mask reports no positives.
    """
    mask = np.asarray(mask)
    total = mask.size
    positive = int(np.count_nonzero(mask))
    if positive == 0:
        return "The segmentation model did not mark any pixels as tumor in this slice."
    return (
        f"The segmentation model marked {positive} of {total} pixels "
        f"({positive / total:.2%}) as tumor in this slice."
    )


def rishigpt_handler(image_input, groq_api_key):
    """Segment an uploaded MRI slice and ask the LLM to explain the result.

    Args:
        image_input: Image array from the Gradio upload component, or ``None``
            when nothing was uploaded.
        groq_api_key: Groq API key entered by the user.

    Returns:
        ``(mask, description)``: the mask from ``classify_image`` and a text
        containing the mask summary followed by the LLM's explanation.

    Raises:
        gr.Error: if the image or the API key is missing (shown in the UI).
    """
    if image_input is None:
        raise gr.Error("Please upload a brain MRI slice.")
    groq_api_key = (groq_api_key or "").strip()
    if not groq_api_key:
        raise gr.Error("Please enter your Groq API key.")

    mask = classify_image(image_input)
    # Ground the LLM in facts derived from the actual mask. The former
    # agent/tool pipeline only ever saw the constant string
    # "Brain tumor mask generated.", so its answer ignored the segmentation.
    summary = summarize_mask(mask)

    # Hand the key to this client only. os.environ is process-wide and shared
    # by every concurrent Gradio session, so writing it there would let other
    # users' requests pick up this user's key.
    llm = ChatGroq(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        temperature=0.3,
        api_key=groq_api_key,
    )
    prompt = PromptTemplate(
        input_variables=["result"],
        template="You are a medical imaging expert. Based on the result: {result}, explain what this means for diagnosis."
    )
    # `prompt | llm` replaces the deprecated LLMChain(...).run(...).
    response = (prompt | llm).invoke({"result": summary})
    description = f"{summary}\n\n{response.content}"
    return mask, description
# Gradio components: the MRI upload plus the user's Groq key (masked input).
inputs = [
    gr.Image(type="numpy", label="Upload Brain MRI Slice"),
    gr.Textbox(type="password", label="Groq API Key"),
]
# Outputs map positionally onto rishigpt_handler's (mask, description) tuple.
outputs = [
    gr.Image(type="numpy", label="Tumor Segmentation Mask"),
    gr.Textbox(label="Medical Explanation"),
]
# Keep the interface in a module-level `demo` variable (the name Gradio's
# reload mode looks for by default) and only start the server when this file
# is executed as a script, so importing the module has no side effects.
demo = gr.Interface(
    fn=rishigpt_handler,
    inputs=inputs,
    outputs=outputs,
    title="RishiGPT Medical Brain Segmentation",
    description="UNet++ Brain Tumor Segmentation",
)

if __name__ == "__main__":
    demo.launch()