Tim Luka Horstmann
committed on
Commit
·
7ee4aae
1
Parent(s):
0e9cc30
Rate limiting
Browse files- app.py +29 -9
- requirements.txt +2 -1
app.py
CHANGED
@@ -3,7 +3,7 @@ import json
|
|
3 |
import time
|
4 |
import numpy as np
|
5 |
from sentence_transformers import SentenceTransformer
|
6 |
-
from fastapi import FastAPI, HTTPException, BackgroundTasks
|
7 |
from fastapi.responses import StreamingResponse, Response
|
8 |
from fastapi.middleware.cors import CORSMiddleware
|
9 |
from pydantic import BaseModel
|
@@ -18,6 +18,9 @@ from google import genai
|
|
18 |
from google.genai import types
|
19 |
import httpx
|
20 |
from elevenlabs import ElevenLabs, VoiceSettings
|
|
|
|
|
|
|
21 |
|
22 |
# Set up logging
|
23 |
logging.basicConfig(level=logging.INFO)
|
@@ -25,6 +28,18 @@ logger = logging.getLogger(__name__)
|
|
25 |
|
26 |
app = FastAPI()
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
# Add CORS middleware to handle cross-origin requests
|
29 |
app.add_middleware(
|
30 |
CORSMiddleware,
|
@@ -331,20 +346,22 @@ def get_ram_usage():
|
|
331 |
}
|
332 |
|
333 |
@app.post("/api/predict")
|
334 |
-
|
335 |
-
|
336 |
-
|
|
|
337 |
return StreamingResponse(stream_response(query, history), media_type="text/event-stream")
|
338 |
|
339 |
@app.post("/api/tts")
|
340 |
-
|
|
|
341 |
"""Convert text to speech using ElevenLabs API"""
|
342 |
if not elevenlabs_client:
|
343 |
raise HTTPException(status_code=503, detail="TTS service not available")
|
344 |
|
345 |
try:
|
346 |
# Clean the text for TTS (remove markdown and special characters)
|
347 |
-
clean_text =
|
348 |
|
349 |
if not clean_text:
|
350 |
raise HTTPException(status_code=400, detail="No text provided for TTS")
|
@@ -381,11 +398,13 @@ async def text_to_speech(request: TTSRequest):
|
|
381 |
raise HTTPException(status_code=500, detail=f"TTS conversion failed: {str(e)}")
|
382 |
|
383 |
@app.get("/health")
|
384 |
-
|
|
|
385 |
return {"status": "healthy"}
|
386 |
|
387 |
@app.get("/model_info")
|
388 |
-
|
|
|
389 |
base_info = {
|
390 |
"embedding_model": sentence_transformer_model,
|
391 |
"faiss_index_size": len(cv_chunks),
|
@@ -411,7 +430,8 @@ async def model_info():
|
|
411 |
return base_info
|
412 |
|
413 |
@app.get("/ram_usage")
|
414 |
-
|
|
|
415 |
"""Endpoint to get current RAM usage."""
|
416 |
try:
|
417 |
ram_stats = get_ram_usage()
|
|
|
3 |
import time
|
4 |
import numpy as np
|
5 |
from sentence_transformers import SentenceTransformer
|
6 |
+
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request
|
7 |
from fastapi.responses import StreamingResponse, Response
|
8 |
from fastapi.middleware.cors import CORSMiddleware
|
9 |
from pydantic import BaseModel
|
|
|
18 |
from google.genai import types
|
19 |
import httpx
|
20 |
from elevenlabs import ElevenLabs, VoiceSettings
|
21 |
+
from slowapi import Limiter, _rate_limit_exceeded_handler
|
22 |
+
from slowapi.util import get_remote_address
|
23 |
+
from slowapi.errors import RateLimitExceeded
|
24 |
|
25 |
# Set up logging
|
26 |
logging.basicConfig(level=logging.INFO)
|
|
|
28 |
|
29 |
app = FastAPI()
|
30 |
|
31 |
+
# Initialize rate limiter
|
32 |
+
limiter = Limiter(key_func=get_remote_address)
|
33 |
+
app.state.limiter = limiter
|
34 |
+
|
35 |
+
# Custom rate limit exceeded handler with logging.
async def custom_rate_limit_handler(request: Request, exc: RateLimitExceeded):
    """Log the offending client and delegate to slowapi's default 429 handler.

    Args:
        request: Incoming FastAPI request that tripped the rate limit.
        exc: The ``RateLimitExceeded`` exception raised by slowapi.

    Returns:
        The 429 response produced by slowapi's ``_rate_limit_exceeded_handler``.
    """
    client_ip = get_remote_address(request)
    # Lazy %-style args: the message is only formatted if WARNING is emitted
    # (preferred over f-strings in logging calls).
    logger.warning(
        "Rate limit exceeded for IP %s on endpoint %s", client_ip, request.url.path
    )
    return await _rate_limit_exceeded_handler(request, exc)
|
40 |
+
|
41 |
+
app.add_exception_handler(RateLimitExceeded, custom_rate_limit_handler)
|
42 |
+
|
43 |
# Add CORS middleware to handle cross-origin requests
|
44 |
app.add_middleware(
|
45 |
CORSMiddleware,
|
|
|
346 |
}
|
347 |
|
348 |
@app.post("/api/predict")
@limiter.limit("5/minute")  # Allow 5 chat requests per minute per IP
async def predict(request: Request, query_request: QueryRequest):
    """Stream a chat response for the given query and conversation history.

    ``request`` must be present in the signature so slowapi's limiter can
    resolve the client IP; it is otherwise unused. The response is streamed
    as server-sent events produced by ``stream_response``.
    """
    query = query_request.query
    history = query_request.history
    return StreamingResponse(stream_response(query, history), media_type="text/event-stream")
|
354 |
|
355 |
@app.post("/api/tts")
|
356 |
+
@limiter.limit("5/minute") # Allow 5 TTS requests per minute per IP (more restrictive as TTS is more expensive)
|
357 |
+
async def text_to_speech(request: Request, tts_request: TTSRequest):
|
358 |
"""Convert text to speech using ElevenLabs API"""
|
359 |
if not elevenlabs_client:
|
360 |
raise HTTPException(status_code=503, detail="TTS service not available")
|
361 |
|
362 |
try:
|
363 |
# Clean the text for TTS (remove markdown and special characters)
|
364 |
+
clean_text = tts_request.text.replace("**", "").replace("*", "").replace("\n", " ").strip()
|
365 |
|
366 |
if not clean_text:
|
367 |
raise HTTPException(status_code=400, detail="No text provided for TTS")
|
|
|
398 |
raise HTTPException(status_code=500, detail=f"TTS conversion failed: {str(e)}")
|
399 |
|
400 |
@app.get("/health")
@limiter.limit("30/minute")  # Allow frequent health checks
async def health_check(request: Request):
    """Liveness probe. ``request`` is required by slowapi's limiter decorator."""
    status_payload = {"status": "healthy"}
    return status_payload
|
404 |
|
405 |
@app.get("/model_info")
|
406 |
+
@limiter.limit("10/minute") # Limit model info requests
|
407 |
+
async def model_info(request: Request):
|
408 |
base_info = {
|
409 |
"embedding_model": sentence_transformer_model,
|
410 |
"faiss_index_size": len(cv_chunks),
|
|
|
430 |
return base_info
|
431 |
|
432 |
@app.get("/ram_usage")
|
433 |
+
@limiter.limit("20/minute") # Allow moderate monitoring requests
|
434 |
+
async def ram_usage(request: Request):
|
435 |
"""Endpoint to get current RAM usage."""
|
436 |
try:
|
437 |
ram_stats = get_ram_usage()
|
requirements.txt
CHANGED
@@ -10,4 +10,5 @@ google-genai
|
|
10 |
asyncio
|
11 |
elevenlabs
|
12 |
httpx
|
13 |
-
llama-cpp-python==0.2.85
|
|
|
|
10 |
asyncio
|
11 |
elevenlabs
|
12 |
httpx
|
13 |
+
llama-cpp-python==0.2.85
|
14 |
+
slowapi==0.1.9
|