Faceshape / app.py
ruminasval's picture
Update app.py
0328ba0 verified
raw
history blame
5.04 kB
import logging
import os

import gradio as gr
import torch
from transformers import SwinForImageClassification, AutoFeatureExtractor
import mediapipe as mp
import cv2
from PIL import Image
import numpy as np
# Face shape descriptions
# Each value is an Indonesian sentence fragment describing the face shape
# named by its key.
face_shape_descriptions = dict(
    Heart="dengan dahi lebar dan dagu yang runcing.",
    Oblong="yang lebih panjang dari lebar dengan garis pipi lurus.",
    Oval="dengan proporsi seimbang dan dagu sedikit melengkung.",
    Round="dengan garis rahang melengkung dan pipi penuh.",
    Square="dengan rahang tegas dan dahi lebar.",
)

# Frame images path
# One illustrative frame picture per face shape, relative to the app directory.
glasses_images = dict(
    Heart="Glasses/RimlessFrame.jpg",
    Oblong="Glasses/RectangleFrame.jpg",
    Oval="Glasses/RoundFrame.jpg",
    Round="Glasses/SquareFrame.jpg",
    Square="Glasses/OvalFrame.jpg",
)

# Class index <-> label name lookups; the index is the tuple position.
id2label = dict(enumerate(("Heart", "Oblong", "Oval", "Round", "Square")))
label2id = {name: index for index, name in id2label.items()}
# Load model
# Use the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_checkpoint = "microsoft/swin-tiny-patch4-window7-224"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
# The label maps configure the classifier for the five face-shape classes.
# NOTE(review): ignore_mismatched_sizes presumably lets the base checkpoint's
# differently sized head be skipped -- confirm against the transformers docs.
model = SwinForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)

# Load your trained weights
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True where the installed torch
# version supports it).
state_dict = torch.load('LR-0001-adamW-32-64swin.pth', map_location=device)
# strict=False tolerates key mismatches instead of raising. Report them so a
# checkpoint that does not actually match the model is not used silently.
_load_result = model.load_state_dict(state_dict, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    logging.getLogger(__name__).warning(
        "Checkpoint keys do not fully match the model: "
        "%d missing %s, %d unexpected %s",
        len(_load_result.missing_keys),
        _load_result.missing_keys[:5],
        len(_load_result.unexpected_keys),
        _load_result.unexpected_keys[:5],
    )
model.to(device)
model.eval()
# Initialize Mediapipe
# A single detector instance is created at import time and reused by every
# call to preprocess_image.
# NOTE(review): model_selection=1 presumably picks MediaPipe's full-range
# detector -- confirm against the MediaPipe Face Detection docs.
mp_face_detection = mp.solutions.face_detection.FaceDetection(
    min_detection_confidence=0.5,
    model_selection=1,
)
# --- New: Decision tree function
# Recommended frame styles per face shape, keyed by the lower-case shape name.
# Tuples keep the table immutable; callers receive a fresh list each time.
_FRAME_RECOMMENDATIONS = {
    "square": ("Oval", "Round"),
    "round": ("Square", "Octagon", "Cat Eye"),
    "oval": ("Oval", "Pilot (Aviator)", "Cat Eye", "Round"),
    "heart": ("Pilot (Aviator)", "Cat Eye", "Round"),
    "oblong": ("Square", "Oval", "Pilot (Aviator)", "Cat Eye"),
}


def recommend_glasses_tree(face_shape):
    """Return the frame styles recommended for a face shape.

    Args:
        face_shape: Face-shape name; matched case-insensitively and with
            surrounding whitespace ignored.

    Returns:
        A new list of frame-style names, or an empty list when the shape is
        unknown or ``face_shape`` is not a string (for example ``None``).
    """
    if not isinstance(face_shape, str):
        # Treat missing or invalid input like an unknown shape, not a crash.
        return []
    key = face_shape.strip().lower()
    return list(_FRAME_RECOMMENDATIONS.get(key, ()))
# Preprocess function
def preprocess_image(image):
    """Crop the first detected face and turn it into a model input tensor.

    Args:
        image: The uploaded picture as a PIL image (any mode) or an RGB/RGBA
            array.

    Returns:
        The ``pixel_values`` tensor produced by ``feature_extractor`` for the
        224x224 face crop, with the leading batch dimension squeezed out.

    Raises:
        ValueError: If no image was supplied, no face is detected, or the
            detected box has no area inside the image.
    """
    if image is None:
        raise ValueError("Gambar belum diunggah.")

    if isinstance(image, Image.Image):
        # Normalise every PIL mode (RGBA, L, P, ...) to 3-channel RGB.
        image = image.convert("RGB")
    rgb = np.array(image)
    if rgb.ndim == 3 and rgb.shape[2] == 4:
        # Array input with an alpha channel: drop it before detection.
        rgb = cv2.cvtColor(rgb, cv2.COLOR_RGBA2RGB)

    # The detector is fed RGB directly, so no BGR round trip is needed.
    results = mp_face_detection.process(rgb)
    if not results.detections:
        raise ValueError("Wajah tidak terdeteksi.")

    bbox = results.detections[0].location_data.relative_bounding_box
    h, w = rgb.shape[:2]
    # The relative box may extend past the frame. Clamp it, because negative
    # indices would wrap around in NumPy slicing and give a wrong/empty crop.
    x1 = max(0, int(bbox.xmin * w))
    y1 = max(0, int(bbox.ymin * h))
    x2 = min(w, int((bbox.xmin + bbox.width) * w))
    y2 = min(h, int((bbox.ymin + bbox.height) * h))
    if x2 <= x1 or y2 <= y1:
        # Nothing of the detected box lies inside the image.
        raise ValueError("Wajah tidak terdeteksi.")

    face = cv2.resize(rgb[y1:y2, x1:x2], (224, 224))
    inputs = feature_extractor(images=face, return_tensors="pt")
    return inputs['pixel_values'].squeeze(0)
# Prediction function
def predict(image):
    """Classify the face shape in ``image`` and recommend glasses frames.

    Args:
        image: The uploaded picture, passed straight to ``preprocess_image``.

    Returns:
        A ``(label, explanation, frame_image)`` tuple: the predicted face
        shape, an Indonesian explanation with the recommended frames, and a
        PIL image of a matching frame (``None`` when no picture could be
        loaded). On failure the tuple is ``("Error", <message>, None)``.
    """
    log = logging.getLogger(__name__)
    try:
        inputs = preprocess_image(image).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=1)
        pred_idx = torch.argmax(probs, dim=1).item()
        pred_label = id2label[pred_idx]
        pred_prob = probs[0][pred_idx].item() * 100

        # --- Use decision tree for recommendations
        frame_recommendations = recommend_glasses_tree(pred_label)
        description = face_shape_descriptions[pred_label]
        frame_image_path = glasses_images.get(pred_label)

        # Build explanation text
        headline = f"Bentuk wajah kamu adalah {pred_label} ({pred_prob:.2f}%). "
        if frame_recommendations:
            recommended_frames = ', '.join(frame_recommendations)
            explanation = headline + (
                f"Kamu memiliki bentuk wajah {description} "
                f"Rekomendasi bentuk kacamata yang sesuai dengan wajah kamu adalah: {recommended_frames}."
            )
        else:
            explanation = headline + "Tidak ada rekomendasi frame untuk bentuk wajah ini."

        # Load frame image if available
        frame_image = None
        if frame_image_path:
            try:
                # copy() reads the pixels so the file can be closed right away.
                with Image.open(frame_image_path) as opened:
                    frame_image = opened.copy()
            except OSError:
                # The picture is optional: a missing or unreadable asset must
                # not discard an otherwise valid prediction.
                log.warning("Frame image %s could not be loaded", frame_image_path)

        return pred_label, explanation, frame_image
    except ValueError as e:
        # Expected user-facing failures such as "no face detected".
        log.warning("Prediction rejected: %s", e)
        return "Error", str(e), None
    except Exception as e:
        # UI boundary: keep the app responsive but record the traceback.
        log.exception("Prediction failed")
        return "Error", str(e), None
# Gradio Interface
# One uploaded PIL image goes to predict(); its three return values fill the
# two text boxes and the image component, in that order.
iface = gr.Interface(
    title="Rekomendasi Kacamata Berdasarkan Bentuk Wajah",
    description="Upload foto wajahmu untuk mendapatkan rekomendasi bentuk kacamata yang sesuai!",
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Textbox(label="Bentuk Wajah Terdeteksi"),
        gr.Textbox(label="Rekomendasi dan Penjelasan"),
        gr.Image(label="Gambar Frame Rekomendasi"),
    ],
)

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()