Sebbe33 committed
Commit b3e38d5 · verified · 1 Parent(s): 2169c1a

Update app.py

Files changed (1)
  1. app.py +34 -38
app.py CHANGED
@@ -5,76 +5,72 @@ from dotenv import load_dotenv
  # LangChain imports for retrieval and generation
  from langchain.document_loaders import WebBaseLoader
  from langchain.text_splitter import CharacterTextSplitter
- from langchain.embeddings import OpenAIEmbeddings
  from langchain.vectorstores import FAISS
  from langchain.chains import RetrievalQA
- from langchain.llms import OpenAI

- # Load environment variables (e.g., OPENAI_API_KEY)
+ # Google Generative AI integrations
+ from langchain_google_genai import GoogleGenerativeAI  # For LLM generation
+ from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings  # For embeddings
+
+ # Load environment variables (GEMINI_API_KEY should be defined)
  load_dotenv()
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
+ if not GEMINI_API_KEY:
+     raise ValueError("GEMINI_API_KEY not found in .env file")
+
+ # Configure the LLM using Google’s Gemini model.
+ # You can change the model name if needed (e.g., "gemini-pro", "gemini-1.5-flash-latest", etc.)
+ llm = GoogleGenerativeAI(model="gemini-1.5-flash-latest", google_api_key=GEMINI_API_KEY)

- # Global variable to store our QA chain.
+ # Global variable for the RetrievalQA chain
  qa_chain = None

  @cl.on_chat_start
  async def start_chat():
      """
-     When the chat starts, load the document using WebBaseLoader, split it into chunks,
-     create embeddings, build a vector store, and finally initialize a RetrievalQA chain.
-     This chain will serve as the backend for our RAG system.
+     On chat start, this function loads a document from the provided URL using WebBaseLoader,
+     splits it into chunks for retrieval, creates embeddings with Google’s embedding model,
+     and builds a vector store (using FAISS). Finally, it creates a RetrievalQA chain that
+     will retrieve relevant document sections and generate answers using the Gemini LLM.
      """
      global qa_chain

-     # URL to crawl (German Wikipedia page on Künstliche Intelligenz)
+     # URL to crawl (German Wikipedia page on "Künstliche Intelligenz")
      url = "https://de.wikipedia.org/wiki/K%C3%BCnstliche_Intelligenz"
-
-     # Retrieve the document from the webpage
      loader = WebBaseLoader(url)
-     documents = loader.load()  # returns a list of Document objects
+     documents = loader.load()  # Returns a list of Document objects

-     # Split the document into manageable chunks for better retrieval
+     # Split the document into chunks for effective retrieval
      text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
      docs = text_splitter.split_documents(documents)
+
+     # Create embeddings using Google Generative AI embeddings
+     embeddings = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004", google_api_key=GEMINI_API_KEY)

-     # Create embeddings (make sure your OPENAI_API_KEY is set in your environment)
-     embeddings = OpenAIEmbeddings()
-
-     # Build a vector store from the documents using FAISS
+     # Build a FAISS vector store for efficient similarity search
      vectorstore = FAISS.from_documents(docs, embeddings)
-
-     # Configure the retriever: retrieve the top 3 most relevant chunks
      retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
-
-     # Set up the language model (using OpenAI LLM here) with desired parameters
-     llm = OpenAI(temperature=0)
-
-     # Create a RetrievalQA chain that first retrieves relevant context and then generates an answer.
+
+     # Build the RetrievalQA chain that augments queries with the retrieved context
      qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)
-
+
      await cl.Message(
-         content="✅ Document loaded and processed successfully! "
-         "You can now ask me questions about 'Künstliche Intelligenz'."
+         content="✅ Document loaded and processed successfully! You can now ask questions about 'Künstliche Intelligenz'."
      ).send()

  @cl.on_message
- async def process_question(message: cl.Message):
+ async def process_message(message: cl.Message):
      """
-     When a message is received, use the QA chain to process the query. The chain:
-     1. Retrieves relevant document chunks.
-     2. Augments your query with the retrieved context.
-     3. Generates an answer via the language model.
+     When a user message arrives, this function uses the RetrievalQA chain to retrieve relevant
+     context from the processed document, augment the user query, and generate an answer using
+     the Gemini-based LLM.
      """
      global qa_chain
-
      if qa_chain is None:
-         await cl.Message(content="❌ The document has not been loaded yet.").send()
+         await cl.Message(content="❌ The document is still being loaded. Please wait a moment.").send()
          return

-     # Get the user's query
+     # Retrieve user query and generate the answer using the chain
      query = message.content.strip()
-
-     # Process the query using the RetrievalQA chain
      result = qa_chain.run(query)
-
-     # Send the answer back to the user
      await cl.Message(content=result).send()
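
For anyone who wants to sanity-check the Gemini-backed pipeline from this commit outside of Chainlit, a minimal standalone sketch follows. It reuses only components that already appear in the updated app.py (WebBaseLoader, CharacterTextSplitter, FAISS, RetrievalQA, GoogleGenerativeAI, GoogleGenerativeAIEmbeddings); the script name smoke_test.py, the sample question, and the assumption that GEMINI_API_KEY is set in .env and that langchain, faiss-cpu, and langchain-google-genai are installed are illustrative, not part of the commit.

# smoke_test.py -- hypothetical standalone check of the Gemini-backed RetrievalQA pipeline.
# Assumes GEMINI_API_KEY is defined in .env and the same packages as app.py are installed.
import os

from dotenv import load_dotenv
from langchain.document_loaders import WebBaseLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain_google_genai import GoogleGenerativeAI
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings

load_dotenv()
api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
    raise ValueError("GEMINI_API_KEY not found in .env file")

# Same source document as app.py: the German Wikipedia article on Künstliche Intelligenz.
documents = WebBaseLoader("https://de.wikipedia.org/wiki/K%C3%BCnstliche_Intelligenz").load()
chunks = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200).split_documents(documents)

# Embed the chunks with Google's embedding model and index them in FAISS.
embeddings = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004", google_api_key=api_key)
retriever = FAISS.from_documents(chunks, embeddings).as_retriever(search_kwargs={"k": 3})

# Gemini LLM plus a "stuff" RetrievalQA chain, as in the updated app.py.
llm = GoogleGenerativeAI(model="gemini-1.5-flash-latest", google_api_key=api_key)
qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)

# One sample query (illustrative); app.py sends the answer back through Chainlit instead.
print(qa_chain.run("Was ist Künstliche Intelligenz?"))

Note that qa_chain.run() and the langchain.document_loaders / langchain.vectorstores import paths mirror the committed code; newer LangChain releases deprecate them in favor of invoke() and the langchain_community package, so both app.py and this sketch may emit deprecation warnings.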