|
from fastapi import FastAPI, HTTPException |
|
from typing import List |
|
import torch |
|
import uvicorn |
|
import gc |
|
import os |
|
|
|
from models.schemas import EmbeddingRequest, EmbeddingResponse, ModelInfo |
|
from utils.helpers import load_models, get_embeddings, cleanup_memory |
|
|
|
# ASGI application object served by uvicorn; metadata feeds the OpenAPI docs.
app = FastAPI(
    version="1.0.0",
    title="Spanish Embedding API",
    description="Dual Spanish embedding models API",
)

# Shared model registry. startup_event() rebinds it to the value returned by
# load_models(); /embed passes it to get_embeddings() and /health inspects it.
models_cache: dict = {}
|
|
|
@app.on_event("startup")
async def startup_event():
    """Load the embedding models once, when the application starts.

    Replaces the module-level ``models_cache`` with the object returned by
    ``load_models()`` and prints a confirmation line to stdout.
    """
    global models_cache
    loaded_models = load_models()
    models_cache = loaded_models
    print("Models loaded successfully!")
|
|
|
@app.get("/")
async def root():
    """Return a small service descriptor: name, model IDs, status and docs path."""
    payload = dict(
        message="Spanish Embedding API",
        models=["jina", "robertalex"],
        status="running",
        docs="/docs",
    )
    return payload
|
|
|
@app.post("/embed", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest):
    """Generate embeddings for the texts in *request*.

    Validates the batch size, delegates to ``get_embeddings`` using the
    module-level ``models_cache``, and wraps the result in an
    ``EmbeddingResponse``.

    Args:
        request: Carries ``texts``, ``model``, ``normalize`` and ``max_length``,
            which are forwarded to ``get_embeddings``.

    Returns:
        EmbeddingResponse with the vectors, the requested model id, the
        length of the first vector (0 when no vectors came back) and the
        number of input texts.

    Raises:
        HTTPException: 400 if ``texts`` is empty or has more than 50 items,
            or if embedding raises ``ValueError``; 500 for any other error.
    """
    # Validate outside the try block. HTTPException subclasses Exception, so
    # raising it inside the try would be caught by the generic handler below
    # and turned into a 500 "Internal error: 400: ..." response.
    if not request.texts:
        raise HTTPException(status_code=400, detail="No texts provided")
    if len(request.texts) > 50:
        raise HTTPException(status_code=400, detail="Maximum 50 texts per request")

    # NOTE(review): get_embeddings presumably runs model inference synchronously;
    # inside an ``async def`` that blocks the event loop for the whole call.
    # Consider a plain ``def`` endpoint (run in FastAPI's threadpool) once the
    # memory impact of concurrent inference has been checked.
    try:
        embeddings = get_embeddings(
            request.texts,
            request.model,
            models_cache,
            request.normalize,
            request.max_length,
        )

        # Release memory after large batches.
        if len(request.texts) > 20:
            cleanup_memory()

        return EmbeddingResponse(
            embeddings=embeddings,
            model_used=request.model,
            dimensions=len(embeddings[0]) if embeddings else 0,
            num_texts=len(request.texts),
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal error: {e}") from e
|
|
|
@app.get("/models", response_model=List[ModelInfo])
async def list_models():
    """Return the fixed specification of each model this API exposes."""
    specs = (
        {
            "model_id": "jina",
            "name": "jinaai/jina-embeddings-v2-base-es",
            "dimensions": 768,
            "max_sequence_length": 8192,
            "languages": ["Spanish", "English"],
            "model_type": "bilingual",
            "description": "Bilingual Spanish-English embeddings with long context support",
        },
        {
            "model_id": "robertalex",
            "name": "PlanTL-GOB-ES/RoBERTalex",
            "dimensions": 768,
            "max_sequence_length": 512,
            "languages": ["Spanish"],
            "model_type": "legal domain",
            "description": "Spanish legal domain specialized embeddings",
        },
    )
    return [ModelInfo(**spec) for spec in specs]
|
|
|
@app.get("/health")
async def health_check():
    """Report liveness and which models are currently held in ``models_cache``.

    ``status`` is always "healthy". ``models_loaded`` is True only when the
    cache holds exactly two entries; presumably these are the two models
    advertised by ``/`` ("jina", "robertalex"). TODO confirm against load_models.
    """
    loaded_ids = list(models_cache.keys())
    both_loaded = len(models_cache) == 2
    return {
        "status": "healthy",
        "models_loaded": both_loaded,
        "available_models": loaded_ids,
    }
|
|
|
if __name__ == "__main__":
    # These settings apply only when this file is executed directly; they are
    # skipped when the app is served via ``uvicorn module:app``.
    # Environment variables override the defaults (the original hard-coded values).
    torch.set_num_threads(int(os.environ.get("TORCH_NUM_THREADS", "8")))
    # Must run before any inter-op parallel work starts, i.e. before the
    # server (and the startup model load) begins.
    torch.set_num_interop_threads(1)

    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "7860")),
    )