VNSLR / app.py
fossbk's picture
sử dụng các tab trong gr.Blocks() thay vì sử dụng gr.TabbedInterface
366b65f verified
raw
history blame
2.85 kB
import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
import tempfile
import os
from moviepy.editor import VideoFileClip
# Kiểm tra thiết bị sử dụng GPU hay CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Tải mô hình phân loại ảnh từ Hugging Face (sử dụng mô hình ảnh cho video)
image_classifier = pipeline("image-classification", model="google/vit-base-patch16-224-in21k", device=0 if device == "cuda" else -1)
# Hàm phân loại ảnh
def classify_image(image, model_name):
if model_name == "ViT":
classifier = image_classifier
else:
classifier = image_classifier # Chỉnh sửa ở đây nếu muốn hỗ trợ thêm các mô hình khác
# Phân loại ảnh
result = classifier(image)
return result[0]['label'], result[0]['score']
# Hàm phân loại video (trích xuất frame đầu tiên của video)
def classify_video(video, model_name):
# Trích xuất frame đầu tiên của video
video_clip = VideoFileClip(video.name)
frame = video_clip.get_frame(0) # Lấy frame đầu tiên
image = Image.fromarray(frame)
# Phân loại frame đầu tiên của video
if model_name == "ViT":
classifier = image_classifier
else:
classifier = image_classifier # Chỉnh sửa ở đây nếu muốn hỗ trợ thêm các mô hình khác
result = classifier(image)
return result[0]['label'], result[0]['score']
# Giao diện Gradio với các tab
with gr.Blocks() as demo:
with gr.Tab("Image Classification"):
gr.Markdown("### Upload an image for classification")
with gr.Row():
model_choice_image = gr.Dropdown(choices=["ViT", "ResNet"], label="Choose a Model", value="ViT")
image_input = gr.Image(type="pil", label="Upload Image")
image_output_label = gr.Textbox(label="Prediction")
image_output_score = gr.Textbox(label="Confidence Score")
classify_image_button = gr.Button("Classify Image")
classify_image_button.click(classify_image, inputs=[image_input, model_choice_image], outputs=[image_output_label, image_output_score])
with gr.Tab("Video Classification"):
gr.Markdown("### Upload a video for classification")
with gr.Row():
model_choice_video = gr.Dropdown(choices=["ViT", "ResNet"], label="Choose a Model", value="ViT")
video_input = gr.Video(label="Upload Video")
video_output_label = gr.Textbox(label="Prediction")
video_output_score = gr.Textbox(label="Confidence Score")
classify_video_button = gr.Button("Classify Video")
classify_video_button.click(classify_video, inputs=[video_input, model_choice_video], outputs=[video_output_label, video_output_score])
demo.launch()