JustusI committed on
Commit
ed3a930
·
verified ·
1 Parent(s): 71619fd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -0
app.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import pandas as pd
4
+ from langchain.document_loaders import DataFrameLoader
5
+ #import tiktoken
6
+ from langchain.vectorstores import Chroma
7
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
8
+ from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
9
+ from langchain_core.messages import HumanMessage, SystemMessage
10
+ from langchain_openai import ChatOpenAI
11
+ from transformers import AutoTokenizer, AutoModelForCausalLM
12
+
13
+ # Function to load and process data
14
def load_data(file_path):
    """Read a CSV source into a pandas DataFrame.

    Args:
        file_path: Anything ``pandas.read_csv`` accepts; the UI below passes
            the uploaded file object directly.

    Returns:
        The parsed ``pandas.DataFrame``.
    """
    return pd.read_csv(file_path)
17
+
18
+ # Function to load documents from DataFrame
19
def load_documents(df, content_column):
    """Turn the rows of a DataFrame into LangChain documents.

    Args:
        df: DataFrame to load from.
        content_column: Name of the column handed to ``DataFrameLoader`` as
            ``page_content_column``.

    Returns:
        Whatever ``DataFrameLoader(...).load()`` produces for ``df``.
    """
    loader = DataFrameLoader(df, page_content_column=content_column)
    return loader.load()
22
+
23
+ # Function to tokenize documents (currently disabled: the tiktoken import above is commented out)
24
+ # def tokenize_documents(docs):
25
+ # encoder = tiktoken.get_encoding("cl100k_base")
26
+ # tokens_per_docs = [len(encoder.encode(doc.page_content)) for doc in docs]
27
+ # total_tokens = sum(tokens_per_docs)
28
+ # cost_per_1000_tokens = 0.0001
29
+ # cost = (total_tokens / 1000) * cost_per_1000_tokens
30
+ # return tokens_per_docs, cost
31
+
32
+ # Function to create vector database
33
def create_vector_db(docs, persist_directory=None):
    """Split documents into chunks, embed them and index them in Chroma.

    Fixes over the previous version:
      * the chunks produced by the splitter are what gets indexed (the
        unsplit ``docs`` were indexed before and the chunks were discarded);
      * the populated store is returned directly instead of being replaced by
        a new ``Chroma(persist_directory=None, ...)`` instance;
      * ``persist()`` is only called when a directory was actually supplied.

    Args:
        docs: Documents to index (as returned by ``load_documents``).
        persist_directory: Optional directory for on-disk persistence. When
            ``None`` (the default) the store is kept in memory only.

    Returns:
        The populated ``Chroma`` vector store.
    """
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    texts = text_splitter.split_documents(docs)

    embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    vectordb = Chroma.from_documents(
        texts,
        embedding_function,
        persist_directory=persist_directory,
    )

    # Persisting only makes sense (and is only permitted) with a target directory.
    if persist_directory is not None:
        vectordb.persist()
    return vectordb
42
+
43
+ # Function to augment prompt
44
def augment_prompt(query, vectordb):
    """Build a retrieval-augmented prompt for ``query``.

    The three most similar entries are fetched from ``vectordb`` via
    ``similarity_search(query, k=3)``; their ``page_content`` values are
    joined with newlines and embedded in the instruction template.

    Args:
        query: The user's question.
        vectordb: Object exposing ``similarity_search(query, k=...)``.

    Returns:
        The prompt string containing the retrieved contexts and the query.
    """
    matches = vectordb.similarity_search(query, k=3)
    context_block = "\n".join(match.page_content for match in matches)
    return f"""Using the contexts below, answer the query. If some information is not provided within
the contexts below, do not include, and if the query cannot be answered with the below information, say "I don't know".

Contexts:
{context_block}

Query: {query}"""
55
+
56
+ # Function to handle chat
57
def chat_with_ai(query, vectordb, openai_api_key):
    """Answer ``query`` with gpt-3.5-turbo, grounded in retrieved context.

    Args:
        query: The user's question.
        vectordb: Vector store passed through to ``augment_prompt`` for
            context retrieval.
        openai_api_key: API key forwarded to ``ChatOpenAI``.

    Returns:
        The ``content`` of the model's reply.
    """
    chat = ChatOpenAI(model_name="gpt-3.5-turbo", openai_api_key=openai_api_key)
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content=augment_prompt(query, vectordb)),
    ]
    # Use invoke(): calling the chat model object directly is the deprecated path.
    res = chat.invoke(messages)
    return res.content
67
+
68
+ # Streamlit UI
69
st.title("Document Processing and AI Chat with LangChain")

# File upload
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

if uploaded_file is not None:
    # Load and process data
    df = load_data(uploaded_file)
    st.write("Data loaded successfully!")

    # The loader below is hard-wired to this column; fail with a readable
    # message instead of an exception when the CSV does not have it.
    if "page_content" not in df.columns:
        st.error("The uploaded CSV must contain a 'page_content' column.")
        st.stop()

    # Load documents
    docs = load_documents(df, 'page_content')
    st.write(f"Loaded {len(docs)} documents")

    # The token/cost report was removed: it called tokenize_documents(), which
    # is commented out in this file (along with its tiktoken import), so the
    # call raised NameError as soon as a file was uploaded.

    # Create vector database
    vectordb = create_vector_db(docs)
    st.write("Vector database created and persisted successfully!")

    # Query input
    query = st.text_input("Enter your query", "Recommend a company to work as a data scientist in the health sector")

    if st.button("Get Answer"):
        # Chat with AI
        openai_api_key = os.getenv("OPENAI_API_KEY")
        if not openai_api_key:
            # Without a key the chat client cannot be constructed; say so explicitly.
            st.error("The OPENAI_API_KEY environment variable is not set.")
        else:
            response = chat_with_ai(query, vectordb, openai_api_key)
            st.write("Response from AI:")
            st.write(response)