khurrameycon committed on
Commit 63d6fee · verified · 1 Parent(s): 313833e

Update app.py

Files changed (1): app.py (+135 -153)
app.py CHANGED
@@ -95,190 +95,172 @@
 # return Response("No audio generated", status_code=400)

 import os
-import uuid
-import base64
 import logging
-from fastapi import FastAPI, HTTPException, Response, Request
 from fastapi.responses import JSONResponse
-from fastapi.staticfiles import StaticFiles
 from pydantic import BaseModel
-from typing import Optional, ClassVar, List
 from huggingface_hub import InferenceClient
-import numpy as np
-import torch
-from kokoro import KPipeline  # Your audio generation pipeline

 # Set up logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

-# Create FastAPI app
 app = FastAPI(
-    title="Text-to-Speech API with Vision Support",
-    description="This API uses meta-llama/Llama-3.2-11B-Vision-Instruct which requires an image input.",
     version="1.0.0"
 )

-# Mount a static directory for serving saved images
 STATIC_DIR = "static_images"
 if not os.path.exists(STATIC_DIR):
     os.makedirs(STATIC_DIR)
-app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
-
-# Pydantic model for request
-class TextImageRequest(BaseModel):
-    text: Optional[str] = None
-    image_base64: Optional[str] = None
-    voice: str = "af_heart"  # Default voice
-    speed: float = 1.0
-
-    # Use ClassVar so that Pydantic doesn't treat this as a model field.
-    AVAILABLE_VOICES: ClassVar[List[str]] = ["af_heart"]

-    def validate_voice(self):
-        if self.voice not in self.AVAILABLE_VOICES:
-            return "af_heart"
-        return self.voice

-# Pydantic model for error responses
-class ErrorResponse(BaseModel):
-    error: str
-    detail: Optional[str] = None

-def llm_chat_response(prompt: str, image_base64: str) -> str:
-    HF_TOKEN = os.getenv("HF_TOKEN")
-    logger.info("Checking HF_TOKEN...")
-    if not HF_TOKEN:
-        logger.error("HF_TOKEN not configured")
-        raise HTTPException(status_code=500, detail="HF_TOKEN not configured")
-
-    logger.info("Initializing InferenceClient...")
-    client = InferenceClient(
-        provider="hf-inference",
-        api_key=HF_TOKEN
-    )
-
-    # Save the base64-encoded image locally
-    filename = f"{uuid.uuid4()}.jpg"
-    image_path = os.path.join(STATIC_DIR, filename)
     try:
-        image_data = base64.b64decode(image_base64)
-    except Exception as e:
-        logger.error(f"Error decoding image: {str(e)}")
-        raise HTTPException(status_code=400, detail="Invalid base64 image data")
-
-    with open(image_path, "wb") as f:
-        f.write(image_data)
-
-    # Construct the public URL for the saved image.
-    # Set BASE_URL to your public URL if needed.
-    base_url = os.getenv("BASE_URL", "http://localhost:8000")
-    image_url = f"{base_url}/static/{filename}"
-
-    # Build the message payload exactly as in the reference:
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                {
-                    "type": "text",
-                    "text": prompt
-                },
-                {
-                    "type": "image_url",
-                    "image_url": {
-                        "url": image_url
-                    }
-                }
-            ]
-        }
-    ]
-    logger.info(f"Message structure: {messages}")
-
-    try:
-        completion = client.chat.completions.create(
-            model="meta-llama/Llama-3.2-11B-Vision-Instruct",
-            messages=messages,
-            max_tokens=500,
         )
-        response = completion.choices[0].message.content
-        logger.info(f"Extracted response: {response}")
-        return response
     except Exception as e:
-        logger.error(f"Error during model inference: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))

-# Initialize the audio generation pipeline (KPipeline)
-try:
-    logger.info("Initializing KPipeline...")
-    pipeline = KPipeline(lang_code='a')
-    logger.info("KPipeline initialized successfully")
-except Exception as e:
-    logger.error(f"Failed to initialize KPipeline: {str(e)}")
-    # The API will run but audio generation will fail if the pipeline is not ready.
-
-@app.post("/generate", responses={
-    200: {"content": {"application/octet-stream": {}}},
-    400: {"model": ErrorResponse},
-    500: {"model": ErrorResponse}
-})
-async def generate_audio(request: TextImageRequest):
-    """
-    Generate audio from a multimodal (text+image) input.
-    This model requires an image input.
-    """
-    logger.info("Received generation request")
-
-    # The model requires an image; if missing, return an error.
-    if not request.image_base64:
-        raise HTTPException(status_code=400, detail="This model requires an image input.")
-
-    prompt = request.text if request.text else "Describe this image in one sentence."
-    logger.info("Calling the LLM model")
-    text_reply = llm_chat_response(prompt, request.image_base64)
-    logger.info(f"LLM response: {text_reply}")
-
-    validated_voice = request.validate_voice()
-    if validated_voice != request.voice:
-        logger.warning(f"Voice '{request.voice}' not available; using '{validated_voice}' instead")
-
-    # Convert the text reply to audio using the KPipeline.
-    logger.info(f"Generating audio using voice={validated_voice}, speed={request.speed}")
     try:
-        generator = pipeline(
-            text_reply,
-            voice=validated_voice,
-            speed=request.speed,
-            split_pattern=r'\n+'
-        )
-        for _, _, audio in generator:
-            audio_numpy = audio.cpu().numpy()
-            audio_numpy = np.clip(audio_numpy, -1, 1)
-            pcm_data = (audio_numpy * 32767).astype(np.int16)
-            raw_audio = pcm_data.tobytes()
-            return Response(
-                content=raw_audio,
-                media_type="application/octet-stream",
-                headers={
-                    "Content-Disposition": 'attachment; filename="output.pcm"',
-                    "X-Sample-Rate": "24000",
-                    "X-Bits-Per-Sample": "16",
-                    "X-Endianness": "little"
-                }
-            )
-        raise HTTPException(status_code=400, detail="No audio segments generated.")
     except Exception as e:
-        logger.error(f"Error generating audio: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))

 @app.get("/")
 async def root():
-    return {"message": "Welcome to the Text-to-Speech API with Vision Support. Use POST /generate with text and image_base64."}

 @app.exception_handler(404)
-async def not_found_handler(request: Request, exc):
-    return JSONResponse(status_code=404, content={"error": "Endpoint not found."})

 @app.exception_handler(405)
-async def method_not_allowed_handler(request: Request, exc):
-    return JSONResponse(status_code=405, content={"error": "Method not allowed."})
 
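For reference, the removed /generate endpoint above returned raw 16-bit little-endian PCM at 24 kHz, per the response headers it set. A minimal client sketch under assumed conditions (a local deployment at http://localhost:8000, a hypothetical sample.jpg test image, and the requests library, none of which are part of the commit) could have looked like this:

import base64
import wave

import requests  # hypothetical client dependency, not part of app.py

BASE_URL = "http://localhost:8000"  # assumed local deployment

# Base64-encode a local test image for the image_base64 field.
with open("sample.jpg", "rb") as f:  # hypothetical test image
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

resp = requests.post(
    f"{BASE_URL}/generate",
    json={
        "text": "Describe this image in one sentence.",
        "image_base64": image_b64,
        "voice": "af_heart",
        "speed": 1.0,
    },
    timeout=120,
)
resp.raise_for_status()

# Wrap the raw PCM bytes in a WAV container, following the format the
# endpoint advertised (16-bit samples at 24 kHz; mono is assumed here).
with wave.open("output.wav", "wb") as wav:
    wav.setnchannels(1)            # assumed mono output
    wav.setsampwidth(2)            # X-Bits-Per-Sample: 16
    wav.setframerate(24000)        # X-Sample-Rate: 24000
    wav.writeframes(resp.content)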
 # return Response("No audio generated", status_code=400)

 import os
 import logging
+import base64
+from typing import Optional
+from fastapi import FastAPI, HTTPException
 from fastapi.responses import JSONResponse
 from pydantic import BaseModel
 from huggingface_hub import InferenceClient
+from requests.exceptions import HTTPError
+import uuid

 # Set up logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

+# Initialize FastAPI app
 app = FastAPI(
+    title="LLM Chat API",
+    description="API for getting chat responses from Llama model (supports text and image input)",
     version="1.0.0"
 )

+# Directory to save images
 STATIC_DIR = "static_images"
 if not os.path.exists(STATIC_DIR):
     os.makedirs(STATIC_DIR)

+# Pydantic models
+class ChatRequest(BaseModel):
+    text: str
+    image_url: Optional[str] = None  # In this updated version, this field is expected to be a base64 encoded image

+class ChatResponse(BaseModel):
+    response: str
+    status: str

+def llm_chat_response(text: str, image_base64: Optional[str] = None) -> str:
     try:
+        HF_TOKEN = os.getenv("HF_TOKEN")
+        logger.info("Checking HF_TOKEN...")
+        if not HF_TOKEN:
+            logger.error("HF_TOKEN not found in environment variables")
+            raise HTTPException(status_code=500, detail="HF_TOKEN not configured")
+
+        logger.info("Initializing InferenceClient...")
+        client = InferenceClient(
+            provider="hf-inference",  # Updated provider
+            api_key=HF_TOKEN
         )
+
+        # Build the messages payload.
+        # For text-only queries, append a default instruction.
+        message_content = [{
+            "type": "text",
+            "text": text + ("" if image_base64 else " describe in one line only")
+        }]
+
+        if image_base64:
+            logger.info("Saving base64 encoded image to file...")
+            # Decode and save the image locally
+            filename = f"{uuid.uuid4()}.jpg"
+            image_path = os.path.join(STATIC_DIR, filename)
+            try:
+                image_data = base64.b64decode(image_base64)
+            except Exception as e:
+                logger.error(f"Error decoding image: {str(e)}")
+                raise HTTPException(status_code=400, detail="Invalid base64 image data")
+            with open(image_path, "wb") as f:
+                f.write(image_data)
+
+            # Construct public URL for the saved image.
+            # Set BASE_URL to your public URL if needed.
+            base_url = os.getenv("BASE_URL", "http://localhost:8000")
+            public_image_url = f"{base_url}/{STATIC_DIR}/{filename}"
+            logger.info(f"Using saved image URL: {public_image_url}")
+
+            message_content.append({
+                "type": "image_url",
+                "image_url": {"url": public_image_url}
+            })
+
+        messages = [{
+            "role": "user",
+            "content": message_content
+        }]
+
+        logger.info("Sending request to model...")
+        try:
+            completion = client.chat.completions.create(
+                model="meta-llama/Llama-3.2-11B-Vision-Instruct",
+                messages=messages,
+                max_tokens=500
+            )
+        except HTTPError as http_err:
+            logger.error(f"HTTP error occurred: {http_err.response.text}")
+            raise HTTPException(status_code=500, detail=http_err.response.text)
+
+        logger.info(f"Raw model response: {completion}")
+
+        if getattr(completion, "error", None):
+            error_details = completion.error
+            error_message = error_details.get("message", "Unknown error")
+            logger.error(f"Model returned error: {error_message}")
+            raise HTTPException(status_code=500, detail=f"Model returned error: {error_message}")
+
+        if not completion.choices or len(completion.choices) == 0:
+            logger.error("No choices returned from model.")
+            raise HTTPException(status_code=500, detail="Model returned no choices.")
+
+        # Extract the response message from the first choice.
+        choice = completion.choices[0]
+        response_message = None
+        if hasattr(choice, "message"):
+            response_message = choice.message
+        elif isinstance(choice, dict):
+            response_message = choice.get("message")
+
+        if not response_message:
+            logger.error(f"Response message is empty: {choice}")
+            raise HTTPException(status_code=500, detail="Model response did not include a message.")
+
+        content = None
+        if isinstance(response_message, dict):
+            content = response_message.get("content")
+        if content is None and hasattr(response_message, "content"):
+            content = response_message.content
+
+        if not content:
+            logger.error(f"Message content is missing: {response_message}")
+            raise HTTPException(status_code=500, detail="Model message did not include content.")
+
+        return content
+
     except Exception as e:
+        logger.error(f"Error in llm_chat_response: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))

+@app.post("/chat", response_model=ChatResponse)
+async def chat(request: ChatRequest):
     try:
+        logger.info(f"Received chat request with text: {request.text}")
+        if request.image_url:
+            logger.info("Image data provided.")
+        response = llm_chat_response(request.text, request.image_url)
+        return ChatResponse(response=response, status="success")
+    except HTTPException as he:
+        logger.error(f"HTTP Exception in chat endpoint: {str(he)}")
+        raise he
     except Exception as e:
+        logger.error(f"Unexpected error in chat endpoint: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))

 @app.get("/")
 async def root():
+    return {"message": "Welcome to the LLM Chat API. Use POST /chat endpoint with 'text' and optionally 'image_url' (base64 encoded) for queries."}

 @app.exception_handler(404)
+async def not_found_handler(request, exc):
+    return JSONResponse(
+        status_code=404,
+        content={"error": "Endpoint not found. Please use POST /chat for queries."}
+    )

 @app.exception_handler(405)
+async def method_not_allowed_handler(request, exc):
+    return JSONResponse(
+        status_code=405,
+        content={"error": "Method not allowed. Please check the API documentation."}
+    )
+
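A minimal client sketch for the new /chat endpoint, under the same assumptions (local deployment at http://localhost:8000, a hypothetical cat.jpg test image, and the requests library). Note that despite its name, image_url carries a base64-encoded image in this version, and BASE_URL presumably needs to point at a host that publicly serves the static_images directory so the inference provider can fetch the saved file:

import base64

import requests  # hypothetical client dependency, not part of app.py

BASE_URL = "http://localhost:8000"  # assumed local deployment

# Base64-encode a local test image; /chat expects the raw base64 string
# in the image_url field, not an actual URL.
with open("cat.jpg", "rb") as f:  # hypothetical test image
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

resp = requests.post(
    f"{BASE_URL}/chat",
    json={"text": "What is in this picture?", "image_url": image_b64},
    timeout=120,
)
resp.raise_for_status()

payload = resp.json()          # shaped like the ChatResponse model
print(payload["status"])       # "success" on a normal reply
print(payload["response"])     # the model's text reply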