Commit 1292878
Tim Luka Horstmann committed
Parent(s): ea17465

Add ElevenLabs TTS integration

Files changed (3)
  1. app.py +86 -12
  2. requirements.txt +3 -1
  3. test_gemini_integration.py +0 -120
app.py CHANGED
@@ -4,7 +4,8 @@ import time
 import numpy as np
 from sentence_transformers import SentenceTransformer
 from fastapi import FastAPI, HTTPException, BackgroundTasks
-from fastapi.responses import StreamingResponse
+from fastapi.responses import StreamingResponse, Response
+from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 from llama_cpp import Llama
 from huggingface_hub import login, hf_hub_download
@@ -15,6 +16,8 @@ import asyncio
 import psutil  # Added for RAM tracking
 from google import genai
 from google.genai import types
+import httpx
+from elevenlabs import ElevenLabs, VoiceSettings
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -22,6 +25,15 @@ logger = logging.getLogger(__name__)
 
 app = FastAPI()
 
+# Add CORS middleware to handle cross-origin requests
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # In production, specify your domain
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
 # Global lock for model access
 model_lock = asyncio.Lock()
 
@@ -51,6 +63,18 @@ else:
     gemini_client = None
     logger.info("Using local model (Gemini disabled)")
 
+# ElevenLabs Configuration
+elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY")
+if elevenlabs_api_key:
+    elevenlabs_client = ElevenLabs(api_key=elevenlabs_api_key)
+    # You can set a specific voice ID here or use the default voice.
+    # Get your voice ID from the ElevenLabs dashboard after cloning your voice.
+    tts_voice_id = os.getenv("ELEVENLABS_VOICE_ID", "21m00Tcm4TlvDq8ikWAM")  # Default voice; replace with your cloned voice ID
+    logger.info("ElevenLabs TTS client initialized")
+else:
+    elevenlabs_client = None
+    logger.info("ElevenLabs TTS disabled (no API key provided)")
+
 # Define FAQs
 faqs = [
     {"question": "What is your name?", "answer": "My name is Tim Luka Horstmann."},
@@ -287,7 +311,10 @@ async def stream_response_local(query, history):
 
 class QueryRequest(BaseModel):
     query: str
-    history: list[dict]
+    history: list
+
+class TTSRequest(BaseModel):
+    text: str
 
 # RAM Usage Tracking Function
 def get_ram_usage():
@@ -309,32 +336,79 @@ async def predict(request: QueryRequest):
     history = request.history
     return StreamingResponse(stream_response(query, history), media_type="text/event-stream")
 
+@app.post("/api/tts")
+async def text_to_speech(request: TTSRequest):
+    """Convert text to speech using ElevenLabs API"""
+    if not elevenlabs_client:
+        raise HTTPException(status_code=503, detail="TTS service not available")
+
+    try:
+        # Clean the text for TTS (remove markdown and special characters)
+        clean_text = request.text.replace("**", "").replace("*", "").replace("\n", " ").strip()
+
+        if not clean_text:
+            raise HTTPException(status_code=400, detail="No text provided for TTS")
+
+        if len(clean_text) > 1000:  # Limit text length to avoid long processing times
+            clean_text = clean_text[:1000] + "..."
+
+        # Generate speech
+        response = elevenlabs_client.text_to_speech.convert(
+            voice_id=tts_voice_id,
+            text=clean_text,
+            voice_settings=VoiceSettings(
+                stability=0.5,
+                similarity_boost=0.8,
+                style=0.2,
+                use_speaker_boost=True
+            )
+        )
+
+        # Convert generator to bytes
+        audio_bytes = b"".join(response)
+
+        return Response(
+            content=audio_bytes,
+            media_type="audio/mpeg",
+            headers={
+                "Content-Disposition": "inline; filename=tts_audio.mp3",
+                "Cache-Control": "no-cache"
+            }
+        )
+
+    except Exception as e:
+        logger.error(f"TTS error: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"TTS conversion failed: {str(e)}")
+
 @app.get("/health")
 async def health_check():
     return {"status": "healthy"}
 
 @app.get("/model_info")
 async def model_info():
+    base_info = {
+        "embedding_model": sentence_transformer_model,
+        "faiss_index_size": len(cv_chunks),
+        "faiss_index_dim": cv_embeddings.shape[1],
+        "tts_available": elevenlabs_client is not None,
+    }
+
     if USE_GEMINI:
-        return {
+        base_info.update({
             "model_type": "gemini",
             "model_name": gemini_model,
             "provider": "Google Gemini API",
-            "embedding_model": sentence_transformer_model,
-            "faiss_index_size": len(cv_chunks),
-            "faiss_index_dim": cv_embeddings.shape[1],
-        }
+        })
     else:
-        return {
+        base_info.update({
             "model_type": "local",
             "model_name": filename,
             "repo_id": repo_id,
             "model_size": "1.7B",
             "quantization": "Q4_K_M",
-            "embedding_model": sentence_transformer_model,
-            "faiss_index_size": len(cv_chunks),
-            "faiss_index_dim": cv_embeddings.shape[1],
-        }
+        })
+
+    return base_info
 
 @app.get("/ram_usage")
 async def ram_usage():
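
For reference, a minimal client-side sketch of calling the new endpoint (hypothetical usage, not part of this commit): it assumes the app is served at http://localhost:8000 and uses the httpx dependency pinned below. The request body matches TTSRequest and the response is the raw audio/mpeg payload.

# Hypothetical client sketch for the new /api/tts endpoint.
# Assumes the FastAPI app is running at http://localhost:8000.
import httpx

resp = httpx.post(
    "http://localhost:8000/api/tts",
    json={"text": "Hi, I'm Tim. Ask me about my CV."},  # matches TTSRequest
    timeout=60.0,  # speech generation can take several seconds
)
resp.raise_for_status()
with open("reply.mp3", "wb") as f:
    f.write(resp.content)  # the endpoint returns audio/mpeg bytes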
requirements.txt CHANGED
@@ -8,4 +8,6 @@ huggingface_hub==0.30.1
 faiss-cpu==1.8.0
 asyncio
 psutil
-google-genai
+google-genai
+elevenlabs==1.1.3
+httpx==0.25.0
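
The two new pins back the TTS endpoint above. As a standalone illustration of why app.py does b"".join(response): in the elevenlabs 1.x client, text_to_speech.convert() yields the audio as an iterator of MP3 byte chunks rather than a single bytes object. A minimal sketch, assuming ELEVENLABS_API_KEY is set in the environment and using the same default voice ID as app.py:

# Minimal sketch of the elevenlabs 1.x call used in app.py (assumption:
# a valid ELEVENLABS_API_KEY is exported in the environment).
import os
from elevenlabs import ElevenLabs, VoiceSettings

client = ElevenLabs(api_key=os.environ["ELEVENLABS_API_KEY"])
audio_chunks = client.text_to_speech.convert(
    voice_id="21m00Tcm4TlvDq8ikWAM",
    text="Hello from the CV assistant.",
    voice_settings=VoiceSettings(stability=0.5, similarity_boost=0.8),
)
with open("out.mp3", "wb") as f:
    for chunk in audio_chunks:  # convert() streams the audio as byte chunks
        f.write(chunk)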
test_gemini_integration.py DELETED
@@ -1,120 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test script for Gemini API integration
-"""
-
-import os
-import asyncio
-from datetime import datetime
-
-# Mock the dependencies for testing
-class MockClient:
-    def __init__(self, api_key):
-        self.api_key = api_key
-
-    class models:
-        @staticmethod
-        def generate_content_stream(model, contents, config):
-            # Mock streaming response
-            class MockChunk:
-                text = "Hello! I'm Tim Luka Horstmann, a Computer Scientist currently pursuing my MSc in Data and AI at Institut Polytechnique de Paris."
-
-            yield MockChunk()
-
-class MockTypes:
-    class Content:
-        def __init__(self, role, parts):
-            self.role = role
-            self.parts = parts
-
-    class Part:
-        def __init__(self, text):
-            self.text = text
-
-        @classmethod
-        def from_text(cls, text):
-            return cls(text)
-
-    class GenerateContentConfig:
-        def __init__(self, temperature, top_p, max_output_tokens):
-            self.temperature = temperature
-            self.top_p = top_p
-            self.max_output_tokens = max_output_tokens
-
-# Test function similar to our Gemini implementation
-async def test_gemini_integration():
-    """Test the Gemini integration logic"""
-
-    # Mock environment variables
-    USE_GEMINI = True
-    gemini_api_key = "test_api_key"
-    gemini_model = "gemini-2.5-flash-preview-05-20"
-
-    # Mock full CV text
-    full_cv_text = "Tim Luka Horstmann is a Computer Scientist pursuing MSc in Data and AI at Institut Polytechnique de Paris."
-
-    # Initialize mock client
-    gemini_client = MockClient(api_key=gemini_api_key)
-    types = MockTypes()
-
-    # Test query and history
-    query = "What is your education?"
-    history = []
-
-    print(f"Testing Gemini integration...")
-    print(f"USE_GEMINI: {USE_GEMINI}")
-    print(f"Query: {query}")
-
-    # Simulate the Gemini function logic
-    current_date = datetime.now().strftime("%Y-%m-%d")
-
-    system_prompt = (
-        "You are Tim Luka Horstmann, a Computer Scientist. A user is asking you a question. Respond as yourself, using the first person, in a friendly and concise manner. "
-        "For questions about your CV, base your answer *exclusively* on the provided CV information below and do not add any details not explicitly stated. "
-        "For casual questions not covered by the CV, respond naturally but limit answers to general truths about yourself (e.g., your current location is Paris, France, or your field is AI) "
-        "and say 'I don't have specific details to share about that' if pressed for specifics beyond the CV or FAQs. Do not invent facts, experiences, or opinions not supported by the CV or FAQs. "
-        f"Today's date is {current_date}. "
-        f"CV: {full_cv_text}"
-    )
-
-    # Build messages for Gemini (no system role - embed instructions in first user message)
-    messages = []
-
-    # Add conversation history
-    for msg in history:
-        role = "user" if msg["role"] == "user" else "model"
-        messages.append(types.Content(role=role, parts=[types.Part.from_text(text=msg["content"])]))
-
-    # Add current query with system prompt embedded
-    if not history:  # If no history, include system prompt with the first message
-        combined_query = f"{system_prompt}\n\nUser question: {query}"
-    else:
-        combined_query = query
-
-    messages.append(types.Content(role="user", parts=[types.Part.from_text(text=combined_query)]))
-
-    print(f"System prompt length: {len(system_prompt)}")
-    print(f"Number of messages: {len(messages)}")
-
-    # Mock the streaming response
-    response = gemini_client.models.generate_content_stream(
-        model=gemini_model,
-        contents=messages,
-        config=types.GenerateContentConfig(
-            temperature=0.3,
-            top_p=0.7,
-            max_output_tokens=512,
-        )
-    )
-
-    print("Streaming response:")
-    for chunk in response:
-        if chunk.text:
-            print(f"Chunk: {chunk.text}")
-
-    print("✅ Gemini integration test completed successfully!")
-
-    return True
-
-if __name__ == "__main__":
-    asyncio.run(test_gemini_integration())
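
With the mock-based Gemini test removed, a smoke test for the new endpoint could look like the following (hypothetical sketch, not part of this commit; in practice app.py's module-level model loading would need to be stubbed before import):

# Hypothetical smoke test for /api/tts using FastAPI's TestClient.
# Assumes app.py can be imported without its heavy model downloads.
from fastapi.testclient import TestClient
from app import app

client = TestClient(app)

def test_tts_returns_503_without_api_key():
    # With no ELEVENLABS_API_KEY configured, elevenlabs_client is None,
    # so the endpoint should answer 503 (see the diff above).
    response = client.post("/api/tts", json={"text": "Hello"})
    assert response.status_code == 503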