himel06 committed
Commit 23c91c6 · verified · Parent: 22234b1

Create BanglaRAG/bangla_rag_pipeline.py

Files changed (1)
  1. BanglaRAG/bangla_rag_pipeline.py +303 -0
BanglaRAG/bangla_rag_pipeline.py ADDED
@@ -0,0 +1,303 @@
import os
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    GenerationConfig,
    BitsAndBytesConfig,
)
from langchain_core.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from rich import print as rprint
from rich.panel import Panel
from tqdm import tqdm
import warnings
import re

warnings.filterwarnings("ignore")

class BanglaRAGChain:
    """
    Bangla Retrieval-Augmented Generation (RAG) chain for question answering.

    This class uses a Hugging Face or local language model for text generation, a Chroma
    vector database for document retrieval, and a custom prompt template to create a RAG
    chain that can generate responses to user queries in Bengali.
    """

    def __init__(self):
        """Initializes the BanglaRAGChain with default parameters."""
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.chat_model_id = None
        self.embed_model_id = None
        self.k = 4
        self.max_new_tokens = 1024
        self.chunk_size = 500
        self.chunk_overlap = 150
        self.text_path = ""
        self.quantization = None
        self.temperature = 0.9
        self.top_p = 0.6
        self.top_k = 50
        self._text_content = None
        self.hf_token = None

        self.tokenizer = None
        self.chat_model = None
        self._llm = None
        self._retriever = None
        self._db = None
        self._documents = []
        self._chain = None

    def load(
        self,
        chat_model_id,
        embed_model_id,
        text_path,
        quantization,
        k=4,
        top_k=2,
        top_p=0.6,
        max_new_tokens=1024,
        temperature=0.6,
        chunk_size=500,
        chunk_overlap=150,
        hf_token=None,
    ):
        """
        Loads the required models and data for the RAG chain.

        Args:
            chat_model_id (str): The Hugging Face model ID for the chat model.
            embed_model_id (str): The Hugging Face model ID for the embedding model.
            text_path (str): Path to the text file to be indexed.
            quantization (bool): Whether to quantize the model.
            k (int): The number of documents to retrieve.
            top_k (int): The top_k parameter for the generation configuration.
            top_p (float): The top_p parameter for the generation configuration.
            max_new_tokens (int): The maximum number of new tokens to generate.
            temperature (float): The temperature parameter for the generation configuration.
            chunk_size (int): The chunk size for text splitting.
            chunk_overlap (int): The chunk overlap for text splitting.
            hf_token (str): The Hugging Face token for authentication.
        """
        self.chat_model_id = chat_model_id
        self.embed_model_id = embed_model_id
        self.k = k
        self.top_k = top_k
        self.top_p = top_p
        self.temperature = temperature
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.text_path = text_path
        self.quantization = quantization
        self.max_new_tokens = max_new_tokens
        self.hf_token = hf_token

        if self.hf_token is not None:
            os.environ["HF_TOKEN"] = str(self.hf_token)

        rprint(Panel("[bold green]Loading chat models...", expand=False))
        self._load_models()

        rprint(Panel("[bold green]Creating document...", expand=False))
        self._create_document()

        rprint(Panel("[bold green]Updating Chroma database...", expand=False))
        self._update_chroma_db()

        rprint(Panel("[bold green]Initializing retriever...", expand=False))
        self._get_retriever()

        rprint(Panel("[bold green]Initializing LLM...", expand=False))
        self._get_llm()

        rprint(Panel("[bold green]Creating chain...", expand=False))
        self._create_chain()

    def _load_models(self):
        """Loads the chat model and tokenizer."""
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.chat_model_id)
            bnb_config = None
            if self.quantization:
                bnb_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.float16,
                )
                rprint(Panel("[bold green]Applying 4-bit quantization...", expand=False))
                self.chat_model = AutoModelForCausalLM.from_pretrained(
                    self.chat_model_id,
                    torch_dtype=torch.float16,
                    low_cpu_mem_usage=True,
                    quantization_config=bnb_config,
                    device_map="auto",
                    # cache_dir removed to use the default Hugging Face cache
                )
                rprint(Panel("[bold green]Applied 4-bit quantization successfully", expand=False))
            else:
                self.chat_model = AutoModelForCausalLM.from_pretrained(
                    self.chat_model_id,
                    torch_dtype=torch.float16,
                    low_cpu_mem_usage=True,
                    device_map="auto",
                    # cache_dir removed to use the default Hugging Face cache
                )
            rprint(Panel("[bold green]Chat model loaded successfully!", expand=False))
        except Exception as e:
            rprint(Panel(f"[red]Error loading chat model: {e}", expand=False))

    def _create_document(self):
        """Splits the input text into chunks using RecursiveCharacterTextSplitter."""
        try:
            with open(self.text_path, "r", encoding="utf-8") as file:
                self._text_content = file.read()
            character_splitter = RecursiveCharacterTextSplitter(
                separators=["!", "?", "।"],
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap,
            )
            self._documents = list(
                tqdm(
                    character_splitter.split_text(self._text_content),
                    desc="Chunking text",
                )
            )
            print(f"Number of chunks: {len(self._documents)}")
            rprint(Panel("[bold green]Document created successfully!", expand=False))
        except Exception as e:
            rprint(Panel(f"[red]Chunking failed: {e}", expand=False))

    def _update_chroma_db(self):
        """Updates the Chroma vector database with the text chunks."""
        try:
            try:
                rprint(Panel("[bold green]Loading embedding model...", expand=False))
                model_kwargs = {"device": self._device}
                embeddings = HuggingFaceEmbeddings(
                    model_name=self.embed_model_id, model_kwargs=model_kwargs
                )
                rprint(Panel("[bold green]Loaded embedding model successfully!", expand=False))
            except Exception as e:
                rprint(Panel(f"[red]Embedding model loading failed: {e}", expand=False))
                return  # cannot build the vector DB without embeddings

            self._db = Chroma.from_texts(texts=self._documents, embedding=embeddings)
            rprint(
                Panel("[bold green]Chroma database updated successfully!", expand=False)
            )
        except Exception as e:
            rprint(Panel(f"[red]Vector DB initialization failed: {e}", expand=False))

    def _create_chain(self):
        """Creates the retrieval-augmented generation (RAG) chain."""
        template = """Below is an instruction in Bengali language that describes a task, paired with an input also in Bengali language that provides further context. Write a response in Bengali that appropriately completes the request.

### Instruction:
{question}

### Input:
{context}

### Response:
"""
        prompt_template = ChatPromptTemplate(
            input_variables=["question", "context"],
            output_parser=None,
            partial_variables={},
            messages=[
                HumanMessagePromptTemplate(
                    prompt=PromptTemplate(
                        input_variables=["question", "context"],
                        output_parser=None,
                        partial_variables={},
                        template=template,
                        template_format="f-string",
                        validate_template=True,
                    ),
                    additional_kwargs={},
                )
            ],
        )

        try:
            rag_chain_from_docs = (
                RunnablePassthrough.assign(
                    context=lambda x: self._format_docs(x["context"])
                )
                | prompt_template
                | self._llm
                | StrOutputParser()
            )

            rag_chain_with_source = RunnableParallel(
                {"context": self._retriever, "question": RunnablePassthrough()}
            ).assign(answer=rag_chain_from_docs)

            self._chain = rag_chain_with_source
            rprint(Panel("[bold green]Chain created successfully!", expand=False))
        except Exception as e:
            rprint(Panel(f"[red]Chain creation failed: {e}", expand=False))

    def _get_retriever(self):
        """Creates a retriever for the vector database."""
        self._retriever = self._db.as_retriever(search_kwargs={"k": self.k})

    def _get_llm(self):
        """Initializes the language model using the Hugging Face pipeline."""
        try:
            # The chat model was loaded with device_map="auto", so no explicit `device`
            # is passed here; accelerate has already placed the weights.
            pipe = pipeline(
                "text-generation",
                model=self.chat_model,
                tokenizer=self.tokenizer,
                max_new_tokens=self.max_new_tokens,
                pad_token_id=self.tokenizer.eos_token_id,
                do_sample=True,
                temperature=self.temperature,
                top_p=self.top_p,
                top_k=self.top_k,
                repetition_penalty=1.2,
                torch_dtype=torch.float16,
            )

            self._llm = HuggingFacePipeline(pipeline=pipe)
            rprint(Panel("[bold green]LLM initialized successfully!", expand=False))
        except Exception as e:
            rprint(Panel(f"[red]LLM initialization failed: {e}", expand=False))

    def _format_docs(self, docs):
        """Formats the retrieved documents for the prompt."""
        # The retriever returns Document objects, so join their page_content.
        formatted_docs = "\n".join(
            [re.sub(r"\s+", " ", doc.page_content) for doc in docs]
        )
        return formatted_docs

    def query(self, prompt: str) -> str:
        """
        Queries the RAG chain with a given prompt.

        Args:
            prompt (str): The input prompt to query the RAG chain.

        Returns:
            str: The generated response from the RAG chain.
        """
        # The chain takes the raw question string: the RunnableParallel routes it to both
        # the retriever and the prompt, and returns a dict that includes an "answer" key.
        response = self._chain.invoke(prompt)
        return response["answer"]

    def __call__(self, prompt: str) -> str:
        """Alias for the query method."""
        return self.query(prompt)
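
For context, a minimal usage sketch of the class added in this commit follows. The model IDs, corpus path, and question are placeholders for illustration only, not values taken from the commit:

# Minimal usage sketch (illustrative; model IDs, path, and question are placeholders).
from BanglaRAG.bangla_rag_pipeline import BanglaRAGChain

rag = BanglaRAGChain()
rag.load(
    chat_model_id="your-org/your-bengali-chat-model",        # placeholder causal LM on the Hub
    embed_model_id="your-org/your-bengali-embedding-model",  # placeholder sentence-embedding model
    text_path="data/bangla_corpus.txt",                      # placeholder path to a Bengali text file
    quantization=True,                                       # load the chat model in 4-bit via bitsandbytes
    k=4,
    max_new_tokens=256,
    hf_token=None,                                           # or a Hugging Face token for gated models
)

answer = rag("রবীন্দ্রনাথ ঠাকুর কোথায় জন্মগ্রহণ করেন?")  # __call__ is an alias for query()
print(answer)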