ndc8 committed
Commit 8d9c495 · 1 Parent(s): 3239c69
Files changed (1)
  1. backend_service.py +52 -93
backend_service.py CHANGED
@@ -15,7 +15,6 @@ warnings.filterwarnings("ignore", message=".*rope_scaling.*")
 os.environ.setdefault("HF_HOME", "/tmp/.cache/huggingface")
 # Suppress advisory warnings from transformers (including deprecation warnings)
 os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
-# Define Hugging Face auth token from environment
 hf_token = os.environ.get("HF_TOKEN")
 import asyncio
 import logging
@@ -28,10 +27,11 @@ from fastapi import FastAPI, HTTPException, Depends, Request
 from fastapi.responses import StreamingResponse, JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, Field, field_validator
-from huggingface_hub import InferenceClient
+
 import uvicorn
 import requests
 from PIL import Image
+from transformers import AutoTokenizer, AutoModelForCausalLM

 # Transformers imports (now required)
 try:
@@ -128,12 +128,13 @@ class CompletionRequest(BaseModel):
     max_tokens: Optional[int] = Field(default=512, ge=1, le=2048)
     temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0)

+
 # Global variables for model management
-inference_client: Optional[InferenceClient] = None
-image_text_pipeline = None  # type: ignore
 current_model = "unsloth/DeepSeek-R1-0528-Qwen3-8B-GGUF"
 vision_model = "Salesforce/blip-image-captioning-base"  # Working model for image captioning
 tokenizer = None
+model = None
+image_text_pipeline = None  # type: ignore

 # Image processing utilities
 async def download_image(url: str) -> Image.Image:
@@ -173,23 +174,22 @@ def has_images(messages: List[ChatMessage]) -> bool:
             return True
     return False

+
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     """Application lifespan manager for startup and shutdown events"""
-    global inference_client, tokenizer, image_text_pipeline
-
-    # Startup
+    global tokenizer, model, image_text_pipeline
     logger.info("🚀 Starting AI Backend Service...")
     try:
-        # Initialize HuggingFace Inference Client for text generation
-        inference_client = InferenceClient(model=current_model)
-        logger.info(f"✅ Initialized inference client with model: {current_model}")
-
-        # Initialize image-text-to-text pipeline
+        # Load local tokenizer and model
+        tokenizer = AutoTokenizer.from_pretrained(current_model)
+        model = AutoModelForCausalLM.from_pretrained(current_model)
+        logger.info(f"✅ Loaded local model and tokenizer: {current_model}")
+        # Optionally, load image pipeline as before
         if transformers_available and pipeline:
             try:
                 logger.info(f"🖼️ Initializing image captioning pipeline with model: {vision_model}")
-                image_text_pipeline = pipeline("image-to-text", model=vision_model)  # Use image-to-text task
+                image_text_pipeline = pipeline("image-to-text", model=vision_model)
                 logger.info("✅ Image captioning pipeline loaded successfully")
             except Exception as e:
                 logger.warning(f"⚠️ Could not load image captioning pipeline: {e}")
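Note: `current_model` points at a GGUF-only repository (`unsloth/DeepSeek-R1-0528-Qwen3-8B-GGUF`). A bare `AutoModelForCausalLM.from_pretrained(current_model)` normally expects a standard `config.json`/safetensors layout; transformers loads GGUF checkpoints only when a `gguf_file` argument is supplied, and only for architectures its GGUF loader supports. A minimal sketch of that variant follows; the `.gguf` filename is a placeholder, not a file verified to exist in the repo.

```python
# Sketch only: loading a GGUF checkpoint with transformers.
# The gguf filename is an assumed placeholder; it must match a real file in the repo.
from transformers import AutoTokenizer, AutoModelForCausalLM

repo_id = "unsloth/DeepSeek-R1-0528-Qwen3-8B-GGUF"
gguf_file = "DeepSeek-R1-0528-Qwen3-8B-Q4_K_M.gguf"  # placeholder filename

tokenizer = AutoTokenizer.from_pretrained(repo_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(repo_id, gguf_file=gguf_file)
```

Pointing `current_model` at the original non-GGUF checkpoint is the other obvious way to keep the plain `from_pretrained` call working.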
@@ -197,37 +197,13 @@ async def lifespan(app: FastAPI):
         else:
             logger.warning("⚠️ Transformers not available, image processing disabled")
             image_text_pipeline = None
-
-        # Initialize tokenizer for better text handling
-        if transformers_available and AutoTokenizer:
-            try:
-                # Load tokenizer, using auth token if provided
-                if hf_token:
-                    tokenizer = AutoTokenizer.from_pretrained(
-                        current_model,
-                        token=hf_token
-                    )  # type: ignore
-                else:
-                    tokenizer = AutoTokenizer.from_pretrained(
-                        current_model
-                    )  # type: ignore
-                logger.info("✅ Tokenizer loaded successfully")
-            except Exception as e:
-                logger.warning(f"⚠️ Could not load tokenizer: {e}")
-                tokenizer = None
-        else:
-            logger.info("⚠️ Tokenizer initialization skipped")
-
     except Exception as e:
-        logger.error(f"❌ Failed to initialize inference client: {e}")
+        logger.error(f"❌ Failed to initialize local model: {e}")
         raise RuntimeError(f"Service initialization failed: {e}")
-
     yield
-
-    # Shutdown
     logger.info("🔄 Shutting down AI Backend Service...")
-    inference_client = None
     tokenizer = None
+    model = None
     image_text_pipeline = None

 # Initialize FastAPI app
@@ -247,11 +223,10 @@ app.add_middleware(
     allow_headers=["*"],
 )

-def get_inference_client() -> InferenceClient:
-    """Dependency to get the inference client"""
-    if inference_client is None:
-        raise HTTPException(status_code=503, detail="Service not ready - inference client not initialized")
-    return inference_client
+
+def ensure_model_ready():
+    if tokenizer is None or model is None:
+        raise HTTPException(status_code=503, detail="Service not ready - model not initialized")

 def convert_messages_to_prompt(messages: List[ChatMessage]) -> str:
     """Convert OpenAI messages format to a single prompt string"""
@@ -341,36 +316,30 @@ async def generate_multimodal_response(
         logger.error(f"Error in multimodal generation: {e}")
         return f"I'm having trouble processing the image. Error: {str(e)}"

-def generate_response_safe(client: InferenceClient, prompt: str, max_tokens: int, temperature: float, top_p: float) -> str:
-    """Safely generate response from the model with fallback methods"""
+
+def generate_response_local(messages: List[ChatMessage], max_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.95) -> str:
+    """Generate response using local model and tokenizer with chat template."""
+    ensure_model_ready()
     try:
-        # Method 1: Try text_generation with new parameters
-        response_text = client.text_generation(
-            prompt=prompt,
-            max_new_tokens=max_tokens,
-            temperature=temperature,
-            top_p=top_p,
-            return_full_text=False,
-            stop=["Human:", "System:"]  # Use stop instead of stop_sequences
+        # Convert messages to OpenAI format for chat template
+        chat_messages = []
+        for m in messages:
+            chat_messages.append({"role": m.role, "content": m.content if isinstance(m.content, str) else extract_text_and_images(m.content)[0]})
+        inputs = tokenizer.apply_chat_template(
+            chat_messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
         )
-        return response_text.strip() if response_text else "I apologize, but I couldn't generate a response."
-
+        inputs = inputs.to(model.device)
+        outputs = model.generate(**inputs, max_new_tokens=max_tokens, do_sample=True, temperature=temperature, top_p=top_p)
+        # Only decode the newly generated tokens
+        generated = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
+        return generated.strip()
     except Exception as e:
-        logger.warning(f"text_generation failed: {e}")
-
-        # Method 2: Try with minimal parameters
-        try:
-            response_text = client.text_generation(
-                prompt=prompt,
-                max_new_tokens=max_tokens,
-                temperature=temperature,
-                return_full_text=False
-            )
-            return response_text.strip() if response_text else "I apologize, but I couldn't generate a response."
-
-        except Exception as e2:
-            logger.error(f"All generation methods failed: {e2}")
-            return "I apologize, but I'm having trouble generating a response right now. Please try again."
+        logger.error(f"Local generation failed: {e}")
+        return "I apologize, but I'm having trouble generating a response right now. Please try again."

 async def generate_streaming_response(
     client: InferenceClient,
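Note: `generate_streaming_response` (left untouched below this hunk) still annotates its first parameter as `InferenceClient`, whose import this commit removes, and the streaming branch is dropped from the chat endpoint in a later hunk. If streaming is wanted on the local path, one option is transformers' `TextIteratorStreamer`. A sketch under that assumption, reusing the global `model`/`tokenizer` and inputs prepared the same way as in `generate_response_local`; this is not part of this commit.

```python
# Sketch: token streaming from the local model via TextIteratorStreamer.
# `inputs` is assumed to be the tensor dict produced by tokenizer.apply_chat_template(...).
from threading import Thread
from transformers import TextIteratorStreamer

def stream_response_local(inputs, max_tokens: int = 512):
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(**inputs, max_new_tokens=max_tokens, streamer=streamer)
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    for chunk in streamer:  # yields decoded text pieces as they are generated
        yield chunk
```

A generator like this could back a `StreamingResponse` in the endpoint, mirroring the branch removed below.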
@@ -491,10 +460,10 @@ async def list_models():

 # ...existing code...

+
 @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
 async def create_chat_completion(
-    request: ChatCompletionRequest,
-    client: InferenceClient = Depends(get_inference_client)
+    request: ChatCompletionRequest
 ) -> ChatCompletionResponse:
     """Create a chat completion (OpenAI-compatible) with multimodal support."""
     try:
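The endpoint no longer receives a client via `Depends(get_inference_client)`; readiness is now checked by `ensure_model_ready()` inside the generation path. If the dependency-injection style is preferred, the same guard can stay a FastAPI dependency. A sketch with an illustrative name, not part of this commit:

```python
# Sketch: the readiness guard expressed as a route dependency instead of an explicit call.
# `require_model` is an illustrative name.
from fastapi import Depends

def require_model() -> None:
    if tokenizer is None or model is None:
        raise HTTPException(status_code=503, detail="Service not ready - model not initialized")

@app.post("/v1/chat/completions", response_model=ChatCompletionResponse,
          dependencies=[Depends(require_model)])
async def create_chat_completion(request: ChatCompletionRequest) -> ChatCompletionResponse:
    ...
```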
@@ -506,22 +475,10 @@ async def create_chat_completion(
                 raise HTTPException(status_code=503, detail="Image processing not available")
             response_text = await generate_multimodal_response(request.messages, request)
         else:
-            prompt = convert_messages_to_prompt(request.messages)
-            logger.info(f"Generated prompt: {prompt[:200]}...")
-            if request.stream:
-                return StreamingResponse(
-                    generate_streaming_response(client, prompt, request),
-                    media_type="text/plain",
-                    headers={
-                        "Cache-Control": "no-cache",
-                        "Connection": "keep-alive",
-                        "Content-Type": "text/plain; charset=utf-8"
-                    }
-                )  # type: ignore
+            logger.info(f"Generating local response for messages: {request.messages}")
             response_text = await asyncio.to_thread(
-                generate_response_safe,
-                client,
-                prompt,
+                generate_response_local,
+                request.messages,
                 request.max_tokens or 512,
                 request.temperature or 0.7,
                 request.top_p or 0.95
@@ -542,19 +499,21 @@
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")


+
 @app.post("/v1/completions")
 async def create_completion(
-    request: CompletionRequest,
-    client: InferenceClient = Depends(get_inference_client)
+    request: CompletionRequest
 ) -> Dict[str, Any]:
     """Create a text completion (OpenAI-compatible)"""
     try:
         if not request.prompt:
             raise HTTPException(status_code=400, detail="Prompt cannot be empty")
+        ensure_model_ready()
+        # Use the prompt as a single user message
+        messages = [ChatMessage(role="user", content=request.prompt)]
         response_text = await asyncio.to_thread(
-            generate_response_safe,
-            client,
-            request.prompt,
+            generate_response_local,
+            messages,
             request.max_tokens or 512,
             request.temperature or 0.7,
             0.95
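For reference, the surviving endpoints stay OpenAI-shaped, so the service can be exercised with any HTTP client once it is running. A sketch follows; the base URL assumes uvicorn's default port, and the request schema may accept more fields than those visible in this diff (for example a `model` field).

```python
# Sketch: calling the OpenAI-compatible endpoints; base URL is an assumption.
import requests

base = "http://localhost:8000"

chat = requests.post(f"{base}/v1/chat/completions", json={
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 128,
    "temperature": 0.7,
})
print(chat.json())

completion = requests.post(f"{base}/v1/completions", json={
    "prompt": "Write one sentence about FastAPI.",
    "max_tokens": 64,
})
print(completion.json())
```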