VNSLR / app.py
fossbk's picture
Update app.py
c50ae26 verified
raw
history blame
2.95 kB
import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
import tempfile
import os
#from moviepy.editor import VideoFileClip
# Kiểm tra thiết bị sử dụng GPU hay CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Tải mô hình phân loại ảnh từ Hugging Face (sử dụng mô hình ảnh cho video)
image_classifier = pipeline("image-classification", model="google/vit-base-patch16-224-in21k", device=0 if device == "cuda" else -1)
# Hàm phân loại ảnh
def classify_image(image, model_name):
if model_name == "ViT":
classifier = image_classifier
else:
classifier = image_classifier # Chỉnh sửa ở đây nếu muốn hỗ trợ thêm các mô hình khác
# Phân loại ảnh
result = classifier(image)
return result[0]['label'], result[0]['score']
# Hàm phân loại video (trích xuất frame đầu tiên của video)
def classify_video(video, model_name):
# Trích xuất frame đầu tiên của video
video_clip = VideoFileClip(video.name)
frame = video_clip.get_frame(0) # Lấy frame đầu tiên
image = Image.fromarray(frame)
# Phân loại frame đầu tiên của video
if model_name == "ViT":
classifier = image_classifier
else:
classifier = image_classifier # Chỉnh sửa ở đây nếu muốn hỗ trợ thêm các mô hình khác
result = classifier(image)
return result[0]['label'], result[0]['score']
# Giao diện Gradio
with gr.Blocks() as demo:
with gr.TabbedInterface() as tabs:
with gr.TabItem("Image Classification"):
gr.Markdown("### Upload an image for classification")
with gr.Row():
model_choice_image = gr.Dropdown(choices=["ViT", "ResNet"], label="Choose a Model", value="ViT")
image_input = gr.Image(type="pil", label="Upload Image")
image_output_label = gr.Textbox(label="Prediction")
image_output_score = gr.Textbox(label="Confidence Score")
classify_image_button = gr.Button("Classify Image")
classify_image_button.click(classify_image, inputs=[image_input, model_choice_image], outputs=[image_output_label, image_output_score])
with gr.TabItem("Video Classification"):
gr.Markdown("### Upload a video for classification")
with gr.Row():
model_choice_video = gr.Dropdown(choices=["ViT", "ResNet"], label="Choose a Model", value="ViT")
video_input = gr.Video(label="Upload Video")
video_output_label = gr.Textbox(label="Prediction")
video_output_score = gr.Textbox(label="Confidence Score")
classify_video_button = gr.Button("Classify Video")
classify_video_button.click(classify_video, inputs=[video_input, model_choice_video], outputs=[video_output_label, video_output_score])
demo.launch()