File size: 5,379 Bytes
ac89d45
e96a966
ac89d45
e96a966
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac89d45
e96a966
 
ac89d45
e96a966
 
 
 
 
 
 
 
ac89d45
 
e96a966
 
ac89d45
 
 
e96a966
 
ac89d45
 
e96a966
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac89d45
 
e96a966
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac89d45
e96a966
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# vector_store.py - Updated for new Pinecone package
import os
from pinecone import Pinecone, ServerlessSpec  # Changed import
from langchain.vectorstores import Pinecone as LangchainPinecone
from langchain.embeddings.base import Embeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
import numpy as np
from typing import List, Dict, Any

class InLegalBERTEmbeddings(Embeddings):
    """Adapter that exposes an InLegalBERT encoder through LangChain's ``Embeddings`` interface."""

    def __init__(self, model):
        # `model` only needs an ``encode(list_of_texts)`` method whose result
        # supports indexing and ``.tolist()``; nothing else is used here.
        self.model = model

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Encode every text in ``texts`` and return the vectors as nested lists."""
        encoded = self.model.encode(texts)
        return encoded.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Encode one query string and return its vector as a flat list."""
        # The encoder is called with a one-item batch; the single row is unwrapped.
        batch = self.model.encode([text])
        query_vector = batch[0]
        return query_vector.tolist()

class LegalDocumentVectorStore:
    """Manage Pinecone-backed vector storage for legal documents.

    The Pinecone client is created lazily on first use, so constructing an
    instance performs no network access and does not require the
    ``PINECONE_API_KEY`` environment variable to be set.
    """

    def __init__(self):
        self.index_name = 'legal-documents'
        self.dimension = 768  # InLegalBERT dimension
        self._initialized = False
        # Not read by any method of this class; kept for external users.
        self.clause_tagger = None
        self.pc = None  # Pinecone client, created by _initialize_pinecone()

    def _initialize_pinecone(self):
        """Create the Pinecone client and make sure the index exists.

        Does nothing when a previous call already succeeded. The
        ``_initialized`` flag is only set after every step has completed, so
        a failed attempt is retried on the next call.

        Raises:
            ValueError: if ``PINECONE_API_KEY`` is unset or empty.
        """
        if self._initialized:
            return

        api_key = os.getenv('PINECONE_API_KEY')
        if not api_key:
            raise ValueError("PINECONE_API_KEY environment variable not set")

        # Use new Pinecone API
        self.pc = Pinecone(api_key=api_key)

        # Create index if doesn't exist
        existing_indexes = {index_info["name"] for index_info in self.pc.list_indexes()}
        if self.index_name not in existing_indexes:
            self.pc.create_index(
                name=self.index_name,
                dimension=self.dimension,
                metric='cosine',
                spec=ServerlessSpec(cloud='aws', region='us-east-1'),
            )
            print(f"✅ Created Pinecone index: {self.index_name}")

        self._initialized = True

    def _build_vectorstore(self, clause_tagger):
        """Return a LangChain Pinecone store bound to this index and the tagger's model.

        Expects ``_initialize_pinecone()`` to have succeeded, so ``self.pc`` is set.
        """
        # Use the clause tagger's InLegalBERT model
        legal_embeddings = InLegalBERTEmbeddings(clause_tagger.embedding_model)
        index = self.pc.Index(self.index_name)
        return LangchainPinecone(
            index=index,
            embedding=legal_embeddings,
            text_key="text",
        )

    @staticmethod
    def _build_chunk_metadata(document_id: str, chunks: List[str],
                              analysis_results: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Return one metadata dict per chunk, in chunk order.

        A missing or ``None`` ``analysis_results`` (or a ``None`` value under
        ``key_clauses`` / ``risky_terms``) is treated as "no findings" rather
        than as an error.
        """
        results = analysis_results or {}
        # These values are identical for every chunk of one document, so they
        # are computed once instead of once per chunk.
        total_chunks = len(chunks)
        has_key_clauses = len(results.get('key_clauses') or []) > 0
        risk_count = len(results.get('risky_terms') or [])
        timestamp = str(np.datetime64('now'))

        # Each chunk gets its own dict object; presumably the vector store
        # adds per-chunk fields (e.g. the text) to it - do not share one dict.
        return [
            {
                'document_id': document_id,
                'chunk_index': chunk_index,
                'total_chunks': total_chunks,
                'source': 'legal_document',
                'has_key_clauses': has_key_clauses,
                'risk_count': risk_count,
                'embedding_model': 'InLegalBERT',
                'timestamp': timestamp,
            }
            for chunk_index in range(total_chunks)
        ]

    def save_document_embeddings(self, document_text: str, document_id: str, 
                               analysis_results: Dict[str, Any], clause_tagger) -> bool:
        """Split a document into chunks, embed them and store them in Pinecone.

        Args:
            document_text: Full text of the document to index.
            document_id: Identifier stored in each chunk's metadata and used
                as the prefix of each vector id (``"<document_id>_chunk_<i>"``).
            analysis_results: Analysis output; only the ``key_clauses`` and
                ``risky_terms`` entries are read, to fill the metadata.
            clause_tagger: Object whose ``embedding_model`` attribute is the
                InLegalBERT encoder used for the embeddings.

        Returns:
            True when the chunks were stored; False when any step raised
            (the error is printed, not propagated).
        """
        try:
            self._initialize_pinecone()

            # Split document into chunks
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=1000,
                chunk_overlap=200,
                separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""],
            )
            chunks = text_splitter.split_text(document_text)

            # Prepare metadata with analysis results
            metadatas = self._build_chunk_metadata(document_id, chunks, analysis_results)

            # Add documents to Pinecone
            vectorstore = self._build_vectorstore(clause_tagger)
            vectorstore.add_texts(
                texts=chunks,
                metadatas=metadatas,
                ids=[f"{document_id}_chunk_{i}" for i in range(len(chunks))],
            )

            print(f"✅ Saved {len(chunks)} chunks using InLegalBERT embeddings for document {document_id}")
            return True

        except Exception as e:
            # Deliberate best-effort boundary: callers get False instead of an exception.
            print(f"❌ Error saving to Pinecone: {e}")
            return False

    def get_retriever(self, clause_tagger, document_id: str = None):
        """Return a retriever over the index for chat functionality.

        Args:
            clause_tagger: Object whose ``embedding_model`` attribute is the
                InLegalBERT encoder used to embed queries.
            document_id: When given (and non-empty), results are restricted to
                chunks whose ``document_id`` metadata equals this value.

        Returns:
            A retriever that fetches the top 5 matches, or None when any step
            raised (the error is printed, not propagated).
        """
        try:
            self._initialize_pinecone()
            vectorstore = self._build_vectorstore(clause_tagger)

            # Create retriever with optional document filtering
            search_kwargs = {'k': 5}
            if document_id:
                search_kwargs['filter'] = {'document_id': document_id}

            return vectorstore.as_retriever(search_kwargs=search_kwargs)

        except Exception as e:
            # Deliberate best-effort boundary: callers get None instead of an exception.
            print(f"❌ Error creating retriever: {e}")
            return None

# Global instance shared by importers of this module. Constructing it only
# sets attributes; the Pinecone connection is made lazily on the first
# save_document_embeddings() / get_retriever() call.
vector_store = LegalDocumentVectorStore()