JustusI committed on
Commit 61caa6e (verified) · Parent: 2285b10

Update app.py

Files changed (1)
app.py +33 -63
app.py CHANGED
@@ -10,52 +10,38 @@ from langchain_core.messages import HumanMessage, SystemMessage
 from langchain_openai import ChatOpenAI
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
-# Function to load and process data
-def load_data(file_path):
-    df = pd.read_csv(file_path)
-    return df
-
-# Function to load documents from DataFrame
-def load_documents(df, content_column):
-    docs = DataFrameLoader(df, page_content_column=content_column).load()
-    return docs
-
-# Function to tokenize documents
-# def tokenize_documents(docs):
-#     encoder = tiktoken.get_encoding("cl100k_base")
-#     tokens_per_docs = [len(encoder.encode(doc.page_content)) for doc in docs]
-#     total_tokens = sum(tokens_per_docs)
-#     cost_per_1000_tokens = 0.0001
-#     cost = (total_tokens / 1000) * cost_per_1000_tokens
-#     return tokens_per_docs, cost
-
-# Function to create vector database
-def create_vector_db(docs):
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
-    texts = text_splitter.split_documents(docs)
+# Function to load vector database
+def load_vector_db(zip_file_path, extract_path):
+    with st.spinner("Loading vector store..."):
+        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
+            zip_ref.extractall(extract_path)
+
     embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
-    vectordb = Chroma.from_documents(docs, embedding_function, persist_directory='./chroma_db')
-    vectordb.persist()
-    vectordb = None
-    vectordb = Chroma(persist_directory=vectordb, embedding_function=embedding_function)
+    vectordb = Chroma(
+        persist_directory=extract_path,
+        embedding_function=embedding_function
+    )
+    st.success("Vector store loaded")
     return vectordb
 
 # Function to augment prompt
 def augment_prompt(query, vectordb):
     results = vectordb.similarity_search(query, k=3)
     source_knowledge = "\n".join([x.page_content for x in results])
-    augmented_prompt = f"""Using the contexts below, answer the query. If some information is not provided within
-    the contexts below, do not include, and if the query cannot be answered with the below information, say "I don't know".
+    augmented_prompt = f"""
+    You are an AI assistant. Use the context provided below to answer the question as comprehensively as possible.
+    If the answer is not contained within the context, respond with "I don't know".
 
-    Contexts:
+    Context:
     {source_knowledge}
 
-    Query: {query}"""
+    Question: {query}
+    """
     return augmented_prompt
 
-# Function to handle chat
-def chat_with_ai(query, vectordb,openai_api_key):
-    chat = ChatOpenAI(model_name="gpt-3.5-turbo",openai_api_key=openai_api_key)
+# Function to handle chat with OpenAI
+def chat_with_openai(query, vectordb, openai_api_key):
+    chat = ChatOpenAI(model_name="gpt-3.5-turbo", openai_api_key=openai_api_key)
     augmented_query = augment_prompt(query, vectordb)
     prompt = HumanMessage(content=augmented_query)
     messages = [
@@ -68,33 +54,17 @@ def chat_with_ai(query, vectordb,openai_api_key):
 # Streamlit UI
 st.title("Document Processing and AI Chat with LangChain")
 
-# File upload
-uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
+# Load vector database
+zip_file_path = "chroma_db_compressed_.zip"
+extract_path = "./chroma_db_extracted"
+vectordb = load_vector_db(zip_file_path, extract_path)
 
-if uploaded_file is not None:
-    # Load and process data
-    df = load_data(uploaded_file)
-    st.write("Data loaded successfully!")
-
-    # Load documents
-    docs = load_documents(df, 'page_content')
-    st.write(f"Loaded {len(docs)} documents")
-
-    # Tokenize documents
-    # tokens_per_docs, cost = tokenize_documents(docs)
-    # st.write(f"Total tokens: {sum(tokens_per_docs)}")
-    # st.write(f"Estimated cost: ${cost:.4f}")
-
-    # Create vector database
-    vectordb = create_vector_db(docs)
-    st.write("Vector database created and persisted successfully!")
-
-    # Query input
-    query = st.text_input("Enter your query", "Recommend a company to work as a data scientist in the health sector")
-
-    if st.button("Get Answer"):
-        # Chat with AI
-        openai_api_key = os.getenv("OPENAI_API_KEY")
-        response = chat_with_ai(query, vectordb, openai_api_key)
-        st.write("Response from AI:")
-        st.write(response)
+# Query input
+query = st.text_input("Enter your query", "Recommend a company to work as a data scientist in the health sector")
+
+if st.button("Get Answer"):
+    # Chat with OpenAI
+    openai_api_key = st.secrets["OPENAI_API_KEY"]
+    response = chat_with_openai(query, vectordb, openai_api_key)
+    st.write("Response from AI:")
+    st.write(response)
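
Note: after this change, app.py no longer builds the index at runtime; it assumes a pre-built chroma_db_compressed_.zip is present in the repository. A minimal sketch of how that archive could be produced offline from the removed code path follows. The script name, source CSV filename, and import paths are assumptions (the app's own import lines for Chroma and SentenceTransformerEmbeddings are outside this hunk), not part of this commit.

# build_vector_db.py -- hypothetical offline builder, not part of this commit.
# Rebuilds the index the removed create_vector_db() produced, then zips the
# persisted directory into the exact filename app.py now expects.
import shutil

import pandas as pd
from langchain_community.document_loaders import DataFrameLoader
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter

df = pd.read_csv("data.csv")  # assumed source CSV with a 'page_content' column
docs = DataFrameLoader(df, page_content_column="page_content").load()

splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = splitter.split_documents(docs)

embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
# Index the split chunks and persist to disk. (The removed code indexed the
# unsplit docs and left texts unused; that looked unintentional.)
Chroma.from_documents(texts, embedding_function, persist_directory="./chroma_db")

# make_archive appends ".zip", yielding chroma_db_compressed_.zip whose root
# holds the DB files directly, matching extractall() + persist_directory above.
shutil.make_archive("chroma_db_compressed_", "zip", root_dir="./chroma_db")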
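Two operational notes on the new flow. First, Streamlit re-executes the whole script on every widget interaction, so load_vector_db() re-extracts the zip and re-opens the index on each button press. Caching the loaded store avoids that; a sketch of one way to do it (the wrapper below is a suggested refinement, not part of this commit):

import streamlit as st

@st.cache_resource  # keeps one loaded instance across script reruns
def get_vector_db(zip_file_path: str, extract_path: str):
    return load_vector_db(zip_file_path, extract_path)

vectordb = get_vector_db("chroma_db_compressed_.zip", "./chroma_db_extracted")

Second, the switch from os.getenv("OPENAI_API_KEY") to st.secrets["OPENAI_API_KEY"] means the key is now read from .streamlit/secrets.toml during local development, or from the host's configured secrets when deployed, and a missing key raises an error at lookup instead of silently returning None.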