File size: 2,846 Bytes
285ea09
b94db6f
 
d3d18f9
8dad76f
 
366b65f
285ea09
b94db6f
285ea09
 
8dad76f
b94db6f
d3d18f9
b94db6f
 
 
 
 
 
 
 
 
 
 
8dad76f
b94db6f
8dad76f
 
 
 
 
 
b94db6f
8dad76f
b94db6f
8dad76f
b94db6f
8dad76f
b94db6f
 
366b65f
b94db6f
366b65f
 
 
 
 
 
 
285ea09
366b65f
285ea09
366b65f
285ea09
366b65f
 
 
 
 
 
 
e42e16e
366b65f
e42e16e
366b65f
e42e16e
285ea09
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
import tempfile
import os
from moviepy.editor import VideoFileClip

# Kiểm tra thiết bị sử dụng GPU hay CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tải mô hình phân loại ảnh từ Hugging Face (sử dụng mô hình ảnh cho video)
image_classifier = pipeline("image-classification", model="google/vit-base-patch16-224-in21k", device=0 if device == "cuda" else -1)

# Hàm phân loại ảnh
def classify_image(image, model_name):
    if model_name == "ViT":
        classifier = image_classifier
    else:
        classifier = image_classifier  # Chỉnh sửa ở đây nếu muốn hỗ trợ thêm các mô hình khác

    # Phân loại ảnh
    result = classifier(image)
    return result[0]['label'], result[0]['score']

# Hàm phân loại video (trích xuất frame đầu tiên của video)
def classify_video(video, model_name):
    # Trích xuất frame đầu tiên của video
    video_clip = VideoFileClip(video.name)
    frame = video_clip.get_frame(0)  # Lấy frame đầu tiên
    image = Image.fromarray(frame)

    # Phân loại frame đầu tiên của video
    if model_name == "ViT":
        classifier = image_classifier
    else:
        classifier = image_classifier  # Chỉnh sửa ở đây nếu muốn hỗ trợ thêm các mô hình khác

    result = classifier(image)
    return result[0]['label'], result[0]['score']

# Giao diện Gradio với các tab
with gr.Blocks() as demo:
    with gr.Tab("Image Classification"):
        gr.Markdown("### Upload an image for classification")
        with gr.Row():
            model_choice_image = gr.Dropdown(choices=["ViT", "ResNet"], label="Choose a Model", value="ViT")
            image_input = gr.Image(type="pil", label="Upload Image")
            image_output_label = gr.Textbox(label="Prediction")
            image_output_score = gr.Textbox(label="Confidence Score")

        classify_image_button = gr.Button("Classify Image")

        classify_image_button.click(classify_image, inputs=[image_input, model_choice_image], outputs=[image_output_label, image_output_score])

    with gr.Tab("Video Classification"):
        gr.Markdown("### Upload a video for classification")
        with gr.Row():
            model_choice_video = gr.Dropdown(choices=["ViT", "ResNet"], label="Choose a Model", value="ViT")
            video_input = gr.Video(label="Upload Video")
            video_output_label = gr.Textbox(label="Prediction")
            video_output_score = gr.Textbox(label="Confidence Score")

        classify_video_button = gr.Button("Classify Video")

        classify_video_button.click(classify_video, inputs=[video_input, model_choice_video], outputs=[video_output_label, video_output_score])

    demo.launch()