rockerritesh committed on
Commit 95f4d99 · verified · 1 Parent(s): 13ed0ee

Update main.py

Files changed (1)
  1. main.py +205 -35
main.py CHANGED
@@ -1,65 +1,235 @@
- # main.py
- from fastapi import FastAPI, File, UploadFile
  from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
- from transformers.image_utils import load_image
  import torch
  from io import BytesIO
  import os
  from dotenv import load_dotenv
  from PIL import Image
-
  from huggingface_hub import login

  # Load environment variables
  load_dotenv()

  # Set the cache directory to a writable path
  os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_inductor_cache"
-
  token = os.getenv("huggingface_ankit")
  # Login to the Hugging Face Hub
  login(token)

  app = FastAPI()

- model_id = "google/paligemma2-3b-mix-448"
- model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).to('cuda')
- processor = PaliGemmaProcessor.from_pretrained(model_id)

  def predict(image):
      prompt = "<image> ocr"
-     model_inputs = processor(text=prompt, images=image, return_tensors="pt").to('cuda')
-     input_len = model_inputs["input_ids"].shape[-1]
      with torch.inference_mode():
          generation = model.generate(**model_inputs, max_new_tokens=200)
-         torch.cuda.empty_cache()
-         decoded = processor.decode(generation[0], skip_special_tokens=True)  # [len(prompt):].lstrip("\n")
      return decoded

  @app.post("/extract_text")
- async def extract_text(file: UploadFile = File(...)):
-     image = Image.open(BytesIO(await file.read())).convert("RGB")  # Ensure it's a valid PIL image
-     text = predict(image)
-     return {"extracted_text": text}

  @app.post("/batch_extract_text")
- async def batch_extract_text(files: list[UploadFile] = File(...)):
-     # if len(files) > 20:
-     #     return {"error": "A maximum of 20 images can be processed at a time."}
-
-     images = [Image.open(BytesIO(await file.read())).convert("RGB") for file in files]
-     prompts = ["OCR"] * len(images)
-
-     model_inputs = processor(text=prompts, images=images, return_tensors="pt").to(torch.bfloat16).to(model.device)
-     input_len = model_inputs["input_ids"].shape[-1]
-
-     with torch.inference_mode():
-         generations = model.generate(**model_inputs, max_new_tokens=200, do_sample=False)
-         torch.cuda.empty_cache()
-         extracted_texts = [processor.decode(generations[i], skip_special_tokens=True) for i in range(len(images))]
-
-     return {"extracted_texts": extracted_texts}

- if __name__ == "__main__":
-     import uvicorn
-     uvicorn.run(app, host="0.0.0.0", port=7860)
+ # # main.py
+ # from fastapi import FastAPI, File, UploadFile
+ # from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
+ # from transformers.image_utils import load_image
+ # import torch
+ # from io import BytesIO
+ # import os
+ # from dotenv import load_dotenv
+ # from PIL import Image
+
+ # from huggingface_hub import login
+
+ # # Load environment variables
+ # load_dotenv()
+
+ # # Set the cache directory to a writable path
+ # os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_inductor_cache"
+
+ # token = os.getenv("huggingface_ankit")
+ # # Login to the Hugging Face Hub
+ # login(token)
+
+ # app = FastAPI()
+
+ # model_id = "google/paligemma2-3b-mix-448"
+ # model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).to('cuda')
+ # processor = PaliGemmaProcessor.from_pretrained(model_id)
+
+ # def predict(image):
+ #     prompt = "<image> ocr"
+ #     model_inputs = processor(text=prompt, images=image, return_tensors="pt").to('cuda')
+ #     input_len = model_inputs["input_ids"].shape[-1]
+ #     with torch.inference_mode():
+ #         generation = model.generate(**model_inputs, max_new_tokens=200)
+ #         torch.cuda.empty_cache()
+ #         decoded = processor.decode(generation[0], skip_special_tokens=True)  # [len(prompt):].lstrip("\n")
+ #     return decoded
+
+ # @app.post("/extract_text")
+ # async def extract_text(file: UploadFile = File(...)):
+ #     image = Image.open(BytesIO(await file.read())).convert("RGB")  # Ensure it's a valid PIL image
+ #     text = predict(image)
+ #     return {"extracted_text": text}
+
+ # @app.post("/batch_extract_text")
+ # async def batch_extract_text(files: list[UploadFile] = File(...)):
+ #     # if len(files) > 20:
+ #     #     return {"error": "A maximum of 20 images can be processed at a time."}
+
+ #     images = [Image.open(BytesIO(await file.read())).convert("RGB") for file in files]
+ #     prompts = ["OCR"] * len(images)
+
+ #     model_inputs = processor(text=prompts, images=images, return_tensors="pt").to(torch.bfloat16).to(model.device)
+ #     input_len = model_inputs["input_ids"].shape[-1]
+
+ #     with torch.inference_mode():
+ #         generations = model.generate(**model_inputs, max_new_tokens=200, do_sample=False)
+ #         torch.cuda.empty_cache()
+ #         extracted_texts = [processor.decode(generations[i], skip_special_tokens=True) for i in range(len(images))]
+
+ #     return {"extracted_texts": extracted_texts}
+
+ # if __name__ == "__main__":
+ #     import uvicorn
+ #     uvicorn.run(app, host="0.0.0.0", port=7860)
+
+
+ from fastapi import FastAPI, File, UploadFile, BackgroundTasks
  from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
  import torch
  from io import BytesIO
  import os
  from dotenv import load_dotenv
  from PIL import Image
  from huggingface_hub import login
+ import gc
+ import logging
+ from typing import List
+ import time
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)

  # Load environment variables
  load_dotenv()

  # Set the cache directory to a writable path
  os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_inductor_cache"
  token = os.getenv("huggingface_ankit")
+
  # Login to the Hugging Face Hub
  login(token)

  app = FastAPI()

+ # Global variables for model and processor
+ model = None
+ processor = None
+
+ def load_model():
+     """Load model and processor when needed"""
+     global model, processor
+     if model is None:
+         model_id = "google/paligemma2-3b-mix-448"
+         logger.info(f"Loading model {model_id}")
+
+         # Load model with memory-efficient settings
+         model = PaliGemmaForConditionalGeneration.from_pretrained(
+             model_id,
+             device_map="auto",
+             torch_dtype=torch.bfloat16  # Use lower precision for memory efficiency
+         )
+         processor = PaliGemmaProcessor.from_pretrained(model_id)
+         logger.info("Model loaded successfully")
+
+ def clean_memory():
+     """Force garbage collection and clear CUDA cache"""
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+     logger.info("Memory cleaned")

  def predict(image):
+     """Process a single image"""
+     load_model()  # Ensure model is loaded
+
+     # Process input
      prompt = "<image> ocr"
+     model_inputs = processor(text=prompt, images=image, return_tensors="pt")
+
+     # Move to appropriate device
+     model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
+
+     # Generate with memory optimization
      with torch.inference_mode():
          generation = model.generate(**model_inputs, max_new_tokens=200)
+
+     # Decode output
+     decoded = processor.decode(generation[0], skip_special_tokens=True)
+
+     # Clean up intermediates
+     del model_inputs, generation
+     clean_memory()
+
      return decoded

  @app.post("/extract_text")
+ async def extract_text(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
+     """Extract text from a single image"""
+     try:
+         start_time = time.time()
+         image = Image.open(BytesIO(await file.read())).convert("RGB")
+         text = predict(image)
+
+         # Schedule cleanup after response
+         background_tasks.add_task(clean_memory)
+
+         logger.info(f"Processing completed in {time.time() - start_time:.2f} seconds")
+         return {"extracted_text": text}
+     except Exception as e:
+         logger.error(f"Error processing image: {str(e)}")
+         return {"error": str(e)}

  @app.post("/batch_extract_text")
+ async def batch_extract_text(background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)):
+     """Extract text from multiple images with batching"""
+     try:
+         start_time = time.time()
+
+         # Limit batch size for memory management
+         max_batch_size = 5  # Adjust based on your GPU memory
+
+         if len(files) > 20:
+             return {"error": "A maximum of 20 images can be processed at a time."}
+
+         load_model()  # Ensure model is loaded
+
+         all_results = []
+
+         # Process in smaller batches
+         for i in range(0, len(files), max_batch_size):
+             batch_files = files[i:i + max_batch_size]
+
+             # Load images
+             images = []
+             for file in batch_files:
+                 image_data = await file.read()
+                 img = Image.open(BytesIO(image_data)).convert("RGB")
+                 images.append(img)
+
+             # Create batch inputs
+             prompts = ["<image> ocr"] * len(images)
+             model_inputs = processor(text=prompts, images=images, return_tensors="pt")
+
+             # Move to appropriate device
+             model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
+
+             # Generate with memory optimization
+             with torch.inference_mode():
+                 generations = model.generate(**model_inputs, max_new_tokens=200, do_sample=False)
+
+             # Decode outputs
+             batch_results = [processor.decode(generations[j], skip_special_tokens=True) for j in range(len(images))]
+             all_results.extend(batch_results)
+
+             # Clean up batch resources
+             del model_inputs, generations, images
+             clean_memory()
+
+         # Schedule cleanup after response
+         background_tasks.add_task(clean_memory)
+
+         logger.info(f"Batch processing completed in {time.time() - start_time:.2f} seconds")
+         return {"extracted_texts": all_results}
+     except Exception as e:
+         logger.error(f"Error in batch processing: {str(e)}")
+         return {"error": str(e)}
+
+ # Health check endpoint
+ @app.get("/health")
+ async def health_check():
+     return {"status": "healthy"}
+
+ # if __name__ == "__main__":
+ #     import uvicorn

+ #     # Start the server with proper worker configuration
+ #     uvicorn.run(
+ #         app,
+ #         host="0.0.0.0",
+ #         port=7860,
+ #         log_level="info",
+ #         workers=1  # Multiple workers can cause GPU memory issues
+ #     )
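
Not part of the commit, but as a reference point: a minimal client-side sketch of how the updated endpoints could be exercised once the app is serving on port 7860. The multipart field names "file" and "files" follow the FastAPI signatures above; the base URL and image paths are placeholder assumptions.

# client_example.py - hypothetical usage sketch, not included in this commit
import requests

BASE_URL = "http://localhost:7860"  # assumed host; port matches the uvicorn config above

# Single image: /extract_text expects one multipart field named "file"
with open("sample_page.png", "rb") as f:  # placeholder image path
    resp = requests.post(f"{BASE_URL}/extract_text", files={"file": f})
print(resp.json())  # {"extracted_text": "..."} or {"error": "..."}

# Batch: /batch_extract_text expects the "files" field repeated per image (max 20 per request)
paths = ["page_1.png", "page_2.png"]  # placeholder image paths
files = [("files", open(p, "rb")) for p in paths]
try:
    resp = requests.post(f"{BASE_URL}/batch_extract_text", files=files)
    print(resp.json())  # {"extracted_texts": [...]}
finally:
    for _, fh in files:
        fh.close()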