File size: 6,224 Bytes
c4eb084
 
 
 
 
 
 
a085452
 
c4eb084
 
 
a085452
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4eb084
a085452
c4eb084
 
a085452
 
 
 
c4eb084
 
a085452
 
 
 
 
 
 
 
c4eb084
 
 
a085452
 
c4eb084
a085452
 
c4eb084
a085452
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4eb084
 
 
 
a085452
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4eb084
 
 
 
 
 
a085452
c4eb084
 
a085452
c4eb084
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a085452
c4eb084
 
 
 
 
 
a085452
c4eb084
 
 
 
 
 
 
 
 
 
 
 
 
a085452
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import time
from typing import List, Dict, Any
from transformers import pipeline, AutoTokenizer
import os

class DocumentSummarizer:
    """Two-stage document summarizer built on Hugging Face pipelines.

    facebook/bart-large-cnn is always loaded as the reliable base model;
    VincentMuriuki/legal-summarizer is loaded opportunistically as a
    domain-specific refinement stage.  If the legal model cannot be
    loaded, every entry point degrades gracefully to BART-only output.

    NOTE(review): the ``async`` methods below contain no awaits and run
    the pipelines synchronously, so they block the event loop for the
    duration of inference — consider ``run_in_executor`` if these are
    called from a live server loop.  Confirm with the call sites.
    """

    def __init__(self):
        # All three attributes are populated by initialize().
        # legal_pipeline stays None when the specialized model fails to
        # load; batch_summarize() branches on that.
        self.bart_pipeline = None
        self.legal_pipeline = None
        self.tokenizer = None

    @staticmethod
    def _join_summaries(outputs) -> str:
        """Concatenate the ``summary_text`` fields of a pipeline result list."""
        return " ".join(item["summary_text"] for item in outputs)

    async def initialize(self):
        """Initialize both summarization models.

        BART is loaded unconditionally.  The legal model is attempted
        with an optional HF_TOKEN from the environment; any failure is
        logged and leaves ``legal_pipeline`` as None (BART-only mode).
        """
        print(" Loading BART summarizer...")
        start_time = time.time()

        # Initialize reliable BART model first
        self.bart_pipeline = pipeline("summarization", model="facebook/bart-large-cnn")
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
        print(f" BART model loaded in {time.time() - start_time:.2f}s")

        # Try to load legal model — optional, guarded, best-effort.
        print(" Loading legal summarizer...")
        hf_token = os.getenv("HF_TOKEN")
        try:
            legal_start = time.time()
            self.legal_pipeline = pipeline(
                "summarization",
                model="VincentMuriuki/legal-summarizer",
                token=hf_token
            )
            print(f" Legal model loaded in {time.time() - legal_start:.2f}s")
        except Exception as e:
            # Deliberate broad catch: the legal model is an enhancement,
            # not a requirement, so any load failure falls back to BART.
            print(f"⚠️ Legal model failed, using BART only: {e}")
            self.legal_pipeline = None

    async def batch_summarize(self, chunks: List[str]) -> Dict[str, Any]:
        """Summarize ``chunks`` using the best available strategy.

        Dispatches to the two-stage hybrid path when the legal model
        loaded, otherwise to the BART-only fallback.
        """
        if self.legal_pipeline:
            return await self.hybrid_summarize(chunks)
        else:
            return await self.bart_only_summarize(chunks)

    async def hybrid_summarize(self, chunks: List[str]) -> Dict[str, Any]:
        """Two-stage summarization: BART → legal-specific refinement.

        Returns a dict with the final summary, the intermediate BART
        summary (for comparison), and per-stage timing strings.
        Empty input short-circuits to empty summaries.
        """
        if not chunks:
            return {"actual_summary": "", "short_summary": ""}

        print(f"Stage 1: Initial summarization with BART ({len(chunks)} chunks)...")
        stage1_start = time.time()

        # Stage 1: Facebook BART for clean, reliable summarization
        initial_summaries = self.bart_pipeline(
            chunks,
            max_length=150,  # Slightly longer for more detail
            min_length=30,
            do_sample=False,
            num_beams=2,
            truncation=True
        )

        initial_summary = self._join_summaries(initial_summaries)
        stage1_time = time.time() - stage1_start
        print(f" Stage 1 completed in {stage1_time:.2f}s")

        # Stage 2: Vincent's legal model for domain refinement
        print(" Stage 2: Legal refinement with specialized model...")
        stage2_start = time.time()

        # Re-chunk the stage-1 output by word count so the legal model
        # never sees an over-long input (no tokenizer-aware chunker here).
        if len(initial_summary) > 3000:
            words = initial_summary.split()
            refined_chunks = []
            chunk_size = 800  # words per chunk
            for i in range(0, len(words), chunk_size):
                refined_chunks.append(" ".join(words[i:i + chunk_size]))
        else:
            refined_chunks = [initial_summary]

        final_summaries = self.legal_pipeline(
            refined_chunks,
            max_length=128,
            min_length=24,
            do_sample=False,
            num_beams=1,
            truncation=True
        )

        final_summary = self._join_summaries(final_summaries)
        stage2_time = time.time() - stage2_start
        print(f" Stage 2 completed in {stage2_time:.2f}s")

        total_time = stage1_time + stage2_time

        return {
            "actual_summary": final_summary,
            "short_summary": final_summary,
            "initial_bart_summary": initial_summary,  # For comparison
            "processing_method": "hybrid_bart_to_legal",
            "time_taken": f"{total_time:.2f}s",
            "stage1_time": f"{stage1_time:.2f}s",
            "stage2_time": f"{stage2_time:.2f}s"
        }

    async def bart_only_summarize(self, chunks: List[str]) -> Dict[str, Any]:
        """Fallback BART-only summarization.

        Summarizes each chunk, joins the results, and re-summarizes the
        joined text into ``short_summary`` only when it exceeds 2000
        characters.  Empty input short-circuits to empty summaries.
        """
        if not chunks:
            return {"actual_summary": "", "short_summary": ""}

        print(f" BART-only summarization ({len(chunks)} chunks)...")
        start_time = time.time()

        outputs = self.bart_pipeline(
            chunks,
            max_length=128,
            min_length=24,
            do_sample=False,
            num_beams=2,
            truncation=True,
        )

        summaries = [output["summary_text"] for output in outputs]
        combined_summary = " ".join(summaries)

        # Create short summary only if the combined text is too long;
        # otherwise the short summary is the combined summary itself.
        short_summary = combined_summary
        if len(combined_summary) > 2000:
            short_outputs = self.bart_pipeline(
                [combined_summary],
                max_length=96,
                min_length=16,
                do_sample=False,
                num_beams=1,
                truncation=True,
            )
            short_summary = short_outputs[0]["summary_text"]

        processing_time = time.time() - start_time

        return {
            "actual_summary": combined_summary,
            "short_summary": short_summary,
            "individual_summaries": summaries,
            "processing_method": "bart_only",
            "time_taken": f"{processing_time:.2f}s"
        }

    def summarize_texts_sync(self, texts: List[str], max_length: int, min_length: int) -> Dict[str, Any]:
        """Synchronous batch summarization for the standalone endpoint.

        Uses BART for reliability.  Guards against empty input (for
        consistency with the async entry points) instead of handing an
        empty batch to the pipeline.
        """
        if not texts:
            return {"summaries": [], "count": 0, "time_taken": "0.00s"}
        start_time = time.time()
        outputs = self.bart_pipeline(  # Use BART for reliability
            texts,
            max_length=max_length,
            min_length=min_length,
            do_sample=False,
            num_beams=1,
            truncation=True,
        )
        summaries = [output["summary_text"] for output in outputs]
        return {
            "summaries": summaries,
            "count": len(summaries),
            "time_taken": f"{time.time() - start_time:.2f}s"
        }