import gradio as gr
import torch
from transformers import SwinForImageClassification, AutoFeatureExtractor
import mediapipe as mp
import cv2
from PIL import Image
import numpy as np
import os

# Face shape descriptions (user-facing text, in Indonesian)
face_shape_descriptions = {
    "Heart": "dengan dahi lebar dan dagu yang runcing.",                 # wide forehead, pointed chin
    "Oblong": "yang lebih panjang dari lebar dengan garis pipi lurus.",  # longer than wide, straight cheek lines
    "Oval": "dengan proporsi seimbang dan dagu sedikit melengkung.",     # balanced proportions, slightly rounded chin
    "Round": "dengan garis rahang melengkung dan pipi penuh.",           # curved jawline, full cheeks
    "Square": "dengan rahang tegas dan dahi lebar."                      # strong jaw, wide forehead
}

# Frame image paths
glasses_images = {
    "Heart": "Glasses/RimlessFrame.jpg",
    "Oblong": "Glasses/RectangleFrame.jpg",
    "Oval": "Glasses/RoundFrame.jpg",
    "Round": "Glasses/SquareFrame.jpg",
    "Square": "Glasses/OvalFrame.jpg"
}

id2label = {0: 'Heart', 1: 'Oblong', 2: 'Oval', 3: 'Round', 4: 'Square'}
label2id = {v: k for k, v in id2label.items()}

# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_checkpoint = "microsoft/swin-tiny-patch4-window7-224"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
model = SwinForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True
)

# Load the trained weights
state_dict = torch.load('LR-0001-adamW-32-64swin.pth', map_location=device)
model.load_state_dict(state_dict, strict=False)
model.to(device)
model.eval()
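# Note: with strict=False, any keys that do not line up are skipped silently.
# An optional sanity check (a sketch, not required by the app) is to inspect
# the value returned by load_state_dict:
#
#     load_result = model.load_state_dict(state_dict, strict=False)
#     print("Missing keys:", load_result.missing_keys)
#     print("Unexpected keys:", load_result.unexpected_keys)
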
# Initialize Mediapipe
mp_face_detection = mp.solutions.face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.5)

# Decision tree for frame recommendations based on the predicted face shape
def recommend_glasses_tree(face_shape):
    face_shape = face_shape.lower()
    if face_shape == "square":
        return ["Oval", "Round"]
    elif face_shape == "round":
        return ["Square", "Octagon", "Cat Eye"]
    elif face_shape == "oval":
        return ["Oval", "Pilot (Aviator)", "Cat Eye", "Round"]
    elif face_shape == "heart":
        return ["Pilot (Aviator)", "Cat Eye", "Round"]
    elif face_shape == "oblong":
        return ["Square", "Oval", "Pilot (Aviator)", "Cat Eye"]
    else:
        return []
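# Example (follows directly from the mapping above):
#
#     recommend_glasses_tree("Round")    ->  ["Square", "Octagon", "Cat Eye"]
#     recommend_glasses_tree("Unknown")  ->  []
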
# Preprocess function: detect the face, crop it, and convert it to model inputs
def preprocess_image(image):
    img = np.array(image)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    results = mp_face_detection.process(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    if results.detections:
        detection = results.detections[0]
        bbox = detection.location_data.relative_bounding_box
        h, w, _ = img.shape
        # Clamp the bounding box to the image, since Mediapipe can return
        # relative coordinates slightly outside [0, 1]
        x1 = max(0, int(bbox.xmin * w))
        y1 = max(0, int(bbox.ymin * h))
        x2 = min(w, int((bbox.xmin + bbox.width) * w))
        y2 = min(h, int((bbox.ymin + bbox.height) * h))
        img = img[y1:y2, x1:x2]
    else:
        raise ValueError("Wajah tidak terdeteksi.")  # "Face not detected."
    img = cv2.resize(img, (224, 224))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    inputs = feature_extractor(images=img, return_tensors="pt")
    return inputs['pixel_values'].squeeze(0)

# Prediction function
def predict(image):
    try:
        inputs = preprocess_image(image).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=1)
            pred_idx = torch.argmax(probs, dim=1).item()
            pred_label = id2label[pred_idx]
            pred_prob = probs[0][pred_idx].item() * 100

        # Use the decision tree for recommendations
        frame_recommendations = recommend_glasses_tree(pred_label)
        description = face_shape_descriptions[pred_label]
        frame_image_path = glasses_images.get(pred_label)

        # Build the explanation text (user-facing, in Indonesian)
        if frame_recommendations:
            recommended_frames = ', '.join(frame_recommendations)
            explanation = (f"Bentuk wajah kamu adalah {pred_label} ({pred_prob:.2f}%). "
                           f"Kamu memiliki bentuk wajah {description} "
                           f"Rekomendasi bentuk kacamata yang sesuai dengan wajah kamu adalah: {recommended_frames}.")
        else:
            explanation = (f"Bentuk wajah kamu adalah {pred_label} ({pred_prob:.2f}%). "
                           f"Tidak ada rekomendasi frame untuk bentuk wajah ini.")

        # Load the frame image if available
        if frame_image_path and os.path.exists(frame_image_path):
            frame_image = Image.open(frame_image_path)
        else:
            frame_image = None

        return pred_label, explanation, frame_image
    except Exception as e:
        return "Error", str(e), None
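# Optional local smoke test (a sketch; "sample_face.jpg" is a hypothetical
# path, not part of this Space):
#
#     label, explanation, frame = predict(Image.open("sample_face.jpg"))
#     print(label, explanation)
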
# Gradio interface (labels, title, and description are in Indonesian)
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Textbox(label="Bentuk Wajah Terdeteksi"),     # "Detected face shape"
        gr.Textbox(label="Rekomendasi dan Penjelasan"),  # "Recommendation and explanation"
        gr.Image(label="Gambar Frame Rekomendasi")       # "Recommended frame image"
    ],
    title="Rekomendasi Kacamata Berdasarkan Bentuk Wajah",
    description="Upload foto wajahmu untuk mendapatkan rekomendasi bentuk kacamata yang sesuai!"
)

if __name__ == "__main__":
    iface.launch()
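# When running locally outside Hugging Face Spaces, a temporary public link
# can be requested with Gradio's optional share flag:
#
#     iface.launch(share=True)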