Midnightar committed · verified
Commit 8f136e1 · 1 Parent(s): bf22ef4

Update app.py

Files changed (1): app.py (+23 -42)
app.py CHANGED
@@ -1,68 +1,49 @@
 import os
+os.environ["HF_HOME"] = "/tmp/hf_cache"
+os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
+
+import io
 import torch
 from fastapi import FastAPI, File, UploadFile
 from fastapi.responses import JSONResponse, HTMLResponse
 from transformers import AutoImageProcessor, AutoModelForImageClassification
 from PIL import Image
 
-# Force cache to /tmp/hf_cache before anything else
-os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
-os.environ["HF_HOME"] = "/tmp/hf_cache"
-os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_cache"
-
-# Create cache directory if missing
-os.makedirs("/tmp/hf_cache", exist_ok=True)
-
-# Load processor + model
-processor = AutoImageProcessor.from_pretrained(
-    "prithivMLmods/Realistic-Gender-Classification", cache_dir="/tmp/hf_cache"
-)
-model = AutoModelForImageClassification.from_pretrained(
-    "prithivMLmods/Realistic-Gender-Classification", cache_dir="/tmp/hf_cache"
-)
+# Load model and processor
+processor = AutoImageProcessor.from_pretrained("prithivMLmods/Realistic-Gender-Classification")
+model = AutoModelForImageClassification.from_pretrained("prithivMLmods/Realistic-Gender-Classification")
 
-# Create FastAPI app
+# FastAPI app
 app = FastAPI()
 
 @app.get("/", response_class=HTMLResponse)
 async def home():
-    return """
+    return '''
    <html>
    <body>
-    <h2>Upload Image for Gender Detection</h2>
+    <h2>Upload an Image for Gender Detection</h2>
    <form action="/predict" enctype="multipart/form-data" method="post">
    <input name="file" type="file" accept="image/*">
    <input type="submit" value="Upload">
    </form>
    </body>
    </html>
-    """
+    '''
 
 @app.post("/predict")
 async def predict(file: UploadFile = File(...)):
-    try:
-        # Load image
-        image = Image.open(file.file).convert("RGB")
-
-        # Preprocess
-        inputs = processor(images=image, return_tensors="pt")
-
-        # Predict
-        with torch.no_grad():
-            outputs = model(**inputs)
-        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
-
-        # Get labels
-        labels = list(model.config.id2label.values())
+    image = Image.open(io.BytesIO(await file.read())).convert("RGB")
+    inputs = processor(images=image, return_tensors="pt")
 
-        # Clean result for FlutterFlow
-        result = {
-            "female": float(probs[labels.index("female portrait")]),
-            "male": float(probs[labels.index("male portrait")])
-        }
+    with torch.no_grad():
+        logits = model(**inputs).logits
+    probs = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()[0]
 
-        return JSONResponse(content=result)
+    labels = model.config.id2label
+    result = {labels[i]: float(probs[i]) for i in range(len(labels))}
 
-    except Exception as e:
-        return JSONResponse(content={"error": str(e)}, status_code=500)
-
+    result = {
+        "female": float(probs[0]),
+        "male": float(probs[1])
+    }
+    return JSONResponse(content=result)
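
For reference, a minimal sketch of how the updated /predict endpoint could be exercised once the app is running. The uvicorn command, port, and test image name below are assumptions for illustration, not part of this commit; the multipart field name "file" matches the UploadFile parameter in app.py.

# client.py -- hypothetical test client for the /predict endpoint above.
# Assumes the server was started with: uvicorn app:app --host 0.0.0.0 --port 7860
import requests

with open("portrait.jpg", "rb") as f:  # any local test image
    resp = requests.post(
        "http://localhost:7860/predict",
        files={"file": ("portrait.jpg", f, "image/jpeg")},
    )

print(resp.status_code)  # 200 on success
print(resp.json())       # expected shape: {"female": <float>, "male": <float>}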