File size: 3,688 Bytes
ed3a930
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import streamlit as st
import os
import pandas as pd
from langchain.document_loaders import DataFrameLoader
#import tiktoken
from langchain.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from transformers import AutoTokenizer, AutoModelForCausalLM

# Function to load and process data
def load_data(file_path):
    """Read CSV data from *file_path* and return it as a pandas DataFrame.

    *file_path* may be a filesystem path or any file-like object accepted by
    ``pandas.read_csv`` (the Streamlit upload handle is passed here).
    """
    return pd.read_csv(file_path)

# Function to load documents from DataFrame
def load_documents(df, content_column):
    """Convert the rows of *df* into LangChain documents.

    Args:
        df: Source DataFrame, one document per row.
        content_column: Name of the column whose text becomes each
            document's ``page_content``.

    Returns:
        The list of documents produced by ``DataFrameLoader.load()``.
    """
    loader = DataFrameLoader(df, page_content_column=content_column)
    documents = loader.load()
    return documents

# Function to tokenize documents
# def tokenize_documents(docs):
#     encoder = tiktoken.get_encoding("cl100k_base")
#     tokens_per_docs = [len(encoder.encode(doc.page_content)) for doc in docs]
#     total_tokens = sum(tokens_per_docs)
#     cost_per_1000_tokens = 0.0001
#     cost = (total_tokens / 1000) * cost_per_1000_tokens
#     return tokens_per_docs, cost

# Function to create vector database
def create_vector_db(docs, persist_directory=None, chunk_size=1000, chunk_overlap=0):
    """Split documents into chunks, embed them and index them in Chroma.

    Args:
        docs: LangChain documents to index.
        persist_directory: Optional directory for an on-disk Chroma store.
            When ``None`` (the default) the store is kept in memory and no
            persist step is attempted.
        chunk_size: Maximum characters per chunk passed to the splitter.
        chunk_overlap: Characters shared between consecutive chunks.

    Returns:
        The Chroma vector store holding the embedded chunks.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    # Embed the split chunks (not the whole documents) so long rows are
    # broken into retrievable pieces before indexing.
    texts = text_splitter.split_documents(docs)
    embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    # NOTE(review): depending on the chromadb version, in-memory stores created
    # in the same process may share the default collection, so repeated calls
    # could accumulate documents — confirm and pass a unique collection_name
    # if that turns out to be the case.
    vectordb = Chroma.from_documents(
        texts,
        embedding_function,
        persist_directory=persist_directory,
    )
    if persist_directory is not None:
        # Chroma only supports persist() when a directory was given at creation.
        vectordb.persist()
    # Return the populated store itself rather than re-opening a new instance.
    return vectordb

# Function to augment prompt
def augment_prompt(query, vectordb, k=3):
    """Build a retrieval-augmented prompt for *query*.

    Args:
        query: The user's question; also used as the similarity-search text.
        vectordb: Object exposing ``similarity_search(query, k=...)`` that
            returns items with a ``page_content`` attribute.
        k: Number of most similar chunks to include as context (default 3).

    Returns:
        The prompt string: answering instructions, the retrieved chunk texts
        joined by newlines, then the query.
    """
    results = vectordb.similarity_search(query, k=k)
    source_knowledge = "\n".join(doc.page_content for doc in results)
    augmented_prompt = f"""Using the contexts below, answer the query. If some information is not provided within
    the contexts below, do not include, and if the query cannot be answered with the below information, say "I don't know".

    Contexts:
    {source_knowledge}

    Query: {query}"""
    return augmented_prompt

# Function to handle chat
def chat_with_ai(query, vectordb, openai_api_key, model_name="gpt-3.5-turbo"):
    """Answer *query* with an OpenAI chat model grounded on *vectordb*.

    Args:
        query: The user's question.
        vectordb: Vector store searched by ``augment_prompt`` for context.
        openai_api_key: API key passed to ``ChatOpenAI``.
        model_name: OpenAI chat model to use (default ``"gpt-3.5-turbo"``).

    Returns:
        The text content of the model's reply.
    """
    chat = ChatOpenAI(model_name=model_name, openai_api_key=openai_api_key)
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content=augment_prompt(query, vectordb)),
    ]
    # ``invoke`` is the Runnable entry point; calling the model object
    # directly (``chat(messages)``) is deprecated in langchain_core.
    res = chat.invoke(messages)
    return res.content

# Streamlit UI
st.title("Document Processing and AI Chat with LangChain")

# File upload
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

if uploaded_file is not None:
    # Load and process data
    df = load_data(uploaded_file)
    st.write("Data loaded successfully!")

    # The documents are built from a 'page_content' column; report a missing
    # column clearly instead of failing with a KeyError traceback.
    if 'page_content' not in df.columns:
        st.error("The CSV file must contain a 'page_content' column.")
        st.stop()

    # Load documents
    docs = load_documents(df, 'page_content')
    st.write(f"Loaded {len(docs)} documents")
    if not docs:
        st.warning("The CSV file contains no rows to index.")
        st.stop()

    # Token/cost estimation is disabled: tokenize_documents (and its tiktoken
    # import) is commented out above, so calling it would raise NameError.

    # Streamlit re-runs this whole script on every widget interaction, so
    # cache the vector database per uploaded file in session state to avoid
    # re-embedding every document on each keystroke or button click.
    file_key = (uploaded_file.name, uploaded_file.size)
    if st.session_state.get("vectordb_key") != file_key:
        st.session_state["vectordb"] = create_vector_db(docs)
        st.session_state["vectordb_key"] = file_key
    vectordb = st.session_state["vectordb"]
    st.write("Vector database created successfully!")

    # Query input
    query = st.text_input("Enter your query", "Recommend a company to work as a data scientist in the health sector")

    if st.button("Get Answer"):
        openai_api_key = os.getenv("OPENAI_API_KEY")
        if not openai_api_key:
            st.error("Set the OPENAI_API_KEY environment variable to query the AI.")
        elif not query.strip():
            st.warning("Please enter a query.")
        else:
            # Chat with AI
            response = chat_with_ai(query, vectordb, openai_api_key)
            st.write("Response from AI:")
            st.write(response)