import requests
import io
import re
import numpy as np
import faiss
import torch
import time
import gradio as gr
from pypdf import PdfReader
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from bert_score import score
def download_pdf(url):
    """Downloads a PDF from a URL and returns its content as bytes."""
    try:
        response = requests.get(url, timeout=60)
        response.raise_for_status()
        return response.content
    except requests.exceptions.RequestException as e:
        print(f"Error downloading PDF from {url}: {e}")
        return None
def extract_text_from_pdf(pdf_bytes):
    """Extracts text from a PDF byte stream."""
    try:
        pdf_file = io.BytesIO(pdf_bytes)
        reader = PdfReader(pdf_file)
        text = ""
        for page in reader.pages:
            text += page.extract_text() or ""  # extract_text() can return None
        return text
    except Exception as e:
        print(f"Error extracting text from PDF: {e}")
        return None
def preprocess_text(text):
    """Cleans text while retaining financial symbols and ensuring proper formatting."""
    if not text:
        return ""
    # Currency symbols to preserve
    financial_symbols = "$€₹£¥₩₽₮₦₲"
    # Keep letters, digits, underscores, whitespace, currency symbols,
    # and common punctuation (.,%/-); strip everything else
    text = re.sub(fr"[^\w\s{re.escape(financial_symbols)}.,%/-]", "", text)
    # Normalize whitespace
    text = re.sub(r"\s+", " ", text).strip()
    return text
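# Quick sanity check (illustrative string, not taken from the reports):
# preprocess_text("Revenue:\n₹5,388 Mn (up 12%)") -> "Revenue ₹5,388 Mn up 12%"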
def load_financial_pdfs(pdf_urls):
    """Downloads and extracts text from a list of PDF URLs."""
    all_data = []
    for url in pdf_urls:
        pdf_bytes = download_pdf(url)
        if pdf_bytes:
            text = extract_text_from_pdf(pdf_bytes)
            if text:
                preprocessed_text = preprocess_text(text)
                all_data.append(preprocessed_text)
    return all_data
# Example usage (replace with actual PDF URLs)
pdf_urls = [
    "https://www.latentview.com/wp-content/uploads/2023/07/LatentView-Annual-Report-2022-23.pdf",
    "https://www.latentview.com/wp-content/uploads/2024/08/LatentView-Annual-Report-2023-24.pdf",
]
all_data = load_financial_pdfs(pdf_urls)
def chunk_text(text, chunk_size=700, overlap_size=150):
    """Chunks text into overlapping windows without splitting words."""
    chunks = []
    start = 0
    text_length = len(text)
    while start < text_length:
        end = min(start + chunk_size, text_length)
        # Pull the boundary back to the last space so words are not split
        if end < text_length and text[end].isalnum():
            last_space = text.rfind(" ", start, end)
            if last_space != -1:
                end = last_space
        chunk = text[start:end].strip()
        if chunk:  # Skip empty chunks
            chunks.append(chunk)
        if end == text_length:
            break
        # Step the next chunk back by overlap_size, snapping to a word boundary
        overlap_start = max(0, end - overlap_size)
        if overlap_start < end:  # Guard against a zero-progress loop
            last_overlap_space = text.rfind(" ", 0, overlap_start)
            if last_overlap_space != -1 and last_overlap_space > start:
                start = last_overlap_space + 1
            else:
                start = end  # No usable space found; continue from the previous end
        else:
            start = end
    return chunks
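# Illustrative check of the chunker on synthetic text (sizes chosen for demo):
# demo = chunk_text(" ".join(f"w{i}" for i in range(500)), chunk_size=50, overlap_size=10)
# Adjacent demo chunks should share a few words at their boundaries.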
chunks = []
for data in all_data:
    chunks.extend(chunk_text(data))
embedding_model = SentenceTransformer("BAAI/bge-large-en")
# embedding_model = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
embeddings = embedding_model.encode(chunks)
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
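# Note: BGE embeddings are usually L2-normalized and searched with inner
# product (cosine similarity). An equivalent cosine setup, if preferred:
#   embeddings = embedding_model.encode(chunks, normalize_embeddings=True)
#   index = faiss.IndexFlatIP(embeddings.shape[1])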
def bm25_retrieval(query, documents, top_k=3):
    """Lexical retrieval over whitespace-tokenized documents using BM25."""
    tokenized_docs = [doc.split() for doc in documents]
    bm25 = BM25Okapi(tokenized_docs)
    scores = bm25.get_scores(query.split())
    top_indices = np.argsort(scores)[::-1][:top_k]
    return [documents[i] for i in top_indices]
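# For larger corpora, build the BM25Okapi index once at startup and reuse it
# across queries instead of re-tokenizing every document on each call.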
def adaptive_retrieval(query, index, chunks, top_k=3, bm25_weight=0.5):
    """Combines dense (FAISS) and lexical (BM25) retrieval results."""
    # encode() has no dtype argument; FAISS expects float32 vectors
    query_embedding = embedding_model.encode([query], convert_to_numpy=True).astype(np.float32)
    _, indices = index.search(query_embedding, top_k)
    vector_results = [chunks[i] for i in indices[0]]
    bm25_results = bm25_retrieval(query, chunks, top_k)
    return list(set(vector_results + bm25_results))
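# Note: bm25_weight is currently unused; the union of both result sets is
# returned and ordering is delegated to rerank() below.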
def rerank(query, results):
    """Orders retrieved chunks by dot-product similarity to the query."""
    query_embedding = embedding_model.encode([query], convert_to_numpy=True)
    result_embeddings = embedding_model.encode(results, convert_to_numpy=True)
    similarities = np.dot(result_embeddings, query_embedding.T).flatten()
    return [results[i] for i in np.argsort(similarities)[::-1]], similarities
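# The dot product above equals cosine similarity only when embeddings are
# L2-normalized; pass normalize_embeddings=True to encode() to guarantee that.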
# Chunk merging
def merge_chunks(retrieved_chunks, overlap_size=100):
    """Merges overlapping chunks by detecting the actual overlap."""
    merged_chunks = []
    buffer = retrieved_chunks[0] if retrieved_chunks else ""
    for i in range(1, len(retrieved_chunks)):
        chunk = retrieved_chunks[i]
        # Use the tail of the previous chunk as the candidate overlap
        overlap = buffer[-overlap_size:]
        overlap_index = chunk.find(overlap)
        if overlap_index != -1:
            # Append only the non-overlapping remainder; len(overlap) may be
            # shorter than overlap_size when the buffer itself is short
            buffer += chunk[overlap_index + len(overlap):]
        else:
            # No overlap found: store the merged buffer and start a new one
            merged_chunks.append(buffer)
            buffer = chunk
    if buffer:
        merged_chunks.append(buffer)
    return merged_chunks
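# Illustrative merge (synthetic strings, overlap_size=5):
# merge_chunks(["abcde fghij", "fghij klmno"], 5) -> ["abcde fghij klmno"]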
# def calculate_confidence(query, context, similarities):
#     return np.mean(similarities)  # Averaged similarity scores
def calculate_confidence(query, answer):
    """Scores answer-query agreement with BERTScore F1."""
    P, R, F1 = score([answer], [query], lang="en", verbose=False)
    return F1.item()
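# Example (hypothetical strings):
# calculate_confidence("What was FY23 revenue?", "Revenue grew 12% in FY23.")
# returns a BERTScore F1 in roughly [0, 1]; higher means closer agreement.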
# Load the small language model (SLM)
accelerator = Accelerator()
accelerator.free_memory()
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto", cache_dir="./my_models")
model = accelerator.prepare(model)
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
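# Note: device_map="auto" already places the model across available devices,
# so accelerator.prepare() is likely redundant for pure inference here.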
def generate_response(query, context):
    prompt = f"""Your task is to analyze the given Context and answer the Question concisely in plain English.
**Guidelines:**
- Do NOT include a </think> tag; provide the final answer only.
- Provide a direct, factual answer based strictly on the Context.
- Avoid generating Python code, solutions, or any irrelevant information.
Context: {context}
Question: {query}
Answer:
"""
    response = generator(prompt, max_new_tokens=150, num_return_sequences=1)[0]['generated_text']
    print(response)
    # Take everything after the final "Answer:" so the prompt itself is excluded
    answer = response.split("Answer:")[-1].strip()
    return answer
# Gradio UI wiring; the functions defined above are reused here:
# adaptive_retrieval, merge_chunks, rerank, generate_response, calculate_confidence
def inference_pipeline(query):
    retrieved_chunks = adaptive_retrieval(query, index, chunks)
    merged_chunks = merge_chunks(retrieved_chunks, 150)
    reranked_chunks, similarities = rerank(query, merged_chunks)
    context = " ".join(reranked_chunks[:3])  # Keep the top 3 most relevant chunks
    response = generate_response(query, context)
    # calculate_confidence takes (query, answer); score the generated response
    confidence = calculate_confidence(query, response)
    return response, confidence
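# Smoke test before launching the UI (query is illustrative):
# answer, confidence = inference_pipeline("What was the total revenue in FY 2022-23?")
# print(answer, confidence)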
# Define the Gradio UI
ui = gr.Interface(
    fn=inference_pipeline,
    inputs=gr.Textbox(label="Enter your financial question"),
    outputs=[
        gr.Textbox(label="Generated Response"),
        gr.Number(label="Confidence Score"),
    ],
    title="Financial Q&A Assistant",
    description="Ask financial questions and get AI-powered responses with confidence scores.",
)
# Launch the UI
ui.launch(share=True)  # share=True exposes a public link