Abid Ali Awan committed on
Commit 9b3bd46 · 1 Parent(s): 355b607

Fix the issues with the app and optimize it.

Files changed (1)
  1. main.py +231 -404
main.py CHANGED
@@ -1,81 +1,47 @@
  import os
  import zipfile
- from typing import Dict, List, Optional, Union

  import gradio as gr
  from groq import Groq
  from langchain.text_splitter import RecursiveCharacterTextSplitter
  from langchain_core.output_parsers import StrOutputParser
  from langchain_core.prompts import PromptTemplate
  from langchain_core.runnables import RunnablePassthrough
  from langchain_groq import ChatGroq
  from langchain_huggingface import HuggingFaceEmbeddings
- from langchain_core.vectorstores import InMemoryVectorStore
- # Retrieve API key for Groq from the environment variables
- GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
-
- # Initialize the Groq client
- client = Groq(api_key=GROQ_API_KEY)
-
- # Initialize the LLM
- llm = ChatGroq(model="meta-llama/llama-4-scout-17b-16e-instruct", api_key=GROQ_API_KEY)

- # Initialize the embedding model
- embed_model = HuggingFaceEmbeddings(model_name="mixedbread-ai/mxbai-embed-large-v1")
-
- # General constants for the UI
- TITLE = """<h1 align="center">✨ Llama 4 RAG Application</h1>"""
  AVATAR_IMAGES = (
      None,
      "https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png",
  )

- # List of supported text extensions (alphabetically sorted)
- TEXT_EXTENSIONS = [
-     ".bat",
-     ".c",
-     ".cfg",
-     ".conf",
-     ".cpp",
-     ".cs",
-     ".css",
-     ".docx",
-     ".go",
-     ".h",
-     ".html",
-     ".ini",
-     ".java",
-     ".js",
-     ".json",
-     ".jsx",
-     ".md",
-     ".php",
-     ".ps1",
-     ".py",
-     ".rb",
-     ".rs",
-     ".sh",
-     ".toml",
-     ".ts",
-     ".tsx",
-     ".txt",
-     ".xml",
-     ".yaml",
-     ".yml",
- ]
-
- # Global variables
- EXTRACTED_FILES = {}
- VECTORSTORE = None
- RAG_CHAIN = None
-
- # Initialize the text splitter
  text_splitter = RecursiveCharacterTextSplitter(
-     chunk_size=1000, chunk_overlap=100, separators=["\n\n", "\n"]
  )

- # Define the RAG prompt template
- template = """You are an expert assistant tasked with answering questions based on the provided documents.
  Use only the given context to generate your answer.
  If the answer cannot be found in the context, clearly state that you do not know.
  Be detailed and precise in your response, but avoid mentioning or referencing the context itself.
@@ -87,424 +53,285 @@ Question:
  {question}

  Answer:"""

- # Create the PromptTemplate
- rag_prompt = PromptTemplate.from_template(template)


- def extract_text_from_zip(zip_file_path: str) -> Dict[str, str]:
-     """
-     Extract text content from files in a ZIP archive.

-     Parameters:
-         zip_file_path (str): Path to the ZIP file.

-     Returns:
-         Dict[str, str]: Dictionary mapping filenames to their text content.
-     """
-     text_contents = {}

-     with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
-         for file_info in zip_ref.infolist():
-             # Skip directories
-             if file_info.filename.endswith("/"):
-                 continue

-             # Skip binary files and focus on text files
-             file_ext = os.path.splitext(file_info.filename)[1].lower()

-             if file_ext in TEXT_EXTENSIONS:
-                 try:
-                     with zip_ref.open(file_info) as file:
-                         content = file.read().decode("utf-8", errors="replace")
-                         text_contents[file_info.filename] = content
-                 except Exception as e:
-                     text_contents[file_info.filename] = (
-                         f"Error extracting file: {str(e)}"
-                     )

-     return text_contents


- def extract_text_from_single_file(file_path: str) -> Dict[str, str]:
-     """
-     Extract text content from a single file.

-     Parameters:
-         file_path (str): Path to the file.

-     Returns:
-         Dict[str, str]: Dictionary mapping filename to its text content.
-     """
-     text_contents = {}
-     filename = os.path.basename(file_path)
-     file_ext = os.path.splitext(filename)[1].lower()

-     if file_ext in TEXT_EXTENSIONS:
-         try:
-             with open(file_path, "r", encoding="utf-8", errors="replace") as file:
-                 content = file.read()
-                 text_contents[filename] = content
-         except Exception as e:
-             text_contents[filename] = f"Error reading file: {str(e)}"

-     return text_contents


  def upload_files(
-     files: Optional[List[str]], chatbot: List[Union[dict, gr.ChatMessage]]
  ):
-     """
-     Process uploaded files (ZIP or single text files): extract text content and append a message to the chat.
-
-     Parameters:
-         files (Optional[List[str]]): List of file paths.
-         chatbot (List[Union[dict, gr.ChatMessage]]): The conversation history.
-
-     Returns:
-         List[Union[dict, gr.ChatMessage]]: Updated conversation history.
-     """
-     global EXTRACTED_FILES, VECTORSTORE, RAG_CHAIN
-
-     # Handle multiple file uploads
-     if len(files) > 1:
-         total_files_processed = 0
-         total_files_extracted = 0
-         file_types = set()
-
-         # Process each file
-         for file in files:
-             filename = os.path.basename(file)
-             file_ext = os.path.splitext(filename)[1].lower()
-
-             # Process based on file type
-             if file_ext == ".zip":
-                 extracted_files = extract_text_from_zip(file)
-                 file_types.add("zip")
-             else:
-                 extracted_files = extract_text_from_single_file(file)
-                 file_types.add("text")

-             if extracted_files:
-                 total_files_extracted += len(extracted_files)
-                 # Store the extracted content in the global variable
-                 EXTRACTED_FILES[filename] = extracted_files

-                 total_files_processed += 1

-         # Create a summary message for multiple files
-         file_types_str = (
-             "files"
-             if len(file_types) > 1
-             else ("ZIP files" if "zip" in file_types else "text files")
-         )

-         # Create a list of uploaded file names
-         file_list = "\n".join([f"- {os.path.basename(file)}" for file in files])

          chatbot.append(
              gr.ChatMessage(
-                 role="user",
-                 content=f"<p>📚 Multiple {file_types_str} uploaded ({total_files_processed} files)</p><p>Extracted {total_files_extracted} text file(s) in total</p><p>Uploaded files:</p><pre>{file_list}</pre>",
-             )
-         )
-
-     # Handle single file upload
-     elif len(files) == 1:
-         file = files[0]
-         filename = os.path.basename(file)
-         file_ext = os.path.splitext(filename)[1].lower()
-
-         # Process based on file type
-         if file_ext == ".zip":
-             extracted_files = extract_text_from_zip(file)
-             file_type_msg = "📦 ZIP file"
-         else:
-             extracted_files = extract_text_from_single_file(file)
-             file_type_msg = "📄 File"
-
-         if not extracted_files:
-             chatbot.append(
-                 gr.ChatMessage(
-                     role="user",
-                     content=f"<p>{file_type_msg} uploaded: {filename}, but no text content was found or the file format is not supported.</p>",
-                 )
-             )
-         else:
-             file_list = "\n".join([f"- {name}" for name in extracted_files.keys()])
-             chatbot.append(
-                 gr.ChatMessage(
-                     role="user",
-                     content=f"<p>{file_type_msg} uploaded: {filename}</p><p>Extracted {len(extracted_files)} text file(s):</p><pre>{file_list}</pre>",
-                 )
              )
-
-         # Store the extracted content in the global variable
-         EXTRACTED_FILES[filename] = extracted_files
-
-     # Process the extracted files and create vector embeddings
-     if EXTRACTED_FILES:
-         # Prepare documents for processing
-         all_texts = []
-         for filename, files in EXTRACTED_FILES.items():
-             for file_path, content in files.items():
-                 all_texts.append(
-                     {"page_content": content, "metadata": {"source": file_path}}
-                 )
-
-         # Create document objects
-         from langchain_core.documents import Document
-
-         documents = [
-             Document(page_content=item["page_content"], metadata=item["metadata"])
-             for item in all_texts
-         ]
-
-         # Split the documents into chunks
-         chunks = text_splitter.split_documents(documents)
-
-         # Create the vector store
-         VECTORSTORE = InMemoryVectorStore.from_documents(
-             documents=chunks,
-             embedding=embed_model,
-         )
-
-         # Create the retriever
-         retriever = VECTORSTORE.as_retriever()
-
-         # Create the RAG chain
-         RAG_CHAIN = (
-             {"context": retriever, "question": RunnablePassthrough()}
-             | rag_prompt
-             | llm
-             | StrOutputParser()
          )

-         # Add a confirmation message
          chatbot.append(
              gr.ChatMessage(
-                 role="assistant",
-                 content="Documents processed and indexed. You can now ask questions about the content.",
              )
          )

-     return chatbot
-

- def user(text_prompt: str, chatbot: List[gr.ChatMessage]):
-     """
-     Append a new user text message to the chat history.

-     Parameters:
-         text_prompt (str): The input text provided by the user.
-         chatbot (List[gr.ChatMessage]): The existing conversation history.

-     Returns:
-         Tuple[str, List[gr.ChatMessage]]: A tuple of an empty string (clearing the prompt)
-         and the updated conversation history.
-     """
-     if text_prompt:
          chatbot.append(gr.ChatMessage(role="user", content=text_prompt))
      return "", chatbot


- def get_message_content(msg):
-     """
-     Retrieve the content of a message that can be either a dictionary or a gr.ChatMessage.
-
-     Parameters:
-         msg (Union[dict, gr.ChatMessage]): The message object.
-
-     Returns:
-         str: The textual content of the message.
-     """
-     if isinstance(msg, dict):
-         return msg.get("content", "")
-     return msg.content
-
-
- def process_query(chatbot: List[Union[dict, gr.ChatMessage]]):
-     """
-     Process the user's query using the RAG pipeline.
-
-     Parameters:
-         chatbot (List[Union[dict, gr.ChatMessage]]): The conversation history.
-
-     Returns:
-         List[Union[dict, gr.ChatMessage]]: The updated conversation history with the response.
-     """
-     global RAG_CHAIN
-
-     if len(chatbot) == 0:
-         chatbot.append(
-             gr.ChatMessage(
-                 role="assistant",
-                 content="Please enter a question or upload documents to start the conversation.",
-             )
-         )
-         return chatbot
-
-     # Get the last user message as the prompt
-     user_messages = [
-         msg
-         for msg in chatbot
-         if (isinstance(msg, dict) and msg.get("role") == "user")
-         or (hasattr(msg, "role") and msg.role == "user")
-     ]
-
-     if not user_messages:
          chatbot.append(
-             gr.ChatMessage(
-                 role="assistant",
-                 content="Please enter a question to start the conversation.",
-             )
          )
          return chatbot

-     last_user_msg = user_messages[-1]
-     prompt = get_message_content(last_user_msg)
-
-     # Skip if the last message was about uploading a file
-     if (
-         "📦 ZIP file uploaded:" in prompt
-         or "📄 File uploaded:" in prompt
-         or "📚 Multiple files uploaded" in prompt
-     ):
-         return chatbot
-
-     # Check if RAG chain is available
-     if RAG_CHAIN is None:
          chatbot.append(
-             gr.ChatMessage(
-                 role="assistant",
-                 content="Please upload documents first to enable question answering.",
-             )
          )
          return chatbot

-     # Append a placeholder for the assistant's response
      chatbot.append(gr.ChatMessage(role="assistant", content="Thinking..."))

      try:
-         # Process the query through the RAG chain
-         response = RAG_CHAIN.invoke(prompt)
-
-         # Update the placeholder with the actual response
          chatbot[-1].content = response
      except Exception as e:
-         # Handle any errors
-         chatbot[-1].content = f"Error processing your query: {str(e)}"

      return chatbot


- def reset_app(chatbot):
-     """
-     Reset the app by clearing the chat context and removing any uploaded files.
-
-     Parameters:
-         chatbot (List[Union[dict, gr.ChatMessage]]): The conversation history.
-
-     Returns:
-         List[Union[dict, gr.ChatMessage]]: A fresh conversation history.
-     """
-     global EXTRACTED_FILES, VECTORSTORE, RAG_CHAIN
-
-     # Clear the global variables
-     EXTRACTED_FILES = {}
-     VECTORSTORE = None
-     RAG_CHAIN = None
-
-     # Reset the chatbot with a welcome message
      return [
          gr.ChatMessage(
-             role="assistant",
-             content="App has been reset. You can start a new conversation or upload new documents.",
          )
      ]


- # Define the Gradio UI components
- chatbot_component = gr.Chatbot(
-     label="Llama 4 RAG",
-     type="messages",
-     bubble_full_width=False,
-     avatar_images=AVATAR_IMAGES,
-     scale=2,
-     height=350,
- )
- text_prompt_component = gr.Textbox(
-     placeholder="Ask a question about your documents...",
-     show_label=False,
-     autofocus=True,
-     scale=28,
- )
- upload_files_button_component = gr.UploadButton(
-     label="Upload",
-     file_count="multiple",
-     file_types=[".zip", ".docx"] + TEXT_EXTENSIONS,
-     scale=1,
-     min_width=80,
- )
- send_button_component = gr.Button(
-     value="Send", variant="primary", scale=1, min_width=80
- )
- reset_button_component = gr.Button(value="Reset", variant="stop", scale=1, min_width=80)
-
- # Define input lists for button chaining
- user_inputs = [text_prompt_component, chatbot_component]

- with gr.Blocks(theme=gr.themes.Ocean()) as demo:
      gr.HTML(TITLE)
-     with gr.Column():
-         chatbot_component.render()
-     with gr.Row(equal_height=True):
-         text_prompt_component.render()
-         send_button_component.render()
-         upload_files_button_component.render()
-         reset_button_component.render()
-
-     # When the Send button is clicked, first process the user text then process the query
-     send_button_component.click(
-         fn=user,
-         inputs=user_inputs,
-         outputs=[text_prompt_component, chatbot_component],
-         queue=False,
-     ).then(
-         fn=process_query,
-         inputs=[chatbot_component],
-         outputs=[chatbot_component],
-         api_name="process_query",
      )

-     # Allow submission using the Enter key
-     text_prompt_component.submit(
-         fn=user,
-         inputs=user_inputs,
-         outputs=[text_prompt_component, chatbot_component],
-         queue=False,
-     ).then(
-         fn=process_query,
-         inputs=[chatbot_component],
-         outputs=[chatbot_component],
-         api_name="process_query_submit",
-     )

-     # Handle file uploads
-     upload_files_button_component.upload(
-         fn=upload_files,
-         inputs=[upload_files_button_component, chatbot_component],
-         outputs=[chatbot_component],
          queue=False,
-     )

-     # Handle Reset button clicks
-     reset_button_component.click(
-         fn=reset_app,
-         inputs=[chatbot_component],
-         outputs=[chatbot_component],
          queue=False,
      )

- # Launch the demo interface
  demo.queue().launch()
@@ -1,81 +1,47 @@
+ # ========== Standard Library ==========
  import os
+ import tempfile
  import zipfile
+ from typing import List, Optional, Tuple, Union
+ import collections

+
+ # ========== Third-Party Libraries ==========
  import gradio as gr
  from groq import Groq
  from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.document_loaders import DirectoryLoader, UnstructuredFileLoader
  from langchain_core.output_parsers import StrOutputParser
  from langchain_core.prompts import PromptTemplate
  from langchain_core.runnables import RunnablePassthrough
+ from langchain_core.vectorstores import InMemoryVectorStore
  from langchain_groq import ChatGroq
  from langchain_huggingface import HuggingFaceEmbeddings

+ # ========== Configs ==========
+ TITLE = """<h1 align="center">🗨️🦙 Llama 4 Docx Chatter</h1>"""
  AVATAR_IMAGES = (
      None,
      "https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png",
  )

+ # Acceptable file extensions
+ TEXT_EXTENSIONS = [".docx", ".zip"]
+
+ # ========== Models & Clients ==========
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY")
+ client = Groq(api_key=GROQ_API_KEY)
+ llm = ChatGroq(model="meta-llama/llama-4-scout-17b-16e-instruct", api_key=GROQ_API_KEY)
+ embed_model = HuggingFaceEmbeddings(model_name="mixedbread-ai/mxbai-embed-large-v1")
+
+ # ========== Core Components ==========
  text_splitter = RecursiveCharacterTextSplitter(
+     chunk_size=1000,
+     chunk_overlap=100,
+     separators=["\n\n", "\n"],
  )

+ rag_template = """You are an expert assistant tasked with answering questions based on the provided documents.
  Use only the given context to generate your answer.
  If the answer cannot be found in the context, clearly state that you do not know.
  Be detailed and precise in your response, but avoid mentioning or referencing the context itself.
@@ -87,424 +53,285 @@ Question:
  {question}

  Answer:"""
+ rag_prompt = PromptTemplate.from_template(rag_template)


+ # ========== App State ==========
+ class AppState:
+     vectorstore: Optional[InMemoryVectorStore] = None
+     rag_chain = None


+ state = AppState()

+ # ========== Utility Functions ==========


+ def load_documents_from_files(files: List[str]) -> List:
+     """Load documents from uploaded files directly without moving."""
+     all_documents = []
+
+     # Temporary directory if ZIP needs extraction
+     with tempfile.TemporaryDirectory() as temp_dir:
+         for file_path in files:
+             ext = os.path.splitext(file_path)[1].lower()
+
+             if ext == ".zip":
+                 # Extract ZIP inside temp_dir
+                 with zipfile.ZipFile(file_path, "r") as zip_ref:
+                     zip_ref.extractall(temp_dir)
+
+                 # Load all docx from extracted zip
+                 loader = DirectoryLoader(
+                     path=temp_dir,
+                     glob="**/*.docx",
+                     use_multithreading=True,
+                 )
+                 docs = loader.load()
+                 all_documents.extend(docs)

+             elif ext == ".docx":
+                 # Load single docx directly
+                 loader = UnstructuredFileLoader(file_path)
+                 docs = loader.load()
+                 all_documents.extend(docs)

+     return all_documents


+ def get_last_user_message(chatbot: List[Union[gr.ChatMessage, dict]]) -> Optional[str]:
+     """Get last user prompt."""
+     for message in reversed(chatbot):
+         content = (
+             message.get("content") if isinstance(message, dict) else message.content
+         )
+         if (
+             message.get("role") if isinstance(message, dict) else message.role
+         ) == "user":
+             return content
+     return None
+

+ # ========== Main Logic ==========


  def upload_files(
+     files: Optional[List[str]], chatbot: List[Union[gr.ChatMessage, dict]]
  ):
+     """Handle file upload - .docx or .zip containing docx."""
+     if not files:
+         return chatbot

+     file_summaries = []  # <-- Collect formatted file/folder info
+     documents = []

+     with tempfile.TemporaryDirectory() as temp_dir:
+         for file_path in files:
+             filename = os.path.basename(file_path)
+             ext = os.path.splitext(file_path)[1].lower()

+             if ext == ".zip":
+                 file_summaries.append(f"📦 **{filename}** (ZIP file) contains:")
+                 try:
+                     with zipfile.ZipFile(file_path, "r") as zip_ref:
+                         zip_ref.extractall(temp_dir)
+                         zip_contents = zip_ref.namelist()
+
+                         # Group files by folder
+                         folder_map = collections.defaultdict(list)
+                         for item in zip_contents:
+                             if item.endswith("/"):
+                                 continue  # skip folder entries themselves
+                             folder = os.path.dirname(item)
+                             file_name = os.path.basename(item)
+                             folder_map[folder].append(file_name)
+
+                         # Format nicely
+                         for folder, files_in_folder in folder_map.items():
+                             if folder:
+                                 file_summaries.append(f"📂 {folder}/")
+                             else:
+                                 file_summaries.append(f"📄 (root)")
+                             for f in files_in_folder:
+                                 file_summaries.append(f" - {f}")
+
+                         # Load docx files extracted from ZIP
+                         loader = DirectoryLoader(
+                             path=temp_dir,
+                             glob="**/*.docx",
+                             use_multithreading=True,
+                         )
+                         docs = loader.load()
+                         documents.extend(docs)
+
+                 except zipfile.BadZipFile:
+                     chatbot.append(
+                         gr.ChatMessage(
+                             role="assistant",
+                             content=f"❌ Failed to open ZIP file: {filename}",
+                         )
+                     )
+
+             elif ext == ".docx":
+                 file_summaries.append(f"📄 **{filename}**")
+                 loader = UnstructuredFileLoader(file_path)
+                 docs = loader.load()
+                 documents.extend(docs)

+             else:
+                 file_summaries.append(f" Unsupported file type: {filename}")

+     if not documents:
          chatbot.append(
              gr.ChatMessage(
+                 role="assistant", content="No valid .docx files found in upload."
              )
          )
+         return chatbot

+     # Split documents
+     chunks = text_splitter.split_documents(documents)
+     if not chunks:
          chatbot.append(
              gr.ChatMessage(
+                 role="assistant", content="Failed to split documents into chunks."
              )
          )
+         return chatbot

+     # Create Vectorstore
+     state.vectorstore = InMemoryVectorStore.from_documents(
+         documents=chunks,
+         embedding=embed_model,
+     )
+     retriever = state.vectorstore.as_retriever()
+
+     # Build RAG Chain
+     state.rag_chain = (
+         {"context": retriever, "question": RunnablePassthrough()}
+         | rag_prompt
+         | llm
+         | StrOutputParser()
+     )

+     # Final display
+     chatbot.append(
+         gr.ChatMessage(
+             role="assistant",
+             content="**Uploaded Files:**\n"
+             + "\n".join(file_summaries)
+             + "\n\n✅ Ready to chat!",
+         )
+     )
+     return chatbot


+ def user_message(
+     text_prompt: str, chatbot: List[Union[gr.ChatMessage, dict]]
+ ) -> Tuple[str, List[Union[gr.ChatMessage, dict]]]:
+     """Add user's text input to conversation."""
+     if text_prompt.strip():
          chatbot.append(gr.ChatMessage(role="user", content=text_prompt))
      return "", chatbot


+ def process_query(
+     chatbot: List[Union[gr.ChatMessage, dict]],
+ ) -> List[Union[gr.ChatMessage, dict]]:
+     """Process user's query through RAG pipeline."""
+     prompt = get_last_user_message(chatbot)
+     if not prompt:
          chatbot.append(
+             gr.ChatMessage(role="assistant", content="Please type a question first.")
          )
          return chatbot

+     if state.rag_chain is None:
          chatbot.append(
+             gr.ChatMessage(role="assistant", content="Please upload documents first.")
          )
          return chatbot

      chatbot.append(gr.ChatMessage(role="assistant", content="Thinking..."))

      try:
+         response = state.rag_chain.invoke(prompt)
          chatbot[-1].content = response
      except Exception as e:
+         chatbot[-1].content = f"Error: {str(e)}"

      return chatbot


+ def reset_app(
+     chatbot: List[Union[gr.ChatMessage, dict]],
+ ) -> List[Union[gr.ChatMessage, dict]]:
+     """Reset application state."""
+     state.vectorstore = None
+     state.rag_chain = None
      return [
          gr.ChatMessage(
+             role="assistant", content="App reset! Upload new documents to start."
          )
      ]


+ # ========== UI Layout ==========

+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
      gr.HTML(TITLE)
+     chatbot = gr.Chatbot(
+         label="Llama 4 RAG",
+         type="messages",
+         bubble_full_width=False,
+         avatar_images=AVATAR_IMAGES,
+         scale=2,
+         height=350,
      )

+     with gr.Row(equal_height=True):
+         text_prompt = gr.Textbox(
+             placeholder="Ask a question...", show_label=False, autofocus=True, scale=28
+         )
+         send_button = gr.Button(
+             value="Send",
+             variant="primary",
+             scale=1,
+             min_width=80,
+         )
+         upload_button = gr.UploadButton(
+             label="Upload",
+             file_count="multiple",
+             file_types=TEXT_EXTENSIONS,
+             scale=1,
+             min_width=80,
+         )
+         reset_button = gr.Button(
+             value="Reset",
+             variant="stop",
+             scale=1,
+             min_width=80,
+         )

+     send_button.click(
+         fn=user_message,
+         inputs=[text_prompt, chatbot],
+         outputs=[text_prompt, chatbot],
          queue=False,
+     ).then(fn=process_query, inputs=[chatbot], outputs=[chatbot])

+     text_prompt.submit(
+         fn=user_message,
+         inputs=[text_prompt, chatbot],
+         outputs=[text_prompt, chatbot],
          queue=False,
+     ).then(fn=process_query, inputs=[chatbot], outputs=[chatbot])
+
+     upload_button.upload(
+         fn=upload_files, inputs=[upload_button, chatbot], outputs=[chatbot], queue=False
      )
+     reset_button.click(fn=reset_app, inputs=[chatbot], outputs=[chatbot], queue=False)

  demo.queue().launch()
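
For reference, the pipeline this commit lands on can be exercised outside Gradio. The sketch below mirrors the chain built in `upload_files`, under stated assumptions: `GROQ_API_KEY` is set in the environment, the Space's dependencies are installed (including `unstructured` for .docx parsing), and `example.docx` is a hypothetical stand-in for a real upload.

```python
import os

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import UnstructuredFileLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings

# Same models as main.py; GROQ_API_KEY must be set in the environment.
llm = ChatGroq(
    model="meta-llama/llama-4-scout-17b-16e-instruct",
    api_key=os.getenv("GROQ_API_KEY"),
)
embed_model = HuggingFaceEmbeddings(model_name="mixedbread-ai/mxbai-embed-large-v1")

# Load one .docx (hypothetical path), chunk it, and index the chunks in memory.
docs = UnstructuredFileLoader("example.docx").load()
chunks = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=100, separators=["\n\n", "\n"]
).split_documents(docs)
retriever = InMemoryVectorStore.from_documents(chunks, embed_model).as_retriever()

# Same LCEL shape as state.rag_chain: retrieved documents fill {context},
# while RunnablePassthrough forwards the raw question into {question}.
prompt = PromptTemplate.from_template(
    "Use only the given context to answer.\n\n"
    "Context:\n{context}\n\nQuestion:\n{question}\n\nAnswer:"
)
chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)

print(chain.invoke("What is this document about?"))
```

The in-memory vector store keeps the sketch dependency-free, matching the commit's switch away from per-file global state: everything lives for the length of the process, and resetting is just dropping the references.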