ruminasval commited on
Commit
efebb94
·
verified ·
1 Parent(s): 110be59

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -4
app.py CHANGED
@@ -54,13 +54,23 @@ def preprocess_image(image):
54
 
55
  image = cv2.resize(image, (224, 224))
56
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
57
- return feature_extractor(images=image, return_tensors="pt")['pixel_values']
 
 
 
 
 
 
 
58
 
59
- def predict(image):
60
- image_tensor = preprocess_image(image)
 
61
  with torch.no_grad():
62
  outputs = model(image_tensor)
63
- return id2label[torch.argmax(outputs.logits).item()]
 
 
64
 
65
  # Streamlit UI
66
  st.title("Face Shape & Glasses Recommender")
 
54
 
55
  image = cv2.resize(image, (224, 224))
56
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
57
+
58
+ # Use feature extractor with proper batching
59
+ inputs = feature_extractor(
60
+ images=[image], # Pass as a list of images
61
+ return_tensors="pt",
62
+ padding=True # Enable padding as a safety measure
63
+ )
64
+ return inputs['pixel_values']
65
 
66
+ def predict_face_shape(image):
67
+ image_tensor = preprocess_image(image).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
68
+
69
  with torch.no_grad():
70
  outputs = model(image_tensor)
71
+ predicted_class_idx = torch.argmax(outputs.logits, dim=1).item()
72
+
73
+ return id2label[predicted_class_idx]
74
 
75
  # Streamlit UI
76
  st.title("Face Shape & Glasses Recommender")