# clause_tagger.py
from typing import List, Dict, Any
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import csv
import os

class ClauseTagger:
    """Tag legal-document chunks with matching reference clauses.

    Workflow:
      1. ``initialize()`` loads a sentence-embedding model and embeds the
         reference clauses read from ``clause_refrence.csv``.
      2. ``tag_clauses()`` (embeds chunks itself) or
         ``tag_clauses_with_embeddings()`` (uses pre-computed embeddings)
         compare chunk embeddings against the reference embeddings and
         return the highest-similarity matches above ``SIMILARITY_THRESHOLD``.
    """

    # Minimum cosine similarity for a chunk/clause pair to count as a match.
    SIMILARITY_THRESHOLD = 0.7

    def __init__(self):
        # SentenceTransformer instance; stays None until initialize() runs.
        self.embedding_model = None
        # Reference clause dicts; initialize() adds an 'embedding' key to each.
        self.clause_reference = []

    async def initialize(self):
        """Initialize embedding model and load clause references."""
        if self.embedding_model is None:
            print("🧠 Loading embedding model for clause tagging...")

            # Set cache directory explicitly for HF Spaces
            cache_folder = "/tmp/sentence_transformers_cache"
            os.makedirs(cache_folder, exist_ok=True)

            # Use a legal-domain model with explicit cache directory
            self.embedding_model = SentenceTransformer(
                'law-ai/InLegalBERT',
                cache_folder=cache_folder
            )
            print("βœ… Embedding model loaded")

        # Load clause references and pre-embed them once so the tagging
        # methods never have to re-encode the reference set.
        self.clause_reference = self._load_clause_reference()
        if self.clause_reference:
            clause_texts = [clause['text'] for clause in self.clause_reference]
            clause_embeddings = self.embedding_model.encode(clause_texts)
            for clause, embedding in zip(self.clause_reference, clause_embeddings):
                clause['embedding'] = embedding
            print(f"πŸ“‹ Loaded and embedded {len(self.clause_reference)} clause references")

    def _load_clause_reference(self) -> List[Dict[str, Any]]:
        """Load clause reference rows from CSV.

        Returns a list of dicts with keys 'id', 'type', 'text', 'category';
        returns [] when the file is missing or unreadable (best-effort, so a
        bad reference file disables tagging instead of crashing the caller).
        """
        # NOTE(review): filename looks like a typo of "clause_reference.csv",
        # but it must match the file actually shipped — confirm before renaming.
        clause_file = "clause_refrence.csv"  # Your existing file
        if not os.path.exists(clause_file):
            print(f"⚠️ Clause reference file not found: {clause_file}")
            return []

        clauses = []
        try:
            with open(clause_file, 'r', encoding='utf-8') as f:
                for row in csv.DictReader(f):
                    clauses.append({
                        'id': row.get('id', ''),
                        'type': row.get('type', ''),
                        'text': row.get('text', ''),
                        'category': row.get('category', 'general')
                    })
        except Exception as e:
            print(f"❌ Error loading clause reference: {e}")
            return []

        return clauses

    def _match_chunk(self, chunk_embedding, chunk_text: str,
                     chunk_idx: int) -> List[Dict[str, Any]]:
        """Score one chunk embedding against every reference clause.

        Returns one match dict per clause whose cosine similarity exceeds
        SIMILARITY_THRESHOLD. Scores all clauses in a single vectorized
        cosine_similarity call instead of one sklearn call per clause.
        """
        if not self.clause_reference:
            return []

        # Stack the reference embeddings once; one call yields every score.
        clause_matrix = np.vstack([c['embedding'] for c in self.clause_reference])
        similarities = cosine_similarity([chunk_embedding], clause_matrix)[0]

        # Truncated preview of the matched chunk for display purposes.
        preview = chunk_text[:200] + '...' if len(chunk_text) > 200 else chunk_text

        matches = []
        for clause, similarity in zip(self.clause_reference, similarities):
            if similarity > self.SIMILARITY_THRESHOLD:
                matches.append({
                    'clause_id': clause['id'],
                    'clause_type': clause['type'],
                    'clause_category': clause['category'],
                    'matched_text': preview,
                    'similarity_score': float(similarity),
                    'chunk_index': chunk_idx,
                    'reference_text': clause['text']
                })
        return matches

    async def tag_clauses(self, chunks: List[str]) -> List[Dict[str, Any]]:
        """Tag clauses in document chunks - GENERATES NEW EMBEDDINGS.

        Embeds every chunk in one batch, then returns the top 20 matches
        across all chunks, sorted by descending similarity.
        """
        if not self.clause_reference:
            return []

        print(f"🏷️ Tagging clauses in {len(chunks)} chunks...")

        # Embed all chunks in a single batch call.
        chunk_embeddings = self.embedding_model.encode(chunks)

        tagged_clauses = []
        for chunk_idx, chunk in enumerate(chunks):
            tagged_clauses.extend(
                self._match_chunk(chunk_embeddings[chunk_idx], chunk, chunk_idx)
            )

        # Sort by similarity score and return top matches
        tagged_clauses.sort(key=lambda x: x['similarity_score'], reverse=True)
        return tagged_clauses[:20]

    async def tag_clauses_with_embeddings(self, chunk_data: List[Dict]) -> List[Dict[str, Any]]:
        """Tag clauses using pre-computed embeddings - OPTIMIZED VERSION.

        Each chunk_data item is expected to carry 'embedding' and 'text'
        keys (assumed from usage — confirm against the caller). Items whose
        embedding is None are skipped. Returns the top 6 matches, sorted by
        descending similarity.
        """
        if not self.clause_reference:
            return []

        print(f"🏷️ Tagging clauses using pre-computed embeddings for {len(chunk_data)} chunks...")

        tagged_clauses = []
        for chunk_idx, chunk_info in enumerate(chunk_data):
            chunk_embedding = chunk_info["embedding"]
            if chunk_embedding is None:
                # Chunk was never embedded upstream; nothing to compare.
                continue
            tagged_clauses.extend(
                self._match_chunk(chunk_embedding, chunk_info["text"], chunk_idx)
            )

        tagged_clauses.sort(key=lambda x: x['similarity_score'], reverse=True)
        return tagged_clauses[:6]