import os

import pandas as pd
import streamlit as st
import tiktoken
from langchain_community.document_loaders import DataFrameLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
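
# Assumed dependencies (not pinned in the original source): streamlit, pandas,
# tiktoken, chromadb, sentence-transformers, langchain, langchain-community,
# langchain-openai.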

# Function to load and process data
def load_data(file_path):
    df = pd.read_csv(file_path)
    return df

# Function to load documents from a DataFrame
def load_documents(df, content_column):
    docs = DataFrameLoader(df, page_content_column=content_column).load()
    return docs
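
# Illustrative only: the UI below passes 'page_content' as the content column, so the
# uploaded CSV is assumed to look roughly like
#
#   page_content,company
#   "Hiring data scientists for clinical analytics.",Acme Health
#
# DataFrameLoader uses the named column as each document's text; the remaining
# columns are attached as document metadata.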

# Function to tokenize documents and estimate embedding cost
def tokenize_documents(docs):
    encoder = tiktoken.get_encoding("cl100k_base")
    tokens_per_docs = [len(encoder.encode(doc.page_content)) for doc in docs]
    total_tokens = sum(tokens_per_docs)
    cost_per_1000_tokens = 0.0001
    cost = (total_tokens / 1000) * cost_per_1000_tokens
    return tokens_per_docs, cost

# Function to create and persist the vector database
def create_vector_db(docs):
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    texts = text_splitter.split_documents(docs)
    embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    # Index the split chunks (not the raw docs) and persist them to disk.
    vectordb = Chroma.from_documents(texts, embedding_function, persist_directory="chroma_db")
    vectordb.persist()
    return vectordb
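
# To reopen the persisted store in a later run (a sketch, assuming the same
# "chroma_db" directory and embedding model as above):
#   embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
#   vectordb = Chroma(persist_directory="chroma_db", embedding_function=embedding_function)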

# Function to augment the prompt with retrieved context
def augment_prompt(query, vectordb):
    results = vectordb.similarity_search(query, k=3)
    source_knowledge = "\n".join([x.page_content for x in results])
    augmented_prompt = f"""Answer the query using only the contexts below. Do not include information that is
not in the contexts, and if the query cannot be answered from them, say "I don't know".

Contexts:
{source_knowledge}

Query: {query}"""
    return augmented_prompt

# Function to handle chat
def chat_with_ai(query, vectordb, openai_api_key):
    chat = ChatOpenAI(model_name="gpt-3.5-turbo", openai_api_key=openai_api_key)
    augmented_query = augment_prompt(query, vectordb)
    prompt = HumanMessage(content=augmented_query)
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        prompt,
    ]
    res = chat.invoke(messages)
    return res.content
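
# Minimal end-to-end sketch without the UI (assumes a hypothetical sample.csv with a
# 'page_content' column and OPENAI_API_KEY set in the environment):
#   df = load_data("sample.csv")
#   docs = load_documents(df, "page_content")
#   vectordb = create_vector_db(docs)
#   print(chat_with_ai("Which company fits a data scientist?", vectordb, os.getenv("OPENAI_API_KEY")))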

# Streamlit UI
st.title("Document Processing and AI Chat with LangChain")

# File upload
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
if uploaded_file is not None:
    # Load and process data
    df = load_data(uploaded_file)
    st.write("Data loaded successfully!")

    # Load documents
    docs = load_documents(df, 'page_content')
    st.write(f"Loaded {len(docs)} documents")

    # Tokenize documents
    tokens_per_docs, cost = tokenize_documents(docs)
    st.write(f"Total tokens: {sum(tokens_per_docs)}")
    st.write(f"Estimated cost: ${cost:.4f}")

    # Create vector database
    vectordb = create_vector_db(docs)
    st.write("Vector database created and persisted successfully!")

    # Query input
    query = st.text_input("Enter your query", "Recommend a company to work as a data scientist in the health sector")
    if st.button("Get Answer"):
        # Chat with AI
        openai_api_key = os.getenv("OPENAI_API_KEY")
        response = chat_with_ai(query, vectordb, openai_api_key)
        st.write("Response from AI:")
        st.write(response)
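
# To run locally (assuming this file is saved as app.py):
#   export OPENAI_API_KEY=sk-...
#   streamlit run app.py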