blasisd's picture
Changed endpoint variable logic to base url variable logic.
99bb9dd
import os
from typing import Any
import gradio as gr
import requests
from dotenv import load_dotenv
load_dotenv()
# Build the classify endpoint from the configured base URL.
# A trailing slash on the base URL is stripped so the result never contains a
# double slash (e.g. "http://host/" -> "http://host/classify"). An unset
# variable yields "/classify" instead of the literal string "None/classify".
_base_url = (os.getenv("MUSHEFF_API_BASE_URL") or "").rstrip("/")
API_ENDPOINT = f"{_base_url}/classify"
# (connect, read) timeouts in seconds for the classification request.
# Without a timeout, requests waits indefinitely and the Timeout handler
# below could never be reached.
_REQUEST_TIMEOUT = (10, 60)


def classify_image(img_path: str) -> tuple[dict[str, Any], ...]:
    """Classify a mushroom image through the backend service.

    The image is uploaded to ``API_ENDPOINT`` as the ``image_file`` form field.
    On success the service's JSON payload is handed to ``show_results``; every
    failure is converted into the tuple produced by ``show_error``, so this
    function does not raise for the handled cases.

    :param img_path: the path to the mushroom image
    :type img_path: str
    :return: updates for the confidence, species, edibility and error
        components, in that order
    :rtype: tuple[dict[str, Any], ...]
    """
    # A missing path (None or empty) cannot be opened; report it directly
    # rather than relying on the exception raised by open().
    if not img_path:
        return show_error("Invalid input. Please select a valid file and try again.")

    try:
        with open(img_path, "rb") as fp:
            files = {"image_file": (fp.name, fp)}
            # The file must stay open while requests streams the upload.
            response = requests.post(
                API_ENDPOINT, files=files, timeout=_REQUEST_TIMEOUT
            )

        # Handle HTTP errors
        if response.status_code != 200:
            fallback = f"API Error {response.status_code}"
            # Error bodies are not guaranteed to be JSON objects (e.g. an HTML
            # page from a proxy), so fall back to the status code.
            try:
                payload = response.json()
            except ValueError:
                payload = None
            if isinstance(payload, dict):
                error_msg = payload.get("detail", fallback)
            else:
                error_msg = fallback
            return show_error(f"Backend Error: {error_msg}")

        # Process successful response
        return show_results(response.json())

    except requests.exceptions.Timeout:
        return show_error("Request timed out. Please try again.")
    except requests.exceptions.ConnectionError:
        return show_error("Cannot connect to the server. Please check your connection.")
    except TypeError:
        return show_error("Invalid input. Please select a valid file and try again.")
    except Exception as e:
        # Last-resort boundary: the UI must always get a renderable result.
        return show_error(f"System error: {e}")
def show_results(data: dict) -> tuple[dict[str, Any]]:
    """Turn the classify service payload into output component updates.

    Reads the ``confidence``, ``mushroom_type`` and ``toxicity_profile`` keys
    from ``data`` and renders each into its HTML snippet. The error component
    is cleared.

    :param data: the data retrieved from the classify service
    :type data: dict
    :return: updates for the confidence, species, edibility and error
        components, in that order
    :rtype: tuple[dict[str, Any]]
    """
    score = data["confidence"]
    species = data["mushroom_type"]
    toxicity = data["toxicity_profile"]

    confidence_update = gr.update(value=generate_confidence_html(score))
    species_update = gr.update(value=generate_class_html(species))
    edibility_update = gr.update(
        value=generate_edibility_html(toxicity, score, species)
    )
    # A successful classification wipes any previously shown error.
    cleared_error = gr.update(value="")

    return (confidence_update, species_update, edibility_update, cleared_error)
def show_error(message: str) -> tuple[dict[str, Any]]:
    """Build component updates that display an error and blank the results.

    :param message: the error message
    :type message: str
    :return: empty updates for the confidence, species and edibility
        components, followed by the rendered error alert
    :rtype: tuple[dict[str, Any]]
    """
    # Confidence, class and edibility outputs are all emptied.
    blank_confidence, blank_class, blank_edibility = (
        gr.update(value="") for _ in range(3)
    )
    error_alert = gr.update(value=generate_error_html(message))
    return (blank_confidence, blank_class, blank_edibility, error_alert)
def generate_confidence_html(confidence: float) -> str:
    """Render the confidence card: a header, a progress bar and the value.

    :param confidence: classification confidence as a fraction (it is
        multiplied by 100 for display)
    :return: the HTML snippet for the confidence output component
    """
    # Format once so the bar width and the printed value can never disagree.
    percent = f"{confidence * 100:.2f}"
    return f"""<div class="confidence-display" style="font-size:1.1rem">
<div class="confidence-header">
<span class="confidence-icon">📊</span>
<span class="confidence-title">Classification Confidence</span>
</div>
<div class="confidence-visual">
<div class="confidence-bar-bg">
<div class="confidence-bar-fill" style="width: {percent}%"></div>
</div>
<div class="confidence-value">{percent}%</div>
</div>
</div>
"""
def generate_class_html(class_name: str) -> str:
    """Render the species card for the identified mushroom class.

    Underscores in the class name are shown as spaces. The name comes from the
    backend response, so it is HTML-escaped before being interpolated.

    :param class_name: the class label returned by the classify service
    :return: the HTML snippet for the species output component
    """
    from html import escape

    display_name = escape(class_name.replace("_", " "), quote=False)
    return f"""
<div style='
font-size: 1.5rem;
text-align: center;
padding: 20px;
background: #E3F2FD;
border-radius: 8px;
'>
🍄 <br><strong>{display_name}</strong>
</div>
"""
def generate_edibility_html(edibility: str, confidence: float, class_name: str) -> str:
    """Render the safety alert for a classification result.

    Only the exact value ``"edible"`` produces the "safe" alert; every other
    value falls through to the poison alert, which is the cautious default.

    :param edibility: the toxicity profile returned by the classify service
    :param confidence: classification confidence as a fraction; its
        complement is shown as the misidentification risk
    :param class_name: the class label (currently unused; kept so existing
        callers keep working)
    :return: the HTML snippet for the edibility output component
    """
    if edibility == "edible":
        return (
            "<div class='edible-alert'>"
            "✅ SAFE TO EAT (with verification)<hr style='margin:10px 0;'>"
            "<div style='font-size:1.1rem'>"
            "Always confirm with mycologist before consumption"
            "</div></div>"
        )

    return (
        "<div class='poison-alert'>"
        "☠️ <strong>POISONOUS!</strong> DO NOT CONSUME<hr style='margin:10px 0;'>"
        "<div style='font-size:1.1rem;color:var(--poison-color)'>"
        f"Misidentification risk: {100 - confidence*100:.2f}%<br>"
        "<em>Immediately contact poison control if ingested</em>"
        "</div></div>"
    )
def generate_error_html(message: str) -> str:
    """Render the error banner shown when classification fails.

    The message may carry backend or exception text, which can contain
    characters such as ``<`` and ``>``; it is HTML-escaped so it is displayed
    verbatim instead of being interpreted as markup.

    :param message: the error message to display
    :return: the HTML snippet for the error output component
    """
    from html import escape

    safe_message = escape(str(message), quote=False)
    return f"""
<div class='error-banner'>
<span class='error-icon'>❌</span>
<strong>CLASSIFICATION FAILED</strong><br>
{safe_message}<br>
<em>Please try again or contact support</em>
</div>
"""
def toggle_row_visibility(*comp_vals):
    """Return a Row that is visible when at least one given value is truthy.

    :param comp_vals: current values of the components placed on the row
    :return: a ``gr.Row`` whose ``visible`` flag reflects whether any value is
        non-empty
    """
    has_content = any(comp_vals)
    return gr.Row(visible=has_content)
def handle_image_change(new_image):
    """Reset the output area whenever the input image changes.

    The image value is passed through untouched, whether it is a newly
    selected image or ``None`` after a clear. Both rows and the separator are
    hidden in either case.

    :param new_image: the new value of the image input
    :return: the image value, two hidden rows and a hiding update
    """
    hidden_results_row = gr.Row(visible=False)
    hidden_errors_row = gr.Row(visible=False)
    hidden_separator = gr.update(visible=False)
    return new_image, hidden_results_row, hidden_errors_row, hidden_separator
def hr_line_update():
    """Return an update that makes the receiving component visible."""
    reveal = gr.update(visible=True)
    return reveal
# Custom HTML components
# Safety disclaimer banner; static markup with no interpolated values.
safety_html = """
<div class="safety-banner">
<div class="warning-icon">⚠️</div>
<strong>CRITICAL SAFETY NOTICE:</strong> FungiSage Vision provides probabilistic guidance only - not guarantees.
Mushroom misidentification can be fatal. Always consult certified mycologists before consumption.
</div>
"""
# Brand slogan block; static markup with no interpolated values.
slogan_html = """
<div class="slogan-container">
<div class="mushroom-icon">🍄</div>
<div class="slogan-text">
<span class="brand-tagline">Massif Mushroom Intelligence.</span>
<div class="brand-action">
<span class="brand-name">FungiSage Vision</span> Guides,
<span class="verify-emphasis">You Verify.</span>
</div>
</div>
</div>
"""
# Location of the custom stylesheet, resolved relative to this module's
# directory.
_module_dir = os.path.dirname(__file__)
css_path = os.path.join(_module_dir, "static/styles/custom.css")
# Build the Gradio UI: static banners, an image input with a classify button,
# a results row, an errors row and a hidden separator used as a scroll target.
# NOTE(review): the nesting of the layout containers below is assumed from the
# section comments — confirm against the rendered UI.
with gr.Blocks(
    theme=gr.themes.Glass(),
    css_paths=css_path,
) as demo:
    gr.HTML(safety_html)
    gr.HTML(slogan_html)
    # Input section
    with gr.Group(elem_id="inputsContainer"):
        gr.Markdown("### Upload Mushroom Image", elem_id="uploadHeader")
        # type="filepath" hands classify_image a path string.
        image_input = gr.Image(type="filepath", label="", height=300)
        classify_btn = gr.Button(
            "Classify Mushroom", elem_id="classifyBtn", variant="primary"
        )
    # Output section
    with gr.Group(elem_id="outputsContainer"):
        # Both rows start hidden and are revealed by toggle_row_visibility.
        with gr.Row(visible=False, scale=1, equal_height=True) as results_group:
            with gr.Column(scale=1, elem_id="confCol"):
                confidence_output = gr.HTML(
                    label="Confidence Level", elem_classes="output-card"
                )
            with gr.Column(scale=2, elem_id="speciesCol"):
                class_output = gr.HTML(
                    label="Identified Species", elem_classes="output-card"
                )
            with gr.Column(scale=1, elem_id="edibilityCol"):
                edibility_output = gr.HTML(
                    label="Safety Assessment", elem_classes="output-card"
                )
        with gr.Row(visible=False) as errors_group:
            error_output = gr.HTML(elem_classes="error-card")
    # Used only for auto scrolling when results or errors occur
    bottom_line = gr.HTML("<hr>", visible=False, elem_id="bottomLine")
    # Classification function called on button click
    # The chained steps run in order: classify, show the results row if any
    # result component has content, show the errors row if the error component
    # has content, then reveal the separator.
    classify_btn.click(
        fn=classify_image,
        inputs=image_input,
        outputs=[
            confidence_output,
            class_output,
            edibility_output,
            error_output,
        ],
    ).then(
        fn=toggle_row_visibility,
        inputs=[confidence_output, class_output, edibility_output],
        outputs=results_group,  # If any of the result outputs (got results from service), then show results row
    ).then(
        fn=toggle_row_visibility,
        inputs=[error_output],
        outputs=errors_group,  # If error occurred, then show errors row
    ).then(
        fn=hr_line_update,
        inputs=[],
        outputs=[bottom_line],
        scroll_to_output=True,  # Useful to scroll there after results or errors occur
    )
    # Handle image changes (upload or clear/hide rows with results and errors)
    image_input.change(
        fn=handle_image_change,
        inputs=image_input,
        outputs=[image_input, results_group, errors_group, bottom_line],
    )
# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()