face-shape / app.py
ruminasval's picture
Update app.py
efebb94 verified
raw
history blame
3.19 kB
import streamlit as st
import torch
import cv2
import mediapipe as mp
from transformers import SwinForImageClassification, AutoFeatureExtractor
from PIL import Image
import numpy as np
# Module-level MediaPipe face detector, built once at import time and reused
# by preprocess_image.  Per the MediaPipe documentation, model_selection=1
# selects the full-range detection model; min_detection_confidence=0.5 is the
# score threshold below which detections are dropped.
_FACE_DETECTOR_SETTINGS = {
    "model_selection": 1,
    "min_detection_confidence": 0.5,
}
mp_face_detection = mp.solutions.face_detection.FaceDetection(
    **_FACE_DETECTOR_SETTINGS
)
# Initialize model and labels
@st.cache_resource
def load_model():
    """Build the face-shape classifier and its matching feature extractor.

    The Swin-Tiny checkpoint is instantiated with a five-class label mapping
    (Heart, Oblong, Oval, Round, Square), the fine-tuned weights are read
    from the local file ``swin.pth`` onto the CPU, and the model is switched
    to evaluation mode.  ``st.cache_resource`` wraps the function so that
    Streamlit can reuse the loaded objects instead of rebuilding them.

    Returns:
        A ``(model, feature_extractor)`` tuple.
    """
    checkpoint = "microsoft/swin-tiny-patch4-window7-224"

    shape_names = ['Heart', 'Oblong', 'Oval', 'Round', 'Square']
    id2label = dict(enumerate(shape_names))
    label2id = {name: index for index, name in id2label.items()}

    # ignore_mismatched_sizes lets the checkpoint load even though its
    # classification head does not match the five-label head requested here.
    classifier = SwinForImageClassification.from_pretrained(
        checkpoint,
        label2id=label2id,
        id2label=id2label,
        ignore_mismatched_sizes=True,
    )

    # NOTE(review): torch.load unpickles the file; only load weights from a
    # trusted source.
    fine_tuned_weights = torch.load('swin.pth', map_location='cpu')
    classifier.load_state_dict(fine_tuned_weights)
    classifier.eval()

    extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    return classifier, extractor


# Shared instances used by preprocess_image and predict_face_shape.
model, feature_extractor = load_model()
# Glasses-frame suggestion for each face-shape label, keyed by the label
# name.  The values are Indonesian strings that the UI displays verbatim.
glasses_recommendations = dict(
    Heart="Rimless (tanpa bingkai bawah)",  # rimless (no bottom frame)
    Oblong="Kotak",  # square / boxy
    Oval="Berbagai bentuk bingkai",  # various frame shapes
    Round="Kotak",  # square / boxy
    Square="Oval atau bundar",  # oval or round
)
def _face_crop_bounds(bbox, width, height):
    """Convert a relative bounding box to pixel bounds clamped to the image."""
    def clamp(value, upper):
        return min(max(int(value), 0), upper)

    left = clamp(bbox.xmin * width, width)
    top = clamp(bbox.ymin * height, height)
    right = clamp((bbox.xmin + bbox.width) * width, width)
    bottom = clamp((bbox.ymin + bbox.height) * height, height)
    return left, top, right, bottom


def preprocess_image(image):
    """Crop the first detected face and turn it into model input.

    Args:
        image: Image array in OpenCV's BGR channel order with a
            (height, width, channels) shape.

    Returns:
        The ``pixel_values`` tensor produced by the feature extractor for
        the cropped, 224x224, RGB face image.

    Raises:
        ValueError: If no face is detected, or if the detected box has no
            area inside the image.
    """
    results = mp_face_detection.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    if not results.detections:
        raise ValueError("No face detected")

    # Only the first detection is used.
    bbox = results.detections[0].location_data.relative_bounding_box
    height, width = image.shape[:2]

    # The relative box can extend past the frame when a face is partly out
    # of view.  Unclamped, a negative start index would wrap around in the
    # slice and yield a wrong or empty crop, which then breaks cv2.resize.
    left, top, right, bottom = _face_crop_bounds(bbox, width, height)
    if right <= left or bottom <= top:
        raise ValueError("Detected face region lies outside the image")

    face = image[top:bottom, left:right]
    face = cv2.resize(face, (224, 224))
    face = cv2.cvtColor(face, cv2.COLOR_BGR2RGB)

    # Use feature extractor with proper batching
    inputs = feature_extractor(
        images=[face],  # Pass as a list of images
        return_tensors="pt",
        # NOTE(review): kept from the original; presumably a no-op for a
        # single fixed-size image - confirm before removing.
        padding=True,
    )
    return inputs['pixel_values']
def predict_face_shape(image):
    """Predict the face-shape label for a BGR image.

    Args:
        image: Image array accepted by ``preprocess_image``.

    Returns:
        The label string for the highest-scoring class.

    Raises:
        ValueError: Propagated from ``preprocess_image`` when no usable face
            is found.
    """
    # Put the input on the device the model actually lives on.  Picking CUDA
    # whenever it is available mismatches a model that was left on the CPU.
    image_tensor = preprocess_image(image).to(model.device)
    with torch.no_grad():
        outputs = model(image_tensor)
    predicted_class_idx = torch.argmax(outputs.logits, dim=1).item()
    # The label mapping is local to load_model, so read the copy that was
    # passed into the model's config rather than a module-level name.
    return model.config.id2label[predicted_class_idx]
# Streamlit UI
# Page flow: upload an image, show it, run the classifier, then display the
# predicted face shape together with the matching frame recommendation.
st.title("Face Shape & Glasses Recommender")
st.write("Upload a face photo for shape analysis and glasses recommendations")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    image = Image.open(uploaded_file).convert('RGB')
    img_array = np.array(image)
    st.image(image, caption='Uploaded Image', use_column_width=True)
    try:
        with st.spinner('Analyzing...'):
            # PIL yields RGB; the prediction pipeline works on OpenCV-style
            # BGR arrays, so convert before handing the image over.
            # Calls predict_face_shape: the previous name `predict` is not
            # defined in this file and raised NameError on every upload.
            prediction = predict_face_shape(cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
            recommendation = glasses_recommendations[prediction]
        st.success(f"Predicted Face Shape: **{prediction}**")
        st.info(f"Recommended Glasses Frame: **{recommendation}**")
    except Exception as e:
        # Top-level UI boundary: surface any failure (e.g. no face detected)
        # to the user instead of crashing the page.
        st.error(f"Error: {str(e)}")