File size: 5,228 Bytes
e96a966
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# vector_store.py
"""
Vector store integration for legal document embeddings using InLegalBERT and Pinecone
"""
import os
from typing import Any, Dict, List, Optional

import numpy as np
import pinecone
from langchain.embeddings.base import Embeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Pinecone as LangchainPinecone

class InLegalBERTEmbeddings(Embeddings):
    """Adapts an InLegalBERT sentence-encoder to the LangChain Embeddings API."""

    def __init__(self, model):
        # Underlying encoder; expected to expose `.encode(list[str]) -> ndarray`.
        self.model = model

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Encode *texts*, returning one embedding vector per input document."""
        vectors = self.model.encode(texts)
        return vectors.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Encode a single query string into one embedding vector."""
        # A query is just a one-document batch; reuse the batch path.
        return self.embed_documents([text])[0]

class LegalDocumentVectorStore:
    """Manages Pinecone-backed vector storage for legal documents.

    Connection setup is deferred until first use (see ``_initialize_pinecone``),
    so constructing this class has no side effects at import time.
    """

    def __init__(self):
        # Name of the Pinecone index holding all legal-document chunks.
        self.index_name = 'legal-documents'
        # Embedding width produced by InLegalBERT (BERT-base hidden size).
        self.dimension = 768
        # Guards against re-running pinecone.init / index creation.
        self._initialized = False
        self.clause_tagger = None

    def _initialize_pinecone(self):
        """Initialize the Pinecone connection and ensure the index exists.

        Idempotent: subsequent calls return immediately.

        Raises:
            ValueError: if the PINECONE_API_KEY environment variable is unset.
        """
        if self._initialized:
            return

        api_key = os.getenv('PINECONE_API_KEY')
        environment = os.getenv('PINECONE_ENV', 'us-west1-gcp')

        if not api_key:
            raise ValueError("PINECONE_API_KEY environment variable not set")

        pinecone.init(api_key=api_key, environment=environment)

        # Create the index only on first-ever run for this project.
        if self.index_name not in pinecone.list_indexes():
            pinecone.create_index(
                name=self.index_name,
                dimension=self.dimension,
                metric='cosine'
            )
            print(f"βœ… Created Pinecone index: {self.index_name}")

        self._initialized = True

    def save_document_embeddings(self, document_text: str, document_id: str,
                               analysis_results: Dict[str, Any], clause_tagger) -> bool:
        """Chunk a document and upsert its InLegalBERT embeddings into Pinecone.

        Args:
            document_text: Full plain text of the legal document.
            document_id: Unique identifier; chunk vector ids are derived from it.
            analysis_results: Prior analysis output; ``key_clauses`` and
                ``risky_terms`` lists (if present) are summarized into metadata.
            clause_tagger: Object exposing ``embedding_model`` (the InLegalBERT
                encoder) — presumably a ClauseTagger; verify against caller.

        Returns:
            True on success, False on any failure (errors are printed, not raised).
        """
        try:
            self._initialize_pinecone()

            # Wrap the clause tagger's InLegalBERT model for LangChain.
            legal_embeddings = InLegalBERTEmbeddings(clause_tagger.embedding_model)

            # Overlapping chunks keep clause context intact across boundaries.
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=1000,
                chunk_overlap=200,
                separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
            )
            chunks = text_splitter.split_text(document_text)

            # Hoisted out of the loop: one timestamp for the whole batch so
            # every chunk of a document carries the same ingestion time
            # (previously re-evaluated per chunk, which could straddle seconds).
            ingested_at = str(np.datetime64('now'))

            metadatas = [
                {
                    'document_id': document_id,
                    'chunk_index': i,
                    'total_chunks': len(chunks),
                    'source': 'legal_document',
                    'has_key_clauses': len(analysis_results.get('key_clauses', [])) > 0,
                    'risk_count': len(analysis_results.get('risky_terms', [])),
                    'embedding_model': 'InLegalBERT',
                    'timestamp': ingested_at,
                }
                for i in range(len(chunks))
            ]

            # Connect to the index and upsert all chunks with deterministic ids
            # so re-ingesting the same document overwrites its old vectors.
            index = pinecone.Index(self.index_name)
            vectorstore = LangchainPinecone(
                index=index,
                embedding=legal_embeddings,
                text_key="text"
            )
            vectorstore.add_texts(
                texts=chunks,
                metadatas=metadatas,
                ids=[f"{document_id}_chunk_{i}" for i in range(len(chunks))]
            )

            print(f"βœ… Saved {len(chunks)} chunks using InLegalBERT embeddings for document {document_id}")
            return True

        except Exception as e:
            # Best-effort: callers branch on the bool rather than handling exceptions.
            print(f"❌ Error saving to Pinecone: {e}")
            return False

    def get_retriever(self, clause_tagger, document_id: Optional[str] = None):
        """Build a LangChain retriever over the legal-document index.

        Args:
            clause_tagger: Object exposing ``embedding_model`` (InLegalBERT encoder).
            document_id: When given, restrict retrieval to that document's chunks
                via a metadata filter.

        Returns:
            A retriever yielding the top-5 matching chunks, or None on failure.
        """
        try:
            self._initialize_pinecone()

            legal_embeddings = InLegalBERTEmbeddings(clause_tagger.embedding_model)
            index = pinecone.Index(self.index_name)

            vectorstore = LangchainPinecone(
                index=index,
                embedding=legal_embeddings,
                text_key="text"
            )

            # Optionally scope the search to a single document's chunks.
            search_kwargs = {'k': 5}
            if document_id:
                search_kwargs['filter'] = {'document_id': document_id}

            return vectorstore.as_retriever(search_kwargs=search_kwargs)

        except Exception as e:
            print(f"❌ Error creating retriever: {e}")
            return None

# Module-level singleton shared by importers. Safe to create at import time:
# Pinecone is initialized lazily on first use, so this has no side effects here.
vector_store = LegalDocumentVectorStore()