from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import List
import torch
import uvicorn
from models.schemas import EmbeddingRequest, EmbeddingResponse, ModelInfo
from utils.helpers import load_models, get_embeddings, cleanup_memory

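# NOTE: EmbeddingRequest, EmbeddingResponse, and ModelInfo live in
# models/schemas.py, which is not part of this file. A minimal sketch of those
# Pydantic models, inferred only from the fields this module uses (texts,
# normalize, max_length; embeddings, model_used, dimensions, num_texts; and the
# ModelInfo constructor arguments), might look like:
#
#     from typing import List, Optional
#     from pydantic import BaseModel
#
#     class EmbeddingRequest(BaseModel):
#         texts: List[str]
#         normalize: bool = True          # assumed default
#         max_length: Optional[int] = None
#
#     class EmbeddingResponse(BaseModel):
#         embeddings: List[List[float]]
#         model_used: str
#         dimensions: int
#         num_texts: int
#
#     class ModelInfo(BaseModel):
#         model_id: str
#         name: str
#         dimensions: int
#         max_sequence_length: int
#         languages: List[str]
#         model_type: str
#         description: str
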
# Global model cache
models_cache = {}

# Load jina-v3 at startup (most important model)
STARTUP_MODEL = "jina-v3"

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler for startup and shutdown"""
    # Startup - load the jina-v3 model; a failure here is tolerated because
    # the model can still be loaded on demand by its endpoint
    global models_cache
    try:
        print(f"Loading startup model: {STARTUP_MODEL}...")
        models_cache = load_models([STARTUP_MODEL])
        print(f"Startup model loaded successfully: {list(models_cache.keys())}")
    except Exception as e:
        print(f"Failed to load startup model: {str(e)}")
    try:
        yield
    finally:
        # Shutdown - cleanup resources
        cleanup_memory()

def ensure_model_loaded(model_name: str, max_length_limit: int):
    """Load a specific model on demand if not already loaded"""
    global models_cache
    if model_name not in models_cache:
        try:
            print(f"Loading model on demand: {model_name}...")
            new_models = load_models([model_name])
            models_cache.update(new_models)
            print(f"Model {model_name} loaded successfully!")
        except Exception as e:
            print(f"Failed to load model {model_name}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Model {model_name} loading failed: {str(e)}")

def validate_request_for_model(request: EmbeddingRequest, model_name: str, max_length_limit: int):
    """Validate request parameters for specific model"""
    if not request.texts:
        raise HTTPException(status_code=400, detail="No texts provided")
    if len(request.texts) > 50:
        raise HTTPException(status_code=400, detail="Maximum 50 texts per request")
    if request.max_length is not None and request.max_length > max_length_limit:
        raise HTTPException(status_code=400, detail=f"Max length for {model_name} is {max_length_limit}")

app = FastAPI(
    title="Multilingual & Legal Embedding API",
    description="Multi-model embedding API with dedicated endpoints per model",
    version="4.0.0",
    lifespan=lifespan
)

# Add CORS middleware to allow cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify actual domains
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")
async def root():
    """Service metadata and available endpoints"""
    return {
        "message": "Multilingual & Legal Embedding API - Endpoint Per Model",
        "version": "4.0.0",
        "status": "running",
        "docs": "/docs",
        "startup_model": STARTUP_MODEL,
        "available_endpoints": {
            "jina-v3": "/embed/jina-v3",
            "roberta-ca": "/embed/roberta-ca",
            "jina": "/embed/jina",
            "robertalex": "/embed/robertalex",
            "legal-bert": "/embed/legal-bert"
        }
    }

# Jina v3 - Multilingual (loads at startup)
@app.post("/embed/jina-v3", response_model=EmbeddingResponse)
async def embed_jina_v3(request: EmbeddingRequest):
    """Generate embeddings using Jina v3 model (multilingual)"""
    try:
        ensure_model_loaded("jina-v3", 8192)
        validate_request_for_model(request, "jina-v3", 8192)
        embeddings = get_embeddings(
            request.texts,
            "jina-v3",
            models_cache,
            request.normalize,
            request.max_length
        )
        return EmbeddingResponse(
            embeddings=embeddings,
            model_used="jina-v3",
            dimensions=len(embeddings[0]) if embeddings else 0,
            num_texts=len(request.texts)
        )
    except HTTPException:
        # Re-raise validation/loading errors with their original status codes
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")

# Catalan RoBERTa
@app.post("/embed/roberta-ca", response_model=EmbeddingResponse)
async def embed_roberta_ca(request: EmbeddingRequest):
    """Generate embeddings using Catalan RoBERTa model"""
    try:
        ensure_model_loaded("roberta-ca", 512)
        validate_request_for_model(request, "roberta-ca", 512)
        embeddings = get_embeddings(
            request.texts,
            "roberta-ca",
            models_cache,
            request.normalize,
            request.max_length
        )
        return EmbeddingResponse(
            embeddings=embeddings,
            model_used="roberta-ca",
            dimensions=len(embeddings[0]) if embeddings else 0,
            num_texts=len(request.texts)
        )
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")

# Jina v2 - Spanish/English
@app.post("/embed/jina", response_model=EmbeddingResponse)
async def embed_jina(request: EmbeddingRequest):
    """Generate embeddings using Jina v2 Spanish/English model"""
    try:
        ensure_model_loaded("jina", 8192)
        validate_request_for_model(request, "jina", 8192)
        embeddings = get_embeddings(
            request.texts,
            "jina",
            models_cache,
            request.normalize,
            request.max_length
        )
        return EmbeddingResponse(
            embeddings=embeddings,
            model_used="jina",
            dimensions=len(embeddings[0]) if embeddings else 0,
            num_texts=len(request.texts)
        )
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")

# RoBERTalex - Spanish Legal
@app.post("/embed/robertalex", response_model=EmbeddingResponse)
async def embed_robertalex(request: EmbeddingRequest):
    """Generate embeddings using RoBERTalex Spanish legal model"""
    try:
        ensure_model_loaded("robertalex", 512)
        validate_request_for_model(request, "robertalex", 512)
        embeddings = get_embeddings(
            request.texts,
            "robertalex",
            models_cache,
            request.normalize,
            request.max_length
        )
        return EmbeddingResponse(
            embeddings=embeddings,
            model_used="robertalex",
            dimensions=len(embeddings[0]) if embeddings else 0,
            num_texts=len(request.texts)
        )
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")

# Legal BERT - English Legal
@app.post("/embed/legal-bert", response_model=EmbeddingResponse)
async def embed_legal_bert(request: EmbeddingRequest):
    """Generate embeddings using Legal BERT English model"""
    try:
        ensure_model_loaded("legal-bert", 512)
        validate_request_for_model(request, "legal-bert", 512)
        embeddings = get_embeddings(
            request.texts,
            "legal-bert",
            models_cache,
            request.normalize,
            request.max_length
        )
        return EmbeddingResponse(
            embeddings=embeddings,
            model_used="legal-bert",
            dimensions=len(embeddings[0]) if embeddings else 0,
            num_texts=len(request.texts)
        )
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")

@app.get("/models", response_model=List[ModelInfo])
async def list_models():
    """List available models and their specifications"""
    return [
        ModelInfo(
            model_id="jina-v3",
            name="jinaai/jina-embeddings-v3",
            dimensions=1024,
            max_sequence_length=8192,
            languages=["Multilingual"],
            model_type="multilingual",
            description="Latest Jina v3 with superior multilingual performance - loaded at startup"
        ),
        ModelInfo(
            model_id="roberta-ca",
            name="projecte-aina/roberta-large-ca-v2",
            dimensions=1024,
            max_sequence_length=512,
            languages=["Catalan"],
            model_type="general",
            description="Catalan RoBERTa-large model trained on a large corpus"
        ),
        ModelInfo(
            model_id="jina",
            name="jinaai/jina-embeddings-v2-base-es",
            dimensions=768,
            max_sequence_length=8192,
            languages=["Spanish", "English"],
            model_type="bilingual",
            description="Bilingual Spanish-English embeddings with long context support"
        ),
        ModelInfo(
            model_id="robertalex",
            name="PlanTL-GOB-ES/RoBERTalex",
            dimensions=768,
            max_sequence_length=512,
            languages=["Spanish"],
            model_type="legal domain",
            description="Spanish legal domain specialized embeddings"
        ),
        ModelInfo(
            model_id="legal-bert",
            name="nlpaueb/legal-bert-base-uncased",
            dimensions=768,
            max_sequence_length=512,
            languages=["English"],
            model_type="legal domain",
            description="English legal domain BERT model"
        )
    ]

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    startup_loaded = STARTUP_MODEL in models_cache
    return {
        "status": "healthy" if startup_loaded else "partial",
        "startup_model": STARTUP_MODEL,
        "startup_model_loaded": startup_loaded,
        "available_models": list(models_cache.keys()),
        "models_count": len(models_cache),
        "endpoints": {
            "jina-v3": f"/embed/jina-v3 {'(ready)' if 'jina-v3' in models_cache else '(loads on demand)'}",
            "roberta-ca": f"/embed/roberta-ca {'(ready)' if 'roberta-ca' in models_cache else '(loads on demand)'}",
            "jina": f"/embed/jina {'(ready)' if 'jina' in models_cache else '(loads on demand)'}",
            "robertalex": f"/embed/robertalex {'(ready)' if 'robertalex' in models_cache else '(loads on demand)'}",
            "legal-bert": f"/embed/legal-bert {'(ready)' if 'legal-bert' in models_cache else '(loads on demand)'}"
        }
    }

if __name__ == "__main__":
    # Set multi-threading for CPU
    torch.set_num_threads(8)
    torch.set_num_interop_threads(1)
    uvicorn.run(app, host="0.0.0.0", port=7860)
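
# Example client call (illustrative sketch; assumes the API is reachable on the
# host/port passed to uvicorn.run() above and uses the request/response fields
# defined in models/schemas.py):
#
#     import requests
#
#     resp = requests.post(
#         "http://localhost:7860/embed/jina-v3",
#         json={"texts": ["Hola, mundo", "Hello, world"], "normalize": True},
#     )
#     data = resp.json()
#     # data keys mirror EmbeddingResponse: embeddings, model_used,
#     # dimensions, num_texts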