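"""Chainlit QA assistant: parent-document RAG over Chroma, answered by Gemini."""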
import chainlit as cl
from langchain.retrievers import ParentDocumentRetriever
from langchain.schema.runnable import RunnableConfig
from langchain.storage import LocalFileStore
from langchain.storage._lc_store import create_kv_docstore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.chroma import Chroma
from langchain_google_genai import (
    GoogleGenerativeAI,
    GoogleGenerativeAIEmbeddings,
    HarmBlockThreshold,
    HarmCategory,
)

import config
from prompts import prompt
from utils import PostMessageHandler, format_docs

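# Gemini chat model; the dangerous-content safety filter is set to BLOCK_NONE
# so answers touching on sensitive technical topics are not silently blocked.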
model = GoogleGenerativeAI(
    model=config.GOOGLE_CHAT_MODEL,
    google_api_key=config.GOOGLE_API_KEY,
    safety_settings={
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
    },
)  # type: ignore

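# Embedding model used to index the child chunks in the vectorstore.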
embeddings_model = GoogleGenerativeAIEmbeddings(
    model=config.GOOGLE_EMBEDDING_MODEL,
    google_api_key=config.GOOGLE_API_KEY,
)  # type: ignore


## retriever
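# Parent-document strategy: ~500-character child chunks (split on newlines)
# are embedded for similarity search, while the full parent documents they
# came from are what the retriever ultimately returns.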
child_splitter = RecursiveCharacterTextSplitter(chunk_size=500, separators=["\n"])

# The vectorstore to use to index the child chunks
vectorstore = Chroma(
    persist_directory=config.STORAGE_PATH + "vectorstore",
    collection_name="full_documents",
    embedding_function=embeddings_model,
)

# The storage layer for the parent documents
fs = LocalFileStore(config.STORAGE_PATH + "docstore")
store = create_kv_docstore(fs)

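# No parent_splitter is given, so each ingested document is stored whole and
# returned in full whenever one of its child chunks matches a query.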
retriever = ParentDocumentRetriever(
    vectorstore=vectorstore,
    docstore=store,
    child_splitter=child_splitter,
)


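# Put the shared retriever in the user session so message handlers can use it.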
@cl.on_chat_start
async def on_chat_start():
    cl.user_session.set("retriever", retriever)


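# Answer a user message: retrieve context, fill the prompt, stream the answer.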
@cl.on_message
async def on_message(message: cl.Message):
    chain = prompt | model
    msg = cl.Message(content="")

    async with cl.Step(type="run", name="QA Assistant"):
        question = message.content
        retriever = cl.user_session.get("retriever")  # stored in on_chat_start
        docs = await retriever.aget_relevant_documents(question)
        context = format_docs(docs)
        async for chunk in chain.astream(
            input={"context": context, "question": question},
            config=RunnableConfig(
                callbacks=[cl.LangchainCallbackHandler(), PostMessageHandler(msg)]
            ),
        ):
            await msg.stream_token(chunk)

    await msg.send()