import streamlit as st
import PyPDF2
import pandas as pd
import tempfile
import os
import logging

from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings import GPT4AllEmbeddings
from langchain.llms import LlamaCpp
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from huggingface_hub import hf_hub_download



# Configure logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S')


# In-memory LLM response caching; ref: https://python.langchain.com/docs/integrations/llms/llm_caching
import langchain
from langchain.cache import InMemoryCache

langchain.llm_cache = InMemoryCache()
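# A persistent alternative (a sketch; assumes this langchain version ships
# langchain.cache.SQLiteCache), if cached completions should survive restarts:
#   from langchain.cache import SQLiteCache
#   langchain.llm_cache = SQLiteCache(database_path=".langchain.db")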

# Work around Chroma's sqlite3 version requirement by swapping in pysqlite3.
# This must run before chromadb is first imported (i.e. before Chroma is used).
__import__('pysqlite3')
import sys
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')

@st.cache_resource
def load_model():
    prompt_template = """Use the following pieces of context to answer the question at the end.
    Even if it is a legal document, I give you my consent; you have full access to the document.
    I need you to answer very quickly.
    If you don't know the answer, just say that you don't know and can't help; don't try to make up an answer.
    {context}
    Question: {question}
    Answer:"""
    prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
    model_name_or_path = "TheBloke/Llama-2-7B-chat-GGML"
    model_basename = "llama-2-7b-chat.ggmlv3.q5_1.bin"  # the model is in GGML .bin format

    logging.info("downloading model from the Hugging Face Hub")
    model_path = hf_hub_download(repo_id=model_name_or_path, filename=model_basename)

    n_gpu_layers = 1  # Change this value based on your model and your GPU VRAM pool.
    n_batch = 512  # Should be between 1 and n_ctx; consider the amount of VRAM on your GPU.
    llm = LlamaCpp(
        model_path=model_path,
        n_ctx=2048,
        temperature=0.75,
        max_tokens=2000,
        top_p=1,
        n_gpu_layers=n_gpu_layers,
        n_batch=n_batch,
        verbose=True,
    )
    llm_chain = LLMChain(llm=llm, prompt=prompt)
    logging.info("model load done")
    return llm_chain
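# Usage sketch (hypothetical values): the chain takes both prompt variables as
# named inputs and returns the model's completion as a string:
#   chain = load_model()
#   answer = chain.run({"context": "some retrieved passage", "question": "What is this document about?"})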


def return_embeddings():
    logging.info("loading embeddings")
    embeddings = GPT4AllEmbeddings()
    logging.info("embeddings loaded")
    return embeddings




# Function to convert PDF to text
@st.cache_data
def pdf_to_text(file):
    pdf_reader = PyPDF2.PdfReader(file)
    text = ""
    for page in pdf_reader.pages:
        # extract_text() can return None for pages without a text layer
        text += page.extract_text() or ""
    return text

# Function to convert CSV to text
@st.cache_data
def csv_to_text(file):
    df = pd.read_csv(file)
    text = df.to_string(index=False)
    return text

@st.cache_data
def read_txt(file):
    # Read an uploaded text file (a Streamlit UploadedFile, not a path on disk)
    text = file.getvalue().decode('utf-8')
    return text


def process_file(uploaded_file):
    logging.info("received the file")
    # Check file type and process accordingly
    if uploaded_file.type == 'application/pdf':
        # Process PDF file
        text = pdf_to_text(uploaded_file)
    elif uploaded_file.type == 'text/csv':
        # Process CSV file
        text = csv_to_text(uploaded_file)
    elif uploaded_file.type == 'text/plain':
        # Process TXT file (browsers report .txt uploads as text/plain)
        text = read_txt(uploaded_file)
    else:
        raise ValueError("Unsupported file format. Please upload a PDF, CSV, or TXT file.")

    # Write the extracted text to a temporary file so TextLoader can read it
    temp_file = tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False, encoding='utf-8')
    temp_file.write(text)
    temp_file.close()

    return temp_file.name
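# Usage sketch (hypothetical variable names): process_file returns a filesystem
# path, which TextLoader in main() below consumes:
#   temp_path = process_file(uploaded_file)
#   docs = TextLoader(temp_path).load()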


def main():
    # Optional llama.cpp build flags for OpenBLAS-accelerated CPU inference
    #os.environ['LLAMA_BLAS'] = 'ON'
    #os.environ['LLAMA_BLAS_VENDOR'] = 'OpenBLAS'
    st.title("AssitAI: Chat with your files")
    st.markdown("""A Llama-2-7B and LangChain powered app to chat with your files.""")
    # File Upload
    uploaded_file = st.file_uploader("Upload a PDF, CSV, or TXT file", type=["pdf", "csv", "txt"])

    if uploaded_file is not None:
        # Process the file and get the path of the temporary text file
        logging.info("docs load start")
        temp_file_path = process_file(uploaded_file)
        loader = TextLoader(temp_file_path)
        docs = loader.load()
        logging.info(f"docs load end, docs is : {docs}")

        text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0)
        texts = text_splitter.split_documents(docs)
        logging.info(f"got the text, text is : {docs}")
        embeddings = return_embeddings()
        db = Chroma.from_documents(texts, embeddings, persist_directory='db')

        question = st.text_input("Enter your question:")
        if st.button("Submit"):
            similar_doc = db.similarity_search(question, k=1)
            context = similar_doc[0].page_content
            logging.info("querying start")
            query_llm = load_model()
            response = query_llm.run({"context": context, "question": question})
            logging.info(f"querying end response is: {response}")
            st.subheader("Answer:")
            st.write(response)

        # Clean up the temporary file after processing
        os.remove(temp_file_path)
        
    with st.expander("Example prompts"):
        st.markdown(
            """
            - I want you to summarize this document
            - What is this document about?
            - Can you help me understand the ... (fill in the blank) part of this document?
            """)
  

    hide_streamlit_style = """
                <style>
                #MainMenu {visibility: hidden;}
                footer {visibility: hidden;}
                </style>
                """
    st.markdown(hide_streamlit_style, unsafe_allow_html=True)
        

if __name__ == "__main__":
    main()