|
import gradio as gr |
|
from transformers import pipeline |
|
from PIL import Image |
|
import torch |
|
import tempfile |
|
import os |
|
from moviepy.editor import VideoFileClip |
|
|
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face checkpoint backing each choice offered in the UI dropdowns.
# NOTE: the previously used "google/vit-base-patch16-224-in21k" checkpoint has
# no fine-tuned classification head, so the pipeline would attach a randomly
# initialised one and return meaningless labels; use the ImageNet-1k
# fine-tuned checkpoints instead.
MODEL_IDS = {
    "ViT": "google/vit-base-patch16-224",
    "ResNet": "microsoft/resnet-50",
}
DEFAULT_MODEL = "ViT"

# transformers' pipeline() takes a device index: 0 = first GPU, -1 = CPU.
_PIPELINE_DEVICE = 0 if device == "cuda" else -1

# The default model is loaded eagerly so the first request is fast; the
# others are loaded on first use by get_classifier() and cached here.
image_classifier = pipeline(
    "image-classification",
    model=MODEL_IDS[DEFAULT_MODEL],
    device=_PIPELINE_DEVICE,
)
_classifiers = {DEFAULT_MODEL: image_classifier}


def get_classifier(model_name):
    """Return the image-classification pipeline for a dropdown choice.

    Pipelines are built on first request and cached in ``_classifiers`` so
    each checkpoint is loaded at most once per process.

    Args:
        model_name: A key of ``MODEL_IDS`` (the dropdown values).

    Returns:
        The transformers image-classification pipeline for that model.

    Raises:
        gr.Error: If ``model_name`` is not a key of ``MODEL_IDS``.
    """
    classifier = _classifiers.get(model_name)
    if classifier is None:
        if model_name not in MODEL_IDS:
            raise gr.Error(f"Unknown model: {model_name!r}")
        classifier = pipeline(
            "image-classification",
            model=MODEL_IDS[model_name],
            device=_PIPELINE_DEVICE,
        )
        _classifiers[model_name] = classifier
    return classifier


def _top_prediction(classifier, image):
    """Run *classifier* on *image* and return ``(label, score)`` of the best class."""
    # The pipeline returns a list of {"label": ..., "score": ...} dicts ordered
    # by descending score, so the first entry is the top prediction.
    best = classifier(image)[0]
    return best["label"], best["score"]


def classify_image(image, model_name):
    """Classify a single image.

    Args:
        image: The uploaded image (a PIL image, per ``gr.Image(type="pil")``),
            or ``None`` if nothing was uploaded.
        model_name: Dropdown choice selecting the model (see ``MODEL_IDS``).

    Returns:
        ``(label, score)`` for the highest-scoring class.

    Raises:
        gr.Error: If no image was provided or the model name is unknown.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    return _top_prediction(get_classifier(model_name), image)


def classify_video(video, model_name):
    """Classify a video by running the image classifier on its first frame.

    Args:
        video: The uploaded video. ``gr.Video`` passes a filepath string;
            objects exposing the path as ``.name`` (tempfile wrappers) are
            accepted too. ``None`` if nothing was uploaded.
        model_name: Dropdown choice selecting the model (see ``MODEL_IDS``).

    Returns:
        ``(label, score)`` for the highest-scoring class of the frame at t=0.

    Raises:
        gr.Error: If no video was provided or the model name is unknown.
    """
    if video is None:
        raise gr.Error("Please upload a video first.")
    path = getattr(video, "name", video)

    clip = VideoFileClip(path)
    try:
        frame = clip.get_frame(0)
    finally:
        # Release the underlying ffmpeg reader / file handle.
        clip.close()

    image = Image.fromarray(frame)
    return _top_prediction(get_classifier(model_name), image)
|
|
|
|
|
# Dropdown values shared by both tabs; each must be a model name the
# classify_* callbacks understand.
model_choices = ["ViT", "ResNet"]

with gr.Blocks() as demo:
    with gr.Tab("Image Classification"):
        gr.Markdown("### Upload an image for classification")
        with gr.Row():
            model_choice_image = gr.Dropdown(choices=model_choices, label="Choose a Model", value="ViT")
            image_input = gr.Image(type="pil", label="Upload Image")
            image_output_label = gr.Textbox(label="Prediction")
            image_output_score = gr.Textbox(label="Confidence Score")

        classify_image_button = gr.Button("Classify Image")
        classify_image_button.click(
            classify_image,
            inputs=[image_input, model_choice_image],
            outputs=[image_output_label, image_output_score],
        )

    with gr.Tab("Video Classification"):
        gr.Markdown("### Upload a video for classification")
        with gr.Row():
            model_choice_video = gr.Dropdown(choices=model_choices, label="Choose a Model", value="ViT")
            video_input = gr.Video(label="Upload Video")
            video_output_label = gr.Textbox(label="Prediction")
            video_output_score = gr.Textbox(label="Confidence Score")

        classify_video_button = gr.Button("Classify Video")
        classify_video_button.click(
            classify_video,
            inputs=[video_input, model_choice_video],
            outputs=[video_output_label, video_output_score],
        )

# Launch only when run as a script, so importing this module (e.g. for tests
# or Gradio's reload mode, which looks up `demo`) does not start a server.
if __name__ == "__main__":
    demo.launch()
|
|