Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image, ImageDraw | |
| import torch | |
| from torchvision import transforms | |
| from transformers import AutoModelForImageClassification, AutoFeatureExtractor | |
# Hugging Face Hub checkpoints offered in the UI, keyed by the short name shown
# in the model dropdown (insertion order is the order of the dropdown choices).
MODEL_LIST = dict(
    beit="microsoft/beit-base-patch16-224-pt22k-ft22k",
    vit="google/vit-base-patch16-224",
    convnext="facebook/convnext-tiny-224",
)

# Active model and its preprocessor; both stay None until
# load_model_and_preprocessor() assigns them.
current_model = current_preprocessor = None

# Prefer the GPU whenever torch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Load model and preprocessor
# (model, preprocessor) pairs that have already been loaded, keyed by MODEL_LIST
# name. Only MODEL_LIST keys are ever inserted, so the cache holds at most
# len(MODEL_LIST) entries. Repeated calls (e.g. from UI event handlers) reuse the
# loaded weights instead of re-initialising them. Trade-off: every model that has
# been selected once stays resident in memory.
_loaded_models = {}


def _preprocessor_class():
    """Return the transformers auto-class used to load a checkpoint's preprocessor."""
    try:
        # Current transformers API for vision preprocessing.
        from transformers import AutoImageProcessor
    except ImportError:
        # Older transformers releases only provide the feature-extractor API.
        return AutoFeatureExtractor
    return AutoImageProcessor


def load_model_and_preprocessor(model_name):
    """Make ``model_name`` the active model and preprocessor.

    The first call for a name loads the checkpoint listed in ``MODEL_LIST``,
    moves the model to ``device`` and switches it to eval mode. Later calls for
    the same name reuse the cached pair. Either way, the module globals
    ``current_model`` / ``current_preprocessor`` are updated.

    Args:
        model_name: A key of ``MODEL_LIST``.

    Returns:
        A human-readable status message.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODEL_LIST``.
    """
    global current_model, current_preprocessor
    cached = _loaded_models.get(model_name)
    if cached is None:
        checkpoint = MODEL_LIST[model_name]
        print(f"Loading model and preprocessor for: {model_name} on {device}")
        model = AutoModelForImageClassification.from_pretrained(checkpoint).to(device).eval()
        preprocessor = _preprocessor_class().from_pretrained(checkpoint)
        cached = _loaded_models[model_name] = (model, preprocessor)
    current_model, current_preprocessor = cached
    return f"Model {model_name} loaded successfully on {device}."
# Predict function
def predict(image, model, preprocessor):
    """Classify a single image patch and return the predicted label.

    Args:
        image: The patch to classify, in any form ``preprocessor`` accepts.
        model: Image-classification model whose ``config.id2label`` maps class
            indices to label strings.
        preprocessor: Callable turning ``images=...`` into model inputs.

    Returns:
        The label of the highest-scoring class.

    Raises:
        ValueError: If the model/preprocessor are not loaded, or no image is given.
    """
    if model is None or preprocessor is None:
        raise ValueError("Model and preprocessor are not loaded.")
    if image is None:
        # Happens when Predict is used before any patch has been cropped; fail
        # with a clear message instead of an opaque preprocessor error.
        raise ValueError("No image patch to classify; crop a region first.")
    # Send inputs to wherever the *given* model lives rather than assuming the
    # module-level device, so any model passed in works.
    inputs = preprocessor(images=image, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
    # .item() requires a single prediction: one patch in, one class index out.
    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    return model.config.id2label[predicted_class]
# Function to draw a rectangle on the image
def draw_rectangle(image, x, y, size=224):
    """Return a copy of ``image`` with a red box outlining a ``size`` x ``size`` patch.

    The top-left corner is clamped to ``[0, dim - size]`` on each axis, so the
    box stays inside the image and lines up with a crop clamped the same way.

    Args:
        image: PIL image to annotate; it is not modified.
        x, y: Requested top-left corner in pixels (floats are truncated).
        size: Side length of the outlined square.

    Returns:
        A new PIL image with the box drawn on it.
    """
    annotated = image.copy()  # draw on a copy so the caller's image is untouched
    width, height = annotated.size
    left = max(0, min(int(x), width - size))
    top = max(0, min(int(y), height - size))
    draw = ImageDraw.Draw(annotated)
    # PIL's rectangle includes both corner pixels, so a box exactly `size`
    # pixels wide ends at left + size - 1.
    draw.rectangle(
        [left, top, left + size - 1, top + size - 1], outline="red", width=5
    )
    return annotated
# Function to crop the image
def crop_image(image, x, y, size=224):
    """Crop a ``size`` x ``size`` square from ``image`` with top-left corner (x, y).

    The corner is clamped so the square stays inside the image. If the image is
    smaller than ``size`` along an axis, the crop starts at 0 on that axis and
    covers its whole extent.

    Args:
        image: PIL image (anything ``np.array`` turns into an H x W or H x W x C array).
        x, y: Requested top-left corner in pixels (floats are truncated).
        size: Side length of the square patch.

    Returns:
        The cropped region as a PIL image.
    """
    pixels = np.array(image)
    # shape[:2] works for grayscale (H, W) as well as colour (H, W, C) arrays.
    height, width = pixels.shape[:2]
    # Clamp into [0, dim - size]. The outer max() keeps the start non-negative
    # when the image is smaller than the patch; a negative start would wrap
    # around and return a wrong, tiny crop.
    left = max(0, min(int(x), width - size))
    top = max(0, min(int(y), height - size))
    return Image.fromarray(pixels[top:top + size, left:left + size])
# Gradio Interface
DEFAULT_MODEL = 'beit'
# Side length, in pixels, of the square patch that is outlined, cropped and classified.
PATCH_SIZE = 224

# Name of the checkpoint currently held in the shared module-level model globals,
# plus the status message returned when it was loaded. Every browser session
# shares those globals, so handlers activate the model shown in their own
# dropdown instead of trusting whatever another session (or a page load) left there.
_active_model = {"name": None, "status": ""}


def activate_model(model_name):
    """Make ``model_name`` the active model, loading it only if it is not already active.

    Returns the status message produced by the load.
    """
    if _active_model["name"] != model_name:
        _active_model["status"] = load_model_and_preprocessor(model_name)
        _active_model["name"] = model_name
    return _active_model["status"]


def update_model(model_name):
    """Dropdown/page-load handler: activate the selected model and report its status."""
    return activate_model(model_name)


def update_selection(image, x, y):
    """Return the image with the selected patch outlined, and the cropped patch itself."""
    if image is None:
        raise gr.Error("Please upload an image first.")
    overlay_image = draw_rectangle(image, x, y, PATCH_SIZE)
    cropped = crop_image(image, x, y, PATCH_SIZE)
    return overlay_image, cropped


def predict_from_cropped(cropped, model_name):
    """Classify the cropped patch with the model selected in this session's dropdown."""
    if cropped is None:
        raise gr.Error("Please crop a patch before predicting.")
    activate_model(model_name)
    return predict(cropped, current_model, current_preprocessor)


def update_sliders(image):
    """Limit the coordinate sliders so the patch fits inside the uploaded image."""
    if image is None:
        return gr.update(), gr.update()
    width, height = image.size
    # Images smaller than the patch leave no room to move it; keep maximum >= minimum (0).
    return (
        gr.update(maximum=max(0, width - PATCH_SIZE)),
        gr.update(maximum=max(0, height - PATCH_SIZE)),
    )


with gr.Blocks() as demo:
    gr.Markdown("## Test Public Models for Coral Classification")
    with gr.Row():
        with gr.Column():
            model_selector = gr.Dropdown(choices=list(MODEL_LIST.keys()), value=DEFAULT_MODEL, label="Select Model")
            model_status = gr.Textbox(label="Model Status", interactive=False)
            image_input = gr.Image(type="pil", label="Upload Image", interactive=True)
            x_slider = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="X Coordinate")
            y_slider = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="Y Coordinate")
        with gr.Column():
            interactive_image = gr.Image(label="Interactive Image with Selection")
            cropped_image = gr.Image(label="Cropped Patch")
            label_output = gr.Textbox(label="Predicted Label")

    # Buttons and interactions
    crop_button = gr.Button("Crop")
    crop_button.click(fn=update_selection, inputs=[image_input, x_slider, y_slider], outputs=[interactive_image, cropped_image])
    predict_button = gr.Button("Predict")
    # The dropdown value is passed along so each session predicts with the model it shows.
    predict_button.click(fn=predict_from_cropped, inputs=[cropped_image, model_selector], outputs=label_output)
    model_selector.change(fn=update_model, inputs=model_selector, outputs=model_status)
    # Update sliders dynamically based on uploaded image size
    image_input.change(fn=update_sliders, inputs=image_input, outputs=[x_slider, y_slider])
    # Activate this session's selected model on page load; no reload if it is already active.
    demo.load(fn=update_model, inputs=model_selector, outputs=model_status)

demo.launch(server_name="0.0.0.0", server_port=7860)