Image Classification
Transformers
English
art
File size: 4,141 Bytes
0dab8bf
 
 
 
 
 
 
5b16cb6
0c1c962
5b16cb6
 
0c1c962
 
0dab8bf
14cad36
94ce1e2
14cad36
10b8909
 
 
 
 
 
 
 
 
94ce1e2
5b16cb6
 
 
 
 
 
0dab8bf
 
 
 
 
 
 
 
0c1c962
0dab8bf
 
 
 
 
 
 
0bbf37a
 
 
5b16cb6
 
0dab8bf
 
 
 
 
94ce1e2
0c1c962
94ce1e2
 
0c1c962
14cad36
6151861
10b8909
 
 
 
 
0c1c962
 
10b8909
0dab8bf
5b16cb6
 
14cad36
 
 
 
0dab8bf
 
5b16cb6
 
0dab8bf
 
5b16cb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6151861
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import html
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, precision_score, confusion_matrix, f1_score, average_precision_score
from sklearn.metrics import recall_score
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTForImageClassification

from vit_model_traning import labeling, CustomDataset

def display_video(video_url):
    """Return an HTML snippet that embeds and plays an MP4 video.

    The markup wraps a 640x480 ``<video>`` element (with ``controls`` and
    ``autoplay``) in a container that starts hidden; an inline script then
    makes the container visible. The script only runs if the snippet is
    rendered by a browser -- printing it to a console just shows the text.

    Args:
        video_url: Value placed in the ``<source src=...>`` attribute. It is
            converted with ``str()`` and HTML-escaped, so characters such as
            ``"``, ``&`` or ``<`` cannot break out of the attribute and
            corrupt the surrounding markup.

    Returns:
        The HTML snippet as a string.
    """
    # Escape quotes too (quote=True): the value sits inside a double-quoted attribute.
    safe_url = html.escape(str(video_url), quote=True)
    return f'''
    <div id="video-container" style="display: none;">
        <video width="640" height="480" controls autoplay>
            <source src="{safe_url}" type="video/mp4">
            Your browser does not support the video tag.
        </video>
    </div>
    <script>
        document.getElementById('video-container').style.display = 'block';
    </script>
    '''

def shuffle_and_split_data(dataframe, test_size=0.2, random_state=59):
    """Shuffle ``dataframe`` and split it into training and validation parts.

    Rows are first permuted with ``DataFrame.sample(frac=1)`` and re-indexed
    from 0, then divided with ``train_test_split`` using the same seed.

    Args:
        dataframe: The pandas DataFrame to split.
        test_size: Passed straight to ``train_test_split`` (e.g. 0.2 for a
            20% validation share).
        random_state: Seed used for both the shuffle and the split, so the
            result is reproducible.

    Returns:
        A ``(train_df, val_df)`` tuple.
    """
    seed = random_state
    permuted = dataframe.sample(frac=1, random_state=seed)
    permuted = permuted.reset_index(drop=True)
    train_part, val_part = train_test_split(permuted, test_size=test_size, random_state=seed)
    return train_part, val_part

if __name__ == "__main__":
    # Use the GPU when one is available; otherwise fall back to the CPU so the
    # script still runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the pre-trained ViT backbone and replace its head with a 2-way
    # classifier; this head must match the one saved in 'trained_model.pth'.
    model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
    model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)

    # Image preprocessing: resize to 224x224 and convert to a tensor.
    # NOTE(review): there is no Normalize step; presumably this has to mirror
    # the training-time transform in vit_model_traning -- confirm.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    # Load the test dataset
    test_real_folder = 'datasets/test_set/real/'
    test_fake_folder = 'datasets/test_set/fake/'

    test_set = labeling(test_real_folder, test_fake_folder)
    test_dataset = CustomDataset(test_set, transform=preprocess)
    test_loader = DataLoader(test_dataset, batch_size=32)

    # Load the fine-tuned weights. map_location lets a checkpoint that was
    # saved from a GPU be loaded on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load('trained_model.pth', map_location=device))

    # Link to the video -- replace this with the URL of your own video.
    # A raw string is required: in a normal literal '\U' starts a unicode
    # escape and the whole file fails to compile. A browser will generally need
    # a file:// or http(s):// URL here rather than a bare Windows path.
    video_url = r'C:\Users\litav\Downloads\0001-0120.mp4'
    video_html = display_video(video_url)

    # Emit the video HTML once; it is meant to be shown in your dashboard.
    print(video_html)

    # Evaluate the model
    model.eval()
    true_labels = []
    predicted_labels = []
    positive_scores = []  # softmax probability of class 1, used for average precision

    # JavaScript to play the video when a SUBMIT button is clicked could be added here,
    # e.g. <button onclick="playVideo()">Submit</button>

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            logits = model(images).logits  # (batch, 2), given the 2-way head above
            predicted = logits.argmax(dim=1)
            probs = torch.softmax(logits, dim=1)[:, 1]

            true_labels.extend(labels.cpu().numpy())
            predicted_labels.extend(predicted.cpu().numpy())
            positive_scores.extend(probs.cpu().numpy())

    if not true_labels:
        raise SystemExit('No test images were loaded; check the test_set folders.')

    # Calculate evaluation metrics (class 1 is the positive label, sklearn's default)
    accuracy = accuracy_score(true_labels, predicted_labels)
    precision = precision_score(true_labels, predicted_labels)
    cm = confusion_matrix(true_labels, predicted_labels)
    f1 = f1_score(true_labels, predicted_labels)
    # Average precision ranks by score; hard 0/1 predictions would make it degenerate.
    ap = average_precision_score(true_labels, positive_scores)
    recall = recall_score(true_labels, predicted_labels)

    print(f"Test Accuracy: {accuracy:.2%}")
    print(f"Precision: {precision:.2%}")
    print(f"F1 Score: {f1:.2%}")
    print(f"Average Precision: {ap:.2%}")
    print(f"Recall: {recall:.2%}")

    # Plot the confusion matrix
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.title('Confusion Matrix')
    plt.show()