sagar008 committed on
Commit
137a4e7
·
verified ·
1 Parent(s): a5a31ff

Update clause_tagger.py

Browse files
Files changed (1) hide show
  1. clause_tagger.py +51 -5
clause_tagger.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from typing import List, Dict, Any
2
  from sentence_transformers import SentenceTransformer
3
  import numpy as np
@@ -14,8 +15,16 @@ class ClauseTagger:
14
  """Initialize embedding model and load clause references"""
15
  if self.embedding_model is None:
16
  print("🧠 Loading embedding model for clause tagging...")
17
- # Use a legal-domain model for better clause understanding
18
- self.embedding_model = SentenceTransformer('law-ai/InLegalBERT')
 
 
 
 
 
 
 
 
19
  print("✅ Embedding model loaded")
20
 
21
  # Load clause references
@@ -53,7 +62,7 @@ class ClauseTagger:
53
  return clauses
54
 
55
  async def tag_clauses(self, chunks: List[str]) -> List[Dict[str, Any]]:
56
- """Tag clauses in document chunks"""
57
  if not self.clause_reference:
58
  return []
59
 
@@ -75,7 +84,7 @@ class ClauseTagger:
75
  )[0][0]
76
 
77
  # Only include matches above threshold
78
- if similarity > 0.7: # Adjust threshold as needed
79
  tagged_clauses.append({
80
  'clause_id': clause['id'],
81
  'clause_type': clause['type'],
@@ -88,4 +97,41 @@ class ClauseTagger:
88
 
89
  # Sort by similarity score and return top matches
90
  tagged_clauses.sort(key=lambda x: x['similarity_score'], reverse=True)
91
- return tagged_clauses[:20] # Return top 20 matches
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # clause_tagger.py
2
  from typing import List, Dict, Any
3
  from sentence_transformers import SentenceTransformer
4
  import numpy as np
 
15
  """Initialize embedding model and load clause references"""
16
  if self.embedding_model is None:
17
  print("🧠 Loading embedding model for clause tagging...")
18
+
19
+ # Set cache directory explicitly for HF Spaces
20
+ cache_folder = "/tmp/sentence_transformers_cache"
21
+ os.makedirs(cache_folder, exist_ok=True)
22
+
23
+ # Use a legal-domain model with explicit cache directory
24
+ self.embedding_model = SentenceTransformer(
25
+ 'law-ai/InLegalBERT',
26
+ cache_folder=cache_folder
27
+ )
28
  print("✅ Embedding model loaded")
29
 
30
  # Load clause references
 
62
  return clauses
63
 
64
  async def tag_clauses(self, chunks: List[str]) -> List[Dict[str, Any]]:
65
+ """Tag clauses in document chunks - GENERATES NEW EMBEDDINGS"""
66
  if not self.clause_reference:
67
  return []
68
 
 
84
  )[0][0]
85
 
86
  # Only include matches above threshold
87
+ if similarity > 0.7:
88
  tagged_clauses.append({
89
  'clause_id': clause['id'],
90
  'clause_type': clause['type'],
 
97
 
98
  # Sort by similarity score and return top matches
99
  tagged_clauses.sort(key=lambda x: x['similarity_score'], reverse=True)
100
+ return tagged_clauses[:20]
101
+
102
+ async def tag_clauses_with_embeddings(self, chunk_data: List[Dict]) -> List[Dict[str, Any]]:
103
+ """Tag clauses using pre-computed embeddings - OPTIMIZED VERSION"""
104
+ if not self.clause_reference:
105
+ return []
106
+
107
+ print(f"🏷️ Tagging clauses using pre-computed embeddings for {len(chunk_data)} chunks...")
108
+
109
+ tagged_clauses = []
110
+
111
+ for chunk_idx, chunk_info in enumerate(chunk_data):
112
+ chunk_embedding = chunk_info["embedding"]
113
+
114
+ if chunk_embedding is None:
115
+ continue
116
+
117
+ # Find best matching clauses using pre-computed embedding
118
+ for clause in self.clause_reference:
119
+ similarity = cosine_similarity(
120
+ [chunk_embedding],
121
+ [clause['embedding']]
122
+ )[0][0]
123
+
124
+ if similarity > 0.7:
125
+ tagged_clauses.append({
126
+ 'clause_id': clause['id'],
127
+ 'clause_type': clause['type'],
128
+ 'clause_category': clause['category'],
129
+ 'matched_text': chunk_info["text"][:200] + '...' if len(chunk_info["text"]) > 200 else chunk_info["text"],
130
+ 'similarity_score': float(similarity),
131
+ 'chunk_index': chunk_idx,
132
+ 'reference_text': clause['text']
133
+ })
134
+
135
+ tagged_clauses.sort(key=lambda x: x['similarity_score'], reverse=True)
136
+ return tagged_clauses[:6]
137
+