# Hugging Face Space (status when captured: Sleeping)
import streamlit as st | |
import torch | |
import cv2 | |
import mediapipe as mp | |
from transformers import SwinForImageClassification, AutoFeatureExtractor | |
from PIL import Image | |
import numpy as np | |
# Module-level MediaPipe face detector used by preprocess_image().
# NOTE(review): per the MediaPipe docs, model_selection=1 picks the full-range
# model; detections scoring below 0.5 are discarded — confirm against the
# installed MediaPipe version.
mp_face_detection = mp.solutions.face_detection.FaceDetection(
    model_selection=1,
    min_detection_confidence=0.5,
)
# Initialize model and labels | |
@st.cache_resource
def load_model():
    """Build the face-shape classifier, its feature extractor and the label map.

    The function is wrapped in ``st.cache_resource`` because Streamlit re-runs
    the whole script on every user interaction; without caching the Swin
    backbone and the fine-tuned weights would be reloaded each time.

    Returns:
        A ``(model, feature_extractor, id2label)`` tuple: the
        ``SwinForImageClassification`` instance in eval mode with the weights
        from ``swin.pth`` loaded on CPU, the ``AutoFeatureExtractor`` of the
        same base checkpoint, and the ``{class index: label}`` dict.
    """
    checkpoint = "microsoft/swin-tiny-patch4-window7-224"

    id2label = {0: 'Heart', 1: 'Oblong', 2: 'Oval', 3: 'Round', 4: 'Square'}
    label2id = {label: index for index, label in id2label.items()}

    # The base checkpoint ships with a different classification head;
    # ignore_mismatched_sizes lets the five-class head be created anyway.
    model = SwinForImageClassification.from_pretrained(
        checkpoint,
        label2id=label2id,
        id2label=id2label,
        ignore_mismatched_sizes=True,
    )

    # Overwrite the freshly created model with the fine-tuned weights.
    # NOTE: torch.load deserialises the file via pickle; only ship a
    # checkpoint that comes from a trusted source.
    model.load_state_dict(torch.load('swin.pth', map_location='cpu'))
    model.eval()

    feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    return model, feature_extractor, id2label
# Module-level handles read by preprocess_image() and predict_face_shape():
# the classifier, its feature extractor and the index-to-label mapping.
(model, feature_extractor, id2label) = load_model()
# Recommended glasses frame for each face-shape label returned by the model.
# The values are shown to the user verbatim (they appear to be Indonesian).
glasses_recommendations = dict(
    Heart="Rimless (tanpa bingkai bawah)",
    Oblong="Kotak",
    Oval="Berbagai bentuk bingkai",
    Round="Kotak",
    Square="Oval atau bundar",
)
def preprocess_image(image):
    """Crop the first detected face and turn it into model input.

    Args:
        image: Image array in OpenCV BGR channel order with three dimensions
            (height, width, channels).

    Returns:
        The ``pixel_values`` entry produced by the feature extractor for the
        224x224 RGB face crop.

    Raises:
        ValueError: If no face is detected, or if the bounding box, once
            clamped to the image, is narrower or shorter than 10 pixels.
    """
    detections = mp_face_detection.process(
        cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    ).detections
    if not detections:
        raise ValueError("No face detected")

    # Only the first detection is used; its box is expressed as fractions of
    # the image size, so scale to pixels and clamp to the image bounds.
    box = detections[0].location_data.relative_bounding_box
    height, width, _ = image.shape
    left = max(0, int(box.xmin * width))
    top = max(0, int(box.ymin * height))
    right = min(width, int((box.xmin + box.width) * width))
    bottom = min(height, int((box.ymin + box.height) * height))

    # An empty or inverted box has a non-positive side, which the minimum-size
    # check already rejects.
    if right - left < 10 or bottom - top < 10:
        raise ValueError("Invalid face crop dimensions")

    face_bgr = cv2.resize(image[top:bottom, left:right], (224, 224))
    face_rgb = cv2.cvtColor(face_bgr, cv2.COLOR_BGR2RGB)

    # The feature extractor is fed a PIL image and asked for PyTorch tensors.
    encoded = feature_extractor(
        images=Image.fromarray(face_rgb), return_tensors="pt"
    )
    return encoded['pixel_values']
def predict_face_shape(image):
    """Return the predicted face-shape label for a BGR image.

    Args:
        image: Image array in OpenCV BGR channel order, passed straight to
            ``preprocess_image``.

    Returns:
        The ``id2label`` entry for the highest-scoring class.

    Raises:
        ValueError: Propagated from ``preprocess_image`` when no usable face
            crop can be produced.
    """
    # Force CPU usage on Hugging Face Spaces
    cpu = torch.device("cpu")
    pixel_values = preprocess_image(image).to(cpu)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(pixel_values).logits

    class_index = logits.argmax(dim=1).item()
    return id2label[class_index]
# Streamlit UI: upload a photo, show it, then display the predicted face shape
# and the matching glasses recommendation.
st.title("Face Shape & Glasses Recommender")
st.write("Upload a face photo for shape analysis and glasses recommendations")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        # Decoding happens inside the try block: the uploader only filters by
        # file extension, so the bytes may still fail to decode as an image.
        image = Image.open(uploaded_file).convert('RGB')
        img_array = np.array(image)
        st.image(image, caption='Uploaded Image', use_column_width=True)

        with st.spinner('Analyzing...'):
            # PIL yields RGB; the prediction pipeline expects OpenCV's BGR.
            prediction = predict_face_shape(
                cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
            )
        recommendation = glasses_recommendations[prediction]
        st.success(f"Predicted Face Shape: **{prediction}**")
        st.info(f"Recommended Glasses Frame: **{recommendation}**")
    except Exception as e:
        # Top-level UI boundary: report any failure (undecodable upload, no
        # face detected, ...) in the page instead of a raw traceback.
        st.error(f"Error: {str(e)}")