"""RAG-based PDF chat: a Streamlit app that fetches PDFs, indexes their
text with FAISS, and answers questions via BART summarization."""
import requests
import streamlit as st
from io import BytesIO
from PyPDF2 import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
# On langchain >= 0.2, the next two imports live in langchain_community
# (langchain_community.embeddings / langchain_community.vectorstores).
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from transformers import pipeline
import torch

st.set_page_config(page_title="RAG-based PDF Chat", layout="centered", page_icon="πŸ“„")

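# Load the BART summarizer once per session; @st.cache_resource keeps a
# single model instance in memory across Streamlit reruns.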
@st.cache_resource
def load_summarization_pipeline():
    try:
        device = 0 if torch.cuda.is_available() else -1  # GPU if available, else CPU
        return pipeline("summarization", model="facebook/bart-large-cnn", device=device)
    except Exception as e:
        st.error(f"Failed to load the summarization model: {e}")
        return None

summarizer = load_summarization_pipeline()

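# Folder label -> list of source URLs. NOTE: each URL must point at a raw
# PDF file (e.g. a Hugging Face ".../resolve/main/<name>.pdf" link); a
# "/tree/" URL returns an HTML listing page that PdfReader cannot parse.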
PDF_FOLDERS = {
    "Folder 1": ["https://huggingface.co/spaces/ZeeAI1/LawTest3/tree/main/documents1"]
}

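# Download each PDF and concatenate its extracted text, tagging every
# folder's contribution. Not cached, so this re-downloads on each rerun;
# the downstream chunking and FAISS index are cached instead.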
def fetch_pdf_text_from_folders(pdf_folders):
    all_text = ""
    for folder_name, urls in pdf_folders.items():
        folder_text = f"\n[Folder: {folder_name}]\n"
        for url in urls:
            try:
                response = requests.get(url, timeout=30)
                response.raise_for_status()
                pdf_file = BytesIO(response.content)
                pdf_reader = PdfReader(pdf_file)
                for page in pdf_reader.pages:
                    page_text = page.extract_text()
                    if page_text:
                        folder_text += page_text + "\n"
            except Exception as e:
                st.error(f"Error fetching PDF from {url}: {e}")
        all_text += folder_text
    return all_text

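# Split the combined text into 2,000-character chunks; the 200-character
# overlap preserves context across chunk boundaries.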
@st.cache_data
def get_text_chunks(text):
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
    return text_splitter.split_text(text)

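# MiniLM sentence-transformer embeddings (384 dimensions) for the chunks.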
@st.cache_resource
def load_embedding_function():
    try:
        return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    except Exception as e:
        st.error(f"Failed to load embedding model: {e}")
        return None

embedding_function = load_embedding_function()

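# Build the FAISS index over the chunks; cached so the embedding pass
# runs only once per session.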
@st.cache_resource
def load_or_create_vector_store(text_chunks):
    if not text_chunks:
        st.error("No valid text chunks found.")
        return None
    try:
        return FAISS.from_texts(text_chunks, embedding=embedding_function)
    except Exception as e:
        st.error(f"Failed to create or load vector store: {e}")
        return None

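# Summarize the retrieved context conditioned on the user's question.
# BART's input limit is 1,024 tokens; slicing the string to 1,024
# characters is a rough, conservative proxy for that limit.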
def generate_summary_with_huggingface(query, retrieved_text):
    if summarizer is None:
        return "Summarization model is unavailable."
    summarization_input = f"{query}\n\nRelated information:\n{retrieved_text}"[:1024]
    try:
        summary = summarizer(summarization_input, max_length=500, min_length=50, do_sample=False)
        return summary[0]["summary_text"]
    except Exception as e:
        st.error(f"Failed to generate summary: {e}")
        return "Error generating summary."

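# Retrieve the chunks most similar to the question (FAISS returns the
# top 4 by default) and pass them to the summarizer as context.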
def user_input(user_question, vector_store):
    if vector_store is None:
        return "Vector store is empty."
    try:
        docs = vector_store.similarity_search(user_question)
        context_text = " ".join([doc.page_content for doc in docs])
        return generate_summary_with_huggingface(user_question, context_text)
    except Exception as e:
        st.error(f"Error in similarity search: {e}")
        return "Error in similarity search."

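# Streamlit entry point: build the index on first load, then answer the
# user's question when the button is clicked.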
def main():
    st.title("πŸ“„ Gen AI Lawyers Guide")
    raw_text = fetch_pdf_text_from_folders(PDF_FOLDERS)
    text_chunks = get_text_chunks(raw_text)
    vector_store = load_or_create_vector_store(text_chunks)

    user_question = st.text_input("Ask a Question:", placeholder="Type your question here...")

    if st.button("Get Response"):
        if not user_question:
            st.warning("Please enter a question before submitting.")
        else:
            with st.spinner("Generating response..."):
                answer = user_input(user_question, vector_store)
                st.markdown(f"**πŸ€– AI:** {answer}")

if __name__ == "__main__":
    main()