ndc8 committed
Commit 68f41f4 · 1 Parent(s): 83df634

Files changed (1):
  1. backend_service.py (+59 -82)

backend_service.py CHANGED
@@ -7,6 +7,8 @@ Provides OpenAI-compatible chat completion endpoints
 import os
 os.environ.setdefault("HF_HOME", "/tmp/.cache/huggingface")
 os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/.cache/huggingface")
+# Define Hugging Face auth token from environment
+hf_token = os.environ.get("HF_TOKEN")
 import asyncio
 import logging
 import time
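
Since `hf_token` is bound once at import time in the hunk above, the variable has to be present in the process environment before `backend_service` is imported. A minimal sketch of that assumption (the token value is a placeholder):

```python
# Illustration only: HF_TOKEN must be in the environment before the module loads,
# because hf_token is bound once at import time.
import os

os.environ["HF_TOKEN"] = "hf_placeholder"  # normally exported by the shell or Space secrets

import backend_service  # noqa: E402  - hf_token is bound during this import

assert backend_service.hf_token == "hf_placeholder"
```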
@@ -191,7 +193,16 @@ async def lifespan(app: FastAPI):
     # Initialize tokenizer for better text handling
     if transformers_available and AutoTokenizer:
         try:
-            tokenizer = AutoTokenizer.from_pretrained(current_model) # type: ignore
+            # Load tokenizer, using auth token if provided
+            if hf_token:
+                tokenizer = AutoTokenizer.from_pretrained(
+                    current_model,
+                    use_auth_token=hf_token
+                ) # type: ignore
+            else:
+                tokenizer = AutoTokenizer.from_pretrained(
+                    current_model
+                ) # type: ignore
             logger.info("✅ Tokenizer loaded successfully")
         except Exception as e:
             logger.warning(f"⚠️ Could not load tokenizer: {e}")
@@ -469,33 +480,49 @@ async def list_models():
 
     return ModelsResponse(data=models)
 
-@app.post("/v1/chat/completions")
+@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
 async def create_chat_completion(
     request: ChatCompletionRequest,
     client: InferenceClient = Depends(get_inference_client)
-):
-    """Create a chat completion (OpenAI-compatible) with multimodal support"""
+) -> ChatCompletionResponse:
+    """Create a chat completion (OpenAI-compatible) with multimodal support."""
     try:
-        # Validate request
         if not request.messages:
             raise HTTPException(status_code=400, detail="Messages cannot be empty")
-
-        # Check if this is a multimodal request (contains images)
         is_multimodal = has_images(request.messages)
-
         if is_multimodal:
-            # Handle multimodal request with image-text pipeline
             if not image_text_pipeline:
                 raise HTTPException(status_code=503, detail="Image processing not available")
-
             response_text = await generate_multimodal_response(request.messages, request)
         else:
-            # Handle text-only request with existing logic
             prompt = convert_messages_to_prompt(request.messages)
             logger.info(f"Generated prompt: {prompt[:200]}...")
-
             if request.stream:
-                # Return streaming response
                 return StreamingResponse(
                     generate_streaming_response(client, prompt, request),
                     media_type="text/plain",
@@ -504,37 +531,26 @@ async def create_chat_completion(
                         "Connection": "keep-alive",
                         "Content-Type": "text/plain; charset=utf-8"
                     }
-                )
-            else:
-                # Generate non-streaming response
-                response_text = await asyncio.to_thread(
-                    generate_response_safe,
-                    client,
-                    prompt,
-                    request.max_tokens or 512,
-                    request.temperature or 0.7,
-                    request.top_p or 0.95
-                )
-
-        # Clean up the response
+                ) # type: ignore
+            response_text = await asyncio.to_thread(
+                generate_response_safe,
+                client,
+                prompt,
+                request.max_tokens or 512,
+                request.temperature or 0.7,
+                request.top_p or 0.95
+            )
         response_text = response_text.strip() if response_text else "No response generated."
-
-        # Create OpenAI-compatible response
-        response = ChatCompletionResponse(
+        return ChatCompletionResponse(
             id=f"chatcmpl-{int(time.time())}",
             created=int(time.time()),
             model=request.model,
-            choices=[
-                ChatCompletionChoice(
-                    index=0,
-                    message=ChatMessage(role="assistant", content=response_text),
-                    finish_reason="stop"
-                )
-            ]
+            choices=[ChatCompletionChoice(
+                index=0,
+                message=ChatMessage(role="assistant", content=response_text),
+                finish_reason="stop"
+            )]
         )
-
-        return response
-
     except Exception as e:
         logger.error(f"Error in chat completion: {e}")
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
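
The rewritten handler keeps the OpenAI-compatible request/response shape, so any HTTP client can exercise it. A hypothetical call against a local instance (host, port, and model id are assumptions; `requests` is not part of the service itself):

```python
import requests

# Field names mirror ChatCompletionRequest as used in the handler above.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "gpt2",  # placeholder model id
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": False,  # True would return plain-text streamed chunks instead
        "max_tokens": 128,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```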
@@ -548,17 +564,14 @@ async def create_completion(
     try:
         if not request.prompt:
             raise HTTPException(status_code=400, detail="Prompt cannot be empty")
-
-        # Generate response
         response_text = await asyncio.to_thread(
            generate_response_safe,
            client,
            request.prompt,
            request.max_tokens or 512,
            request.temperature or 0.7,
-            0.95 # default top_p
+            0.95
         )
-
         return {
             "id": f"cmpl-{int(time.time())}",
             "object": "text_completion",
@@ -570,57 +583,21 @@
                 "finish_reason": "stop"
             }]
         }
-
     except Exception as e:
         logger.error(f"Error in completion: {e}")
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 @app.post("/api/response")
-async def api_response(request: Request):
+async def api_response(request: Request) -> JSONResponse:
     """Endpoint to receive and send responses via API."""
     try:
         data = await request.json()
         message = data.get("message", "No message provided")
-        response: dict[str, str] = {
+        return JSONResponse(content={
             "status": "success",
             "received_message": message,
             "response_message": f"You sent: {message}"
-        }
-        return JSONResponse(content=response)
+        })
     except Exception as e:
         logger.error(f"Error processing API response: {e}")
         raise HTTPException(status_code=500, detail="Internal server error")
-
-@app.exception_handler(Exception)
-async def global_exception_handler(request: Any, exc: Exception) -> JSONResponse:
-    """Global exception handler"""
-    logger.error(f"Unhandled exception: {exc}")
-    return JSONResponse(
-        status_code=500,
-        content={"detail": f"Internal server error: {str(exc)}"}
-    )
-
-if __name__ == "__main__":
-    import argparse
-
-    parser = argparse.ArgumentParser(description="AI Backend Service")
-    parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
-    parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
-    parser.add_argument("--model", default=current_model, help="HuggingFace model to use")
-    parser.add_argument("--reload", action="store_true", help="Enable auto-reload for development")
-
-    args = parser.parse_args()
-
-    if args.model != current_model:
-        current_model = args.model
-        logger.info(f"Using model: {current_model}")
-
-    logger.info(f"🚀 Starting AI Backend Service on {args.host}:{args.port}")
-
-    uvicorn.run(
-        "backend_service:app",
-        host=args.host,
-        port=args.port,
-        reload=args.reload,
-        log_level="info"
-    )
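
Because the commit deletes the whole `__main__`/argparse launcher (along with the global exception handler), the service now has to be started externally, e.g. via the uvicorn CLI. A minimal launcher sketch reconstructed from the removed block, using the defaults that appeared there:

```python
# Reconstructed from the deleted __main__ block; host/port defaults as they appeared there.
import uvicorn

if __name__ == "__main__":
    uvicorn.run("backend_service:app", host="0.0.0.0", port=8000, log_level="info")
```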
 