# app.py -- last change "Update app.py" by jithin14 (commit eccfb10, verified)
import gradio as gr
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from model import BoundingBoxPredictor
import matplotlib.patches as patches
import os
import uuid
# Every asset path is built from the directory containing this file rather
# than from the process's working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
cnn_path = os.path.join(current_dir, 'convolutional_nn.h5')
knn_path = os.path.join(current_dir, 'knn_model_tuned.pkl')

# Sample gallery entries: image4996.jpg .. image4999.jpg beside this file.
sample_images = []
for sample_number in range(4996, 5000):
    sample_images.append(os.path.join(current_dir, f"image{sample_number}.jpg"))

# Load both models once at import time; a failure is reported and then
# re-raised so the app does not start without usable models.
predictor = BoundingBoxPredictor()
try:
    predictor.load_models(cnn_path, knn_path)
    print("Models loaded successfully!")
except Exception as load_error:
    print(f"Error loading models: {str(load_error)}")
    raise
def predict_and_draw_bbox(input_image, model_choice):
    """Predict a bounding box for an image and save a copy with the box drawn.

    Args:
        input_image: The image to analyse, or ``None`` when nothing was
            provided. It is passed to ``predictor.predict`` and converted
            with ``np.array`` for display.
        model_choice: Name of the model to use; it is lower-cased and passed
            to ``predictor.predict`` as ``model_type``.

    Returns:
        A ``(image_path, message)`` tuple. On success ``image_path`` is the
        PNG written into ``current_dir`` and ``message`` lists the model used
        and the box coordinates. When ``input_image`` is ``None`` or any step
        raises, ``image_path`` is ``None`` and ``message`` explains why.
    """
    if input_image is None:
        return None, "Please upload an image."

    try:
        # The prediction is indexed by 'x', 'y', 'width' and 'height' below.
        bbox = predictor.predict(input_image, model_type=model_choice.lower())
        img_array = np.array(input_image)

        fig, ax = plt.subplots()
        try:
            ax.imshow(img_array, cmap='gray')
            rect = patches.Rectangle(
                (bbox['x'], bbox['y']),
                bbox['width'],
                bbox['height'],
                linewidth=2,
                edgecolor='r',
                facecolor='none',
            )
            ax.add_patch(rect)
            # Work on this figure/axes explicitly instead of pyplot's global
            # "current figure", so an overlapping call cannot redirect the
            # axis/save/close operations to a different figure.
            ax.axis('off')

            # A random suffix keeps outputs of separate calls from
            # overwriting each other.
            output_filename = f'output_{uuid.uuid4().hex[:8]}.png'
            output_path = os.path.join(current_dir, output_filename)
            fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
        finally:
            # Always release the figure, including when drawing or saving
            # fails; otherwise pyplot keeps it alive indefinitely.
            plt.close(fig)

        return (
            output_path,
            f"Model Used: {model_choice}\n\nBounding Box Coordinates:\nX: {bbox['x']:.2f}\nY: {bbox['y']:.2f}\nWidth: {bbox['width']:.2f}\nHeight: {bbox['height']:.2f}"
        )
    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return None, f"Error processing image: {str(e)}"
def select_sample(evt: gr.SelectData):
    """Open the sample image that was selected in the gallery.

    Args:
        evt: Selection event; ``evt.index`` is used as the position of the
            chosen entry in the module-level ``sample_images`` list.

    Returns:
        The image opened with ``Image.open`` from the selected sample path.
    """
    chosen_path = sample_images[evt.index]
    return Image.open(chosen_path)
# Build the UI: an input column (image, model selector, button) beside an
# output column (annotated image, coordinates), followed by a clickable
# gallery of sample images.
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation had been flattened -- confirm it matches the intended layout.
with gr.Blocks() as iface:
    gr.Markdown("# Bounding Box Detector")
    gr.Markdown("""Upload an image to detect the bounding box. Choose between CNN Model or KNN Model.""")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            model_choice = gr.Radio(
                choices=["CNN", "KNN"],
                label="Choose Model",
                value="CNN",
            )
            submit_btn = gr.Button("Detect Bounding Box")
        with gr.Column():
            output_image = gr.Image(type="filepath", label="Detected Bounding Box")
            output_text = gr.Textbox(label="Coordinates")

    gr.Markdown("### Sample Images (click to use)")
    # Bind the component itself to `gallery`; previously the name held the
    # return value of `.select(...)` instead of the Gallery.
    gallery = gr.Gallery(
        value=sample_images,
        label="Sample Images",
        show_label=False,
        columns=4,
        height="auto",
    )
    # Selecting a sample loads it into the input image component.
    gallery.select(select_sample, None, input_image)

    # Run the detector on the current input with the chosen model.
    submit_btn.click(
        predict_and_draw_bbox,
        inputs=[input_image, model_choice],
        outputs=[output_image, output_text],
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()