# app.py
import torch
from transformers import SwinForImageClassification, AutoFeatureExtractor
import cv2
import mediapipe as mp
from PIL import Image
import gradio as gr
import os
import numpy as np

# Initialize id2label and label2id
id2label = {0: 'Heart', 1: 'Oblong', 2: 'Oval', 3: 'Round', 4: 'Square'}
label2id = {v: k for k, v in id2label.items()}
# Initialize glasses recommendations (frame style per face shape)
glasses_recommendations = {
    "Heart": "Rimless Frame",
    "Oblong": "Rectangular Frame",
    "Oval": "Round Frame",
    "Round": "Square Frame",
    "Square": "Oval Frame"
}

# Glasses images should be in the repo (e.g., "glasses/Heart.jpg")
glasses_images = {
    "Heart": "glasses/RimlessFrame.jpg",
    "Oblong": "glasses/RectangleFrame.jpg",
    "Oval": "glasses/RoundFrame.jpg",
    "Round": "glasses/SquareFrame.jpg",
    "Square": "glasses/OvalFrame.jpg"
}

# Load the Swin-tiny backbone with a 5-class face-shape head
model_checkpoint = "microsoft/swin-tiny-patch4-window7-224"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SwinForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True
)
# Load the fine-tuned weights (uploaded into the Space). strict=False silently
# skips mismatched keys, so inspect load_result.missing_keys /
# load_result.unexpected_keys if predictions ever look untrained.
load_result = model.load_state_dict(
    torch.load('LR-0001-adamW-32-64swin.pth', map_location=device), strict=False
)
model = model.to(device)
model.eval()

# Initialize feature extractor
feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
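# Note: newer transformers releases deprecate AutoFeatureExtractor for vision
# models in favor of AutoImageProcessor; swapping it in should be a drop-in
# change here, assuming the Space's pinned transformers version supports it
# (not verified against this repo's requirements).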
# Initialize Mediapipe Face Detection (model_selection=1 selects the
# full-range model, suited to faces up to roughly 5 m from the camera)
mp_face_detection = mp.solutions.face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.5)

# Preprocess image: detect the face, crop it, and convert it to model inputs
def preprocess_image(image):
    # Gradio provides a PIL image in RGB; keep RGB throughout, since both
    # Mediapipe and the feature extractor expect RGB input.
    image = np.array(image)
    results = mp_face_detection.process(image)
    if not results.detections:
        raise ValueError("No face detected in the image.")
    detection = results.detections[0]
    bbox = detection.location_data.relative_bounding_box
    h, w, _ = image.shape
    # Clamp the relative bounding box to the image bounds; Mediapipe can
    # return coordinates slightly outside [0, 1], which would corrupt the crop.
    x1 = max(int(bbox.xmin * w), 0)
    y1 = max(int(bbox.ymin * h), 0)
    x2 = min(int((bbox.xmin + bbox.width) * w), w)
    y2 = min(int((bbox.ymin + bbox.height) * h), h)
    face = image[y1:y2, x1:x2]
    face = cv2.resize(face, (224, 224))
    pixel_values = feature_extractor(images=face, return_tensors="pt")['pixel_values']
    return pixel_values.squeeze(0)
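# Quick shape check (interactive sketch; "example.jpg" is a hypothetical
# sample path, not a file in this repo):
# >>> preprocess_image(Image.open("example.jpg")).shape
# torch.Size([3, 224, 224])  # Swin-tiny expects a 224x224 RGB crop
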
# Prediction: classify the face shape and look up the matching glasses
def predict(image):
    try:
        image_tensor = preprocess_image(image)
        image_tensor = image_tensor.unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(image_tensor)
            logits = outputs.logits
        probabilities = torch.nn.functional.softmax(logits, dim=1).squeeze(0)
        sorted_probs = sorted(
            [(id2label[i], probabilities[i].item() * 100) for i in range(len(probabilities))],
            key=lambda x: x[1],
            reverse=True
        )
        predicted_label, predicted_prob = sorted_probs[0]
        all_probs = {
            label: (f"{prob:.2f}%", glasses_recommendations[label])
            for label, prob in sorted_probs
        }
        # Prepare result text
        result_text = f"Face Shape: {predicted_label} ({predicted_prob:.2f}%)\n\n"
        result_text += "Probability per Class:\n"
        for label, (prob, recommendation) in all_probs.items():
            result_text += f"{label}: {prob} - Recommended Glasses: {recommendation}\n"
        # Prepare glasses image
        glasses_image_path = glasses_images.get(predicted_label)
        glasses_img = None
        if glasses_image_path and os.path.exists(glasses_image_path):
            glasses_img = Image.open(glasses_image_path)
        return result_text, glasses_img
    except Exception as e:
        return f"Error: {str(e)}", None
# Gradio Interface
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Textbox(label="Prediction Result"), gr.Image(label="Glasses Recommendation")],
    title="Face Shape Detection & Glasses Recommendation",
    description="Upload a photo of your face to get your face shape and a glasses recommendation!"
)

if __name__ == "__main__":
    demo.launch()