# main.py
from io import BytesIO
import os

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, File, HTTPException, UploadFile
from huggingface_hub import login
from PIL import Image
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
# Load environment variables
load_dotenv()

# Point the TorchInductor cache at a writable path
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_inductor_cache"

# Log in to the Hugging Face Hub with the token from the environment
token = os.getenv("huggingface_ankit")
login(token)
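
# load_dotenv() reads the token from a local .env file; a sketch of the
# expected entry (the value itself is a placeholder, not a real token):
#   huggingface_ankit=hf_...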

app = FastAPI()

model_id = "google/paligemma2-3b-mix-448"
# Load the weights in bfloat16 so they match the bfloat16 inputs prepared below;
# without this, the fp32 weights would clash with the bf16-cast pixel values.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16
).to("cuda")
processor = PaliGemmaProcessor.from_pretrained(model_id)

def predict(image):
    # PaliGemma expects the <image> placeholder ahead of the task prompt
    prompt = "<image> ocr"
    model_inputs = (
        processor(text=prompt, images=image, return_tensors="pt")
        .to(torch.bfloat16)
        .to(model.device)
    )
    input_len = model_inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generation = model.generate(**model_inputs, max_new_tokens=200)
        torch.cuda.empty_cache()
    # Slice off the prompt tokens so only the newly generated text is decoded
    decoded = processor.decode(generation[0][input_len:], skip_special_tokens=True)
    return decoded
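
# A minimal local sanity check for predict(), assuming a hypothetical
# "sample.png" sits next to this file; not part of the API surface:
#
#   from PIL import Image
#   print(predict(Image.open("sample.png").convert("RGB")))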

# Route path inferred from the handler name; adjust if the Space exposes another
@app.post("/extract_text")
async def extract_text(file: UploadFile = File(...)):
    # Decode the upload into a valid RGB PIL image
    image = Image.open(BytesIO(await file.read())).convert("RGB")
    text = predict(image)
    return {"extracted_text": text}

@app.post("/batch_extract_text")
async def batch_extract_text(files: list[UploadFile] = File(...)):
    if len(files) > 20:
        raise HTTPException(
            status_code=400,
            detail="A maximum of 20 images can be processed at a time.",
        )
    images = [Image.open(BytesIO(await file.read())).convert("RGB") for file in files]
    # Use the same <image>-prefixed prompt as the single-image path
    prompts = ["<image> ocr"] * len(images)
    model_inputs = (
        processor(text=prompts, images=images, return_tensors="pt")
        .to(torch.bfloat16)
        .to(model.device)
    )
    input_len = model_inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generations = model.generate(**model_inputs, max_new_tokens=200, do_sample=False)
        torch.cuda.empty_cache()
    # Strip the prompt tokens from each sequence before decoding
    extracted_texts = [
        processor.decode(generations[i][input_len:], skip_special_tokens=True)
        for i in range(len(images))
    ]
    return {"extracted_texts": extracted_texts}

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
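
# Example requests against the routes defined above (paths are the ones
# assumed in the decorators; port 7860 matches the uvicorn call):
#
#   curl -X POST http://localhost:7860/extract_text \
#        -F "file=@invoice.png"
#
#   curl -X POST http://localhost:7860/batch_extract_text \
#        -F "files=@page1.png" -F "files=@page2.png"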