sagar008 committed on
Commit
e96a966
·
verified ·
1 Parent(s): c7129a7

Create vector_store.py

Browse files
Files changed (1) hide show
  1. vector_store.py +141 -0
vector_store.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # vector_store.py
2
+ """
3
+ Vector store integration for legal document embeddings using InLegalBERT and Pinecone
4
+ """
5
+ import os
6
+ import pinecone
7
+ from langchain.vectorstores import Pinecone as LangchainPinecone
8
+ from langchain.embeddings.base import Embeddings
9
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
10
+ import numpy as np
11
+ from typing import List, Dict, Any
12
+
13
class InLegalBERTEmbeddings(Embeddings):
    """Adapter exposing an InLegalBERT encoder through LangChain's ``Embeddings`` API.

    The wrapped ``model`` must provide ``encode(list_of_str)`` and return an
    object that supports indexing and ``.tolist()`` (presumably a NumPy array
    with one row per input text -- confirm against the clause tagger).
    """

    def __init__(self, model):
        # Keep a reference to the encoder; nothing is loaded or copied here.
        self.model = model

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Encode every text in ``texts`` and return the vectors as plain lists."""
        encoded = self.model.encode(texts)
        return encoded.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Encode a single query string and return its vector as a plain list."""
        # The encoder is called with a one-element batch; unwrap its only row.
        encoded = self.model.encode([text])
        query_vector = encoded[0]
        return query_vector.tolist()
26
+
27
class LegalDocumentVectorStore:
    """Manage Pinecone-backed vector storage for legal documents.

    Document text is split into chunks, embedded with the model exposed as
    ``clause_tagger.embedding_model`` and written to one fixed Pinecone index.
    The Pinecone connection is established lazily on first use; constructing
    an instance only sets attributes.
    """

    def __init__(self):
        self.index_name = 'legal-documents'
        self.dimension = 768  # InLegalBERT dimension
        self._initialized = False
        self.clause_tagger = None

    def _initialize_pinecone(self):
        """Connect to Pinecone and create the index if it does not exist yet.

        Reads ``PINECONE_API_KEY`` and ``PINECONE_ENV`` (default
        ``'us-west1-gcp'``) from the environment. Does nothing on repeat calls
        once a previous call has completed successfully.

        Raises:
            ValueError: if ``PINECONE_API_KEY`` is unset or empty.
        """
        if self._initialized:
            return

        api_key = os.getenv('PINECONE_API_KEY')
        environment = os.getenv('PINECONE_ENV', 'us-west1-gcp')

        if not api_key:
            raise ValueError("PINECONE_API_KEY environment variable not set")

        pinecone.init(api_key=api_key, environment=environment)

        # Create index if doesn't exist
        if self.index_name not in pinecone.list_indexes():
            pinecone.create_index(
                name=self.index_name,
                dimension=self.dimension,
                metric='cosine',
            )
            print(f"✅ Created Pinecone index: {self.index_name}")

        self._initialized = True

    def _build_vectorstore(self, clause_tagger):
        """Return a LangChain Pinecone store for this index using the tagger's model."""
        self._initialize_pinecone()
        legal_embeddings = InLegalBERTEmbeddings(clause_tagger.embedding_model)
        return LangchainPinecone(
            index=pinecone.Index(self.index_name),
            embedding=legal_embeddings,
            text_key="text",
        )

    @staticmethod
    def _build_chunk_metadata(document_id, total_chunks, analysis_results):
        """Build one metadata dict per chunk; only ``chunk_index`` differs between them."""
        # A missing analysis, or a None entry inside it, counts as "nothing found"
        # instead of raising and aborting the whole save.
        results = analysis_results or {}
        key_clauses = results.get('key_clauses') or []
        risky_terms = results.get('risky_terms') or []

        # Everything below is identical for every chunk, so compute it once.
        # Taking the timestamp once also gives all chunks of a document the
        # same value.
        shared = {
            'total_chunks': total_chunks,
            'source': 'legal_document',
            'has_key_clauses': len(key_clauses) > 0,
            'risk_count': len(risky_terms),
            'embedding_model': 'InLegalBERT',
            'timestamp': str(np.datetime64('now')),
        }
        return [
            {'document_id': document_id, 'chunk_index': index, **shared}
            for index in range(total_chunks)
        ]

    def save_document_embeddings(self, document_text: str, document_id: str,
                                 analysis_results: Dict[str, Any], clause_tagger) -> bool:
        """Chunk, embed and upsert a document into the Pinecone index.

        Args:
            document_text: Full text of the document to store.
            document_id: Identifier stored in each chunk's metadata and used
                as the prefix of each vector id (``"<document_id>_chunk_<i>"``).
            analysis_results: Analysis output; only the ``'key_clauses'`` and
                ``'risky_terms'`` entries are read. ``None`` or missing
                entries are treated as empty.
            clause_tagger: Object whose ``embedding_model`` attribute provides
                the encoder.

        Returns:
            ``True`` when the chunks were added, ``False`` if any step raised
            (the error is printed, not propagated).
        """
        try:
            vectorstore = self._build_vectorstore(clause_tagger)

            # Split document into chunks
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=1000,
                chunk_overlap=200,
                separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""],
            )
            chunks = text_splitter.split_text(document_text)
            chunk_count = len(chunks)

            # Add documents to Pinecone
            vectorstore.add_texts(
                texts=chunks,
                metadatas=self._build_chunk_metadata(
                    document_id, chunk_count, analysis_results
                ),
                ids=[f"{document_id}_chunk_{i}" for i in range(chunk_count)],
            )

            print(f"✅ Saved {chunk_count} chunks using InLegalBERT embeddings for document {document_id}")
            return True

        except Exception as e:
            # Best-effort contract: callers get a boolean, never an exception.
            print(f"❌ Error saving to Pinecone: {e}")
            return False

    def get_retriever(self, clause_tagger, document_id: str = None):
        """Return a retriever over the index for chat functionality.

        Args:
            clause_tagger: Object whose ``embedding_model`` attribute provides
                the encoder used for queries.
            document_id: When truthy, results are filtered to chunks whose
                ``document_id`` metadata matches; otherwise no filter is set.

        Returns:
            A retriever configured with ``k=5``, or ``None`` if setup raised
            (the error is printed, not propagated).
        """
        try:
            vectorstore = self._build_vectorstore(clause_tagger)

            # Create retriever with optional document filtering
            search_kwargs = {'k': 5}
            if document_id:
                search_kwargs['filter'] = {'document_id': document_id}

            return vectorstore.as_retriever(search_kwargs=search_kwargs)

        except Exception as e:
            print(f"❌ Error creating retriever: {e}")
            return None
139
+
140
# Module-level instance, created when this file is imported.
vector_store = LegalDocumentVectorStore()