File size: 18,761 Bytes
f4c0f01
1f683db
 
 
 
72544b8
1f683db
 
a8285e4
 
 
 
1f683db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8285e4
 
 
 
 
 
 
 
1f683db
a8285e4
 
 
 
 
1f683db
a8285e4
 
 
 
 
 
 
 
 
 
1f683db
a8285e4
 
1f683db
a8285e4
 
 
1f683db
a8285e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f683db
a8285e4
 
 
 
 
 
1f683db
 
a8285e4
 
 
 
 
1f683db
 
a8285e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f683db
a8285e4
 
 
 
 
 
 
1f683db
a8285e4
 
 
 
 
 
 
 
 
 
 
 
 
1f683db
a8285e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f683db
a8285e4
 
1f683db
a8285e4
 
 
 
 
 
 
 
1f683db
 
a8285e4
 
1f683db
a8285e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f683db
a8285e4
 
 
1f683db
a8285e4
 
 
 
 
1f683db
a8285e4
 
1f683db
a8285e4
 
1f683db
a8285e4
 
1f683db
a8285e4
 
1f683db
a8285e4
 
 
1f683db
a8285e4
 
 
 
 
 
 
 
 
 
 
1f683db
a8285e4
1f683db
a8285e4
 
1f683db
a8285e4
 
 
 
1f683db
a8285e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f683db
a8285e4
 
 
1f683db
a8285e4
 
 
 
 
 
 
1f683db
a8285e4
1f683db
a8285e4
 
 
 
1f683db
a8285e4
 
1f683db
 
 
a8285e4
1f683db
a8285e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
732ba20
a8285e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
732ba20
a8285e4
1f683db
a8285e4
 
 
 
 
 
 
 
 
 
 
 
5224f4e
8f83e1c
1f683db
a8285e4
1f683db
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
import os
import re
import json
from tqdm import tqdm
from pathlib import Path
import spaces
import gradio as gr

# WARNING: Don't import torch, cuda, or GPU-related modules at the top level
# They must ONLY be imported inside functions decorated with @spaces.GPU

# Helper functions that don't use GPU
def safe_tokenize(text):
    """Tokenize *text* into lowercase tokens using regexes only (no NLTK).

    Punctuation marks become standalone tokens. Returns [] for empty or
    None input.
    """
    if not text:
        return []
    # Pad each punctuation mark with spaces so it splits off as its own token.
    spaced = re.sub(r'([.,!?;:()\[\]{}"\'/\\])', r' \1 ', text)
    # Split on whitespace runs; filter(None, ...) drops the empty strings
    # produced by leading/trailing separators.
    return list(filter(None, re.split(r'\s+', spaced.lower())))

def detect_language(text):
    """Return "arabic" if over half of *text* consists of Arabic-block
    characters, otherwise "english"."""
    # Count codepoints in the basic Arabic Unicode block (U+0600-U+06FF).
    arabic_count = sum(1 for ch in text if '\u0600' <= ch <= '\u06FF')
    # Majority vote against the *total* length (spaces/punctuation included).
    return "arabic" if arabic_count > len(text) * 0.5 else "english"

# Comprehensive evaluation dataset: bilingual (Arabic/English) query/reference
# pairs grouped by Vision 2030 topic. In this file it is only read by the
# "Sample Questions" tab in main(), which lists each item's "query"; the
# "reference" answers are presumably intended for offline evaluation.
comprehensive_evaluation_data = [
    # === Overview ===
    {
        "query": "ما هي رؤية السعودية 2030؟",
        "reference": "رؤية السعودية 2030 هي خطة استراتيجية تهدف إلى تنويع الاقتصاد السعودي وتقليل الاعتماد على النفط مع تطوير قطاعات مختلفة مثل الصحة والتعليم والسياحة.",
        "category": "overview",
        "language": "arabic"
    },
    {
        "query": "What is Saudi Vision 2030?",
        "reference": "Saudi Vision 2030 is a strategic framework aiming to diversify Saudi Arabia's economy and reduce dependence on oil, while developing sectors like health, education, and tourism.",
        "category": "overview",
        "language": "english"
    },
    
    # === Economic Goals ===
    {
        "query": "ما هي الأهداف الاقتصادية لرؤية 2030؟",
        "reference": "تشمل الأهداف الاقتصادية زيادة مساهمة القطاع الخاص إلى 65%، وزيادة الصادرات غير النفطية إلى 50% من الناتج المحلي غير النفطي، وخفض البطالة إلى 7%.",
        "category": "economic",
        "language": "arabic"
    },
    {
        "query": "What are the economic goals of Vision 2030?",
        "reference": "The economic goals of Vision 2030 include increasing private sector contribution from 40% to 65% of GDP, raising non-oil exports from 16% to 50%, reducing unemployment from 11.6% to 7%.",
        "category": "economic",
        "language": "english"
    },
    
    # === Social Goals ===
    {
        "query": "كيف تعزز رؤية 2030 الإرث الثقافي السعودي؟",
        "reference": "تتضمن رؤية 2030 الحفاظ على الهوية الوطنية، تسجيل مواقع أثرية في اليونسكو، وتعزيز الفعاليات الثقافية.",
        "category": "social",
        "language": "arabic"
    },
    {
        "query": "How does Vision 2030 aim to improve quality of life?",
        "reference": "Vision 2030 plans to enhance quality of life by expanding sports facilities, promoting cultural activities, and boosting tourism and entertainment sectors.",
        "category": "social",
        "language": "english"
    }
]

# RAG (retrieval-augmented generation) service over the Vision 2030 PDFs.
class Vision2030Service:
    """Retrieval-augmented QA service for Saudi Vision 2030 documents.

    Combines a FAISS vector store built from the Vision 2030 PDFs with the
    ALLaM-7B instruct model. Every method that may touch CUDA is decorated
    with @spaces.GPU so GPU work never runs in the main process (see the
    module-level warning about Hugging Face Spaces ZeroGPU).
    """

    def __init__(self):
        # Heavy setup is deferred to initialize(); the constructor only
        # declares state so the object is cheap to create at import time.
        self.initialized = False        # set True after a successful initialize()
        self.model = None               # causal LM, loaded lazily
        self.tokenizer = None           # matching tokenizer, loaded lazily
        self.vector_store = None        # FAISS index over PDF chunks
        self.conversation_history = []  # alternating {"role", "content"} dicts
        
    @spaces.GPU
    def initialize(self):
        """Initialize the system - ALL GPU operations must happen here.

        Builds (or reloads) the FAISS vector store from the Vision 2030
        PDFs and loads the ALLaM-7B model/tokenizer. Returns True on
        success, False on any failure (the error is printed, not raised).
        Idempotent: returns immediately once initialized.
        """
        if self.initialized:
            return True
            
        try:
            # Import all GPU-dependent libraries only inside this function
            import torch
            import PyPDF2
            from transformers import AutoTokenizer, AutoModelForCausalLM
            # NOTE(review): SentenceTransformer is imported but never used in
            # this method — embeddings go through HuggingFaceEmbeddings below.
            from sentence_transformers import SentenceTransformer
            from langchain.text_splitter import RecursiveCharacterTextSplitter
            from langchain_community.vectorstores import FAISS
            from langchain.schema import Document
            from langchain.embeddings import HuggingFaceEmbeddings
            
            # Define paths for PDF files
            # NOTE(review): "saudi_vision203.pdf" looks like a typo for
            # "saudi_vision2030.pdf", but it matches check_pdfs() in the UI —
            # confirm the actual filename on disk before changing it.
            pdf_files = ["saudi_vision203.pdf", "saudi_vision2030_ar.pdf"]
            
            # Process PDFs and create vector store
            vector_store_dir = "vector_stores"
            os.makedirs(vector_store_dir, exist_ok=True)
            
            if os.path.exists(os.path.join(vector_store_dir, "index.faiss")):
                print("Loading existing vector store...")
                embedding_function = HuggingFaceEmbeddings(
                    model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
                )
                # NOTE(review): recent langchain releases require passing
                # allow_dangerous_deserialization=True to load_local — verify
                # against the pinned langchain version.
                self.vector_store = FAISS.load_local(vector_store_dir, embedding_function)
            else:
                print("Creating new vector store...")
                # Process PDFs
                documents = []
                for pdf_path in pdf_files:
                    if not os.path.exists(pdf_path):
                        print(f"Warning: {pdf_path} does not exist")
                        continue
                        
                    print(f"Processing {pdf_path}...")
                    text = ""
                    with open(pdf_path, 'rb') as file:
                        reader = PyPDF2.PdfReader(file)
                        for page in reader.pages:
                            page_text = page.extract_text()
                            # Some pages yield None/empty text; skip them.
                            if page_text:
                                text += page_text + "\n\n"
                    
                    if text.strip():
                        # One Document per PDF; chunking happens below.
                        doc = Document(
                            page_content=text,
                            metadata={"source": pdf_path, "filename": os.path.basename(pdf_path)}
                        )
                        documents.append(doc)
                
                if not documents:
                    raise ValueError("No documents were processed successfully.")
                
                # Split into chunks
                text_splitter = RecursiveCharacterTextSplitter(
                    chunk_size=500,
                    chunk_overlap=50,
                    separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
                )
                
                chunks = []
                for doc in documents:
                    doc_chunks = text_splitter.split_text(doc.page_content)
                    # Each chunk inherits its parent document's metadata so
                    # sources can be reported after retrieval.
                    chunks.extend([
                        Document(page_content=chunk, metadata=doc.metadata)
                        for chunk in doc_chunks
                    ])
                
                # Create vector store
                embedding_function = HuggingFaceEmbeddings(
                    model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
                )
                self.vector_store = FAISS.from_documents(chunks, embedding_function)
                # Persist so the next startup takes the fast load path above.
                self.vector_store.save_local(vector_store_dir)
            
            # Load model
            model_name = "ALLaM-AI/ALLaM-7B-Instruct-preview"
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                use_fast=False
            )
            
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device_map="auto",
            )
            
            self.initialized = True
            return True
            
        except Exception as e:
            import traceback
            print(f"Initialization error: {e}")
            print(traceback.format_exc())
            return False
    
    @spaces.GPU
    def retrieve_context(self, query, top_k=5):
        """Return up to *top_k* chunks relevant to *query* from the vector store.

        Each result is a dict with "content", "source", and "relevance_score"
        keys. Returns [] when the service is uninitialized or on any error.
        """
        # No local imports needed here; the vector store was fully built in
        # initialize(), which also runs under @spaces.GPU.
        
        if not self.initialized:
            return []
            
        try:
            results = self.vector_store.similarity_search_with_score(query, k=top_k)
            
            contexts = []
            for doc, score in results:
                contexts.append({
                    "content": doc.page_content,
                    "source": doc.metadata.get("source", "Unknown"),
                    "relevance_score": score
                })
            
            return contexts
        except Exception as e:
            print(f"Error retrieving context: {e}")
            return []
    
    @spaces.GPU
    def generate_response(self, query, contexts, language="auto"):
        """Generate an answer to *query* grounded in the retrieved *contexts*.

        language: "arabic", "english", or "auto" (detected from the query).
        Builds an ALLaM [INST]-style prompt and samples up to 512 new tokens.
        Returns an apology string (never raises) on error.
        """
        # Import must be inside the function to avoid CUDA init in main process
        import torch
        
        if not self.initialized or self.model is None or self.tokenizer is None:
            return "I'm still initializing. Please try again in a moment."
        
        try:
            # Auto-detect language if not specified
            if language == "auto":
                language = detect_language(query)
            
            # Format the prompt based on language
            if language == "arabic":
                instruction = (
                    "أنت مساعد افتراضي يهتم برؤية السعودية 2030. استخدم المعلومات التالية للإجابة على السؤال. "
                    "إذا لم تعرف الإجابة، فقل بأمانة إنك لا تعرف."
                )
            else:  # english
                instruction = (
                    "You are a virtual assistant for Saudi Vision 2030. Use the following information to answer the question. "
                    "If you don't know the answer, honestly say you don't know."
                )
            
            # Combine retrieved contexts
            context_text = "\n\n".join([f"Document: {ctx['content']}" for ctx in contexts])
            
            # Format the prompt for ALLaM instruction format
            # NOTE(review): the trailing </s> closes the sequence *before*
            # generation, which is unusual for instruct prompting — confirm
            # against the ALLaM model card.
            prompt = f"""<s>[INST] {instruction}

Context:
{context_text}

Question: {query} [/INST]</s>"""
            
            # Generate response
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            
            outputs = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=512,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                repetition_penalty=1.1
            )
            
            # Decode the response
            full_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            
            # Extract just the answer part (after the instruction)
            response = full_output.split("[/INST]")[-1].strip()
            
            # If response is empty for some reason, return the full output
            if not response:
                response = full_output
                
            return response
                
        except Exception as e:
            import traceback
            print(f"Error generating response: {e}")
            print(traceback.format_exc())
            return f"Sorry, I encountered an error while generating a response."
    
    @spaces.GPU
    def answer_question(self, query):
        """Process a user query; return (response, unique_source_list).

        Retrieval uses the last 6 history messages plus the query (to help
        follow-up questions); generation sees only the bare query. Lazily
        initializes the service on first call.
        """
        if not self.initialized:
            if not self.initialize():
                return "System initialization failed. Please check the logs.", []
        
        try:
            # Add user query to conversation history
            self.conversation_history.append({"role": "user", "content": query})
            
            # Get the full conversation context
            conversation_context = "\n".join([
                f"{'User' if msg['role'] == 'user' else 'Assistant'}: {msg['content']}"
                for msg in self.conversation_history[-6:]  # Keep last 3 turns
            ])
            
            # Enhance query with conversation context 
            enhanced_query = f"{conversation_context}\n{query}"
            
            # Retrieve relevant contexts
            contexts = self.retrieve_context(enhanced_query, top_k=5)
            
            # Generate response
            response = self.generate_response(query, contexts)
            
            # Add response to conversation history
            self.conversation_history.append({"role": "assistant", "content": response})
            
            # Get sources
            # set() deduplicates but does not preserve retrieval order.
            sources = [ctx.get("source", "Unknown") for ctx in contexts]
            unique_sources = list(set(sources))
            
            return response, unique_sources
        except Exception as e:
            import traceback
            print(f"Error answering question: {e}")
            print(traceback.format_exc())
            return f"Sorry, I encountered an error: {str(e)}", []
    
    def reset_conversation(self):
        """Reset the conversation history and return a confirmation string."""
        self.conversation_history = []
        return "Conversation has been reset."

# Main function with Gradio UI
def main():
    """Build and return the Gradio Blocks app for the Vision 2030 assistant.

    Tabs: Chat (RAG-backed Q&A), System Status (initialization, PDF and
    dependency checks), and Sample Questions (queries drawn from the
    evaluation dataset). All handlers close over one shared service instance.
    """
    # Create the Vision 2030 service
    service = Vision2030Service()
    
    # Build the Gradio interface
    with gr.Blocks(title="Vision 2030 Assistant") as demo:
        gr.Markdown("# Vision 2030 Assistant")
        gr.Markdown("Ask questions about Saudi Vision 2030 in English or Arabic")
        
        with gr.Tab("Chat"):
            chatbot = gr.Chatbot()
            msg = gr.Textbox(label="Your question", placeholder="Ask about Vision 2030...")
            clear = gr.Button("Clear History")
            
            # GPU-decorated: answer_question may load the model on first call.
            @spaces.GPU
            def respond(message, history):
                """Handle a chat submission; returns (new history, cleared textbox)."""
                if not message:
                    return history, ""
                
                response, sources = service.answer_question(message)
                sources_text = ", ".join(sources) if sources else "No specific sources"
                
                # Format the response to include sources
                full_response = f"{response}\n\nSources: {sources_text}"
                
                # Returning "" as the second value clears the input textbox.
                return history + [[message, full_response]], ""
            
            def reset_chat():
                """Clear both the service history and the chat widget."""
                service.reset_conversation()
                # Second value lands in the msg textbox (see outputs below).
                return [], "Conversation history has been reset."
            
            msg.submit(respond, [msg, chatbot], [chatbot, msg])
            clear.click(reset_chat, None, [chatbot, msg])
        
        with gr.Tab("System Status"):
            init_btn = gr.Button("Initialize System")
            status_box = gr.Textbox(label="Status", value="System not initialized")
            
            @spaces.GPU
            def initialize_system():
                """Trigger service.initialize() and report the outcome as text."""
                success = service.initialize()
                if success:
                    return "System initialized successfully!"
                else:
                    return "System initialization failed. Check logs for details."
            
            init_btn.click(initialize_system, None, status_box)
            
            # PDF Check section
            gr.Markdown("### PDF Status")
            pdf_btn = gr.Button("Check PDF Files")
            pdf_status = gr.Textbox(label="PDF Files")
            
            # Pure filesystem check — no GPU decoration needed.
            def check_pdfs():
                """Report presence and size of the expected source PDFs."""
                result = []
                for pdf_file in ["saudi_vision203.pdf", "saudi_vision2030_ar.pdf"]:
                    if os.path.exists(pdf_file):
                        size = os.path.getsize(pdf_file) / (1024 * 1024)  # Size in MB
                        result.append(f"{pdf_file}: Found ({size:.2f} MB)")
                    else:
                        result.append(f"{pdf_file}: Not found")
                return "\n".join(result)
            
            pdf_btn.click(check_pdfs, None, pdf_status)
            
            # System check section
            gr.Markdown("### Dependencies")
            sys_btn = gr.Button("Check Dependencies")
            sys_status = gr.Textbox(label="Dependencies Status")
            
            # GPU-decorated because it imports torch (see module-level warning).
            @spaces.GPU
            def check_dependencies():
                """Probe each required package and report version / absence."""
                result = []
                
                # Safe imports inside GPU-decorated function
                try:
                    import torch
                    result.append(f"✓ PyTorch: {torch.__version__}")
                except ImportError:
                    result.append("✗ PyTorch: Not installed")
                
                try:
                    import transformers
                    result.append(f"✓ Transformers: {transformers.__version__}")
                except ImportError:
                    result.append("✗ Transformers: Not installed")
                
                try:
                    import sentencepiece
                    result.append("✓ SentencePiece: Installed")
                except ImportError:
                    result.append("✗ SentencePiece: Not installed")
                
                try:
                    import accelerate
                    result.append(f"✓ Accelerate: {accelerate.__version__}")
                except ImportError:
                    result.append("✗ Accelerate: Not installed")
                
                try:
                    import langchain
                    result.append(f"✓ LangChain: {langchain.__version__}")
                except ImportError:
                    result.append("✗ LangChain: Not installed")
                
                try:
                    import langchain_community
                    result.append(f"✓ LangChain Community: {langchain_community.__version__}")
                except ImportError:
                    result.append("✗ LangChain Community: Not installed")
                
                return "\n".join(result)
            
            sys_btn.click(check_dependencies, None, sys_status)
        
        with gr.Tab("Sample Questions"):
            gr.Markdown("### Sample Questions to Try")
            
            # Pull just the query strings out of the evaluation dataset.
            sample_questions = []
            
            for item in comprehensive_evaluation_data:
                sample_questions.append(item["query"])
            
            questions_md = "\n".join([f"- {q}" for q in sample_questions])
            gr.Markdown(questions_md)
    
    return demo

if __name__ == "__main__":
    # Build the UI, enable request queuing (long-running GPU handlers),
    # then start the web server.
    demo = main()
    demo.queue()
    demo.launch()