# face-shape / app.py
# Created by ruminasval in commit 26fd668 ("Create app.py", verified); original size 2.86 kB.
import cv2
import mediapipe as mp
import numpy as np
import streamlit as st
import torch
from PIL import Image, ImageOps
from transformers import AutoFeatureExtractor, SwinForImageClassification
# MediaPipe face detector shared by preprocess_image(), which uses it to find
# the face before cropping it for the classifier. Both settings are forwarded
# unchanged to MediaPipe: model_selection=1 chooses which of its detection
# models to load, and detections scoring below 0.5 confidence are discarded.
mp_face_detection = mp.solutions.face_detection.FaceDetection(
    min_detection_confidence=0.5,
    model_selection=1,
)
# Model + preprocessing setup. st.cache_resource memoizes load_model() so that
# Streamlit reruns reuse the same objects instead of rebuilding them.
@st.cache_resource
def load_model():
    """Build the face-shape classifier and its image preprocessor.

    A Swin-Tiny checkpoint from the Hugging Face Hub is instantiated with a
    five-class head (Heart, Oblong, Oval, Round, Square). The flag
    ``ignore_mismatched_sizes=True`` lets the pretrained classification head,
    whose shape differs, be replaced. The fine-tuned weights in ``swin.pth``
    are then loaded onto the CPU and the model is put in eval mode.

    Returns:
        ``(model, feature_extractor)``. The label mapping is passed to
        ``from_pretrained``, which records it on the model config
        (``model.config.id2label`` / ``model.config.label2id``).
    """
    checkpoint = "microsoft/swin-tiny-patch4-window7-224"
    id2label = {0: 'Heart', 1: 'Oblong', 2: 'Oval', 3: 'Round', 4: 'Square'}
    label2id = {label: idx for idx, label in id2label.items()}
    model = SwinForImageClassification.from_pretrained(
        checkpoint,
        label2id=label2id,
        id2label=id2label,
        ignore_mismatched_sizes=True,
    )
    # swin.pth goes straight into load_state_dict, so it only needs to hold
    # tensors. weights_only=True refuses to unpickle arbitrary Python objects,
    # which is also the default in recent PyTorch releases.
    state_dict = torch.load('swin.pth', map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model, AutoFeatureExtractor.from_pretrained(checkpoint)


model, feature_extractor = load_model()
# Suggested frame style for each face-shape label. Keys match the labels the
# classifier is configured with. The user-facing text is Indonesian:
#   Heart  -> rimless (no bottom rim)
#   Oblong -> rectangular/square ("Kotak")
#   Oval   -> various frame shapes
#   Round  -> rectangular/square ("Kotak")
#   Square -> oval or round
glasses_recommendations = dict(
    Heart="Rimless (tanpa bingkai bawah)",
    Oblong="Kotak",
    Oval="Berbagai bentuk bingkai",
    Round="Kotak",
    Square="Oval atau bundar",
)
def preprocess_image(image):
    """Crop the first detected face from a BGR image and convert it to model input.

    Args:
        image: H x W x 3 array in OpenCV BGR channel order.

    Returns:
        The ``pixel_values`` PyTorch tensor produced by ``feature_extractor``
        for the 224x224 RGB face crop.

    Raises:
        ValueError: if no face is detected, or if the detected box does not
            overlap the image.
    """
    results = mp_face_detection.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    if not results.detections:
        raise ValueError("No face detected")

    bbox = results.detections[0].location_data.relative_bounding_box
    h, w, _ = image.shape
    # Nothing guarantees the relative box lies within [0, 1] (e.g. a face at
    # the frame edge). A negative coordinate would become a negative slice
    # index and wrap to the far side of the array, so clamp to pixel bounds.
    x1 = max(0, int(bbox.xmin * w))
    y1 = max(0, int(bbox.ymin * h))
    x2 = min(w, int((bbox.xmin + bbox.width) * w))
    y2 = min(h, int((bbox.ymin + bbox.height) * h))
    if x2 <= x1 or y2 <= y1:
        # An empty crop would otherwise fail inside cv2.resize with an opaque error.
        raise ValueError("Detected face box is empty")

    face = cv2.resize(image[y1:y2, x1:x2], (224, 224))
    face = cv2.cvtColor(face, cv2.COLOR_BGR2RGB)
    return feature_extractor(images=face, return_tensors="pt")['pixel_values']
def predict(image):
    """Classify the face shape in a BGR image.

    Args:
        image: H x W x 3 array in OpenCV BGR channel order.

    Returns:
        The predicted label string (e.g. ``'Oval'``).

    Raises:
        ValueError: propagated from ``preprocess_image`` when no usable face
            is found.
    """
    pixel_values = preprocess_image(image)
    with torch.no_grad():
        logits = model(pixel_values).logits
    # The id2label dict built in load_model() is local to that function and is
    # not visible here. The model config carries the same mapping, keyed by
    # class index.
    class_index = int(logits.argmax(dim=-1)[0])
    return model.config.id2label[class_index]
# Streamlit UI: upload a photo, classify the face shape, show a frame suggestion.
st.title("Face Shape & Glasses Recommender")
st.write("Upload a face photo for shape analysis and glasses recommendations")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    # Camera JPEGs often store rotation in the EXIF Orientation tag rather
    # than in the pixels, and Image.open does not apply it. exif_transpose
    # does, so both the preview and the face detector see the photo upright.
    image = ImageOps.exif_transpose(Image.open(uploaded_file)).convert('RGB')
    img_array = np.array(image)
    st.image(image, caption='Uploaded Image', use_column_width=True)
    try:
        with st.spinner('Analyzing...'):
            # predict() works on OpenCV's BGR channel order.
            prediction = predict(cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
            recommendation = glasses_recommendations[prediction]
        st.success(f"Predicted Face Shape: **{prediction}**")
        st.info(f"Recommended Glasses Frame: **{recommendation}**")
    except Exception as e:
        # Top-level UI boundary: show any failure (e.g. "No face detected")
        # to the user instead of aborting the script run.
        st.error(f"Error: {str(e)}")