Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
|
@@ -36,19 +36,32 @@ def read_root():
|
|
| 36 |
def health_check():
|
| 37 |
return {"status": "ok"}
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
| 41 |
@app.post("/extract", tags=["Extract Embeddings"])
|
| 42 |
async def extract_embeddings(file: UploadFile = File(...)):
|
| 43 |
# Load the image
|
| 44 |
contents = await file.read()
|
| 45 |
image = Image.open(io.BytesIO(contents)).convert('RGB')
|
| 46 |
|
| 47 |
-
#
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
# Extract the face embeddings
|
| 51 |
-
|
|
|
|
| 52 |
|
| 53 |
return JSONResponse(content={"embeddings": embeddings})
|
| 54 |
|
|
|
|
| 36 |
def health_check():
|
| 37 |
return {"status": "ok"}
|
| 38 |
|
|
|
|
|
|
|
| 39 |
@app.post("/extract", tags=["Extract Embeddings"])
|
| 40 |
async def extract_embeddings(file: UploadFile = File(...)):
|
| 41 |
# Load the image
|
| 42 |
contents = await file.read()
|
| 43 |
image = Image.open(io.BytesIO(contents)).convert('RGB')
|
| 44 |
|
| 45 |
+
# Detect faces
|
| 46 |
+
faces = mtcnn(image)
|
| 47 |
+
|
| 48 |
+
# Check if any faces were detected
|
| 49 |
+
if faces is None:
|
| 50 |
+
return JSONResponse(content={"error": "No faces detected in the image"})
|
| 51 |
+
|
| 52 |
+
# If faces is a list, take the first face. If it's a tensor, it's already the first (and only) face
|
| 53 |
+
if isinstance(faces, list):
|
| 54 |
+
face = faces[0]
|
| 55 |
+
else:
|
| 56 |
+
face = faces
|
| 57 |
+
|
| 58 |
+
# Ensure the face tensor is 4D (batch_size, channels, height, width)
|
| 59 |
+
if face.dim() == 3:
|
| 60 |
+
face = face.unsqueeze(0)
|
| 61 |
+
|
| 62 |
# Extract the face embeddings
|
| 63 |
+
with torch.no_grad():
|
| 64 |
+
embeddings = resnet(face).cpu().numpy().tolist()
|
| 65 |
|
| 66 |
return JSONResponse(content={"embeddings": embeddings})
|
| 67 |
|