""" | |
FastAPI Endpoint for GASM-LLM Integration | |
This module provides a FastAPI endpoint that can be used with OpenAI's CustomGPT | |
to access GASM-enhanced language processing capabilities. | |
""" | |
from fastapi import FastAPI, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from typing import Dict, List, Optional, Any
import torch
import logging
from datetime import datetime
from contextlib import asynccontextmanager

from gasm_llm_layer import GASMEnhancedLLM
from gasm.utils import check_se3_invariance
from gasm.core import GASM
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global model instance and process start time (used for the /health uptime field)
model_instance = None
_startup_time = datetime.now()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan manager for the FastAPI app: load the model on startup, release it on shutdown
    """
    global model_instance

    # Startup
    logger.info("Loading GASM-LLM model...")
    try:
        model_instance = GASMEnhancedLLM(
            base_model_name="distilbert-base-uncased",
            gasm_hidden_dim=256,
            gasm_output_dim=128,
            enable_geometry=True
        )
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        model_instance = None

    yield

    # Shutdown
    logger.info("Shutting down...")
    model_instance = None
# Create FastAPI app
app = FastAPI(
    title="GASM-LLM API",
    description="API for GASM-enhanced Large Language Model processing",
    version="1.0.0",
    lifespan=lifespan
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Pydantic models for request/response
class TextProcessingRequest(BaseModel):
    """Request model for text processing"""
    text: str = Field(..., description="Text to process", min_length=1, max_length=10000)
    enable_geometry: bool = Field(True, description="Enable geometric processing")
    return_embeddings: bool = Field(False, description="Return raw embeddings")
    return_geometry: bool = Field(False, description="Return geometric information")
    max_length: int = Field(512, description="Maximum sequence length", ge=1, le=2048)
    model_config: Optional[Dict[str, Any]] = Field(None, description="Model configuration overrides")
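
# Example /process request body (illustrative only; the field values below are
# arbitrary and not taken from any real workload):
#
#     {
#         "text": "The satellite orbits the Earth while the Moon rotates around it.",
#         "enable_geometry": true,
#         "return_embeddings": false,
#         "return_geometry": true,
#         "max_length": 512
#     }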
class GeometricAnalysisRequest(BaseModel):
    """Request model for geometric analysis"""
    text: str = Field(..., description="Text to analyze geometrically")
    analysis_type: str = Field("full", description="Type of analysis: 'full', 'curvature', 'invariance'")
    num_invariance_tests: int = Field(10, description="Number of invariance tests", ge=1, le=100)
    tolerance: float = Field(1e-3, description="Tolerance for invariance tests", ge=1e-6, le=1e-1)
class ComparisonRequest(BaseModel):
    """Request model for comparing geometric vs standard processing"""
    text: str = Field(..., description="Text to compare")
    metrics: List[str] = Field(
        ["embedding_norm", "attention_patterns", "geometric_consistency"],
        description="Metrics to compare"
    )
class BatchProcessingRequest(BaseModel):
    """Request model for batch processing"""
    texts: List[str] = Field(..., description="List of texts to process", min_items=1, max_items=100)
    enable_geometry: bool = Field(True, description="Enable geometric processing")
    return_summary: bool = Field(True, description="Return summary statistics only (omit per-text results)")
class TextProcessingResponse(BaseModel):
    """Response model for text processing"""
    success: bool
    timestamp: str
    processing_time: float
    text_length: int
    model_info: Dict[str, Any]
    embedding_stats: Dict[str, float]
    geometric_stats: Optional[Dict[str, Any]] = None
    # embeddings come back as a (batch, seq_len, hidden) nested list from tensor.tolist()
    embeddings: Optional[List[List[List[float]]]] = None
    geometric_info: Optional[Dict[str, Any]] = None
    error: Optional[str] = None

class GeometricAnalysisResponse(BaseModel):
    """Response model for geometric analysis"""
    success: bool
    timestamp: str
    analysis_type: str
    curvature_analysis: Optional[Dict[str, Any]] = None
    invariance_results: Optional[Dict[str, Any]] = None
    geometric_properties: Optional[Dict[str, Any]] = None
    error: Optional[str] = None

class ComparisonResponse(BaseModel):
    """Response model for comparison"""
    success: bool
    timestamp: str
    geometric_results: Dict[str, Any]
    standard_results: Dict[str, Any]
    comparison_metrics: Dict[str, Any]
    error: Optional[str] = None

class BatchProcessingResponse(BaseModel):
    """Response model for batch processing"""
    success: bool
    timestamp: str
    num_texts: int
    processing_times: List[float]
    batch_summary: Dict[str, Any]
    individual_results: Optional[List[Dict[str, Any]]] = None
    error: Optional[str] = None

class HealthResponse(BaseModel):
    """Response model for health check"""
    status: str
    model_loaded: bool
    device: str
    memory_usage: Dict[str, Any]
    uptime: str
def get_model():
    """
    Dependency to get the model instance
    """
    global model_instance
    if model_instance is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    return model_instance
@app.get("/")
async def root():
    """
    Root endpoint
    """
    return {
        "message": "GASM-LLM API",
        "version": "1.0.0",
        "description": "API for GASM-enhanced Large Language Model processing",
        "endpoints": {
            "process": "POST /process - Process text with geometric enhancement",
            "analyze": "POST /analyze - Perform geometric analysis",
            "compare": "POST /compare - Compare geometric vs standard processing",
            "batch": "POST /batch - Process multiple texts",
            "health": "GET /health - Health check",
            "info": "GET /info - Model information"
        }
    }
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Health check endpoint
    """
    global model_instance

    # Check memory usage
    memory_info = {}
    if torch.cuda.is_available():
        memory_info["gpu_memory"] = {
            "allocated": torch.cuda.memory_allocated(),
            "reserved": torch.cuda.memory_reserved(),
            "max_allocated": torch.cuda.max_memory_allocated()
        }

    # Check system memory (simplified; degrade gracefully if psutil is unavailable)
    try:
        import psutil
        memory_info["system_memory"] = {
            "used": psutil.virtual_memory().used,
            "total": psutil.virtual_memory().total,
            "percent": psutil.virtual_memory().percent
        }
    except ImportError:
        memory_info["system_memory"] = {"error": "psutil not installed"}

    return HealthResponse(
        status="healthy" if model_instance is not None else "unhealthy",
        model_loaded=model_instance is not None,
        device=str(torch.device("cuda" if torch.cuda.is_available() else "cpu")),
        memory_usage=memory_info,
        uptime=str(datetime.now() - _startup_time)
    )
@app.get("/info")
async def model_info(model: GASMEnhancedLLM = Depends(get_model)):
    """
    Get model information
    """
    gasm = model.gasm_embedding.gasm if hasattr(model, 'gasm_embedding') else None
    return {
        "model_name": model.base_model_name,
        "geometry_enabled": model.enable_geometry,
        "device": str(next(model.parameters()).device),
        "total_parameters": sum(p.numel() for p in model.parameters()),
        "trainable_parameters": sum(p.numel() for p in model.parameters() if p.requires_grad),
        "model_size_mb": sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024),
        "gasm_config": {
            "hidden_dim": getattr(gasm, 'hidden_dim', None),
            "output_dim": getattr(gasm, 'output_dim', None),
            "max_iterations": getattr(gasm, 'max_iterations', None),
        }
    }
@app.post("/process", response_model=TextProcessingResponse)
async def process_text(
    request: TextProcessingRequest,
    model: GASMEnhancedLLM = Depends(get_model)
):
    """
    Process text with the GASM-enhanced LLM
    """
    start_time = datetime.now()
    try:
        # Configure model
        model.enable_geometry = request.enable_geometry

        # Process text
        outputs = model.encode_text(
            request.text,
            return_geometry=request.return_geometry
        )

        # Calculate processing time
        processing_time = (datetime.now() - start_time).total_seconds()

        # Extract embeddings
        embeddings = outputs['last_hidden_state']
        embedding_stats = {
            "shape": list(embeddings.shape),
            "mean": float(embeddings.mean()),
            "std": float(embeddings.std()),
            "min": float(embeddings.min()),
            "max": float(embeddings.max()),
            "norm": float(torch.norm(embeddings))
        }

        # Prepare response
        response = TextProcessingResponse(
            success=True,
            timestamp=start_time.isoformat(),
            processing_time=processing_time,
            text_length=len(request.text),
            model_info={
                "model_name": model.base_model_name,
                "geometry_enabled": request.enable_geometry,
                "device": str(next(model.parameters()).device)
            },
            embedding_stats=embedding_stats
        )

        # Add embeddings if requested
        if request.return_embeddings:
            response.embeddings = embeddings.detach().cpu().numpy().tolist()

        # Add geometric information if available
        if request.return_geometry and 'geometric_info' in outputs:
            geometric_info = outputs['geometric_info']
            if geometric_info:
                response.geometric_info = {
                    "num_sequences": len(geometric_info),
                    "has_curvature": any('output' in info for info in geometric_info),
                    "has_constraints": any('constraints' in info for info in geometric_info),
                    "has_relations": any('relations' in info for info in geometric_info)
                }

        return response

    except Exception as e:
        logger.error(f"Error processing text: {e}")
        return TextProcessingResponse(
            success=False,
            timestamp=start_time.isoformat(),
            processing_time=(datetime.now() - start_time).total_seconds(),
            text_length=len(request.text),
            model_info={},
            embedding_stats={},
            error=str(e)
        )
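
# Example call to /process (illustrative sketch; assumes the server is running
# locally on port 8000):
#
#     curl -X POST http://localhost:8000/process \
#         -H "Content-Type: application/json" \
#         -d '{"text": "The robot arm moves above the conveyor belt.", "return_geometry": true}'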
@app.post("/analyze", response_model=GeometricAnalysisResponse)
async def analyze_geometry(
    request: GeometricAnalysisRequest,
    model: GASMEnhancedLLM = Depends(get_model)
):
    """
    Perform geometric analysis of text
    """
    start_time = datetime.now()
    try:
        # Enable geometry for analysis
        model.enable_geometry = True

        # Process text with geometric information
        outputs = model.encode_text(request.text, return_geometry=True)

        response = GeometricAnalysisResponse(
            success=True,
            timestamp=start_time.isoformat(),
            analysis_type=request.analysis_type
        )

        # Perform requested analysis
        if request.analysis_type in ["full", "curvature"]:
            # Curvature analysis
            geometric_info = outputs.get('geometric_info', [])
            if geometric_info:
                curvature_stats = []
                for info in geometric_info:
                    if 'output' in info:
                        geo_output = info['output']
                        curvature_norm = torch.norm(geo_output, dim=1)
                        curvature_stats.append({
                            "mean": float(curvature_norm.mean()),
                            "std": float(curvature_norm.std()),
                            "min": float(curvature_norm.min()),
                            "max": float(curvature_norm.max())
                        })

                response.curvature_analysis = {
                    "per_sequence": curvature_stats,
                    "global_stats": {
                        "num_sequences": len(curvature_stats),
                        "avg_mean_curvature": sum(s["mean"] for s in curvature_stats) / len(curvature_stats) if curvature_stats else 0
                    }
                }

        if request.analysis_type in ["full", "invariance"]:
            # SE(3) invariance analysis
            try:
                # Create simple test data for the invariance check
                test_points = torch.randn(10, 3)
                test_features = torch.randn(10, model.base_model.config.hidden_size)
                test_relations = torch.randn(10, 10, 16)

                # Test with a simplified GASM model for invariance
                gasm_model = GASM(
                    feature_dim=model.base_model.config.hidden_size,
                    hidden_dim=256,
                    output_dim=3
                )

                is_invariant = check_se3_invariance(
                    gasm_model,
                    test_points,
                    test_features,
                    test_relations,
                    num_tests=request.num_invariance_tests,
                    tolerance=request.tolerance
                )

                response.invariance_results = {
                    "is_invariant": is_invariant,
                    "num_tests": request.num_invariance_tests,
                    "tolerance": request.tolerance,
                    "test_type": "SE(3) invariance"
                }
            except Exception as e:
                response.invariance_results = {
                    "is_invariant": None,
                    "error": str(e)
                }

        return response

    except Exception as e:
        logger.error(f"Error in geometric analysis: {e}")
        return GeometricAnalysisResponse(
            success=False,
            timestamp=start_time.isoformat(),
            analysis_type=request.analysis_type,
            error=str(e)
        )
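
# Example call to /analyze (illustrative sketch; assumes the server is running
# locally on port 8000, and arbitrary example values):
#
#     curl -X POST http://localhost:8000/analyze \
#         -H "Content-Type: application/json" \
#         -d '{"text": "The box is left of the ball.", "analysis_type": "invariance", "num_invariance_tests": 20}'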
@app.post("/compare", response_model=ComparisonResponse)
async def compare_processing(
    request: ComparisonRequest,
    model: GASMEnhancedLLM = Depends(get_model)
):
    """
    Compare geometric vs standard processing
    """
    start_time = datetime.now()
    try:
        # Process with geometry
        model.enable_geometry = True
        geometric_outputs = model.encode_text(request.text, return_geometry=True)

        # Process without geometry
        model.enable_geometry = False
        standard_outputs = model.encode_text(request.text, return_geometry=False)

        # Restore the default so subsequent requests see geometry enabled again
        model.enable_geometry = True

        # Extract results
        geometric_embeddings = geometric_outputs['last_hidden_state']
        standard_embeddings = standard_outputs['last_hidden_state']

        # Calculate comparison metrics
        comparison_metrics = {}

        if "embedding_norm" in request.metrics:
            comparison_metrics["embedding_norm"] = {
                "geometric": float(torch.norm(geometric_embeddings)),
                "standard": float(torch.norm(standard_embeddings)),
                "ratio": float(torch.norm(geometric_embeddings) / torch.norm(standard_embeddings))
            }

        if "attention_patterns" in request.metrics:
            # Simplified attention pattern comparison
            geo_attention = torch.softmax(geometric_embeddings @ geometric_embeddings.transpose(-2, -1), dim=-1)
            std_attention = torch.softmax(standard_embeddings @ standard_embeddings.transpose(-2, -1), dim=-1)

            comparison_metrics["attention_patterns"] = {
                "geometric_entropy": float(torch.sum(-geo_attention * torch.log(geo_attention + 1e-9))),
                "standard_entropy": float(torch.sum(-std_attention * torch.log(std_attention + 1e-9))),
                "pattern_difference": float(torch.norm(geo_attention - std_attention))
            }

        if "geometric_consistency" in request.metrics:
            comparison_metrics["geometric_consistency"] = {
                "has_geometric_info": 'geometric_info' in geometric_outputs,
                "embedding_difference": float(torch.norm(geometric_embeddings - standard_embeddings)),
                "relative_change": float(torch.norm(geometric_embeddings - standard_embeddings) / torch.norm(standard_embeddings))
            }

        return ComparisonResponse(
            success=True,
            timestamp=start_time.isoformat(),
            geometric_results={
                "embedding_stats": {
                    "shape": list(geometric_embeddings.shape),
                    "mean": float(geometric_embeddings.mean()),
                    "std": float(geometric_embeddings.std()),
                    "norm": float(torch.norm(geometric_embeddings))
                }
            },
            standard_results={
                "embedding_stats": {
                    "shape": list(standard_embeddings.shape),
                    "mean": float(standard_embeddings.mean()),
                    "std": float(standard_embeddings.std()),
                    "norm": float(torch.norm(standard_embeddings))
                }
            },
            comparison_metrics=comparison_metrics
        )

    except Exception as e:
        logger.error(f"Error in comparison: {e}")
        return ComparisonResponse(
            success=False,
            timestamp=start_time.isoformat(),
            geometric_results={},
            standard_results={},
            comparison_metrics={},
            error=str(e)
        )
@app.post("/batch", response_model=BatchProcessingResponse)
async def batch_process(
    request: BatchProcessingRequest,
    model: GASMEnhancedLLM = Depends(get_model)
):
    """
    Process multiple texts in batch
    """
    start_time = datetime.now()
    try:
        model.enable_geometry = request.enable_geometry

        processing_times = []
        individual_results = []

        for i, text in enumerate(request.texts):
            text_start = datetime.now()
            outputs = model.encode_text(text, return_geometry=False)
            embeddings = outputs['last_hidden_state']
            processing_time = (datetime.now() - text_start).total_seconds()
            processing_times.append(processing_time)

            if not request.return_summary:
                individual_results.append({
                    "text_index": i,
                    "text_length": len(text),
                    "processing_time": processing_time,
                    "embedding_norm": float(torch.norm(embeddings))
                })

        # Calculate batch summary
        batch_summary = {
            "total_texts": len(request.texts),
            "total_processing_time": sum(processing_times),
            "average_processing_time": sum(processing_times) / len(processing_times),
            "texts_per_second": len(request.texts) / sum(processing_times),
            "geometry_enabled": request.enable_geometry,
            "total_characters": sum(len(text) for text in request.texts),
            "average_text_length": sum(len(text) for text in request.texts) / len(request.texts)
        }

        return BatchProcessingResponse(
            success=True,
            timestamp=start_time.isoformat(),
            num_texts=len(request.texts),
            processing_times=processing_times,
            batch_summary=batch_summary,
            individual_results=individual_results if not request.return_summary else None
        )

    except Exception as e:
        logger.error(f"Error in batch processing: {e}")
        return BatchProcessingResponse(
            success=False,
            timestamp=start_time.isoformat(),
            num_texts=len(request.texts),
            processing_times=[],
            batch_summary={},
            error=str(e)
        )
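
# Example /batch request body (illustrative only; the texts are arbitrary):
#
#     {
#         "texts": ["The cup is on the table.", "The drone hovers above the field."],
#         "enable_geometry": true,
#         "return_summary": false
#     }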
# Error handlers
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    return JSONResponse(
        status_code=exc.status_code,
        content={"error": exc.detail, "timestamp": datetime.now().isoformat()}
    )

@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    logger.error(f"Unhandled exception: {exc}")
    return JSONResponse(
        status_code=500,
        content={"error": "Internal server error", "timestamp": datetime.now().isoformat()}
    )
# OpenAPI customization for CustomGPT
def custom_openapi():
    """
    Custom OpenAPI schema for CustomGPT integration
    (synchronous, since FastAPI calls the openapi override directly)
    """
    from fastapi.openapi.utils import get_openapi

    if app.openapi_schema:
        return app.openapi_schema

    openapi_schema = get_openapi(
        title="GASM-LLM API",
        version="1.0.0",
        description="API for GASM-enhanced Large Language Model processing with geometric inference capabilities",
        routes=app.routes,
    )

    # Add custom metadata for CustomGPT
    openapi_schema["info"]["x-logo"] = {
        "url": "https://huggingface.co/spaces/your-username/gasm-llm/resolve/main/logo.png"
    }

    app.openapi_schema = openapi_schema
    return app.openapi_schema

app.openapi = custom_openapi
if __name__ == "__main__": | |
import uvicorn | |
uvicorn.run( | |
"fastapi_endpoint:app", | |
host="0.0.0.0", | |
port=8000, | |
reload=True, | |
log_level="info" | |
) |
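
# --- Example client usage ---------------------------------------------------
# A minimal sketch of how a client (or a CustomGPT action backend) might call
# this API, assuming the server is running locally on port 8000 and the
# `requests` package is installed:
#
#     import requests
#
#     resp = requests.post(
#         "http://localhost:8000/process",
#         json={"text": "The gear sits below the motor.", "return_geometry": True},
#         timeout=60,
#     )
#     resp.raise_for_status()
#     result = resp.json()
#     print(result["embedding_stats"]["norm"], result.get("geometric_info"))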