meshsl commited on
Commit
203c55a
·
verified ·
1 Parent(s): 63de98e

Update rag_system.py

Browse files
Files changed (1) hide show
  1. rag_system.py +20 -4
rag_system.py CHANGED
@@ -67,13 +67,26 @@ class RAGSystem:
67
  )
68
  return self.prompt_template
69
 
 
 
 
 
70
  def ask_question(self, user_input):
71
- retriever = self.vectorstore.as_retriever(search_kwargs={"k": 5})
72
- docs = retriever.invoke(user_input) # تستخدم invoke بدلاً من get_relevant_documents
 
 
 
 
 
 
 
 
73
 
74
  if not docs:
75
  return "The answer is not available in the ECC guide."
76
 
 
77
  context = "\n\n".join([d.page_content for d in docs])
78
  raw_sources = [
79
  f"source={d.metadata.get('source','?')};page={d.metadata.get('page_label', d.metadata.get('page','?'))}"
@@ -81,7 +94,10 @@ class RAGSystem:
81
  ]
82
  sources = " | ".join(set(raw_sources))
83
 
84
- answer_prompt = self.prompt_template.format(context=context, question=user_input, sources=sources)
 
 
 
85
  answer = self.llm(answer_prompt)
86
 
87
- return answer
 
67
  )
68
  return self.prompt_template
69
 
70
+ GREETINGS = [
71
+ "hi", "hello", "hey", "good morning", "good afternoon", "good evening"
72
+ ]
73
+
74
  def ask_question(self, user_input):
75
+ # التعامل مع التحية
76
+ if user_input.lower() in [g.lower() for g in self.GREETINGS]:
77
+ return (
78
+ "Hi! Please ask a question related to any of the following Saudi cybersecurity documents: "
79
+ "ECC (Essential Cybersecurity Controls), SCYWF (Saudi Cybersecurity Workforce Framework), or OSMACC (Organizations’ Social Media Accounts Cybersecurity Controls)."
80
+ )
81
+
82
+ # استرجاع الوثائق ذات الصلة
83
+ retriever = self.vectorstore.as_retriever(search_kwargs={"k": 3})
84
+ docs = retriever.invoke(user_input)
85
 
86
  if not docs:
87
  return "The answer is not available in the ECC guide."
88
 
89
+ # بناء السياق والمصادر
90
  context = "\n\n".join([d.page_content for d in docs])
91
  raw_sources = [
92
  f"source={d.metadata.get('source','?')};page={d.metadata.get('page_label', d.metadata.get('page','?'))}"
 
94
  ]
95
  sources = " | ".join(set(raw_sources))
96
 
97
+ # إعداد السؤال والإجابة
98
+ answer_prompt = self.prompt_template.format(
99
+ context=context, question=user_input, sources=sources
100
+ )
101
  answer = self.llm(answer_prompt)
102
 
103
+ return answer