Tim Luka Horstmann committed on
Commit
7ee4aae
·
1 Parent(s): 0e9cc30

Rate limiting

Browse files
Files changed (2) hide show
  1. app.py +29 -9
  2. requirements.txt +2 -1
app.py CHANGED
@@ -3,7 +3,7 @@ import json
3
  import time
4
  import numpy as np
5
  from sentence_transformers import SentenceTransformer
6
- from fastapi import FastAPI, HTTPException, BackgroundTasks
7
  from fastapi.responses import StreamingResponse, Response
8
  from fastapi.middleware.cors import CORSMiddleware
9
  from pydantic import BaseModel
@@ -18,6 +18,9 @@ from google import genai
18
  from google.genai import types
19
  import httpx
20
  from elevenlabs import ElevenLabs, VoiceSettings
 
 
 
21
 
22
  # Set up logging
23
  logging.basicConfig(level=logging.INFO)
@@ -25,6 +28,18 @@ logger = logging.getLogger(__name__)
25
 
26
  app = FastAPI()
27
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  # Add CORS middleware to handle cross-origin requests
29
  app.add_middleware(
30
  CORSMiddleware,
@@ -331,20 +346,22 @@ def get_ram_usage():
331
  }
332
 
333
  @app.post("/api/predict")
334
- async def predict(request: QueryRequest):
335
- query = request.query
336
- history = request.history
 
337
  return StreamingResponse(stream_response(query, history), media_type="text/event-stream")
338
 
339
  @app.post("/api/tts")
340
- async def text_to_speech(request: TTSRequest):
 
341
  """Convert text to speech using ElevenLabs API"""
342
  if not elevenlabs_client:
343
  raise HTTPException(status_code=503, detail="TTS service not available")
344
 
345
  try:
346
  # Clean the text for TTS (remove markdown and special characters)
347
- clean_text = request.text.replace("**", "").replace("*", "").replace("\n", " ").strip()
348
 
349
  if not clean_text:
350
  raise HTTPException(status_code=400, detail="No text provided for TTS")
@@ -381,11 +398,13 @@ async def text_to_speech(request: TTSRequest):
381
  raise HTTPException(status_code=500, detail=f"TTS conversion failed: {str(e)}")
382
 
383
  @app.get("/health")
384
- async def health_check():
 
385
  return {"status": "healthy"}
386
 
387
  @app.get("/model_info")
388
- async def model_info():
 
389
  base_info = {
390
  "embedding_model": sentence_transformer_model,
391
  "faiss_index_size": len(cv_chunks),
@@ -411,7 +430,8 @@ async def model_info():
411
  return base_info
412
 
413
  @app.get("/ram_usage")
414
- async def ram_usage():
 
415
  """Endpoint to get current RAM usage."""
416
  try:
417
  ram_stats = get_ram_usage()
 
3
  import time
4
  import numpy as np
5
  from sentence_transformers import SentenceTransformer
6
+ from fastapi import FastAPI, HTTPException, BackgroundTasks, Request
7
  from fastapi.responses import StreamingResponse, Response
8
  from fastapi.middleware.cors import CORSMiddleware
9
  from pydantic import BaseModel
 
18
  from google.genai import types
19
  import httpx
20
  from elevenlabs import ElevenLabs, VoiceSettings
21
+ from slowapi import Limiter, _rate_limit_exceeded_handler
22
+ from slowapi.util import get_remote_address
23
+ from slowapi.errors import RateLimitExceeded
24
 
25
  # Set up logging
26
  logging.basicConfig(level=logging.INFO)
 
28
 
29
  app = FastAPI()
30
 
31
+ # Initialize rate limiter
32
+ limiter = Limiter(key_func=get_remote_address)
33
+ app.state.limiter = limiter
34
+
35
+ # Custom rate limit exceeded handler with logging
36
+ async def custom_rate_limit_handler(request: Request, exc: RateLimitExceeded):
37
+ client_ip = get_remote_address(request)
38
+ logger.warning(f"Rate limit exceeded for IP {client_ip} on endpoint {request.url.path}")
39
+ return await _rate_limit_exceeded_handler(request, exc)
40
+
41
+ app.add_exception_handler(RateLimitExceeded, custom_rate_limit_handler)
42
+
43
  # Add CORS middleware to handle cross-origin requests
44
  app.add_middleware(
45
  CORSMiddleware,
 
346
  }
347
 
348
  @app.post("/api/predict")
349
+ @limiter.limit("5/minute")  # Allow 5 chat requests per minute per IP
350
+ async def predict(request: Request, query_request: QueryRequest):
351
+ query = query_request.query
352
+ history = query_request.history
353
  return StreamingResponse(stream_response(query, history), media_type="text/event-stream")
354
 
355
  @app.post("/api/tts")
356
+ @limiter.limit("5/minute") # Allow 5 TTS requests per minute per IP (more restrictive as TTS is more expensive)
357
+ async def text_to_speech(request: Request, tts_request: TTSRequest):
358
  """Convert text to speech using ElevenLabs API"""
359
  if not elevenlabs_client:
360
  raise HTTPException(status_code=503, detail="TTS service not available")
361
 
362
  try:
363
  # Clean the text for TTS (remove markdown and special characters)
364
+ clean_text = tts_request.text.replace("**", "").replace("*", "").replace("\n", " ").strip()
365
 
366
  if not clean_text:
367
  raise HTTPException(status_code=400, detail="No text provided for TTS")
 
398
  raise HTTPException(status_code=500, detail=f"TTS conversion failed: {str(e)}")
399
 
400
  @app.get("/health")
401
+ @limiter.limit("30/minute") # Allow frequent health checks
402
+ async def health_check(request: Request):
403
  return {"status": "healthy"}
404
 
405
  @app.get("/model_info")
406
+ @limiter.limit("10/minute") # Limit model info requests
407
+ async def model_info(request: Request):
408
  base_info = {
409
  "embedding_model": sentence_transformer_model,
410
  "faiss_index_size": len(cv_chunks),
 
430
  return base_info
431
 
432
  @app.get("/ram_usage")
433
+ @limiter.limit("20/minute") # Allow moderate monitoring requests
434
+ async def ram_usage(request: Request):
435
  """Endpoint to get current RAM usage."""
436
  try:
437
  ram_stats = get_ram_usage()
requirements.txt CHANGED
@@ -10,4 +10,5 @@ google-genai
10
  asyncio
11
  elevenlabs
12
  httpx
13
- llama-cpp-python==0.2.85
 
 
10
  asyncio
11
  elevenlabs
12
  httpx
13
+ llama-cpp-python==0.2.85
14
+ slowapi==0.1.9