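# ---------------------------------------------------------------------------
# Legacy implementation, kept commented out for reference: a /generate endpoint
# that fed the LLM reply to the Kokoro KPipeline and returned raw 16-bit PCM audio.
# ---------------------------------------------------------------------------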
# from fastapi import FastAPI, Response
# from fastapi.responses import FileResponse
# from kokoro import KPipeline
# import soundfile as sf
# import os
# import numpy as np
# import torch 
# from huggingface_hub import InferenceClient

# def llm_chat_response(text):
#     HF_TOKEN = os.getenv("HF_TOKEN")
#     client = InferenceClient(api_key=HF_TOKEN)
#     messages = [
# 	{
# 		"role": "user",
# 		"content": [
# 			{
# 				"type": "text",
# 				"text": text + str('describe in one line only')
# 			} #,
# 			# {
# 			# 	"type": "image_url",
# 			# 	"image_url": {
# 			# 		"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
# 			# 	}
# 			# }
#             ]
# 	}
#     ]

#     response_from_llama = client.chat.completions.create(
#     model="meta-llama/Llama-3.2-11B-Vision-Instruct", 
# 	messages=messages, 
# 	max_tokens=500)

#     return response_from_llama.choices[0].message['content']

# app = FastAPI()

# # Initialize pipeline once at startup
# pipeline = KPipeline(lang_code='a')

# @app.post("/generate")
# async def generate_audio(text: str, voice: str = "af_heart", speed: float = 1.0):
    
#     text_reply = llm_chat_response(text)
    
#     # Generate audio
#     generator = pipeline(
#         text_reply, 
#         voice=voice,
#         speed=speed,
#         split_pattern=r'\n+'
#     )
    
#     # # Save first segment only for demo
#     # for i, (gs, ps, audio) in enumerate(generator):
#     #     sf.write(f"output_{i}.wav", audio, 24000)
#     #     return FileResponse(
#     #         f"output_{i}.wav",
#     #         media_type="audio/wav",
#     #         filename="output.wav"
#     #     )
    
#     # return Response("No audio generated", status_code=400)


#     # Process only the first segment for demo
#     for i, (gs, ps, audio) in enumerate(generator):

#         # Convert PyTorch tensor to NumPy array
#         audio_numpy = audio.cpu().numpy()
#         # Convert to 16-bit PCM
        
#         # Ensure the audio is in the range [-1, 1]
#         audio_numpy = np.clip(audio_numpy, -1, 1)
#         # Convert to 16-bit signed integers
#         pcm_data = (audio_numpy * 32767).astype(np.int16)
        
#         # Convert to bytes (automatically uses row-major order)
#         raw_audio = pcm_data.tobytes()
        
#         # Return PCM data with minimal necessary headers
#         return Response(
#             content=raw_audio,
#             media_type="application/octet-stream",
#             headers={
#                 "Content-Disposition": f'attachment; filename="output.pcm"',
#                 "X-Sample-Rate": "24000",
#                 "X-Bits-Per-Sample": "16",
#                 "X-Endianness": "little"
#             }
#         )
    
#     return Response("No audio generated", status_code=400)

import os
import logging
import base64
from typing import Optional
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from huggingface_hub import InferenceClient
from requests.exceptions import HTTPError
import uuid

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="LLM Chat API",
    description="API for getting chat responses from Llama model (supports text and image input)",
    version="1.0.0"
)

# Directory to save uploaded images; mount it as a static route so the
# public URL constructed below actually serves the saved files
STATIC_DIR = "static_images"
os.makedirs(STATIC_DIR, exist_ok=True)
app.mount(f"/{STATIC_DIR}", StaticFiles(directory=STATIC_DIR), name="static_images")

# Pydantic models
class ChatRequest(BaseModel):
    text: str
    image_url: Optional[str] = None  # Despite the field name, this is expected to be a base64-encoded image

class ChatResponse(BaseModel):
    response: str
    status: str
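
# Illustrative request/response bodies for POST /chat (example values only, not taken from this project):
#   request:  {"text": "What is in this picture?", "image_url": "<base64-encoded JPEG bytes>"}
#   response: {"response": "A dog playing fetch in a park.", "status": "success"}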

def llm_chat_response(text: str, image_base64: Optional[str] = None) -> str:
    try:
        HF_TOKEN = os.getenv("HF_TOKEN")
        logger.info("Checking HF_TOKEN...")
        if not HF_TOKEN:
            logger.error("HF_TOKEN not found in environment variables")
            raise HTTPException(status_code=500, detail="HF_TOKEN not configured")
        
        logger.info("Initializing InferenceClient...")
        client = InferenceClient(
            provider="hf-inference",  # Updated provider
            api_key=HF_TOKEN
        )
        
        # Build the messages payload.
        # For text-only queries, append a default instruction.
        message_content = [{
            "type": "text",
            "text": text + ("" if image_base64 else " describe in one line only")
        }]
        
        if image_base64:
            logger.info("Saving base64 encoded image to file...")
            # Decode and save the image locally
            filename = f"{uuid.uuid4()}.jpg"
            image_path = os.path.join(STATIC_DIR, filename)
            try:
                image_data = base64.b64decode(image_base64)
            except Exception as e:
                logger.error(f"Error decoding image: {str(e)}")
                raise HTTPException(status_code=400, detail="Invalid base64 image data")
            with open(image_path, "wb") as f:
                f.write(image_data)
            
            # Construct public URL for the saved image.
            # Set BASE_URL to your public URL if needed.
            base_url = os.getenv("BASE_URL", "http://localhost:8000")
            public_image_url = f"{base_url}/{STATIC_DIR}/{filename}"
            logger.info(f"Using saved image URL: {public_image_url}")
            
            message_content.append({
                "type": "image_url",
                "image_url": {"url": public_image_url}
            })
        
        messages = [{
            "role": "user",
            "content": message_content
        }]
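
        # The assembled payload follows the OpenAI-style chat format accepted by the Inference API,
        # e.g. (illustrative values):
        #   [{"role": "user",
        #     "content": [{"type": "text", "text": "..."},
        #                 {"type": "image_url", "image_url": {"url": "http://.../static_images/<uuid>.jpg"}}]}]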
        
        logger.info("Sending request to model...")
        try:
            completion = client.chat.completions.create(
                model="meta-llama/Llama-3.2-11B-Vision-Instruct",
                messages=messages,
                max_tokens=500
            )
        except HTTPError as http_err:
            logger.error(f"HTTP error occurred: {http_err.response.text}")
            raise HTTPException(status_code=500, detail=http_err.response.text)
        
        logger.info(f"Raw model response: {completion}")
        
        if getattr(completion, "error", None):
            error_details = completion.error
            error_message = error_details.get("message", "Unknown error")
            logger.error(f"Model returned error: {error_message}")
            raise HTTPException(status_code=500, detail=f"Model returned error: {error_message}")
        
        if not completion.choices:
            logger.error("No choices returned from model.")
            raise HTTPException(status_code=500, detail="Model returned no choices.")
        
        # Extract the response message from the first choice.
        choice = completion.choices[0]
        response_message = None
        if hasattr(choice, "message"):
            response_message = choice.message
        elif isinstance(choice, dict):
            response_message = choice.get("message")
        
        if not response_message:
            logger.error(f"Response message is empty: {choice}")
            raise HTTPException(status_code=500, detail="Model response did not include a message.")
        
        content = None
        if isinstance(response_message, dict):
            content = response_message.get("content")
        if content is None and hasattr(response_message, "content"):
            content = response_message.content
        
        if not content:
            logger.error(f"Message content is missing: {response_message}")
            raise HTTPException(status_code=500, detail="Model message did not include content.")
        
        return content

    except HTTPException:
        # Re-raise HTTPExceptions unchanged so their status codes (e.g. 400 for bad base64) are preserved
        raise
    except Exception as e:
        logger.error(f"Error in llm_chat_response: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    try:
        logger.info(f"Received chat request with text: {request.text}")
        if request.image_url:
            logger.info("Image data provided.")
        response = llm_chat_response(request.text, request.image_url)
        return ChatResponse(response=response, status="success")
    except HTTPException as he:
        logger.error(f"HTTP Exception in chat endpoint: {str(he)}")
        raise he
    except Exception as e:
        logger.error(f"Unexpected error in chat endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/")
async def root():
    return {"message": "Welcome to the LLM Chat API. Use POST /chat endpoint with 'text' and optionally 'image_url' (base64 encoded) for queries."}

@app.exception_handler(404)
async def not_found_handler(request, exc):
    return JSONResponse(
        status_code=404,
        content={"error": "Endpoint not found. Please use POST /chat for queries."}
    )

@app.exception_handler(405)
async def method_not_allowed_handler(request, exc):
    return JSONResponse(
        status_code=405,
        content={"error": "Method not allowed. Please check the API documentation."}
    )
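
# Minimal local entry point (a sketch, assuming `uvicorn` is installed; in production the app is
# typically launched externally, e.g. `uvicorn <module>:app --host 0.0.0.0 --port 8000`):
if __name__ == "__main__":
    import uvicorn
    # Host and port below are placeholder defaults, not values taken from this project's config.
    uvicorn.run(app, host="0.0.0.0", port=8000)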