# Spaces:
# Sleeping
# Sleeping
from typing import List, Tuple | |
from concurrent.futures import ThreadPoolExecutor | |
from pymongo import UpdateOne | |
from pymongo.collection import Collection | |
import math | |
def get_embedding(text: str, openai_client, model="text-embedding-ada-002") -> list[float]:
    """Return the embedding vector for *text* from the OpenAI embeddings API.

    Newlines are flattened to spaces before the request, per OpenAI's
    recommendation for embedding inputs.
    """
    cleaned = text.replace("\n", " ")
    response = openai_client.embeddings.create(model=model, input=[cleaned])
    # A single input was sent, so the first (and only) item holds the vector.
    return response.data[0].embedding
def process_batch(docs: List[dict], field_name: str, embedding_field: str, openai_client) -> List[Tuple[str, list]]:
    """Generate embeddings for each document in *docs* that still needs one.

    Args:
        docs: Documents fetched from MongoDB.
        field_name: Name of the field containing the text to embed.
        embedding_field: Name of the field where embeddings are stored.
        openai_client: OpenAI client instance used for the API calls.

    Returns:
        List of (document _id, embedding vector) pairs for the documents
        that were embedded. Documents that already have *embedding_field*,
        lack *field_name*, or hold a non-string value are skipped.
    """
    results: List[Tuple[str, list]] = []
    for doc in docs:
        # Skip documents that were already embedded in a previous run.
        if embedding_field in doc:
            continue
        # .get() instead of [] so a document missing the text field is
        # skipped rather than raising KeyError and killing the whole batch.
        text = doc.get(field_name)
        if isinstance(text, str):
            embedding = get_embedding(text, openai_client)
            results.append((doc["_id"], embedding))
    return results
def parallel_generate_embeddings(
    collection: Collection,
    cursor,
    field_name: str,
    embedding_field: str,
    openai_client,
    total_docs: int,
    batch_size: int = 20,
    callback=None
) -> int:
    """Generate embeddings in parallel using ThreadPoolExecutor with cursor-based batching.

    Args:
        collection: MongoDB collection receiving the bulk updates.
        cursor: MongoDB cursor for document iteration.
        field_name: Field containing text to embed.
        embedding_field: Field to store embeddings.
        openai_client: OpenAI client instance.
        total_docs: Total number of documents to process (used for progress %).
        batch_size: Size of batches for parallel processing.
        callback: Optional ``callback(percent, processed, total)`` progress hook.

    Returns:
        Number of documents actually updated with a new embedding.
    """
    if total_docs == 0:
        return 0

    processed = 0

    def _write_results(results: List[Tuple[str, list]]) -> int:
        """Bulk-write (doc_id, embedding) pairs; return how many were written."""
        if not results:
            return 0
        ops = [
            UpdateOne({"_id": doc_id}, {"$set": {embedding_field: embedding}})
            for doc_id, embedding in results
        ]
        collection.bulk_write(ops)
        return len(ops)

    def _report() -> None:
        """Send a progress update to the callback, if one was supplied."""
        if callback:
            callback((processed / total_docs) * 100, processed, total_docs)

    # Initial progress update so callers see 0% before any work starts.
    if callback:
        callback(0, 0, total_docs)

    with ThreadPoolExecutor(max_workers=20) as executor:
        batch: List[dict] = []
        futures = []

        # Build fixed-size batches from the cursor and fan them out.
        for doc in cursor:
            batch.append(doc)
            if len(batch) >= batch_size:
                futures.append(
                    executor.submit(process_batch, batch, field_name, embedding_field, openai_client)
                )
                batch = []  # rebind (not clear) so the submitted list is untouched

                # Drain any already-finished futures to bound memory while
                # the cursor is still producing documents.
                pending = []
                for future in futures:
                    if future.done():
                        wrote = _write_results(future.result())
                        if wrote:
                            processed += wrote
                            _report()
                    else:
                        pending.append(future)
                futures = pending

        # Submit whatever is left in the final, possibly short, batch.
        if batch:
            futures.append(
                executor.submit(process_batch, batch, field_name, embedding_field, openai_client)
            )

        # Block on the remaining futures and persist their results.
        for future in futures:
            processed += _write_results(future.result())

    # Always emit a final progress update — the original only fired it when
    # the last future produced results, so callers could miss completion.
    _report()
    return processed