File size: 2,846 Bytes
285ea09 b94db6f d3d18f9 8dad76f 366b65f 285ea09 b94db6f 285ea09 8dad76f b94db6f d3d18f9 b94db6f 8dad76f b94db6f 8dad76f b94db6f 8dad76f b94db6f 8dad76f b94db6f 8dad76f b94db6f 366b65f b94db6f 366b65f 285ea09 366b65f 285ea09 366b65f 285ea09 366b65f e42e16e 366b65f e42e16e 366b65f e42e16e 285ea09 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
import tempfile
import os
from moviepy.editor import VideoFileClip
# Kiểm tra thiết bị sử dụng GPU hay CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Tải mô hình phân loại ảnh từ Hugging Face (sử dụng mô hình ảnh cho video)
image_classifier = pipeline("image-classification", model="google/vit-base-patch16-224-in21k", device=0 if device == "cuda" else -1)
# Hàm phân loại ảnh
def classify_image(image, model_name):
if model_name == "ViT":
classifier = image_classifier
else:
classifier = image_classifier # Chỉnh sửa ở đây nếu muốn hỗ trợ thêm các mô hình khác
# Phân loại ảnh
result = classifier(image)
return result[0]['label'], result[0]['score']
# Hàm phân loại video (trích xuất frame đầu tiên của video)
def classify_video(video, model_name):
# Trích xuất frame đầu tiên của video
video_clip = VideoFileClip(video.name)
frame = video_clip.get_frame(0) # Lấy frame đầu tiên
image = Image.fromarray(frame)
# Phân loại frame đầu tiên của video
if model_name == "ViT":
classifier = image_classifier
else:
classifier = image_classifier # Chỉnh sửa ở đây nếu muốn hỗ trợ thêm các mô hình khác
result = classifier(image)
return result[0]['label'], result[0]['score']
# Giao diện Gradio với các tab
with gr.Blocks() as demo:
with gr.Tab("Image Classification"):
gr.Markdown("### Upload an image for classification")
with gr.Row():
model_choice_image = gr.Dropdown(choices=["ViT", "ResNet"], label="Choose a Model", value="ViT")
image_input = gr.Image(type="pil", label="Upload Image")
image_output_label = gr.Textbox(label="Prediction")
image_output_score = gr.Textbox(label="Confidence Score")
classify_image_button = gr.Button("Classify Image")
classify_image_button.click(classify_image, inputs=[image_input, model_choice_image], outputs=[image_output_label, image_output_score])
with gr.Tab("Video Classification"):
gr.Markdown("### Upload a video for classification")
with gr.Row():
model_choice_video = gr.Dropdown(choices=["ViT", "ResNet"], label="Choose a Model", value="ViT")
video_input = gr.Video(label="Upload Video")
video_output_label = gr.Textbox(label="Prediction")
video_output_score = gr.Textbox(label="Confidence Score")
classify_video_button = gr.Button("Classify Video")
classify_video_button.click(classify_video, inputs=[video_input, model_choice_video], outputs=[video_output_label, video_output_score])
demo.launch()
|