ndc8 committed
Commit 4b4e9ed · 1 Parent(s): 4f67c26

Refactor backend service to support Gemma 3n model and update requirements; remove obsolete test script and add new dependency tests

Files changed (4)
  1. backend_service.py +115 -48
  2. requirements.txt +9 -1
  3. test_app_structure.py +0 -39
  4. test_deps.py +37 -0
backend_service.py CHANGED
@@ -7,8 +7,8 @@ import httpx
 # Hugging Face Spaces: Only transformers backend is supported (no vLLM, no llama-cpp/gguf)
 
 """
-FastAPI Backend AI Service using Gemma-3n-E4B-it-GGUF
-Provides OpenAI-compatible chat completion endpoints powered by unsloth/gemma-3n-E4B-it-GGUF
+FastAPI Backend AI Service using Gemma-3n-E4B-it
+Provides OpenAI-compatible chat completion endpoints powered by google/gemma-3n-E4B-it
 """
 import warnings
 
@@ -45,6 +45,8 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 # Transformers imports (now fallback for non-GGUF models)
 from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, AutoConfig  # type: ignore
 from transformers import BitsAndBytesConfig  # type: ignore
+# Gemma 3n specific imports
+from transformers import Gemma3nForConditionalGeneration, AutoProcessor  # type: ignore
 import torch
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -88,7 +90,7 @@ class ChatMessage(BaseModel):
         return v
 
 class ChatCompletionRequest(BaseModel):
-    model: str = Field(default_factory=lambda: os.environ.get("AI_MODEL", "unsloth/gemma-3n-E4B-it-GGUF"), description="The model to use for completion")
+    model: str = Field(default_factory=lambda: os.environ.get("AI_MODEL", "google/gemma-3n-E4B-it"), description="The model to use for completion")
     messages: List[ChatMessage] = Field(..., description="List of messages in the conversation")
     max_tokens: Optional[int] = Field(default=512, ge=1, le=2048, description="Maximum tokens to generate")
    temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
@@ -137,11 +139,11 @@ class CompletionRequest(BaseModel):
 
 
 # Model can be configured via environment variable - defaults to Gemma 3n (transformers format)
-current_model = os.environ.get("AI_MODEL", "unsloth/gemma-3n-E4B-it-GGUF")
+current_model = os.environ.get("AI_MODEL", "google/gemma-3n-E4B-it")
 vision_model = os.environ.get("VISION_MODEL", "Salesforce/blip-image-captioning-base")
 
 # Transformers model support
-tokenizer = None
+processor = None  # For Gemma 3n we use AutoProcessor instead of just tokenizer
 model = None
 image_text_pipeline = None  # type: ignore
 
@@ -190,39 +192,58 @@ def has_images(messages: List[ChatMessage]) -> bool:
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     """Application lifespan manager for startup and shutdown events"""
-    global tokenizer, model, image_text_pipeline, current_model
+    global processor, model, image_text_pipeline, current_model
     logger.info("🚀 Starting AI Backend Service (Hugging Face Spaces mode)...")
     try:
         logger.info(f"📥 Loading model with transformers: {current_model}")
-        tokenizer = AutoTokenizer.from_pretrained(current_model)
-        # Hugging Face Spaces: Remove device_map and torch_dtype for CPU compatibility
-        model = AutoModelForCausalLM.from_pretrained(
-            current_model,
-            low_cpu_mem_usage=True,
-            trust_remote_code=True,
-        )
-        logger.info(f"✅ Successfully loaded model and tokenizer: {current_model}")
-        # Load image pipeline for multimodal support
-        try:
-            logger.info(f"🖼️ Initializing image captioning pipeline with model: {vision_model}")
-            image_text_pipeline = pipeline("image-to-text", model=vision_model)
-            logger.info("✅ Image captioning pipeline loaded successfully")
-        except Exception as e:
-            logger.warning(f"⚠️ Could not load image captioning pipeline: {e}")
+
+        # For Gemma 3n models, use the specific classes
+        if "gemma-3n" in current_model.lower():
+            processor = AutoProcessor.from_pretrained(current_model)
+            model = Gemma3nForConditionalGeneration.from_pretrained(
+                current_model,
+                low_cpu_mem_usage=True,
+                trust_remote_code=True,
+                torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+            ).eval()
+        else:
+            # Fallback for other models
+            processor = AutoTokenizer.from_pretrained(current_model)
+            model = AutoModelForCausalLM.from_pretrained(
+                current_model,
+                low_cpu_mem_usage=True,
+                trust_remote_code=True,
+            )
+
+        logger.info(f"✅ Successfully loaded model and processor: {current_model}")
+
+        # Gemma 3n is multimodal, so we don't need a separate image pipeline
+        if "gemma-3n" not in current_model.lower():
+            # Load image pipeline for multimodal support (only for non-Gemma-3n models)
+            try:
+                logger.info(f"🖼️ Initializing image captioning pipeline with model: {vision_model}")
+                image_text_pipeline = pipeline("image-to-text", model=vision_model)
+                logger.info("✅ Image captioning pipeline loaded successfully")
+            except Exception as e:
+                logger.warning(f"⚠️ Could not load image captioning pipeline: {e}")
+                image_text_pipeline = None
+        else:
+            logger.info("✅ Gemma 3n has built-in multimodal support")
             image_text_pipeline = None
+
     except Exception as e:
         logger.error(f"❌ Failed to initialize model: {e}")
         raise RuntimeError(f"Service initialization failed: {e}")
     yield
     logger.info("🔄 Shutting down AI Backend Service...")
-    tokenizer = None
+    processor = None
     model = None
     image_text_pipeline = None
 
 # Initialize FastAPI app
 app = FastAPI(
-    title="AI Backend Service - Mistral Nemo",
-    description="OpenAI-compatible chat completion API powered by unsloth/Mistral-Nemo-Instruct-2407",
+    title="AI Backend Service - Gemma 3n",
+    description="OpenAI-compatible chat completion API powered by google/gemma-3n-E4B-it",
     version="1.0.0",
     lifespan=lifespan
 )
@@ -239,7 +260,7 @@ app.add_middleware(
 
 def ensure_model_ready():
     """Check if transformers model is loaded and ready"""
-    if tokenizer is None or model is None:
+    if processor is None or model is None:
         raise HTTPException(status_code=503, detail="Service not ready - no model initialized (transformers)")
 
 def convert_messages_to_prompt(messages: List[ChatMessage]) -> str:
@@ -367,29 +388,75 @@ def convert_messages_to_gemma_prompt(messages: List[ChatMessage]) -> str:
 def generate_response_transformers(messages: List[ChatMessage], max_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.95) -> str:
     """Generate response using transformers model with chat template."""
     try:
-        # Convert messages to HuggingFace format for chat template
-        chat_messages = []
-        for m in messages:
-            content_str = m.content if isinstance(m.content, str) else extract_text_and_images(m.content)[0]
-            chat_messages.append({"role": m.role, "content": content_str})
-
-        # Apply chat template and tokenize for Hugging Face Spaces CPU
-        inputs = tokenizer.apply_chat_template(
-            chat_messages,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt",
-        )
-        # Pass input_ids and attention_mask directly (no .to(model.device))
-        outputs = model.generate(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs.get("attention_mask"),
-            max_new_tokens=max_tokens
-        )
-        # Decode only the newly generated tokens (exclude input)
-        generated_text = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
-        return generated_text.strip()
+        # Check if we're using Gemma 3n
+        if "gemma-3n" in current_model.lower():
+            # Gemma 3n specific handling
+            # Convert messages to HuggingFace format for chat template
+            chat_messages = []
+            for m in messages:
+                # Gemma 3n supports multimodal, but for now we'll handle text only
+                if isinstance(m.content, str):
+                    content = [{"type": "text", "text": m.content}]
+                else:
+                    # Extract text content for now (image support can be added later)
+                    text_content, _ = extract_text_and_images(m.content)
+                    content = [{"type": "text", "text": text_content}]
+
+                chat_messages.append({"role": m.role, "content": content})
+
+            # Apply chat template using processor
+            inputs = processor.apply_chat_template(
+                chat_messages,
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt",
+            )
+
+            # Generate with Gemma 3n
+            input_len = inputs["input_ids"].shape[-1]
+            with torch.inference_mode():
+                generation = model.generate(
+                    **inputs,
+                    max_new_tokens=max_tokens,
+                    temperature=temperature,
+                    top_p=top_p,
+                    do_sample=temperature > 0,
+                )
+            generation = generation[0][input_len:]
+
+            # Decode the response
+            generated_text = processor.decode(generation, skip_special_tokens=True)
+            return generated_text.strip()
+
+        else:
+            # Fallback for other models
+            # Convert messages to HuggingFace format for chat template
+            chat_messages = []
+            for m in messages:
+                content_str = m.content if isinstance(m.content, str) else extract_text_and_images(m.content)[0]
+                chat_messages.append({"role": m.role, "content": content_str})
+
+            # Apply chat template and tokenize
+            inputs = processor.apply_chat_template(
+                chat_messages,
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt",
+            )
+            # Generate response
+            outputs = model.generate(
+                input_ids=inputs["input_ids"],
+                attention_mask=inputs.get("attention_mask"),
+                max_new_tokens=max_tokens,
+                temperature=temperature,
+                top_p=top_p,
+                do_sample=temperature > 0,
+            )
+            # Decode only the newly generated tokens (exclude input)
+            generated_text = processor.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
+            return generated_text.strip()
 
     except Exception as e:
         logger.error(f"Transformers generation failed: {e}")
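For a quick end-to-end check of the refactored service, a minimal client sketch follows. It is not part of the commit: the route and port are assumptions (an OpenAI-style /v1/chat/completions endpoint on the default Spaces port 7860, as implied by the "OpenAI-compatible" description above), and the response parsing assumes the usual OpenAI choices/message shape.

# Hypothetical smoke-test client for the refactored backend.
# Assumptions: OpenAI-compatible route /v1/chat/completions, service on localhost:7860.
import httpx

payload = {
    "model": "google/gemma-3n-E4B-it",
    "messages": [{"role": "user", "content": "Say hello in one short sentence."}],
    "max_tokens": 64,
    "temperature": 0.7,
}

resp = httpx.post("http://localhost:7860/v1/chat/completions", json=payload, timeout=120.0)
resp.raise_for_status()
# Assumes an OpenAI-style response body: choices[0].message.content
print(resp.json()["choices"][0]["message"]["content"])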
requirements.txt CHANGED
@@ -3,11 +3,19 @@
 # Hugging Face Spaces requirements (transformers backend only)
 fastapi
 uvicorn
-transformers
+transformers>=4.53.0
 torch
 python-dotenv
 httpx
 requests
 Pillow
+
+# Required dependencies for Gemma models
+protobuf
+tiktoken
+sentencepiece>=0.2.0
+tokenizers
+regex
+
 # Optional: gradio for demo UI
 # gradio
test_app_structure.py DELETED
@@ -1,39 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Test script to verify the FastAPI app can be imported and started
4
- """
5
-
6
- import sys
7
- import os
8
-
9
- # Add current directory to path
10
- sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
11
-
12
- try:
13
- # Test imports
14
- print("Testing imports...")
15
- from backend_service import app
16
- print("βœ… Successfully imported FastAPI app from backend_service")
17
-
18
- # Test app type
19
- from fastapi import FastAPI
20
- if isinstance(app, FastAPI):
21
- print("βœ… App is a valid FastAPI instance")
22
- else:
23
- print("❌ App is not a FastAPI instance")
24
- sys.exit(1)
25
-
26
- # Test app attributes
27
- print(f"βœ… App title: {app.title}")
28
- print(f"βœ… App version: {app.version}")
29
-
30
- print("\nπŸŽ‰ All tests passed! The app is ready for Hugging Face Spaces")
31
-
32
- except ImportError as e:
33
- print(f"❌ Import error: {e}")
34
- print("This is expected if you don't have all dependencies installed locally.")
35
- print("The Hugging Face Space will install them from requirements.txt")
36
-
37
- except Exception as e:
38
- print(f"❌ Unexpected error: {e}")
39
- sys.exit(1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test_deps.py ADDED
@@ -0,0 +1,37 @@
+#!/usr/bin/env python3
+"""
+Test script to verify the transformers dependencies are working
+"""
+
+def test_imports():
+    """Test that all required transformers imports work"""
+    try:
+        print("Testing transformers imports...")
+
+        from transformers import AutoProcessor, Gemma3nForConditionalGeneration
+        print("✅ Gemma3nForConditionalGeneration import successful")
+
+        from transformers import AutoTokenizer, AutoModelForCausalLM
+        print("✅ Standard transformers imports successful")
+
+        import torch
+        print("✅ PyTorch import successful")
+
+        import sentencepiece
+        print("✅ SentencePiece import successful")
+
+        import tiktoken
+        print("✅ TikToken import successful")
+
+        import google.protobuf  # the protobuf package installs the google.protobuf module
+        print("✅ Protobuf import successful")
+
+        print("\n🎉 All imports successful! Ready for Hugging Face Spaces deployment")
+        return True
+
+    except ImportError as e:
+        print(f"❌ Import error: {e}")
+        return False
+
+if __name__ == "__main__":
+    test_imports()
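Note that test_imports reports failures by returning False rather than raising, so a pytest run would still pass when a dependency is missing; the script only fails visibly when run directly as python test_deps.py. A stricter wrapper (hypothetical, not part of this commit) could assert on the result:

# Hypothetical stricter check, not in the commit: makes a pytest run fail
# when any dependency import is missing, instead of only printing.
from test_deps import test_imports

def test_imports_strict():
    assert test_imports()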