# vector_store.py - Updated for the new Pinecone client package (pinecone>=3)
import os
import time
from typing import Any, Dict, List, Optional

import numpy as np
from pinecone import Pinecone, ServerlessSpec  # new-style client import
# The legacy langchain.vectorstores.Pinecone wrapper predates the new client
# and rejects it; PineconeVectorStore from the langchain-pinecone package is
# its drop-in replacement (pip install langchain-pinecone).
from langchain_pinecone import PineconeVectorStore as LangchainPinecone
from langchain.embeddings.base import Embeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
class InLegalBERTEmbeddings(Embeddings):
"""Custom LangChain embeddings wrapper for InLegalBERT"""
def __init__(self, model):
self.model = model
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed a list of documents"""
return self.model.encode(texts).tolist()
def embed_query(self, text: str) -> List[float]:
"""Embed a single query"""
return self.model.encode([text])[0].tolist()
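
# Illustrative check (hypothetical `model`): the wrapper assumes a
# SentenceTransformer-style object whose .encode() returns a numpy array.
#
#   emb = InLegalBERTEmbeddings(model)
#   vecs = emb.embed_documents(["Indemnification clause ...", "Governing law ..."])
#   assert len(vecs[0]) == 768  # InLegalBERT's hidden size
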
class LegalDocumentVectorStore:
"""Manages vector storage for legal documents"""
def __init__(self):
self.index_name = 'legal-documents'
self.dimension = 768 # InLegalBERT dimension
self._initialized = False
self.clause_tagger = None
self.pc = None # Pinecone client
def _initialize_pinecone(self):
"""Initialize Pinecone connection with new API"""
if self._initialized:
return
PINECONE_API_KEY = os.getenv('PINECONE_API_KEY')
if not PINECONE_API_KEY:
raise ValueError("PINECONE_API_KEY environment variable not set")
# Use new Pinecone API
self.pc = Pinecone(api_key=PINECONE_API_KEY)
        # Create the index if it doesn't already exist
existing_indexes = [index_info["name"] for index_info in self.pc.list_indexes()]
if self.index_name not in existing_indexes:
self.pc.create_index(
name=self.index_name,
dimension=self.dimension,
metric='cosine',
spec=ServerlessSpec(cloud='aws', region='us-east-1')
)
print(f"βœ… Created Pinecone index: {self.index_name}")
self._initialized = True
def save_document_embeddings(self, document_text: str, document_id: str,
analysis_results: Dict[str, Any], clause_tagger) -> bool:
"""Save document embeddings using InLegalBERT model"""
try:
self._initialize_pinecone()
# Use the clause tagger's InLegalBERT model
legal_embeddings = InLegalBERTEmbeddings(clause_tagger.embedding_model)
# Split document into chunks
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200,
separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
)
chunks = text_splitter.split_text(document_text)
# Prepare metadata with analysis results
metadatas = []
for i, chunk in enumerate(chunks):
metadata = {
'document_id': document_id,
'chunk_index': i,
'total_chunks': len(chunks),
'source': 'legal_document',
'has_key_clauses': len(analysis_results.get('key_clauses', [])) > 0,
'risk_count': len(analysis_results.get('risky_terms', [])),
'embedding_model': 'InLegalBERT',
'timestamp': str(np.datetime64('now'))
}
metadatas.append(metadata)
# Create vector store with new API
index = self.pc.Index(self.index_name)
vectorstore = LangchainPinecone(
index=index,
embedding=legal_embeddings,
text_key="text"
)
# Add documents to Pinecone
vectorstore.add_texts(
texts=chunks,
metadatas=metadatas,
ids=[f"{document_id}_chunk_{i}" for i in range(len(chunks))]
)
print(f"βœ… Saved {len(chunks)} chunks using InLegalBERT embeddings for document {document_id}")
return True
except Exception as e:
print(f"❌ Error saving to Pinecone: {e}")
return False
    def get_retriever(self, clause_tagger, document_id: Optional[str] = None):
"""Get retriever for chat functionality"""
try:
self._initialize_pinecone()
legal_embeddings = InLegalBERTEmbeddings(clause_tagger.embedding_model)
index = self.pc.Index(self.index_name)
vectorstore = LangchainPinecone(
index=index,
embedding=legal_embeddings,
text_key="text"
)
# Create retriever with optional document filtering
search_kwargs = {'k': 5}
if document_id:
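                # Pinecone treats a bare value as an implicit equality filter,
                # equivalent to {'document_id': {'$eq': document_id}}.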
search_kwargs['filter'] = {'document_id': document_id}
return vectorstore.as_retriever(search_kwargs=search_kwargs)
except Exception as e:
print(f"❌ Error creating retriever: {e}")
return None
# Global instance
vector_store = LegalDocumentVectorStore()
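
# Illustrative usage (a sketch, not a prescribed API): assumes a `clause_tagger`
# exposing `embedding_model` with a SentenceTransformer-style .encode() method,
# which is what the methods above expect.
#
#   from vector_store import vector_store
#
#   saved = vector_store.save_document_embeddings(
#       document_text=contract_text,
#       document_id="contract-001",
#       analysis_results={"key_clauses": [], "risky_terms": []},
#       clause_tagger=clause_tagger,
#   )
#   retriever = vector_store.get_retriever(clause_tagger, document_id="contract-001")
#   if retriever:
#       docs = retriever.get_relevant_documents("What does the indemnity clause say?")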