File size: 4,220 Bytes
9374fca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# app.py

import logging
import os

import torch
from transformers import SwinForImageClassification, AutoFeatureExtractor
import cv2
import mediapipe as mp
import matplotlib.pyplot as plt
from PIL import Image
import gradio as gr
import numpy as np

# Class index <-> face-shape label mapping for the 5-way classifier head.
# NOTE(review): the index order presumably has to match the label order used
# when the checkpoint was fine-tuned — confirm before reordering.
id2label = dict(enumerate(["Heart", "Oblong", "Oval", "Round", "Square"]))
label2id = {label: idx for idx, label in id2label.items()}

# Recommended frame style (Indonesian display text) for each face shape.
glasses_recommendations = dict(
    Heart="Frame Rimless",
    Oblong="Frame Persegi Panjang",
    Oval="Frame Bulat",
    Round="Frame Kotak",
    Square="Frame Oval",
)

# Example picture of the recommended frame for each face shape. Paths are
# relative to the working directory, and the files are named after the frame
# style rather than the face shape (e.g. "glasses/RimlessFrame.jpg").
glasses_images = dict(
    Heart="glasses/RimlessFrame.jpg",
    Oblong="glasses/RectangleFrame.jpg",
    Oval="glasses/RoundFrame.jpg",
    Round="glasses/SquareFrame.jpg",
    Square="glasses/OvalFrame.jpg",
)

# Load model
model_checkpoint = "microsoft/swin-tiny-patch4-window7-224"
device = "cuda" if torch.cuda.is_available() else "cpu"

model = SwinForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True
)

# Load your fine-tuned model weights (uploaded into Space!)
model.load_state_dict(torch.load('LR-0001-adamW-32-64swin.pth', map_location=device), strict=False)
model = model.to(device)
model.eval()

# Initialize feature extractor
feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)

# Initialize Mediapipe Face Detection
mp_face_detection = mp.solutions.face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.5)

# Preprocess image
def preprocess_image(image):
    """Crop the first detected face out of *image* and turn it into model input.

    Args:
        image: An RGB picture (the app passes a PIL image) that ``np.array``
            can convert into an ``(H, W, C)`` array.

    Returns:
        The feature extractor's ``pixel_values`` tensor for the 224x224 face
        crop, with the leading batch dimension squeezed away.

    Raises:
        ValueError: If no face is detected, or the detected box does not
            overlap the image.
    """
    bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    # Mediapipe expects RGB input; the RGB->BGR->RGB round trip is the same
    # conversion the original pipeline performed before detection.
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    results = mp_face_detection.process(rgb)
    if not results.detections:
        raise ValueError("No face detected in the image.")

    # Only the first detection is used.
    bbox = results.detections[0].location_data.relative_bounding_box
    h, w = rgb.shape[:2]
    # The relative box may extend past the frame (e.g. a negative xmin for a
    # face cut off at the left edge). Clamp it: a negative slice start would
    # wrap around in numpy and produce a wrong or empty crop, which then makes
    # cv2.resize fail.
    x1 = max(int(bbox.xmin * w), 0)
    y1 = max(int(bbox.ymin * h), 0)
    x2 = min(int((bbox.xmin + bbox.width) * w), w)
    y2 = min(int((bbox.ymin + bbox.height) * h), h)
    if x2 <= x1 or y2 <= y1:
        raise ValueError("Detected face lies outside the image.")

    # Cropping and resizing in RGB gives the same pixels as doing it in BGR and
    # converting afterwards, since resize treats each channel independently.
    face = cv2.resize(rgb[y1:y2, x1:x2], (224, 224))
    pixel_values = feature_extractor(images=face, return_tensors="pt")["pixel_values"]

    return pixel_values.squeeze(0)

# Prediction
def predict(image):
    """Classify the face shape in *image* and recommend a glasses frame.

    Args:
        image: The uploaded picture as delivered by the ``gr.Image(type="pil")``
            input component (a PIL image).

    Returns:
        A ``(result_text, glasses_img)`` tuple. ``result_text`` names the top
        class and then lists every class with its probability and frame
        recommendation, most likely first. ``glasses_img`` is the PIL image of
        the recommended frame, or ``None`` when that file does not exist.
        On any failure the tuple is ``("Error: <message>", None)``, so the UI
        shows the problem instead of crashing.
    """
    try:
        # preprocess_image drops the batch dimension; add it back for the model.
        pixel_values = preprocess_image(image).unsqueeze(0).to(device)

        with torch.no_grad():
            logits = model(pixel_values).logits
        probabilities = torch.nn.functional.softmax(logits, dim=1).squeeze(0)

        # (label, percentage) pairs, highest probability first. tolist() moves
        # the values to Python in one call instead of one .item() per class.
        ranked = sorted(
            ((id2label[idx], prob * 100) for idx, prob in enumerate(probabilities.tolist())),
            key=lambda pair: pair[1],
            reverse=True,
        )
        predicted_label, predicted_prob = ranked[0]

        # Prepare result text
        lines = [
            f"Bentuk Wajah: {predicted_label} ({predicted_prob:.2f}%)",
            "",
            "Probabilitas Setiap Kelas:",
        ]
        lines.extend(
            f"{label}: {prob:.2f}% - Rekomendasi Kacamata: {glasses_recommendations[label]}"
            for label, prob in ranked
        )
        result_text = "\n".join(lines) + "\n"

        # Prepare glasses image
        glasses_img = None
        glasses_image_path = glasses_images.get(predicted_label)
        if glasses_image_path and os.path.exists(glasses_image_path):
            # Image.open is lazy and keeps the file handle open until the pixels
            # are read; load them now and let the context manager close the file.
            with Image.open(glasses_image_path) as img:
                img.load()
            glasses_img = img

        return result_text, glasses_img

    except Exception as e:
        # Top-level UI boundary: show the message to the user, but keep the
        # full traceback in the server log instead of discarding it.
        logging.getLogger(__name__).exception("Prediction failed")
        return f"Error: {str(e)}", None

# Gradio Interface
# One PIL image in; the prediction text and the recommended-frame picture out,
# in the same order as predict() returns them.
_face_input = gr.Image(type="pil")
_result_outputs = [
    gr.Textbox(label="Hasil Prediksi"),
    gr.Image(label="Rekomendasi Kacamata"),
]

demo = gr.Interface(
    fn=predict,
    inputs=_face_input,
    outputs=_result_outputs,
    title="Deteksi Bentuk Wajah & Rekomendasi Kacamata",
    description="Upload gambar wajahmu untuk mendapatkan bentuk wajah dan rekomendasi kacamata!",
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()