File size: 1,243 Bytes
212f1ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from transformers import ViTImageProcessor, ViTForImageClassification

# Hugging Face Hub identifier of the checkpoint served by this API.
MODEL_NAME = "ahmed-masoud/sign_language_translator"

# Load the processor and the classifier once, at import time. Any failure is
# reported on stdout and both globals fall back to None so the application can
# still start; the request handler checks for that state.
try:
    processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
    model = ViTForImageClassification.from_pretrained(MODEL_NAME)
    print(f"Modelo '{MODEL_NAME}' cargado")
except Exception as load_error:
    print(f"Error al cargar el modelo {load_error}")
    # Reset both names together: a processor loaded before the model failed
    # must not be left behind half-initialised.
    model = processor = None

app = FastAPI(title="API de ASL con modelo de HF")


@app.post("/predict/")
async def translate_sign(file: UploadFile = File(...)):
    """Classify an uploaded image with the module-level ViT model.

    Args:
        file: Uploaded file; its bytes are read fully into memory and decoded
            with Pillow.

    Returns:
        ``{"prediction": <label>}`` where the label is looked up in
        ``model.config.id2label``, or ``{"error": "Modelo no disponible."}``
        when the model or processor could not be loaded at import time (this
        payload is kept as-is so existing clients keep working).

    Raises:
        HTTPException: status 400 when the upload cannot be decoded as an
            image (empty body, unsupported format, truncated or oversized
            data).
    """
    # Explicit None checks: the load step sets both globals to None on
    # failure, and truthiness of arbitrary objects is not a reliable signal.
    if model is None or processor is None:
        return {"error": "Modelo no disponible."}

    image_bytes = await file.read()

    try:
        # Image.open only parses the header; convert() forces the full decode
        # so corrupt data is caught here. It also normalises grayscale, RGBA
        # and palette images to RGB — presumably what this ViT checkpoint
        # expects (3 channels); confirm against the model config.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        # Previously this escaped as an unhandled exception (HTTP 500);
        # a client-side input problem should be reported as 400.
        raise HTTPException(
            status_code=400,
            detail="El archivo no es una imagen válida.",
        ) from exc

    inputs = processor(images=image, return_tensors="pt")

    # NOTE: the forward pass is synchronous and runs inside this coroutine,
    # so it occupies the event loop for its duration.
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_idx = logits.argmax(-1).item()
    predicted_label = model.config.id2label[predicted_class_idx]

    return {"prediction": predicted_label}


@app.get("/")
def read_root():
    """Return a static status message pointing clients at ``/predict/``."""
    status_payload = {
        "message": "API ok. Usa el endpoint /predict/ para predecir.",
    }
    return status_payload