# vector_store.py - Updated for new Pinecone package
import os
from pinecone import Pinecone, ServerlessSpec # Changed import
from langchain.vectorstores import Pinecone as LangchainPinecone
from langchain.embeddings.base import Embeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
import numpy as np
from typing import List, Dict, Any
class InLegalBERTEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter around an InLegalBERT encoder.

    The wrapped model is expected to expose ``.encode(list_of_str)``
    returning a 2-D numpy-like array (n_texts x dim) — presumably a
    SentenceTransformer-style encoder; confirm against the caller.
    """

    def __init__(self, model):
        # Underlying InLegalBERT encoder supplied by the clause tagger.
        self.model = model

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Encode a batch of documents, one embedding vector per text."""
        vectors = self.model.encode(texts)
        return vectors.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Encode a single query string into one embedding vector."""
        return self.embed_documents([text])[0]
class LegalDocumentVectorStore:
    """Manages Pinecone vector storage for legal documents embedded with InLegalBERT.

    Construction is cheap and performs no I/O; the Pinecone connection is
    established lazily by ``_initialize_pinecone()`` on first use.
    """

    def __init__(self):
        self.index_name = 'legal-documents'
        self.dimension = 768  # InLegalBERT output embedding dimension
        self._initialized = False  # guards one-time Pinecone setup
        self.clause_tagger = None
        self.pc = None  # Pinecone client, created lazily

    def _initialize_pinecone(self):
        """Connect to Pinecone (v3+ client API) and create the index if missing.

        Idempotent: subsequent calls return immediately once initialized.

        Raises:
            ValueError: if the PINECONE_API_KEY environment variable is unset.
        """
        if self._initialized:
            return
        api_key = os.getenv('PINECONE_API_KEY')
        if not api_key:
            raise ValueError("PINECONE_API_KEY environment variable not set")
        # New-style Pinecone API: client object instead of module-level init().
        self.pc = Pinecone(api_key=api_key)
        # Create the serverless index only when it does not already exist.
        existing_indexes = [info["name"] for info in self.pc.list_indexes()]
        if self.index_name not in existing_indexes:
            self.pc.create_index(
                name=self.index_name,
                dimension=self.dimension,
                metric='cosine',
                spec=ServerlessSpec(cloud='aws', region='us-east-1')
            )
            # Fixed: this f-string was garbled/split in the previous revision.
            print(f"✅ Created Pinecone index: {self.index_name}")
        self._initialized = True

    def save_document_embeddings(self, document_text: str, document_id: str,
                                 analysis_results: Dict[str, Any], clause_tagger) -> bool:
        """Chunk a document, embed it with InLegalBERT, and upsert into Pinecone.

        Args:
            document_text: Full text of the legal document.
            document_id: Unique id; chunk vector ids are ``<document_id>_chunk_<i>``.
            analysis_results: Analysis dict; its 'key_clauses' / 'risky_terms'
                lists are summarized into per-chunk metadata.
            clause_tagger: Object exposing an ``embedding_model`` usable by
                InLegalBERTEmbeddings.

        Returns:
            True on success, False on any failure (the error is printed,
            not raised — callers treat this as best-effort).
        """
        try:
            self._initialize_pinecone()
            # Reuse the clause tagger's InLegalBERT model for embeddings.
            legal_embeddings = InLegalBERTEmbeddings(clause_tagger.embedding_model)
            # Overlapping chunks keep clause context intact across boundaries.
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=1000,
                chunk_overlap=200,
                separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
            )
            chunks = text_splitter.split_text(document_text)
            # One metadata dict per chunk, carrying document-level analysis flags.
            metadatas = [
                {
                    'document_id': document_id,
                    'chunk_index': i,
                    'total_chunks': len(chunks),
                    'source': 'legal_document',
                    'has_key_clauses': len(analysis_results.get('key_clauses', [])) > 0,
                    'risk_count': len(analysis_results.get('risky_terms', [])),
                    'embedding_model': 'InLegalBERT',
                    'timestamp': str(np.datetime64('now'))
                }
                for i in range(len(chunks))
            ]
            index = self.pc.Index(self.index_name)
            vectorstore = LangchainPinecone(
                index=index,
                embedding=legal_embeddings,
                text_key="text"
            )
            # Deterministic ids make re-saving a document overwrite its chunks.
            vectorstore.add_texts(
                texts=chunks,
                metadatas=metadatas,
                ids=[f"{document_id}_chunk_{i}" for i in range(len(chunks))]
            )
            # Fixed: this f-string was garbled/split in the previous revision.
            print(f"✅ Saved {len(chunks)} chunks using InLegalBERT embeddings for document {document_id}")
            return True
        except Exception as e:
            # Best-effort: report and signal failure rather than propagate.
            print(f"❌ Error saving to Pinecone: {e}")
            return False

    def get_retriever(self, clause_tagger, document_id: str = None):
        """Build a LangChain retriever over the legal-documents index.

        Args:
            clause_tagger: Object exposing an ``embedding_model`` for queries.
            document_id: Optional; when given, results are filtered to that
                document's chunks via a metadata filter.

        Returns:
            A retriever (top-5 similarity search), or None on failure.
        """
        try:
            self._initialize_pinecone()
            legal_embeddings = InLegalBERTEmbeddings(clause_tagger.embedding_model)
            index = self.pc.Index(self.index_name)
            vectorstore = LangchainPinecone(
                index=index,
                embedding=legal_embeddings,
                text_key="text"
            )
            search_kwargs = {'k': 5}
            if document_id:
                search_kwargs['filter'] = {'document_id': document_id}
            return vectorstore.as_retriever(search_kwargs=search_kwargs)
        except Exception as e:
            print(f"❌ Error creating retriever: {e}")
            return None
# Global instance — module-level singleton shared by importers. Safe to create
# at import time: __init__ only sets fields; the Pinecone connection itself is
# established lazily on first save/retrieve call.
vector_store = LegalDocumentVectorStore()