Spaces:
Running
Running
| # app.py | |
| import torch | |
| from transformers import SwinForImageClassification, AutoFeatureExtractor | |
| import cv2 | |
| import mediapipe as mp | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| import gradio as gr | |
| import os | |
| import numpy as np | |
# Mapping between classifier output index and face-shape name, plus its inverse.
id2label = dict(enumerate(["Heart", "Oblong", "Oval", "Round", "Square"]))
label2id = {name: index for index, name in id2label.items()}
# One row per face shape: (shape, recommended frame style, example frame image).
# The example images are read from the repo's "glasses/" folder at the paths
# below (e.g. "glasses/RimlessFrame.jpg").
_FACE_SHAPE_GLASSES = [
    ("Heart", "Frame Rimless", "glasses/RimlessFrame.jpg"),
    ("Oblong", "Frame Persegi Panjang", "glasses/RectangleFrame.jpg"),
    ("Oval", "Frame Bulat", "glasses/RoundFrame.jpg"),
    ("Round", "Frame Kotak", "glasses/SquareFrame.jpg"),
    ("Square", "Frame Oval", "glasses/OvalFrame.jpg"),
]
# Face shape -> recommended frame style (shown in the result text).
glasses_recommendations = {shape: frame for shape, frame, _ in _FACE_SHAPE_GLASSES}
# Face shape -> path of the example frame image (shown as the output image).
glasses_images = {shape: path for shape, _, path in _FACE_SHAPE_GLASSES}
# Load the Swin-Tiny backbone with a freshly sized 5-class head
# (ignore_mismatched_sizes lets the pretrained checkpoint's own head be
# replaced), then overwrite the weights with the fine-tuned checkpoint that is
# uploaded alongside this app.
import warnings

model_checkpoint = "microsoft/swin-tiny-patch4-window7-224"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SwinForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)
# weights_only=True (torch >= 1.13) restricts unpickling to tensors and plain
# containers, so a tampered checkpoint file cannot execute code on load.
_state_dict = torch.load(
    "LR-0001-adamW-32-64swin.pth", map_location=device, weights_only=True
)
# strict=False is kept for compatibility with the existing checkpoint, but any
# mismatch is reported: a missing classifier weight would otherwise leave the
# head randomly initialised and silently produce meaningless predictions.
_incompatible = model.load_state_dict(_state_dict, strict=False)
if _incompatible.missing_keys or _incompatible.unexpected_keys:
    warnings.warn(
        "Fine-tuned checkpoint did not match the model exactly: "
        f"missing={_incompatible.missing_keys}, "
        f"unexpected={_incompatible.unexpected_keys}"
    )
# Drop the extra reference so the loaded copy of the weights can be freed.
del _state_dict, _incompatible
model = model.to(device)
model.eval()
# Image preprocessor loaded from the same checkpoint as the model, so crops get
# that checkpoint's resize/normalisation settings.
feature_extractor = AutoFeatureExtractor.from_pretrained(
    model_checkpoint
)

# MediaPipe face detector used by preprocess_image to locate the face.
# Per the MediaPipe docs, model_selection=1 is the full-range model; detections
# scoring below 0.5 are discarded.
mp_face_detection = mp.solutions.face_detection.FaceDetection(
    min_detection_confidence=0.5,
    model_selection=1,
)
def _face_box_pixels(bbox, width, height):
    """Convert a MediaPipe relative bounding box to clamped pixel coordinates.

    The relative box may extend past the image border (e.g. when the face is
    partly out of frame). Unclamped negative coordinates would be treated by
    NumPy slicing as offsets from the end of the array, yielding a wrong or
    empty crop.

    Returns:
        ``(x1, y1, x2, y2)`` with ``0 <= x1 <= x2 <= width`` and
        ``0 <= y1 <= y2 <= height``.
    """
    x1 = min(max(int(bbox.xmin * width), 0), width)
    y1 = min(max(int(bbox.ymin * height), 0), height)
    x2 = min(max(int((bbox.xmin + bbox.width) * width), x1), width)
    y2 = min(max(int((bbox.ymin + bbox.height) * height), y1), height)
    return x1, y1, x2, y2


def preprocess_image(image):
    """Detect the first face in *image*, crop it, and turn it into model input.

    Args:
        image: A PIL image (as delivered by the ``gr.Image(type="pil")``
            input) or an array-like image. Arrays are assumed to already be
            H x W x 3 RGB ``uint8`` -- NOTE(review): confirm for non-Gradio
            callers.

    Returns:
        The ``pixel_values`` tensor produced by ``feature_extractor`` for the
        224x224 face crop, with the leading batch dimension removed.

    Raises:
        ValueError: If no face is detected, or the detected box lies entirely
            outside the image.
    """
    if isinstance(image, Image.Image):
        # Normalise grayscale / palette / RGBA uploads to 3-channel RGB so
        # MediaPipe and the feature extractor always see H x W x 3.
        image = image.convert("RGB")
    rgb = np.array(image)

    results = mp_face_detection.process(rgb)
    if not results.detections:
        raise ValueError("No face detected in the image.")

    bbox = results.detections[0].location_data.relative_bounding_box
    height, width = rgb.shape[:2]
    x1, y1, x2, y2 = _face_box_pixels(bbox, width, height)
    if x2 <= x1 or y2 <= y1:
        raise ValueError("Detected face region is empty.")

    # The crop stays in RGB throughout; resizing is per-channel, so no
    # BGR round-trip is needed.
    face = cv2.resize(rgb[y1:y2, x1:x2], (224, 224))
    pixel_values = feature_extractor(images=face, return_tensors="pt")["pixel_values"]
    return pixel_values.squeeze(0)
def _format_prediction(ranked):
    """Build the (Indonesian) result text from ``(label, percent)`` pairs.

    ``ranked`` must be sorted by descending probability; its first entry is
    reported as the predicted face shape. Every class is then listed with its
    probability and recommended frame from ``glasses_recommendations``.
    """
    top_label, top_prob = ranked[0]
    lines = [
        f"Bentuk Wajah: {top_label} ({top_prob:.2f}%)",
        "",
        "Probabilitas Setiap Kelas:",
    ]
    lines.extend(
        f"{label}: {prob:.2f}% - Rekomendasi Kacamata: {glasses_recommendations[label]}"
        for label, prob in ranked
    )
    return "\n".join(lines) + "\n"


def _load_glasses_image(face_shape):
    """Return the example-frame image for *face_shape*, or None if unavailable.

    ``Image.open`` reads lazily and keeps the file open; copying inside the
    ``with`` block loads the pixels into memory and closes the handle at once.
    """
    path = glasses_images.get(face_shape)
    if not path or not os.path.exists(path):
        return None
    with Image.open(path) as img:
        return img.copy()


def predict(image):
    """Classify the face shape in *image* and recommend a glasses frame.

    Args:
        image: The upload from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without choosing an image.

    Returns:
        A ``(result_text, glasses_img)`` tuple for the two Gradio outputs.
        ``glasses_img`` is the example image of the recommended frame, or
        ``None`` when no image file exists for the predicted shape. On any
        failure the text is ``"Error: <message>"`` and the image is ``None``.
    """
    if image is None:
        return "Error: No image provided.", None
    try:
        pixel_values = preprocess_image(image).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(pixel_values).logits
        probabilities = torch.nn.functional.softmax(logits, dim=1).squeeze(0)
        # (class name, probability in percent), most likely class first.
        ranked = sorted(
            ((id2label[idx], prob * 100) for idx, prob in enumerate(probabilities.tolist())),
            key=lambda item: item[1],
            reverse=True,
        )
        return _format_prediction(ranked), _load_glasses_image(ranked[0][0])
    except Exception as e:
        # UI boundary: report any failure (e.g. "No face detected") as text
        # instead of crashing the Gradio callback.
        return f"Error: {e}", None
def _build_demo():
    """Assemble the Gradio UI: one PIL image in; prediction text and a
    recommended-frame image out, both produced by ``predict``."""
    upload = gr.Image(type="pil")
    result_box = gr.Textbox(label="Hasil Prediksi")
    frame_preview = gr.Image(label="Rekomendasi Kacamata")
    return gr.Interface(
        fn=predict,
        inputs=upload,
        outputs=[result_box, frame_preview],
        title="Deteksi Bentuk Wajah & Rekomendasi Kacamata",
        description="Upload gambar wajahmu untuk mendapatkan bentuk wajah dan rekomendasi kacamata!",
    )


demo = _build_demo()

if __name__ == "__main__":
    demo.launch()