"""FastAPI service that keeps uploaded images in memory and returns DINOv2 embeddings.

Flow:
1. ``POST /upload`` stores the raw image bytes together with an ``event_id``
   and returns a freshly generated ``image_id``.
2. ``POST /embedding`` decodes the stored image for an ``image_id`` and
   returns its mean-pooled ``facebook/dinov2-small`` embedding together with
   the ``event_id`` it was uploaded with.
"""

import io
import logging
import uuid

import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

logger = logging.getLogger(__name__)

app = FastAPI()

# Enable CORS if needed.
# NOTE(review): wildcard origins combined with allow_credentials=True means any
# site may make credentialed cross-origin requests; list explicit origins if
# cookies/auth headers matter for this service.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the model once at import time; every request reuses the same instances.
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-small")
model = AutoModel.from_pretrained("facebook/dinov2-small")
model.eval()

# Temporary in-memory storage for uploaded images (a database could be used
# instead). Both dicts are keyed by image_id.
# NOTE(review): entries are never removed, so memory grows with every upload
# and everything is lost on restart; add eviction (e.g. after /embedding or
# by TTL) if this process is long-lived.
temp_images = {}
event_ids = {}


def _compute_embedding(image_bytes: bytes) -> list:
    """Decode *image_bytes* as RGB and return the mean-pooled DINOv2 embedding.

    This is synchronous, compute-heavy work; async callers should dispatch it
    to a worker thread instead of running it on the event loop.
    """
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Average over all tokens (could switch to the CLS token instead).
    return outputs.last_hidden_state.mean(dim=1).squeeze().tolist()


# Step 1: upload an image together with its event_id.
@app.post("/upload")
async def upload_image(file: UploadFile = File(...), event_id: str = Form(...)):
    """Store the uploaded image bytes in memory under a new random image_id.

    Returns ``{"image_id": ..., "event_id": ...}`` on success, or
    ``{"error": message}`` if reading the upload fails.
    """
    try:
        content = await file.read()
    except Exception as e:
        # Keep the original response shape, but don't lose the traceback.
        logger.exception("Failed to read upload for event_id=%s", event_id)
        return {"error": str(e)}

    image_id = str(uuid.uuid4())
    temp_images[image_id] = content
    event_ids[image_id] = event_id
    return {"image_id": image_id, "event_id": event_id}


# Step 2: compute the embedding for a previously uploaded image_id.
@app.post("/embedding")
async def get_embedding(image_id: str = Form(...)):
    """Return the embedding and event_id for a previously uploaded image.

    Returns ``{"event_id": ..., "embedding": [...]}`` on success, or
    ``{"error": message}`` if decoding or inference fails.

    Raises:
        HTTPException: 404 if ``image_id`` is not in the in-memory store.
    """
    if image_id not in temp_images:
        raise HTTPException(status_code=404, detail="image_id not found")

    event_id = event_ids[image_id]
    image_bytes = temp_images[image_id]

    try:
        # Decoding + model inference block; running them in the threadpool
        # keeps the event loop free to serve other requests meanwhile.
        embedding = await run_in_threadpool(_compute_embedding, image_bytes)
    except Exception as e:
        # Keep the original response shape, but don't lose the traceback.
        logger.exception("Embedding failed for image_id=%s", image_id)
        return {"error": str(e)}

    return {"event_id": event_id, "embedding": embedding}