from colpali_engine.models import ColPali, ColPaliProcessor
from colpali_engine.models import ColIdefics3, ColIdefics3Processor
# from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
from torch.utils.data import DataLoader
import torch
from typing import List, cast
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
import os
import spaces
# This part is for local runs: clear any stale GPU allocations before loading the model.
torch.cuda.empty_cache()

# Load the .env file. The intent is to read the model name from an environment
# variable; for now, model_name is hardcoded below.
import dotenv
dotenv_file = dotenv.find_dotenv()
dotenv.load_dotenv(dotenv_file)

model_name = 'vidore/colpali-v1.3'  # alternative: "vidore/colSmol-256M"
device = get_torch_device("cuda")  # use "cpu" instead of "cuda" if no GPU is available

# Download the model once, save it locally, and load it from disk on later runs
# instead of pulling from the Hugging Face Hub every time.
current_working_directory = os.getcwd()
save_directory = os.path.join(current_working_directory, model_name)  # directory for the model weights
processor_directory = os.path.join(current_working_directory, model_name + '_processor')  # directory for the processor
if not os.path.exists(save_directory):  # download if the model has not been saved locally yet
    if "colSmol-256M" in model_name:  # ColSmol variant
        model = ColIdefics3.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map=device,
            # attn_implementation="flash_attention_2",
        ).eval()
        processor = cast(ColIdefics3Processor, ColIdefics3Processor.from_pretrained(model_name))
    else:  # ColPali v1.3 etc.
        model = ColPali.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map=device,
            # attn_implementation="flash_attention_2",
        ).eval()
        processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
    os.makedirs(save_directory)
    print(f"Directory '{save_directory}' created.")
    model.save_pretrained(save_directory)
    os.makedirs(processor_directory)
    processor.save_pretrained(processor_directory)
else:
    # Load from disk, matching the dtype and eval mode used for the fresh download above.
    if "colSmol-256M" in model_name:
        model = ColIdefics3.from_pretrained(save_directory, torch_dtype=torch.bfloat16).eval()
        processor = ColIdefics3Processor.from_pretrained(processor_directory, use_fast=True)
    else:
        model = ColPali.from_pretrained(save_directory, torch_dtype=torch.bfloat16).eval()
        processor = ColPaliProcessor.from_pretrained(processor_directory, use_fast=True)
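# Note: because model_name contains a slash, the weights land under
# <cwd>/vidore/colpali-v1.3 and the processor under <cwd>/vidore/colpali-v1.3_processor.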
class ColpaliManager:
    # The model and processor are module-level globals so the GPU can be
    # hot-potatoed between ColPali and Ollama; the commented-out code below
    # shows the earlier per-instance loading.
    def __init__(self, device="cuda", model_name=model_name):
        print(f"Initializing ColpaliManager with device {device} and model {model_name}")
        # self.device = get_torch_device(device)
        # self.model = ColPali.from_pretrained(
        #     model_name,
        #     torch_dtype=torch.bfloat16,
        #     device_map=self.device,
        # ).eval()
        # self.processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))

    def get_images(self, paths: list[str]) -> List[Image.Image]:
        model.to("cuda")
        return [Image.open(path) for path in paths]
    def process_images(self, image_paths: list[str], batch_size=5):
        model.to("cuda")
        print(f"Processing {len(image_paths)} image_paths")
        images = self.get_images(image_paths)

        dataloader = DataLoader(
            dataset=ListDataset[Image.Image](images),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda x: processor.process_images(x),
        )

        ds: List[torch.Tensor] = []
        for batch_doc in tqdm(dataloader):
            with torch.no_grad():
                batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
                embeddings_doc = model(**batch_doc)
            ds.extend(list(torch.unbind(embeddings_doc.to(device))))
        ds_np = [d.float().cpu().numpy() for d in ds]
        return ds_np
    def process_text(self, texts: list[str]):
        # Queries must use ColPali v1.3/v1.2 etc.; colSmol is not reliable for queries.
        # The model must be on the GPU while embedding the queries; it is moved
        # back to the CPU at the end of this method so the Ollama / multimodal
        # LLM call that follows can use the GPU.
        model.to("cuda")
print(f"Processing {len(texts)} texts") | |
dataloader = DataLoader( | |
dataset=ListDataset[str](texts), | |
batch_size=5, #OG is 5, try reducing batch size to maximise gpu use | |
shuffle=False, | |
collate_fn=lambda x: processor.process_queries(x), | |
) | |
qs: List[torch.Tensor] = [] | |
for batch_query in dataloader: | |
with torch.no_grad(): | |
batch_query = {k: v.to(model.device) for k, v in batch_query.items()} | |
embeddings_query = model(**batch_query) | |
qs.extend(list(torch.unbind(embeddings_query.to(device)))) | |
qs_np = [q.float().cpu().numpy() for q in qs] | |
model.to("cpu") # Moves all model parameters and buffers to the CPU, freeing up gpu for ollama call after this process text call! (THIS WORKS!) | |
return qs_np | |
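    # Each entry of the returned list is one query's multi-vector embedding of
    # shape (n_query_tokens, embedding_dim); retrieval scores are computed
    # downstream by late interaction against the embeddings from process_images.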
plt.close("all") | |