Midnightar commited on
Commit
929b85d
·
verified ·
1 Parent(s): 52a2f69

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -17
app.py CHANGED
@@ -1,45 +1,58 @@
1
  import os
2
- os.environ["HF_HOME"] = "/tmp/hf_cache"
3
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
4
-
5
- import io
6
  import torch
7
  from fastapi import FastAPI, File, UploadFile
8
  from fastapi.responses import JSONResponse, HTMLResponse
9
  from transformers import AutoImageProcessor, AutoModelForImageClassification
10
  from PIL import Image
11
 
12
- # Load model and processor
 
 
 
13
  processor = AutoImageProcessor.from_pretrained("prithivMLmods/Realistic-Gender-Classification")
14
  model = AutoModelForImageClassification.from_pretrained("prithivMLmods/Realistic-Gender-Classification")
15
 
16
- # FastAPI app
17
  app = FastAPI()
18
 
19
  @app.get("/", response_class=HTMLResponse)
20
  async def home():
21
- return '''
22
  <html>
23
  <body>
24
- <h2>Upload an Image for Gender Detection</h2>
25
  <form action="/predict" enctype="multipart/form-data" method="post">
26
  <input name="file" type="file" accept="image/*">
27
  <input type="submit" value="Upload">
28
  </form>
29
  </body>
30
  </html>
31
- '''
32
 
33
  @app.post("/predict")
34
  async def predict(file: UploadFile = File(...)):
35
- image = Image.open(io.BytesIO(await file.read())).convert("RGB")
36
- inputs = processor(images=image, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- with torch.no_grad():
39
- logits = model(**inputs).logits
40
- probs = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()[0]
 
 
41
 
42
- labels = model.config.id2label
43
- result = {labels[i]: float(probs[i]) for i in range(len(labels))}
44
 
45
- return JSONResponse(content=result)
 
 
1
  import os
 
 
 
 
2
  import torch
3
  from fastapi import FastAPI, File, UploadFile
4
  from fastapi.responses import JSONResponse, HTMLResponse
5
  from transformers import AutoImageProcessor, AutoModelForImageClassification
6
  from PIL import Image
7
 
8
+ # Set Hugging Face cache to avoid permission issues
9
+ os.environ["HF_HOME"] = "/tmp/hf_cache"
10
+
11
+ # Load processor + model
12
  processor = AutoImageProcessor.from_pretrained("prithivMLmods/Realistic-Gender-Classification")
13
  model = AutoModelForImageClassification.from_pretrained("prithivMLmods/Realistic-Gender-Classification")
14
 
15
+ # Create FastAPI app
16
  app = FastAPI()
17
 
18
  @app.get("/", response_class=HTMLResponse)
19
  async def home():
20
+ return """
21
  <html>
22
  <body>
23
+ <h2>Upload Image for Gender Detection</h2>
24
  <form action="/predict" enctype="multipart/form-data" method="post">
25
  <input name="file" type="file" accept="image/*">
26
  <input type="submit" value="Upload">
27
  </form>
28
  </body>
29
  </html>
30
+ """
31
 
32
  @app.post("/predict")
33
  async def predict(file: UploadFile = File(...)):
34
+ try:
35
+ # Load image
36
+ image = Image.open(file.file).convert("RGB")
37
+
38
+ # Preprocess
39
+ inputs = processor(images=image, return_tensors="pt")
40
+
41
+ # Predict
42
+ with torch.no_grad():
43
+ outputs = model(**inputs)
44
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
45
+
46
+ # Get labels (ensure consistent order)
47
+ labels = list(model.config.id2label.values())
48
 
49
+ # Fix keys: return "male" and "female" only
50
+ result = {
51
+ "female": float(probs[labels.index("female portrait")]),
52
+ "male": float(probs[labels.index("male portrait")])
53
+ }
54
 
55
+ return JSONResponse(content=result)
 
56
 
57
+ except Exception as e:
58
+ return JSONResponse(content={"error": str(e)}, status_code=500)