AnkitShrestha committed on
Commit 5e9e76c · 1 Parent(s): bb18f6f

Add batching using vllm

Files changed (2):
  1. main.py +44 -0
  2. requirements.txt +2 -1
main.py CHANGED

@@ -77,6 +77,9 @@ import logging
 from typing import List
 import time
 import numpy as np
+from vllm import LLM, SamplingParams
+import torch._dynamo
+torch._dynamo.config.suppress_errors = True

 # Configure logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -97,6 +100,7 @@ app = FastAPI()
 # Global variables for model and processor
 model = None
 processor = None
+llm = None

 def load_model():
     """Load model and processor when needed"""
@@ -114,6 +118,15 @@ def load_model():
     processor = PaliGemmaProcessor.from_pretrained(model_id)
     logger.info("Model loaded successfully")

+def load_vllm_model():
+    global llm
+    if llm is None:
+        llm = LLM(
+            model="google/paligemma2-3b-mix-448",
+            trust_remote_code=True,
+            max_model_len=4096,
+            dtype="float16",
+        )
 def clean_memory():
     """Force garbage collection and clear CUDA cache"""
     gc.collect()
@@ -165,6 +178,37 @@ async def extract_text(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
         logger.error(f"Error processing image: {str(e)}")
         return {"error": str(e)}

+@app.post("/batch_extract_text_vllm")
+async def batch_extract_text_vllm(background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)):
+    try:
+        start_time = time.time()
+        load_vllm_model()
+        results = []
+        sampling_params = SamplingParams(temperature=0.0, max_tokens=32)
+        # Load images
+        images = []
+        for file in files:
+            image_data = await file.read()
+            img = Image.open(BytesIO(image_data)).convert("RGB")
+            images.append(img)
+        for image in images:
+            inputs = {
+                "prompt": "ocr",
+                "multi_modal_data": {
+                    "image": image
+                },
+            }
+            outputs = llm.generate(inputs, sampling_params)
+            for o in outputs:
+                generated_text = o.outputs[0].text
+                results.append(generated_text)
+
+        logger.info(f"vLLM Batch processing completed in {time.time() - start_time:.2f} seconds")
+        return {"extracted_texts": results}
+    except Exception as e:
+        logger.error(f"Error in batch processing vLLM: {str(e)}")
+        return {"error": str(e)}
+
 @app.post("/batch_extract_text")
 async def batch_extract_text(batch_size:int, background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)):
     """Extract text from multiple images with batching"""
requirements.txt CHANGED

@@ -7,4 +7,5 @@ transformers
 torch
 accelerate
 pillow
-python-multipart
+python-multipart
+vllm
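
For completeness, a client call against the new route might look like the sketch below; the host, port, and file names are placeholders rather than anything from the commit.

    import requests

    # Hypothetical client for /batch_extract_text_vllm; URL and filenames
    # are placeholders, not part of the commit.
    files = [
        ("files", ("page1.png", open("page1.png", "rb"), "image/png")),
        ("files", ("page2.png", open("page2.png", "rb"), "image/png")),
    ]
    resp = requests.post("http://localhost:8000/batch_extract_text_vllm", files=files)
    print(resp.json())  # {"extracted_texts": [...]} on success, {"error": "..."} on failure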