meshsl commited on
Commit
1bac175
·
verified ·
1 Parent(s): b818c49

Update rag_system.py

Browse files
Files changed (1) hide show
  1. rag_system.py +87 -0
rag_system.py CHANGED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # rag_system.py
2
+
3
+ from langchain_community.vectorstores import Chroma
4
+ from langchain_community.embeddings import HuggingFaceEmbeddings
5
+ from langchain_community.llms import Together
6
+ from langchain.prompts import PromptTemplate
7
+ from langchain_core.output_parsers import StrOutputParser
8
+ from langchain_core.runnables import RunnablePassthrough
9
+ import os
10
+
11
class RAGSystem:
    """Retrieval-augmented generation (RAG) pipeline for Saudi cybersecurity
    compliance Q&A.

    Wires together a persisted Chroma vector store (HuggingFace sentence
    embeddings) with a Together-hosted chat LLM and a strict compliance
    prompt. Components are created lazily: callers may invoke the individual
    ``load_*`` methods explicitly, or simply call :meth:`ask_question`, which
    initializes anything still missing.
    """

    def __init__(self, vector_db_path="vector_db", model_name="meta-llama/Llama-3-8b-chat-hf", embedding_model_name="sentence-transformers/all-mpnet-base-v2", api_key=None):
        """Store configuration; no model or store is loaded here.

        Args:
            vector_db_path: Directory holding the persisted Chroma DB.
            model_name: Together model identifier for the chat LLM.
            embedding_model_name: HuggingFace sentence-embedding model.
            api_key: Together API key; falls back to the TOGETHER_API_KEY
                environment variable when omitted.
        """
        self.vector_db_path = vector_db_path
        self.embedding_model_name = embedding_model_name
        self.model_name = model_name
        self.api_key = api_key or os.getenv("TOGETHER_API_KEY")

        # Lazily-initialized components (see load_* / get_prompt_template).
        self.embedding_model = None
        self.vectorstore = None
        self.llm = None
        self.prompt_template = None

    def load_vectorstore(self):
        """Open the persisted Chroma store with the configured embeddings.

        Returns:
            The loaded ``Chroma`` vector store (also cached on the instance).
        """
        self.embedding_model = HuggingFaceEmbeddings(model_name=self.embedding_model_name)
        self.vectorstore = Chroma(persist_directory=self.vector_db_path, embedding_function=self.embedding_model)
        print(f"✅ Vectorstore loaded from: {self.vector_db_path}")
        return self.vectorstore

    def load_llm(self):
        """Instantiate the Together-hosted LLM client.

        Returns:
            The configured ``Together`` LLM (also cached on the instance).

        Raises:
            ValueError: If no API key was supplied or found in the environment.
        """
        if not self.api_key:
            raise ValueError("❌ API key not found. Please set TOGETHER_API_KEY.")
        self.llm = Together(
            model=self.model_name,
            temperature=0.2,  # low temperature: factual compliance answers
            max_tokens=512,
            top_p=0.95,
            together_api_key=self.api_key
        )
        print(f"✅ LLM loaded from Together: {self.model_name}")
        return self.llm

    def get_prompt_template(self):
        """Build and cache the strict compliance prompt template.

        Returns:
            A ``PromptTemplate`` expecting ``context``, ``question`` and
            ``sources`` variables.
        """
        self.prompt_template = PromptTemplate(
            input_variables=["context", "question", "sources"],
            template="""
You are a cybersecurity compliance assistant specializing in Saudi Arabian regulations.
Use ONLY the provided official documents (NCA Cybersecurity Framework, YESSER standards, SCYWF, or ECC controls) to answer.
If the answer cannot be found in the provided context, respond with:
"The answer is not available in the ECC guide." (Arabic: "الإجابة غير متوفرة في دليل ECC")
Instructions:
- Provide factual, formal, and concise answers only from the context.
- Do NOT add conversational phrases (e.g., "Hello, I'm happy to help you").
- If the question explicitly asks for a summary, present a short bullet-point summary.
- Merge related points without repetition.
- Always add the sources at the end of the answer in the format: "Sources: {sources}".
- Do NOT mention being an AI model.
- If the context contains no relevant data, state clearly it is not available.

Context:
{context}

Question:
{question}

Answer:
"""
        )
        return self.prompt_template

    def ask_question(self, user_input):
        """Answer a user question from the top-5 retrieved document chunks.

        Lazily initializes the vector store, LLM and prompt template if the
        caller has not done so, so this method is safe to call directly
        after construction.

        Args:
            user_input: The user's question.

        Returns:
            The LLM's answer string, or the fixed not-available message when
            retrieval yields no documents.
        """
        # Lazy initialization: previously this method crashed with a
        # NoneType error unless load_vectorstore/load_llm/get_prompt_template
        # had all been called first.
        if self.vectorstore is None:
            self.load_vectorstore()
        if self.llm is None:
            self.load_llm()
        if self.prompt_template is None:
            self.get_prompt_template()

        retriever = self.vectorstore.as_retriever(search_kwargs={"k": 5})
        docs = retriever.invoke(user_input)  # invoke() replaces deprecated get_relevant_documents()

        if not docs:
            return "The answer is not available in the ECC guide."

        context = "\n\n".join(d.page_content for d in docs)
        raw_sources = [
            f"source={d.metadata.get('source','?')};page={d.metadata.get('page_label', d.metadata.get('page','?'))}"
            for d in docs
        ]
        # dict.fromkeys dedupes while PRESERVING retrieval order; the previous
        # set() made the cited source order nondeterministic between runs.
        sources = " | ".join(dict.fromkeys(raw_sources))

        answer_prompt = self.prompt_template.format(context=context, question=user_input, sources=sources)
        # .invoke() replaces the deprecated LLM __call__ API, matching the
        # retriever.invoke usage above.
        answer = self.llm.invoke(answer_prompt)

        return answer