Update rag_system.py
rag_system.py CHANGED +87 -0
@@ -0,0 +1,87 @@
+# rag_system.py
+
+from langchain_community.vectorstores import Chroma
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.llms import Together
+from langchain.prompts import PromptTemplate
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnablePassthrough
+import os
+
+class RAGSystem:
+    def __init__(self, vector_db_path="vector_db", model_name="meta-llama/Llama-3-8b-chat-hf", embedding_model_name="sentence-transformers/all-mpnet-base-v2", api_key=None):
+        self.vector_db_path = vector_db_path
+        self.embedding_model_name = embedding_model_name
+        self.model_name = model_name
+        self.api_key = api_key or os.getenv("TOGETHER_API_KEY")
+
+        self.embedding_model = None
+        self.vectorstore = None
+        self.llm = None
+        self.prompt_template = None
+
+    def load_vectorstore(self):
+        self.embedding_model = HuggingFaceEmbeddings(model_name=self.embedding_model_name)
+        self.vectorstore = Chroma(persist_directory=self.vector_db_path, embedding_function=self.embedding_model)
+        print(f"✅ Vectorstore loaded from: {self.vector_db_path}")
+        return self.vectorstore
+
+    def load_llm(self):
+        if not self.api_key:
+            raise ValueError("❌ API key not found. Please set TOGETHER_API_KEY.")
+        self.llm = Together(
+            model=self.model_name,
+            temperature=0.2,
+            max_tokens=512,
+            top_p=0.95,
+            together_api_key=self.api_key
+        )
+        print(f"✅ LLM loaded from Together: {self.model_name}")
+        return self.llm
+
+    def get_prompt_template(self):
+        self.prompt_template = PromptTemplate(
+            input_variables=["context", "question", "sources"],
+            template="""
+You are a cybersecurity compliance assistant specializing in Saudi Arabian regulations.
+Use ONLY the provided official documents (NCA Cybersecurity Framework, YESSER standards, SCYWF, or ECC controls) to answer.
+If the answer cannot be found in the provided context, respond with:
+"The answer is not available in the ECC guide." (Arabic: "الإجابة غير متوفرة في دليل ECC")
+Instructions:
+- Provide factual, formal, and concise answers only from the context.
+- Do NOT add conversational phrases (e.g., "Hello, I'm happy to help you").
+- If the question explicitly asks for a summary, present a short bullet-point summary.
+- Merge related points without repetition.
+- Always add the sources at the end of the answer in the format: "Sources: {sources}".
+- Do NOT mention being an AI model.
+- If the context contains no relevant data, state clearly that it is not available.
+
+Context:
+{context}
+
+Question:
+{question}
+
+Answer:
+"""
+        )
+        return self.prompt_template
+
+    def ask_question(self, user_input):
+        retriever = self.vectorstore.as_retriever(search_kwargs={"k": 5})
+        docs = retriever.invoke(user_input)  # invoke replaces the deprecated get_relevant_documents
+
+        if not docs:
+            return "The answer is not available in the ECC guide."
+
+        context = "\n\n".join([d.page_content for d in docs])
+        raw_sources = [
+            f"source={d.metadata.get('source','?')};page={d.metadata.get('page_label', d.metadata.get('page','?'))}"
+            for d in docs
+        ]
+        sources = " | ".join(set(raw_sources))
+
+        answer_prompt = self.prompt_template.format(context=context, question=user_input, sources=sources)
+        answer = self.llm.invoke(answer_prompt)  # .invoke() replaces the deprecated direct LLM call
+
+        return answer
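For context, a minimal usage sketch of the class above (not part of the commit). It assumes a Chroma store was already persisted at vector_db and that TOGETHER_API_KEY is exported; the question string is a placeholder. Note the call order: ask_question relies on load_vectorstore, load_llm, and get_prompt_template having run first, since __init__ leaves all three as None.

# usage sketch — hypothetical driver script, not part of this commit
from rag_system import RAGSystem

rag = RAGSystem(vector_db_path="vector_db")  # path is an assumption; point at your persisted Chroma DB
rag.load_vectorstore()     # embeddings + Chroma retriever backend
rag.load_llm()             # Together-hosted Llama 3 chat model (needs TOGETHER_API_KEY)
rag.get_prompt_template()  # compliance prompt with {context}, {question}, {sources}

print(rag.ask_question("What does ECC require for access control?"))

One design note: StrOutputParser and RunnablePassthrough are imported but never used; ask_question formats the prompt and calls the LLM directly rather than composing an LCEL chain, so those two imports could either be dropped or wired into a retriever | prompt | llm | parser pipeline later.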