File size: 3,242 Bytes
3f48d1f
 
 
 
 
 
 
 
 
 
b3e38d5
 
 
 
 
3f48d1f
b3e38d5
 
 
 
 
 
 
3f48d1f
b3e38d5
3f48d1f
 
 
 
 
b3e38d5
 
 
 
3f48d1f
 
 
b3e38d5
3f48d1f
 
b3e38d5
3f48d1f
b3e38d5
3f48d1f
 
b3e38d5
 
 
3f48d1f
b3e38d5
3f48d1f
 
b3e38d5
 
3f48d1f
b3e38d5
3f48d1f
b3e38d5
3f48d1f
 
 
b3e38d5
3f48d1f
b3e38d5
 
 
3f48d1f
 
 
b3e38d5
3f48d1f
 
b3e38d5
3f48d1f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import chainlit as cl
from dotenv import load_dotenv

# LangChain imports for retrieval and generation
from langchain.document_loaders import WebBaseLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA

# Google Generative AI integrations
from langchain_google_genai import GoogleGenerativeAI  # For LLM generation
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings  # For embeddings

# Read settings from a local .env file (when one exists) into the process
# environment. The Gemini key is mandatory: both the LLM below and the
# embedding model used when indexing the document authenticate with it.
load_dotenv()
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if GEMINI_API_KEY is None or GEMINI_API_KEY == "":
    raise ValueError("GEMINI_API_KEY not found in .env file")

# Gemini model that generates the final answers. Another model name
# (e.g. "gemini-pro" or "gemini-1.5-flash-latest") can be substituted here.
llm = GoogleGenerativeAI(
    google_api_key=GEMINI_API_KEY,
    model="gemini-1.5-flash-latest",
)

# RetrievalQA chain shared by the chat handlers; stays None until the
# on-chat-start handler has indexed the document.
qa_chain = None

def _build_qa_chain(url):
    """
    Synchronously build a RetrievalQA chain over the web page at ``url``.

    Fetches the page with WebBaseLoader, splits it into chunks of roughly
    1000 characters (200-character overlap), embeds the chunks with Google's
    "models/text-embedding-004" model, indexes them in a FAISS vector store,
    and wraps a top-3 retriever in a "stuff" RetrievalQA chain that answers
    with the module-level Gemini ``llm``.

    Every step is blocking (HTTP fetch, embedding API calls, index build), so
    async callers must run this in a worker thread.

    Returns:
        The constructed RetrievalQA chain.
    """
    documents = WebBaseLoader(url).load()  # list of Document objects

    # Split the page into overlapping chunks so retrieval can target sections.
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    docs = text_splitter.split_documents(documents)

    embeddings = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004", google_api_key=GEMINI_API_KEY)
    vectorstore = FAISS.from_documents(docs, embeddings)
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

    # "stuff" inserts all retrieved chunks directly into a single LLM prompt.
    return RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)


@cl.on_chat_start
async def start_chat():
    """
    Make sure the RetrievalQA chain for the fixed source page exists, then greet the user.

    The chain (load page -> split -> embed -> FAISS index -> RetrievalQA with
    the Gemini LLM) is stored in the module-level ``qa_chain`` and reused by
    every later chat session, because the source URL never changes.
    """
    global qa_chain

    # URL to crawl (German Wikipedia page on "Künstliche Intelligenz")
    url = "https://de.wikipedia.org/wiki/K%C3%BCnstliche_Intelligenz"

    # The document is the same for every session, so build the index once per
    # process instead of re-crawling and re-embedding it (and spending
    # embedding-API quota) every time a new chat is opened.
    if qa_chain is None:
        # The build is entirely synchronous and can take seconds; awaiting it
        # through a worker thread keeps the event loop (and all other
        # connected sessions) responsive in the meantime.
        qa_chain = await cl.make_async(_build_qa_chain)(url)

    await cl.Message(
        content="✅ Document loaded and processed successfully! You can now ask questions about 'Künstliche Intelligenz'."
    ).send()

@cl.on_message
async def process_message(message: cl.Message):
    """
    Answer a user question with the shared RetrievalQA chain and reply in the chat.

    The chain retrieves relevant context from the processed document, combines
    it with the question, and generates an answer with the Gemini-based LLM.
    If the chain has not been built yet, the user is asked to wait. If the
    message is blank, the user is asked for a question instead of calling
    the LLM.
    """
    # Snapshot the module-level chain once: it is only read here, and the
    # local keeps this request on a single chain for its whole duration.
    chain = qa_chain
    if chain is None:
        await cl.Message(content="❌ The document is still being loaded. Please wait a moment.").send()
        return

    query = message.content.strip()
    if not query:
        # Nothing to ask: avoid a pointless retrieval + LLM round-trip.
        await cl.Message(content="Please enter a question about 'Künstliche Intelligenz'.").send()
        return

    # Chain.run is synchronous (vector search + LLM request); run it in a
    # worker thread so other sessions are not frozen while the answer is
    # being generated.
    result = await cl.make_async(chain.run)(query)
    await cl.Message(content=result).send()