khurrameycon committed
Commit 4196bc7 · verified · 1 Parent(s): 64f53a8

Update app.py

Files changed (1)
  1. app.py +128 -174
app.py CHANGED
@@ -94,31 +94,45 @@
 
  # return Response("No audio generated", status_code=400)
 
- from fastapi import FastAPI, Response, HTTPException, Request
- from fastapi.responses import JSONResponse
- from fastapi.staticfiles import StaticFiles
- from kokoro import KPipeline
  import os
- import numpy as np
- import torch
- from huggingface_hub import InferenceClient
- from pydantic import BaseModel
+ import uuid
  import base64
  import logging
+ from fastapi import FastAPI, HTTPException, Response, Request
+ from fastapi.responses import JSONResponse
+ from fastapi.staticfiles import StaticFiles
+ from pydantic import BaseModel
  from typing import Optional, ClassVar, List
- import uuid
+ from huggingface_hub import InferenceClient
+ import numpy as np
+ import torch
+ from kokoro import KPipeline # Assuming you have this pipeline for audio generation
 
  # Set up logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)
 
+ # Create FastAPI app
+ app = FastAPI(
+     title="Text-to-Speech API with Vision Support",
+     description="This API uses meta-llama/Llama-3.2-11B-Vision-Instruct, which requires an image input.",
+     version="1.0.0"
+ )
+
+ # Mount a static directory for serving saved images
+ STATIC_DIR = "static_images"
+ if not os.path.exists(STATIC_DIR):
+     os.makedirs(STATIC_DIR)
+ app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
+
+ # Pydantic model for request
  class TextImageRequest(BaseModel):
      text: Optional[str] = None
      image_base64: Optional[str] = None
-     voice: str = "af_heart" # Default voice that we know exists
+     voice: str = "af_heart" # Default voice
      speed: float = 1.0
 
-     # Annotate as a ClassVar so Pydantic ignores it as a field.
+     # Use ClassVar so that Pydantic doesn't treat this as a model field.
      AVAILABLE_VOICES: ClassVar[List[str]] = ["af_heart"]
 
      def validate_voice(self):
@@ -126,6 +140,7 @@ class TextImageRequest(BaseModel):
              return "af_heart"
          return self.voice
 
+ # (Optional) Pydantic models for responses
  class AudioResponse(BaseModel):
      status: str
      message: str
@@ -134,107 +149,72 @@ class ErrorResponse(BaseModel):
      error: str
      detail: Optional[str] = None
 
- # Initialize FastAPI app
- app = FastAPI(
-     title="Text-to-Speech API with Vision Support",
-     description="API for generating speech from text with optional image analysis",
-     version="1.0.0"
- )
-
- # Create and mount static images directory so images are accessible via URL
- STATIC_DIR = "static_images"
- if not os.path.exists(STATIC_DIR):
-     os.makedirs(STATIC_DIR)
- app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
-
- def llm_chat_response(text, image_base64=None):
-     """Get responses from LLM with text and optionally an image input."""
+ # Function to call the LLM model following the reference code exactly
+ def llm_chat_response(text: str, image_base64: str) -> str:
+     HF_TOKEN = os.getenv("HF_TOKEN")
+     logger.info("Checking HF_TOKEN...")
+     if not HF_TOKEN:
+         logger.error("HF_TOKEN not configured")
+         raise HTTPException(status_code=500, detail="HF_TOKEN not configured")
+
+     logger.info("Initializing InferenceClient...")
+     client = InferenceClient(
+         provider="hf-inference",
+         api_key=HF_TOKEN
+     )
+
+     # Save the base64-encoded image locally so it is accessible via a URL
+     filename = f"{uuid.uuid4()}.jpg"
+     image_path = os.path.join(STATIC_DIR, filename)
      try:
-         HF_TOKEN = os.getenv("HF_TOKEN")
-         logger.info("Checking HF_TOKEN...")
-         if not HF_TOKEN:
-             logger.error("HF_TOKEN not found in environment variables")
-             raise HTTPException(status_code=500, detail="HF_TOKEN not configured")
-
-         logger.info("Initializing InferenceClient...")
-         client = InferenceClient(
-             provider="hf-inference", # Using the correct provider as per sample
-             api_key=HF_TOKEN
-         )
-
-         if image_base64:
-             logger.info("Processing request with image")
-             # Save the base64 image to the static folder
-             filename = f"{uuid.uuid4()}.jpg"
-             image_path = os.path.join(STATIC_DIR, filename)
-             try:
-                 image_data = base64.b64decode(image_base64)
-             except Exception as e:
-                 logger.error(f"Error decoding base64 image: {str(e)}")
-                 raise HTTPException(status_code=400, detail="Invalid base64 image data")
-             with open(image_path, "wb") as f:
-                 f.write(image_data)
-             # Construct image URL (assumes BASE_URL environment variable or defaults to localhost)
-             base_url = os.getenv("BASE_URL", "http://localhost:8000")
-             image_url = f"{base_url}/static/{filename}"
-             prompt = text if text else "Describe this image in one sentence."
-             # Construct message exactly as in the reference
-             messages = [
-                 {
-                     "role": "user",
-                     "content": [
-                         {"type": "text", "text": prompt},
-                         {"type": "image_url", "image_url": {"url": image_url}}
-                     ]
-                 }
-             ]
-         else:
-             logger.info("Processing text-only request")
-             messages = [
-                 {
-                     "role": "user",
-                     "content": text + " Describe in one line only."
-                 }
+         image_data = base64.b64decode(image_base64)
+     except Exception as e:
+         logger.error(f"Error decoding image: {str(e)}")
+         raise HTTPException(status_code=400, detail="Invalid base64 image data")
+
+     with open(image_path, "wb") as f:
+         f.write(image_data)
+
+     # Construct the public URL for the saved image.
+     # BASE_URL should be set to your public URL if not running locally.
+     base_url = os.getenv("BASE_URL", "http://localhost:8000")
+     image_url = f"{base_url}/static/{filename}"
+
+     # Build the message exactly as in the reference code.
+     # This model requires a list with two items: one for text and one for the image.
+     prompt = text if text else "Describe this image in one sentence."
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "text", "text": prompt},
+                 {"type": "image_url", "image_url": {"url": image_url}}
              ]
-
-         logger.info("Sending request to model...")
-         logger.info(f"Message structure: {messages}")
-
+         }
+     ]
+     logger.info(f"Message structure: {messages}")
+
+     try:
          completion = client.chat.completions.create(
              model="meta-llama/Llama-3.2-11B-Vision-Instruct",
              messages=messages,
              max_tokens=500
          )
-
-         logger.info("Received response from model")
-         logger.info(f"Model response received: {completion}")
-
-         try:
-             response = completion.choices[0].message.content
-             logger.info(f"Extracted response content: {response}")
-             return response
-         except Exception as e:
-             logger.error(f"Error extracting message content: {str(e)}")
-             try:
-                 if hasattr(completion.choices[0], "message") and hasattr(completion.choices[0].message, "content"):
-                     return completion.choices[0].message.content
-                 return completion.choices[0]["message"]["content"]
-             except Exception as e2:
-                 logger.error(f"All extraction methods failed: {str(e2)}")
-                 return "I couldn't process that input. Please try again with a different query."
-
+         response = completion.choices[0].message.content
+         logger.info(f"Extracted response: {response}")
+         return response
      except Exception as e:
-         logger.error(f"Error in llm_chat_response: {str(e)}")
+         logger.error(f"Error during model inference: {str(e)}")
          raise HTTPException(status_code=500, detail=str(e))
 
- # Initialize the audio generation pipeline once at startup
+ # Initialize audio generation pipeline (your audio conversion pipeline)
  try:
      logger.info("Initializing KPipeline...")
      pipeline = KPipeline(lang_code='a')
      logger.info("KPipeline initialized successfully")
  except Exception as e:
      logger.error(f"Failed to initialize KPipeline: {str(e)}")
-     # The app starts regardless but logs the error
+     # The API can still run, but audio generation will fail.
 
  @app.post("/generate", responses={
      200: {"content": {"application/octet-stream": {}}},
@@ -243,95 +223,69 @@ except Exception as e:
  })
  async def generate_audio(request: TextImageRequest):
      """
-     Generate audio from text and optionally analyze an image.
-
-     - If text is provided, it is used as input.
-     - If an image is provided (base64), it is saved and a URL is generated for processing.
-     - The LLM response is then converted to speech.
+     Generate audio from a multimodal (text+image) input.
+     This model does not support text-only inputs.
      """
+     logger.info("Received generation request")
+     # Ensure an image is provided because the model is multimodal.
+     if not request.image_base64:
+         raise HTTPException(status_code=400, detail="This model requires an image input.")
+
+     # Get the text prompt. If none is provided, use a default.
+     user_text = request.text if request.text else "Describe this image in one sentence."
+
+     # Get the LLM's response
+     logger.info("Calling the LLM model")
+     text_reply = llm_chat_response(user_text, request.image_base64)
+     logger.info(f"LLM response: {text_reply}")
+
+     # Validate voice parameter (if needed for audio generation)
+     validated_voice = request.validate_voice()
+     if validated_voice != request.voice:
+         logger.warning(f"Voice '{request.voice}' not available; using '{validated_voice}' instead")
+
+     # Convert the text reply to audio using your audio pipeline
+     logger.info(f"Generating audio using voice={validated_voice}, speed={request.speed}")
      try:
-         logger.info("Received audio generation request")
-         user_text = request.text if request.text is not None else ""
-         if not user_text and request.image_base64:
-             user_text = "Describe what you see in the image"
-         elif not user_text and not request.image_base64:
-             logger.error("Neither text nor image provided in request")
-             return JSONResponse(
-                 status_code=400,
-                 content={"error": "Request must include either text or image_base64"}
-             )
-
-         logger.info("Getting LLM response...")
-         text_reply = llm_chat_response(user_text, request.image_base64)
-         logger.info(f"LLM response: {text_reply}")
-
-         validated_voice = request.validate_voice()
-         if validated_voice != request.voice:
-             logger.warning(f"Requested voice '{request.voice}' not available, using '{validated_voice}' instead")
-
-         logger.info(f"Generating audio using voice={validated_voice}, speed={request.speed}")
-         try:
-             generator = pipeline(
-                 text_reply,
-                 voice=validated_voice,
-                 speed=request.speed,
-                 split_pattern=r'\n+'
-             )
-
-             for i, (gs, ps, audio) in enumerate(generator):
-                 logger.info(f"Audio generated successfully: segment {i}")
-                 # Convert PyTorch tensor to NumPy array
-                 audio_numpy = audio.cpu().numpy()
-                 # Clip values to range [-1, 1] and convert to 16-bit PCM
-                 audio_numpy = np.clip(audio_numpy, -1, 1)
-                 pcm_data = (audio_numpy * 32767).astype(np.int16)
-                 raw_audio = pcm_data.tobytes()
-
-                 return Response(
-                     content=raw_audio,
-                     media_type="application/octet-stream",
-                     headers={
-                         "Content-Disposition": 'attachment; filename="output.pcm"',
-                         "X-Sample-Rate": "24000",
-                         "X-Bits-Per-Sample": "16",
-                         "X-Endianness": "little"
-                     }
-                 )
-
-             logger.error("No audio segments generated")
-             return JSONResponse(
-                 status_code=400,
-                 content={"error": "No audio generated", "detail": "The pipeline did not produce any audio"}
-             )
+         # Generate audio segments (assumes pipeline yields segments)
+         generator = pipeline(
+             text_reply,
+             voice=validated_voice,
+             speed=request.speed,
+             split_pattern=r'\n+'
+         )
+         for i, (gs, ps, audio) in enumerate(generator):
+             logger.info(f"Audio generated, segment {i}")
+             # Convert audio tensor to 16-bit PCM bytes
+             audio_numpy = audio.cpu().numpy()
+             audio_numpy = np.clip(audio_numpy, -1, 1)
+             pcm_data = (audio_numpy * 32767).astype(np.int16)
+             raw_audio = pcm_data.tobytes()
 
-         except Exception as e:
-             logger.error(f"Error generating audio: {str(e)}")
-             return JSONResponse(
-                 status_code=500,
-                 content={"error": "Audio generation failed", "detail": str(e)}
+             return Response(
+                 content=raw_audio,
+                 media_type="application/octet-stream",
+                 headers={
+                     "Content-Disposition": 'attachment; filename="output.pcm"',
+                     "X-Sample-Rate": "24000",
+                     "X-Bits-Per-Sample": "16",
+                     "X-Endianness": "little"
+                 }
              )
-
+         raise HTTPException(status_code=400, detail="No audio segments generated.")
      except Exception as e:
-         logger.error(f"Unexpected error in generate_audio endpoint: {str(e)}")
-         return JSONResponse(
-             status_code=500,
-             content={"error": "Internal server error", "detail": str(e)}
-         )
+         logger.error(f"Error generating audio: {str(e)}")
+         raise HTTPException(status_code=500, detail=str(e))
 
  @app.get("/")
  async def root():
-     return {"message": "Welcome to the Text-to-Speech API with Vision Support. Use POST /generate with 'text' and optionally 'image_base64' for queries."}
+     return {"message": "Welcome! Use POST /generate with text and image_base64."}
 
  @app.exception_handler(404)
  async def not_found_handler(request: Request, exc):
-     return JSONResponse(
-         status_code=404,
-         content={"error": "Endpoint not found. Please use POST /generate for queries."}
-     )
+     return JSONResponse(status_code=404, content={"error": "Endpoint not found."})
 
  @app.exception_handler(405)
  async def method_not_allowed_handler(request: Request, exc):
-     return JSONResponse(
-         status_code=405,
-         content={"error": "Method not allowed. Please check the API documentation."}
-     )
+     return JSONResponse(status_code=405, content={"error": "Method not allowed."})
+
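
For reference, below is a minimal client sketch for exercising the updated endpoint. It is an illustration only, not part of the commit: it assumes the app from this diff is reachable at http://localhost:8000 (the default BASE_URL in the code above), that the requests package is installed, and that example.jpg is a placeholder path. The /generate route, the request fields (text, image_base64, voice, speed), and the X-Sample-Rate / X-Bits-Per-Sample response headers come from the diff; wrapping the returned raw PCM in a WAV container is just one convenient way to play the result.

# client_example.py -- hypothetical usage sketch, not part of the commit
import base64
import wave

import requests

API_URL = "http://localhost:8000/generate"  # adjust to your deployment's BASE_URL
IMAGE_PATH = "example.jpg"                   # placeholder input image

# The updated endpoint requires an image, so send one as base64.
with open(IMAGE_PATH, "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "text": "What is in this picture?",
    "image_base64": image_b64,
    "voice": "af_heart",  # the only voice listed in AVAILABLE_VOICES
    "speed": 1.0,
}

resp = requests.post(API_URL, json=payload, timeout=120)
resp.raise_for_status()

# The response body is raw 16-bit little-endian PCM; the format is in the headers.
sample_rate = int(resp.headers.get("X-Sample-Rate", "24000"))
bits = int(resp.headers.get("X-Bits-Per-Sample", "16"))

# Wrap the PCM bytes in a WAV container so ordinary players can open the result.
with wave.open("output.wav", "wb") as wav_file:
    wav_file.setnchannels(1)           # assuming mono output from the pipeline
    wav_file.setsampwidth(bits // 8)
    wav_file.setframerate(sample_rate)
    wav_file.writeframes(resp.content)

print(f"Wrote output.wav ({sample_rate} Hz, {bits}-bit PCM)")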