File size: 3,782 Bytes
1bac175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# rag_system.py

from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import Together
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
import os

class RAGSystem:
    """Question answering over a persisted Chroma store, answered by a Together LLM.

    Components are created lazily: call ``load_vectorstore``, ``load_llm`` and
    ``get_prompt_template`` explicitly, or let ``ask_question`` load whichever
    of them is still missing the first time it needs it.
    """

    # Returned verbatim when retrieval yields no documents; same sentence the
    # prompt tells the model to use when the context has no answer.
    NOT_AVAILABLE_MESSAGE = "The answer is not available in the ECC guide."

    def __init__(
        self,
        vector_db_path="vector_db",
        model_name="meta-llama/Llama-3-8b-chat-hf",
        embedding_model_name="sentence-transformers/all-mpnet-base-v2",
        api_key=None,
    ):
        """Store configuration only; no model, store or prompt is loaded here.

        Args:
            vector_db_path: Directory passed to Chroma as ``persist_directory``.
            model_name: Together model identifier used by ``load_llm``.
            embedding_model_name: HuggingFace model name used for embeddings.
            api_key: Together API key. When not given (or empty), the
                ``TOGETHER_API_KEY`` environment variable is used instead.
        """
        self.vector_db_path = vector_db_path
        self.embedding_model_name = embedding_model_name
        self.model_name = model_name
        self.api_key = api_key or os.getenv("TOGETHER_API_KEY")

        # Populated by the load_* / get_* methods (or lazily by ask_question).
        self.embedding_model = None
        self.vectorstore = None
        self.llm = None
        self.prompt_template = None

    def load_vectorstore(self):
        """Create the embedding model and open the Chroma store at ``vector_db_path``.

        Returns:
            The Chroma vectorstore, also stored on ``self.vectorstore``.
        """
        self.embedding_model = HuggingFaceEmbeddings(model_name=self.embedding_model_name)
        self.vectorstore = Chroma(persist_directory=self.vector_db_path, embedding_function=self.embedding_model)
        print(f"✅ Vectorstore loaded from: {self.vector_db_path}")
        return self.vectorstore

    def load_llm(self):
        """Create the Together LLM client for ``model_name``.

        Returns:
            The LLM object, also stored on ``self.llm``.

        Raises:
            ValueError: If no API key was supplied or found in the environment.
        """
        if not self.api_key:
            raise ValueError("❌ API key not found. Please set TOGETHER_API_KEY.")
        self.llm = Together(
            model=self.model_name,
            temperature=0.2,  # low temperature: favour factual, repeatable answers
            max_tokens=512,
            top_p=0.95,
            together_api_key=self.api_key
        )
        print(f"✅ LLM loaded from Together: {self.model_name}")
        return self.llm

    def get_prompt_template(self):
        """Build the compliance-assistant prompt (``context``, ``question``, ``sources``).

        Returns:
            The PromptTemplate, also stored on ``self.prompt_template``.
        """
        self.prompt_template = PromptTemplate(
            input_variables=["context", "question", "sources"],
            template="""
            You are a cybersecurity compliance assistant specializing in Saudi Arabian regulations.
            Use ONLY the provided official documents (NCA Cybersecurity Framework, YESSER standards, SCYWF, or ECC controls) to answer.
            If the answer cannot be found in the provided context, respond with:
            "The answer is not available in the ECC guide." (Arabic: "الإجابة غير متوفرة في دليل ECC")
            Instructions:
            - Provide factual, formal, and concise answers only from the context.
            - Do NOT add conversational phrases (e.g., "Hello, I'm happy to help you").
            - If the question explicitly asks for a summary, present a short bullet-point summary.
            - Merge related points without repetition.
            - Always add the sources at the end of the answer in the format: "Sources: {sources}".
            - Do NOT mention being an AI model.
            - If the context contains no relevant data, state clearly it is not available.

            Context:
            {context}

            Question:
            {question}

            Answer:
            """
        )
        return self.prompt_template

    def ask_question(self, user_input, k=5):
        """Answer ``user_input`` using the top-``k`` retrieved chunks as context.

        Any component that has not been loaded yet (vectorstore, LLM, prompt
        template) is loaded on demand, so this works right after construction.

        Args:
            user_input: Question text; used as the retrieval query and as
                ``{question}`` in the prompt.
            k: Number of chunks to retrieve (defaults to 5).

        Returns:
            The LLM's completion, or ``NOT_AVAILABLE_MESSAGE`` when retrieval
            returns no documents (the LLM is not called in that case).

        Raises:
            ValueError: If the LLM has to be loaded and no API key is configured.
        """
        if self.vectorstore is None:
            self.load_vectorstore()

        retriever = self.vectorstore.as_retriever(search_kwargs={"k": k})
        # Use invoke() rather than the deprecated get_relevant_documents().
        docs = retriever.invoke(user_input)

        if not docs:
            return self.NOT_AVAILABLE_MESSAGE

        # Only pay for LLM / prompt setup when there is context to answer from.
        if self.llm is None:
            self.load_llm()
        if self.prompt_template is None:
            self.get_prompt_template()

        context = "\n\n".join(d.page_content for d in docs)
        raw_sources = (
            f"source={d.metadata.get('source', '?')};"
            f"page={d.metadata.get('page_label', d.metadata.get('page', '?'))}"
            for d in docs
        )
        # Several chunks often share a page: dict.fromkeys removes duplicates
        # while keeping retrieval order (set() would scramble it per run).
        sources = " | ".join(dict.fromkeys(raw_sources))

        answer_prompt = self.prompt_template.format(
            context=context, question=user_input, sources=sources
        )
        # Calling an LLM object directly is deprecated in LangChain; use invoke().
        return self.llm.invoke(answer_prompt)