# clause_tagger.py
from typing import List, Dict, Any
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import csv
import os
class ClauseTagger:
    def __init__(self):
        self.embedding_model = None
        self.clause_reference = []
    async def initialize(self):
        """Initialize embedding model and load clause references"""
        if self.embedding_model is None:
            print("🔧 Loading embedding model for clause tagging...")
            # Set cache directory explicitly for HF Spaces
            cache_folder = "/tmp/sentence_transformers_cache"
            os.makedirs(cache_folder, exist_ok=True)
            # Use a legal-domain model with an explicit cache directory
            self.embedding_model = SentenceTransformer(
                'law-ai/InLegalBERT',
                cache_folder=cache_folder
            )
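            # Note (assumption, not stated in the original): InLegalBERT is a
            # plain BERT checkpoint rather than a native sentence-transformers
            # model, so SentenceTransformer wraps it with a default mean-pooling
            # head and logs a "Creating a new one with mean pooling" warning.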
print("β
Embedding model loaded")
# Load clause references
self.clause_reference = self._load_clause_reference()
if self.clause_reference:
# Pre-embed clause references
clause_texts = [clause['text'] for clause in self.clause_reference]
clause_embeddings = self.embedding_model.encode(clause_texts)
for i, clause in enumerate(self.clause_reference):
clause['embedding'] = clause_embeddings[i]
print(f"π Loaded and embedded {len(self.clause_reference)} clause references")
    def _load_clause_reference(self) -> List[Dict[str, Any]]:
        """Load clause reference data"""
        clause_file = "clause_refrence.csv"  # Your existing file
        if not os.path.exists(clause_file):
            print(f"⚠️ Clause reference file not found: {clause_file}")
            return []

        clauses = []
        try:
            with open(clause_file, 'r', encoding='utf-8') as f:
                reader = csv.DictReader(f)
                for row in reader:
                    clauses.append({
                        'id': row.get('id', ''),
                        'type': row.get('type', ''),
                        'text': row.get('text', ''),
                        'category': row.get('category', 'general')
                    })
        except Exception as e:
            print(f"❌ Error loading clause reference: {e}")
            return []
        return clauses
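    # A sketch of the CSV layout _load_clause_reference expects, inferred from
    # the DictReader keys above (the sample rows are illustrative assumptions;
    # extra columns in the real file are simply ignored):
    #
    #   id,type,text,category
    #   c1,indemnification,"The Supplier shall indemnify the Buyer ...",liability
    #   c2,termination,"Either party may terminate this agreement ...",general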
    async def tag_clauses(self, chunks: List[str]) -> List[Dict[str, Any]]:
        """Tag clauses in document chunks - GENERATES NEW EMBEDDINGS"""
        if not self.clause_reference:
            return []

        print(f"🏷️ Tagging clauses in {len(chunks)} chunks...")

        # Embed all chunks
        chunk_embeddings = self.embedding_model.encode(chunks)

        tagged_clauses = []
        for chunk_idx, chunk in enumerate(chunks):
            chunk_embedding = chunk_embeddings[chunk_idx]
            # Find best matching clauses for this chunk
            for clause in self.clause_reference:
                similarity = cosine_similarity(
                    [chunk_embedding],
                    [clause['embedding']]
                )[0][0]
                # Only include matches above threshold
                if similarity > 0.7:
                    tagged_clauses.append({
                        'clause_id': clause['id'],
                        'clause_type': clause['type'],
                        'clause_category': clause['category'],
                        'matched_text': chunk[:200] + '...' if len(chunk) > 200 else chunk,
                        'similarity_score': float(similarity),
                        'chunk_index': chunk_idx,
                        'reference_text': clause['text']
                    })

        # Sort by similarity score and return top matches
        tagged_clauses.sort(key=lambda x: x['similarity_score'], reverse=True)
        return tagged_clauses[:20]
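    # Note: the loop above makes one cosine_similarity call per
    # (chunk, clause) pair. A vectorized alternative (a sketch, not part of
    # the original code) computes the whole similarity matrix in one call:
    #
    #   ref = np.stack([c['embedding'] for c in self.clause_reference])
    #   sims = cosine_similarity(chunk_embeddings, ref)  # (n_chunks, n_clauses)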
    async def tag_clauses_with_embeddings(self, chunk_data: List[Dict]) -> List[Dict[str, Any]]:
        """Tag clauses using pre-computed embeddings - OPTIMIZED VERSION"""
        if not self.clause_reference:
            return []

        print(f"🏷️ Tagging clauses using pre-computed embeddings for {len(chunk_data)} chunks...")

        tagged_clauses = []
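        # Each chunk_data item is expected to look like
        # {"text": str, "embedding": np.ndarray | None} — inferred from the
        # lookups below; the upstream producer is not shown in this file.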
        for chunk_idx, chunk_info in enumerate(chunk_data):
            chunk_embedding = chunk_info.get("embedding")
            if chunk_embedding is None:
                # Skip chunks whose embedding failed or was never computed
                continue
            # Find best matching clauses using the pre-computed embedding
            for clause in self.clause_reference:
                similarity = cosine_similarity(
                    [chunk_embedding],
                    [clause['embedding']]
                )[0][0]
                if similarity > 0.7:
                    chunk_text = chunk_info["text"]
                    tagged_clauses.append({
                        'clause_id': clause['id'],
                        'clause_type': clause['type'],
                        'clause_category': clause['category'],
                        'matched_text': chunk_text[:200] + '...' if len(chunk_text) > 200 else chunk_text,
                        'similarity_score': float(similarity),
                        'chunk_index': chunk_idx,
                        'reference_text': clause['text']
                    })

        tagged_clauses.sort(key=lambda x: x['similarity_score'], reverse=True)
        return tagged_clauses[:6]
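
# Minimal usage sketch (an illustration, not part of the original module;
# the chunk strings are placeholders and clause_refrence.csv must exist):
if __name__ == "__main__":
    import asyncio

    async def _demo():
        tagger = ClauseTagger()
        await tagger.initialize()
        chunks = [
            "The Supplier shall indemnify the Buyer against all losses ...",
            "Either party may terminate this agreement on 30 days' notice ...",
        ]
        for tag in await tagger.tag_clauses(chunks):
            print(tag['clause_type'], round(tag['similarity_score'], 3))

    asyncio.run(_demo())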