# LegalDoc / stage_4.py
# Author: Hyma Roshini Gompa — "Add Streamlit app files" (commit 0e553a7, 14.2 kB)
import streamlit as st
import shelve
import docx2txt
import PyPDF2
import time # Used to simulate typing effect
import nltk
import re
import os
import time # already imported in your code
import requests
from dotenv import load_dotenv
import torch
from sentence_transformers import SentenceTransformer, util
nltk.download('punkt')
import hashlib
from nltk import sent_tokenize
nltk.download('punkt_tab')
from transformers import LEDTokenizer, LEDForConditionalGeneration
from transformers import pipeline
import asyncio
import sys
# Fix for "RuntimeError: no running event loop" on Windows: the default
# Proactor event loop interacts badly with libraries used here, so fall
# back to the selector-based loop policy.
if sys.platform.startswith("win"):
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

# Page chrome — st.set_page_config must be the first Streamlit call.
st.set_page_config(page_title="Legal Document Summarizer", layout="wide")
st.title("📄 Legal Document Summarizer (stage 4 )")

# Avatars rendered next to each chat message.
USER_AVATAR = "👤"
BOT_AVATAR = "🤖"
# Load chat history
def load_chat_history():
    """Return previously persisted chat messages, or [] when none exist."""
    db = shelve.open("chat_history")
    try:
        return db.get("messages", [])
    finally:
        db.close()
# Save chat history
def save_chat_history(messages):
    """Persist the full chat transcript to the on-disk "chat_history" shelve."""
    db = shelve.open("chat_history")
    try:
        db["messages"] = messages
    finally:
        db.close()
# Function to limit text preview to 500 words
def limit_text(text, word_limit=500):
    """Truncate *text* to at most *word_limit* words, appending "..." if cut."""
    words = text.split()
    if len(words) > word_limit:
        return " ".join(words[:word_limit]) + "..."
    return " ".join(words)
# CLEAN AND NORMALIZE TEXT
def clean_text(text):
    """Normalize whitespace and strip layout artifacts from extracted text."""
    # Collapse line breaks and runs of whitespace into single spaces first.
    flattened = text.replace('\r\n', ' ').replace('\n', ' ')
    flattened = re.sub(r'\s+', ' ', flattened)
    # Drop "Page X of Y" page-number markers.
    flattened = re.sub(r'Page\s+\d+\s+of\s+\d+', '', flattened, flags=re.IGNORECASE)
    # Remove long separator runs left behind by the document layout:
    # underscores (_____), hyphens (-----), and dotted leaders (....).
    for separator in (r'[_]{5,}', r'[-]{5,}', r'[.]{4,}'):
        flattened = re.sub(separator, '', flattened)
    return flattened.strip()
#######################################################################################################################
# LOADING MODELS FOR DIVIDING TEXT INTO SECTIONS

# Load environment variables from a local .env file; the Hugging Face API
# token may be None when the file or variable is absent.
load_dotenv()
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
# Load once at the top (cache for performance)
# Cache the classifier across Streamlit reruns so the model is loaded only
# once per server process.
@st.cache_resource
def load_local_zero_shot_classifier():
    """Load a local zero-shot classification pipeline (DistilBERT MNLI)."""
    return pipeline("zero-shot-classification", model="typeform/distilbert-base-uncased-mnli")

local_classifier = load_local_zero_shot_classifier()

# Candidate section labels used when classifying document chunks.
SECTION_LABELS = ["Facts", "Arguments", "Judgment", "Other"]

def classify_chunk(text):
    """Return the highest-scoring section label for *text*."""
    result = local_classifier(text, candidate_labels=SECTION_LABELS)
    return result["labels"][0]
# NEW: NLP-based sectioning using zero-shot classification
def section_by_zero_shot(text):
    """Partition *text* into Facts/Arguments/Judgment/Other sections.

    Sentences are grouped three at a time (plus any remainder at the end),
    each group is labelled with the zero-shot classifier, and groups are
    concatenated under their predicted section.
    """
    sections = {"Facts": "", "Arguments": "", "Judgment": "", "Other": ""}
    sentences = sent_tokenize(text)
    total = len(sentences)
    buffer = ""
    for idx, sentence in enumerate(sentences, start=1):
        buffer += sentence + " "
        # Classify every 3 sentences, or whatever remains at the very end.
        if idx % 3 != 0 and idx != total:
            continue
        predicted = classify_chunk(buffer.strip())
        print(f"🔎 Chunk: {buffer[:60]}...\n🔖 Predicted Label: {predicted}")
        # Normalize the label and fall back to "Other" for unknown labels.
        predicted = predicted.capitalize()
        if predicted not in sections:
            predicted = "Other"
        sections[predicted] += buffer + "\n"
        buffer = ""
    return sections
#######################################################################################################################
# EXTRACTING TEXT FROM UPLOADED FILES
# Function to extract text from uploaded file
def extract_text(file):
    """Extract the full text of an uploaded PDF, DOCX, or TXT file.

    Returns the document text, or the string "Unsupported file type." for
    any other extension.  The full text is needed downstream for
    summarization.
    """
    # BUG FIX: compare extensions case-insensitively so uploads named e.g.
    # "BRIEF.PDF" or "Notes.TXT" are not rejected as unsupported.
    name = file.name.lower()
    if name.endswith(".pdf"):
        reader = PyPDF2.PdfReader(file)
        # Pages with no extractable text yield None — substitute "".
        full_text = "\n".join(page.extract_text() or "" for page in reader.pages)
    elif name.endswith(".docx"):
        full_text = docx2txt.process(file)
    elif name.endswith(".txt"):
        full_text = file.read().decode("utf-8")
    else:
        return "Unsupported file type."
    return full_text
#######################################################################################################################
# EXTRACTIVE AND ABSTRACTIVE SUMMARIZATION
# Cache both heavy models across Streamlit reruns (one load per process).
@st.cache_resource
def load_legalbert():
    """Load the Legal-BERT sentence encoder used for extractive scoring."""
    return SentenceTransformer("nlpaueb/legal-bert-base-uncased")

legalbert_model = load_legalbert()

@st.cache_resource
def load_led():
    """Load the LED long-document summarizer (tokenizer, model) pair."""
    tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
    model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")
    return tokenizer, model

tokenizer_led, model_led = load_led()
def legalbert_extractive_summary(text, top_ratio=0.2):
    """Extractive summary: keep the sentences most similar to the document.

    Scores each sentence by cosine similarity between its Legal-BERT
    embedding and the mean embedding of all sentences, then returns the
    top `max(3, len * top_ratio)` sentences in their original order.
    Short texts (at or below the cutoff) are returned unchanged.
    """
    sents = sent_tokenize(text)
    keep = max(3, int(len(sents) * top_ratio))
    if len(sents) <= keep:
        # Nothing to trim — the document is already short enough.
        return text
    embeddings = legalbert_model.encode(sents, convert_to_tensor=True)
    centroid = torch.mean(embeddings, dim=0)
    scores = util.pytorch_cos_sim(centroid, embeddings)[0]
    best = torch.topk(scores, k=keep)
    # Sort selected indices so sentences keep their original order.
    ordered = sorted(best.indices.tolist())
    return " ".join(sents[i] for i in ordered)
# Add LED Abstractive Summarization
def led_abstractive_summary(text, max_length=512, min_length=100):
    """Abstractively summarize *text* with LED.

    The input is truncated/padded to 4096 tokens; the decoded summary is
    between *min_length* and *max_length* generated tokens.
    """
    inputs = tokenizer_led(
        text, return_tensors="pt", padding="max_length",
        truncation=True, max_length=4096
    )
    # LED requires global attention on at least one token; giving it to the
    # first token lets the model attend to the whole document from there.
    global_attention_mask = torch.zeros_like(inputs["input_ids"])
    global_attention_mask[:, 0] = 1
    outputs = model_led.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        global_attention_mask=global_attention_mask,
        max_length=max_length,
        min_length=min_length,
        num_beams=4,              # Use beam search
        repetition_penalty=2.0,   # Penalize repetition
        length_penalty=1.0,
        early_stopping=True,
        no_repeat_ngram_size=4    # Prevent repeated phrases
    )
    return tokenizer_led.decode(outputs[0], skip_special_tokens=True)
def led_abstractive_summary_chunked(text, max_tokens=3000):
    """Summarize long text by sentence-chunking it under a token budget.

    Sentences are accumulated until adding the next one would exceed
    *max_tokens* LED tokens; each chunk is summarized independently with
    led_abstractive_summary() and the partial summaries are joined.

    BUG FIX: the original flushed the current chunk unconditionally, so a
    first sentence longer than *max_tokens* appended the empty string ""
    as a chunk (wasting a model call on empty input); chunks also carried
    leading spaces.  The flush is now guarded on a non-empty chunk.
    """
    sentences = sent_tokenize(text)
    chunks = []
    current_chunk = ""
    for sent in sentences:
        candidate = (current_chunk + " " + sent) if current_chunk else sent
        if current_chunk and len(tokenizer_led(candidate)["input_ids"]) > max_tokens:
            # Current chunk is full: flush it and start a new one here.
            chunks.append(current_chunk)
            current_chunk = sent
        else:
            current_chunk = candidate
    if current_chunk:
        chunks.append(current_chunk)
    # Summarize each chunk independently and stitch the results together.
    return " ".join(led_abstractive_summary(chunk) for chunk in chunks)
def hybrid_summary_hierarchical(text, top_ratio=0.8):
    """Build a hierarchical hybrid summary of *text*.

    Cleans the text, splits it into sections via zero-shot classification,
    and for each non-empty section produces an extractive summary
    (Legal-BERT) plus an abstractive summary (LED) of that extract.
    Returns {section_name: {"extractive": ..., "abstractive": ...}}.
    """
    sections = section_by_zero_shot(clean_text(text))
    structured = {}
    for section_name, body in sections.items():
        if not body.strip():
            continue  # skip sections the classifier left empty
        extract = legalbert_extractive_summary(body, top_ratio)
        structured[section_name] = {
            "extractive": extract,
            "abstractive": led_abstractive_summary_chunked(extract),
        }
    return structured
#######################################################################################################################
# STREAMLIT APP INTERFACE CODE
# Initialize or load chat history
# Load any persisted conversation into this browser session (once).
if "messages" not in st.session_state:
    st.session_state.messages = load_chat_history()

# Track the last uploaded file so re-uploads can be detected.
if "last_uploaded" not in st.session_state:
    st.session_state.last_uploaded = None

# Sidebar: wipe both the in-memory and the persisted chat history.
with st.sidebar:
    st.subheader("⚙️ Options")
    if st.button("Delete Chat History"):
        st.session_state.messages = []
        st.session_state.last_uploaded = None
        save_chat_history([])
# Display chat messages with a typing effect
def display_with_typing_effect(text, speed=0.005):
    """Render *text* one character at a time in a Streamlit placeholder.

    Sleeps *speed* seconds between characters (pass speed=0 for an
    instant render) and returns the full text.
    """
    slot = st.empty()
    for end in range(1, len(text) + 1):
        # Re-render the growing prefix so the text appears to be typed.
        slot.markdown(text[:end])
        time.sleep(speed)
    return text
# Show existing chat messages
# Replay the stored conversation so it survives Streamlit reruns.
for message in st.session_state.messages:
    avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
    with st.chat_message(message["role"], avatar=avatar):
        st.markdown(message["content"])

# Standard chat input field (typed text is summarized directly).
prompt = st.chat_input("Type a message...")

# Place uploader before the chat so it's always visible.
with st.container():
    st.subheader("📎 Upload a Legal Document")
    uploaded_file = st.file_uploader("Upload a file (PDF, DOCX, TXT)", type=["pdf", "docx", "txt"])
    reprocess_btn = st.button("🔄 Reprocess Last Uploaded File")
# Hashing logic
def get_file_hash(file):
    """Return the MD5 hex digest of *file*'s full contents.

    Rewinds the file before and after reading so later consumers
    (e.g. extract_text) still see the whole stream.
    """
    file.seek(0)
    digest = hashlib.md5(file.read()).hexdigest()
    file.seek(0)  # leave the cursor at the start for the next reader
    return digest
##############################################################################################################
# Role selector: drives role_based_filter(), which hides summary sections
# not relevant to the chosen audience.
user_role = st.sidebar.selectbox(
    "🎭 Select Your Role for Custom Summary",
    ["General", "Judge", "Lawyer", "Student"]
)
def role_based_filter(section, summary, role):
    """Return *summary* only when *section* is relevant to *role*.

    "General" sees everything.  Other roles see a fixed set of sections;
    irrelevant sections collapse to empty extractive/abstractive strings.
    """
    if role == "General":
        return summary
    # Sections each specialized role is allowed to see.
    relevant_sections = {
        "Judge": ("Judgment", "Facts"),
        "Lawyer": ("Arguments", "Facts"),
        "Student": ("Facts",),
    }
    if section in relevant_sections.get(role, ()):
        return summary
    return {"extractive": "", "abstractive": ""}
# Upload flow: summarize a newly uploaded file (or re-run on demand),
# append both sides of the exchange to the chat, persist, and rerun.
if uploaded_file:
    file_hash = get_file_hash(uploaded_file)
    # Process only when the file is new OR reprocess was requested;
    # otherwise the cached conversation already shows this summary.
    if file_hash != st.session_state.get("last_uploaded_hash") or reprocess_btn:
        start_time = time.time()  # Start the timer
        raw_text = extract_text(uploaded_file)
        summary_dict = hybrid_summary_hierarchical(raw_text)
        st.session_state.messages.append({
            "role": "user",
            "content": f"📤 Uploaded **{uploaded_file.name}**"
        })
        # Build the preview section by section, honoring the role filter.
        preview_text = f"🧾 **Hybrid Summary of {uploaded_file.name}:**\n\n"
        for section in ["Facts", "Arguments", "Judgment", "Other"]:
            if section in summary_dict:
                filtered = role_based_filter(section, summary_dict[section], user_role)
                extractive = filtered.get("extractive", "").strip()
                abstractive = filtered.get("abstractive", "").strip()
                if not extractive and not abstractive:
                    continue  # Skip if empty after filtering
                preview_text += f"### 📘 {section} Section\n"
                preview_text += f"📌 **Extractive Summary:**\n{extractive if extractive else '_No content extracted._'}\n\n"
                preview_text += f"🔍 **Abstractive Summary:**\n{abstractive if abstractive else '_No summary generated._'}\n\n"
        # Display in chat (speed=0 renders instantly).
        with st.chat_message("assistant", avatar=BOT_AVATAR):
            display_with_typing_effect(clean_text(preview_text), speed=0)
        # Show processing time after the summary.
        processing_time = round(time.time() - start_time, 2)
        st.session_state["last_response_time"] = processing_time
        if "last_response_time" in st.session_state:
            st.info(f"⏱️ Response generated in **{st.session_state['last_response_time']} seconds**.")
        st.session_state.messages.append({
            "role": "assistant",
            "content": clean_text(preview_text)
        })
        # Save this file hash only if it's a new upload (avoid overwriting during reprocess)
        if not reprocess_btn:
            st.session_state.last_uploaded_hash = file_hash
        save_chat_history(st.session_state.messages)
        st.rerun()
# Handle chat input and return hybrid summary
# Chat flow: summarize text typed directly into the chat input.
if prompt:
    raw_text = prompt
    start_time = time.time()
    summary_dict = hybrid_summary_hierarchical(raw_text)
    st.session_state.messages.append({
        "role": "user",
        "content": prompt
    })
    # BUG FIX: the original header used `uploaded_file.name`, which raises
    # AttributeError when no file has been uploaded (uploaded_file is None)
    # and otherwise mislabels this summary with an unrelated file's name.
    preview_text = "🧾 **Hybrid Summary of your text:**\n\n"
    # Build the preview section by section, honoring the role filter.
    for section in ["Facts", "Arguments", "Judgment", "Other"]:
        if section in summary_dict:
            filtered = role_based_filter(section, summary_dict[section], user_role)
            extractive = filtered.get("extractive", "").strip()
            abstractive = filtered.get("abstractive", "").strip()
            if not extractive and not abstractive:
                continue  # Skip if empty after filtering
            preview_text += f"### 📘 {section} Section\n"
            preview_text += f"📌 **Extractive Summary:**\n{extractive if extractive else '_No content extracted._'}\n\n"
            preview_text += f"🔍 **Abstractive Summary:**\n{abstractive if abstractive else '_No summary generated._'}\n\n"
    # Display in chat (speed=0 renders instantly).
    with st.chat_message("assistant", avatar=BOT_AVATAR):
        display_with_typing_effect(clean_text(preview_text), speed=0)
    # Show processing time after the summary.
    processing_time = round(time.time() - start_time, 2)
    st.session_state["last_response_time"] = processing_time
    if "last_response_time" in st.session_state:
        st.info(f"⏱️ Response generated in **{st.session_state['last_response_time']} seconds**.")
    st.session_state.messages.append({
        "role": "assistant",
        "content": clean_text(preview_text)
    })
    save_chat_history(st.session_state.messages)
    st.rerun()