khurrameycon committed
Commit b2195c3 · verified · 1 Parent(s): 63d6fee

Update app.py

Files changed (1):
  1. app.py +169 -71
app.py CHANGED
@@ -94,43 +94,68 @@
 
 # return Response("No audio generated", status_code=400)
 
+from fastapi import FastAPI, Response, HTTPException
+from fastapi.responses import FileResponse, JSONResponse
+from kokoro import KPipeline
+import soundfile as sf
 import os
-import logging
+import numpy as np
+import torch
+from huggingface_hub import InferenceClient
+from pydantic import BaseModel
 import base64
+from io import BytesIO
+from PIL import Image
+import logging
 from typing import Optional
-from fastapi import FastAPI, HTTPException
-from fastapi.responses import JSONResponse
-from pydantic import BaseModel
-from huggingface_hub import InferenceClient
-from requests.exceptions import HTTPError
 import uuid
+import pathlib
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+# Create a directory for temporary image storage
+TEMP_DIR = pathlib.Path("./temp_images")
+TEMP_DIR.mkdir(exist_ok=True)
+
+class TextImageRequest(BaseModel):
+    text: Optional[str] = None
+    image_base64: Optional[str] = None
+    voice: str = "af_heart"
+    speed: float = 1.0
+
+class AudioResponse(BaseModel):
+    status: str
+    message: str
+
 # Initialize FastAPI app
 app = FastAPI(
-    title="LLM Chat API",
-    description="API for getting chat responses from Llama model (supports text and image input)",
+    title="Text-to-Speech API with Vision Support",
+    description="API for generating speech from text with optional image analysis",
     version="1.0.0"
 )
 
-# Directory to save images
-STATIC_DIR = "static_images"
-if not os.path.exists(STATIC_DIR):
-    os.makedirs(STATIC_DIR)
-
-# Pydantic models
-class ChatRequest(BaseModel):
-    text: str
-    image_url: Optional[str] = None  # In this updated version, this field is expected to be a base64 encoded image
-
-class ChatResponse(BaseModel):
-    response: str
-    status: str
+def save_base64_image(image_base64):
+    """Save base64 image to a temporary file and return the file path"""
+    try:
+        # Generate a unique filename
+        filename = f"{uuid.uuid4()}.jpg"
+        filepath = TEMP_DIR / filename
+
+        # Decode and save the image
+        image_data = base64.b64decode(image_base64)
+        with open(filepath, "wb") as f:
+            f.write(image_data)
+
+        # Return the file URL (using file:// protocol)
+        return f"file://{filepath.absolute()}"
+    except Exception as e:
+        logger.error(f"Error saving base64 image: {str(e)}")
+        raise HTTPException(status_code=400, detail=f"Invalid base64 image data: {str(e)}")
 
-def llm_chat_response(text: str, image_base64: Optional[str] = None) -> str:
+def llm_chat_response(text, image_base64=None):
+    """Function to get responses from LLM with text and optionally image input."""
     try:
         HF_TOKEN = os.getenv("HF_TOKEN")
         logger.info("Checking HF_TOKEN...")
@@ -140,41 +165,32 @@ def llm_chat_response(text: str, image_base64: Optional[str] = None) -> str:
 
         logger.info("Initializing InferenceClient...")
         client = InferenceClient(
-            provider="hf-inference",  # Updated provider
+            provider="sambanova",  # Using sambanova as in your working example
             api_key=HF_TOKEN
         )
 
-        # Build the messages payload.
-        # For text-only queries, append a default instruction.
+        # Build the messages payload using the format from your working example
         message_content = [{
             "type": "text",
             "text": text + ("" if image_base64 else " describe in one line only")
         }]
 
         if image_base64:
-            logger.info("Saving base64 encoded image to file...")
-            # Decode and save the image locally
-            filename = f"{uuid.uuid4()}.jpg"
-            image_path = os.path.join(STATIC_DIR, filename)
-            try:
-                image_data = base64.b64decode(image_base64)
-            except Exception as e:
-                logger.error(f"Error decoding image: {str(e)}")
-                raise HTTPException(status_code=400, detail="Invalid base64 image data")
-            with open(image_path, "wb") as f:
-                f.write(image_data)
+            logger.info("Processing base64 image...")
+            # Save the base64 image to a file and get the file URL
+            image_url = save_base64_image(image_base64)
+            logger.info(f"Image saved at: {image_url}")
 
-            # Construct public URL for the saved image.
-            # Set BASE_URL to your public URL if needed.
-            base_url = os.getenv("BASE_URL", "http://localhost:8000")
-            public_image_url = f"{base_url}/{STATIC_DIR}/{filename}"
-            logger.info(f"Using saved image URL: {public_image_url}")
+            # Create data URI
+            data_uri = f"data:image/jpeg;base64,{image_base64}"
 
+            # Add image to message content
             message_content.append({
                 "type": "image_url",
-                "image_url": {"url": public_image_url}
+                "image_url": {"url": data_uri}
             })
 
+        # Construct the messages array exactly as in your working example
         messages = [{
             "role": "user",
             "content": message_content
@@ -187,23 +203,19 @@ def llm_chat_response(text: str, image_base64: Optional[str] = None) -> str:
                 messages=messages,
                 max_tokens=500
             )
-        except HTTPError as http_err:
-            logger.error(f"HTTP error occurred: {http_err.response.text}")
-            raise HTTPException(status_code=500, detail=http_err.response.text)
+        except Exception as http_err:
+            # Log HTTP errors from the request
+            logger.error(f"HTTP error occurred: {str(http_err)}")
+            raise HTTPException(status_code=500, detail=str(http_err))
 
-        logger.info(f"Raw model response: {completion}")
-
-        if getattr(completion, "error", None):
-            error_details = completion.error
-            error_message = error_details.get("message", "Unknown error")
-            logger.error(f"Model returned error: {error_message}")
-            raise HTTPException(status_code=500, detail=f"Model returned error: {error_message}")
+        logger.info(f"Raw model response received")
 
+        # Extract the response using the same method as your working code
         if not completion.choices or len(completion.choices) == 0:
             logger.error("No choices returned from model.")
             raise HTTPException(status_code=500, detail="Model returned no choices.")
 
-        # Extract the response message from the first choice.
+        # Extract the response message from the first choice
        choice = completion.choices[0]
        response_message = None
        if hasattr(choice, "message"):
@@ -226,35 +238,122 @@ def llm_chat_response(text: str, image_base64: Optional[str] = None) -> str:
             raise HTTPException(status_code=500, detail="Model message did not include content.")
 
         return content
-
+
     except Exception as e:
         logger.error(f"Error in llm_chat_response: {str(e)}")
-        raise HTTPException(status_code=500, detail=str(e))
+        # Fallback response in case of error
+        return "I couldn't process that input. Please try again with a different image or text query."
+
+# Initialize pipeline once at startup
+try:
+    logger.info("Initializing KPipeline...")
+    pipeline = KPipeline(lang_code='a')
+    logger.info("KPipeline initialized successfully")
+except Exception as e:
+    logger.error(f"Failed to initialize KPipeline: {str(e)}")
+    # We'll let the app start anyway, but log the error
 
-@app.post("/chat", response_model=ChatResponse)
-async def chat(request: ChatRequest):
+@app.post("/generate")
+async def generate_audio(request: TextImageRequest):
+    """
+    Generate audio from text and optionally analyze an image.
+
+    - If text is provided, uses that as input
+    - If image is provided, analyzes the image
+    - Converts the LLM response to speech using the specified voice and speed
+    """
     try:
-        logger.info(f"Received chat request with text: {request.text}")
-        if request.image_url:
-            logger.info("Image data provided.")
-        response = llm_chat_response(request.text, request.image_url)
-        return ChatResponse(response=response, status="success")
-    except HTTPException as he:
-        logger.error(f"HTTP Exception in chat endpoint: {str(he)}")
-        raise he
+        logger.info(f"Received audio generation request")
+
+        # If no text is provided but image is provided, use default prompt
+        user_text = request.text if request.text is not None else ""
+        if not user_text and request.image_base64:
+            user_text = "Describe what you see in the image"
+        elif not user_text and not request.image_base64:
+            logger.error("Neither text nor image provided in request")
+            return JSONResponse(
+                status_code=400,
+                content={"error": "Request must include either text or image_base64"}
+            )
+
+        # Generate response using text and image if provided
+        logger.info("Getting LLM response...")
+        text_reply = llm_chat_response(user_text, request.image_base64)
+        logger.info(f"LLM response: {text_reply}")
+
+        # Generate audio
+        logger.info(f"Generating audio using voice={request.voice}, speed={request.speed}")
+        try:
+            generator = pipeline(
+                text_reply,
+                voice=request.voice,
+                speed=request.speed,
+                split_pattern=r'\n+'
+            )
+
+            # Process only the first segment for demo
+            for i, (gs, ps, audio) in enumerate(generator):
+                logger.info(f"Audio generated successfully: segment {i}")
+
+                # Convert PyTorch tensor to NumPy array
+                audio_numpy = audio.cpu().numpy()
+
+                # Convert to 16-bit PCM
+                # Ensure the audio is in the range [-1, 1]
+                audio_numpy = np.clip(audio_numpy, -1, 1)
+                # Convert to 16-bit signed integers
+                pcm_data = (audio_numpy * 32767).astype(np.int16)
+
+                # Convert to bytes (automatically uses row-major order)
+                raw_audio = pcm_data.tobytes()
+
+                # Return PCM data with minimal necessary headers
+                return Response(
+                    content=raw_audio,
+                    media_type="application/octet-stream",
+                    headers={
+                        "Content-Disposition": f'attachment; filename="output.pcm"',
+                        "X-Sample-Rate": "24000",
+                        "X-Bits-Per-Sample": "16",
+                        "X-Endianness": "little"
+                    }
+                )
+
+            logger.error("No audio segments generated")
+            return JSONResponse(
+                status_code=400,
+                content={"error": "No audio generated", "detail": "The pipeline did not produce any audio"}
+            )
+
+        except Exception as e:
+            logger.error(f"Error generating audio: {str(e)}")
+            return JSONResponse(
+                status_code=500,
+                content={"error": "Audio generation failed", "detail": str(e)}
+            )
+
     except Exception as e:
-        logger.error(f"Unexpected error in chat endpoint: {str(e)}")
-        raise HTTPException(status_code=500, detail=str(e))
+        logger.error(f"Unexpected error in generate_audio endpoint: {str(e)}")
+        return JSONResponse(
+            status_code=500,
+            content={"error": "Internal server error", "detail": str(e)}
+        )
 
 @app.get("/")
 async def root():
-    return {"message": "Welcome to the LLM Chat API. Use POST /chat endpoint with 'text' and optionally 'image_url' (base64 encoded) for queries."}
+    return {"message": "Welcome to the Text-to-Speech API with Vision Support. Use POST /generate endpoint with 'text' and optionally 'image_base64' for queries."}
+
+# Cleanup function to periodically remove old temporary images
+@app.on_event("startup")
+async def startup_event():
+    # You could add scheduled tasks here to clean up old images
+    pass
 
 @app.exception_handler(404)
 async def not_found_handler(request, exc):
     return JSONResponse(
         status_code=404,
-        content={"error": "Endpoint not found. Please use POST /chat for queries."}
+        content={"error": "Endpoint not found. Please use POST /generate for queries."}
     )
 
 @app.exception_handler(405)
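
Since /generate now returns headerless 16-bit little-endian PCM rather than a playable file, a client has to supply its own container, using the X-Sample-Rate, X-Bits-Per-Sample, and X-Endianness response headers added above. A minimal sketch that wraps the bytes in a WAV file; the host is illustrative, and mono output is an assumption about the Kokoro pipeline:

    import wave
    import requests  # client-side dependency for this sketch only

    resp = requests.post("http://localhost:8000/generate", json={"text": "Hello there"})
    resp.raise_for_status()

    rate = int(resp.headers["X-Sample-Rate"])            # "24000" per this commit
    width = int(resp.headers["X-Bits-Per-Sample"]) // 8  # 16 bits -> 2 bytes per sample

    with wave.open("output.wav", "wb") as wav:
        wav.setnchannels(1)                    # assumed mono
        wav.setsampwidth(width)
        wav.setframerate(rate)
        wav.writeframes(resp.content)          # raw little-endian PCM from the body
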
@@ -262,5 +361,4 @@ async def method_not_allowed_handler(request, exc):
     return JSONResponse(
         status_code=405,
         content={"error": "Method not allowed. Please check the API documentation."}
-    )
-
+    )