dino_n8n / app.py
yonadab's picture
Update app.py
533df08 verified
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
import torch
import uuid
import io
app = FastAPI()

# Enable CORS so browser clients on any origin can call this API.
# Credentials are deliberately disabled: FastAPI's CORS documentation states
# that allow_origins/allow_methods/allow_headers must not be ["*"] when
# allow_credentials=True, and the endpoints visible in this file use no
# cookies or auth headers. If credentialed requests are ever needed, list the
# trusted origins explicitly instead of re-enabling the wildcard combination.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Model loading: the image processor and the backbone are both created once,
# at import time, from the same pretrained checkpoint.
_DINO_CHECKPOINT = "facebook/dinov2-small"
processor = AutoImageProcessor.from_pretrained(_DINO_CHECKPOINT)
model = AutoModel.from_pretrained(_DINO_CHECKPOINT)
# The model is only used for inference, so switch it to evaluation mode.
model.eval()
# Temporary in-memory storage, keyed by the image_id generated on upload
# (a database could be used instead if persistence is preferred).
# NOTE(review): nothing in the code shown here removes entries, so both maps
# grow for as long as the process runs.
event_ids = {}    # image_id -> event_id submitted with the upload
temp_images = {}  # image_id -> raw bytes of the uploaded file
# Step 1: upload an image together with its event_id
@app.post("/upload")
async def upload_image(file: UploadFile = File(...), event_id: str = Form(...)):
    """Store an uploaded file in memory and return a handle for it.

    Args:
        file: The uploaded file (multipart form field ``file``).
        event_id: Caller-supplied identifier stored alongside the upload.

    Returns:
        A dict with the generated ``image_id`` and the echoed ``event_id``.

    Raises:
        HTTPException: 500 if the upload cannot be read, 400 if it is empty.
    """
    # Keep the try body minimal: only reading the upload is expected to fail.
    try:
        content = await file.read()
    except Exception as e:
        # Report the failure with an error status instead of a 200 response
        # carrying an "error" field, matching the HTTPException style used by
        # the /embedding endpoint.
        raise HTTPException(status_code=500, detail=str(e)) from e

    # An empty payload can never be decoded as an image later; reject it now.
    if not content:
        raise HTTPException(status_code=400, detail="uploaded file is empty")

    image_id = str(uuid.uuid4())
    temp_images[image_id] = content
    event_ids[image_id] = event_id
    return {"image_id": image_id, "event_id": event_id}
# Step 2: compute the embedding for a previously uploaded image_id
@app.post("/embedding")
async def get_embedding(image_id: str = Form(...)):
    """Compute an embedding for an image stored by the /upload endpoint.

    Args:
        image_id: Identifier returned by ``/upload`` (form field ``image_id``).

    Returns:
        A dict with the ``event_id`` stored at upload time and the
        ``embedding`` as a list of floats.

    Raises:
        HTTPException: 404 if ``image_id`` is unknown, 422 if the stored bytes
            cannot be decoded as an image, 500 if model inference fails.
    """
    try:
        image_bytes = temp_images[image_id]
        event_id = event_ids[image_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="image_id not found") from None

    # Decoding failures are a problem with the stored data, not the server,
    # so they get a client-error status rather than a 200 with an "error" key.
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except Exception as e:
        raise HTTPException(
            status_code=422,
            detail=f"stored file is not a readable image: {e}",
        ) from e

    try:
        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        # Average over all tokens (dim=1); the CLS token could be used
        # instead. squeeze(0) drops only the batch axis of the single image.
        embedding = outputs.last_hidden_state.mean(dim=1).squeeze(0).tolist()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "event_id": event_id,
        "embedding": embedding,
    }