nurasaki committed
Commit a880965 · 1 Parent(s): d51834f

Added retrieval num chunks options

Files changed (5)
  1. .gitignore +1 -0
  2. app.py +24 -82
  3. handler.py +0 -14
  4. input_reader.py +0 -22
  5. rag.py +12 -4
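
Taken together, the commit replaces the unused model-parameter sliders with a single "Num. Retrieved Chunks" control that is forwarded to retrieval. A minimal, self-contained sketch of the resulting Gradio wiring (component names taken from the diffs below; the handler body is simplified for illustration and does not call the real generate()):

```python
import gradio as gr

def submit_input(input_, num_chunks):
    # In the real app this calls generate(input_, model_parameters); here we
    # just echo the values to show how the number control reaches the handler.
    return f"query={input_!r}, retrieving {int(num_chunks)} chunks"

with gr.Blocks() as demo:
    input_ = gr.Textbox(placeholder="Qui va crear la guerra de les Galaxies ?")
    num_chunks = gr.Number(value=5, label="Num. Retrieved Chunks", minimum=1, interactive=True)
    output = gr.Textbox(label="Output")
    submit_btn = gr.Button("Submit", variant="primary")
    submit_btn.click(fn=submit_input, inputs=[input_, num_chunks],
                     outputs=[output], api_name="get-results")

demo.launch(show_api=True)
```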
.gitignore CHANGED
@@ -3,3 +3,4 @@
 .env
 __pycache__
 __pycache__/*
+__DELETE__*
app.py CHANGED
@@ -65,6 +65,8 @@ def submit_input(input_, num_chunks, max_new_tokens, repetition_penalty, top_k,
         "temperature": temperature
     }
 
+    print("Model parameters: ", model_parameters)
+
     output, context, source = generate(input_, model_parameters)
     sources_markup = ""
 
@@ -87,13 +89,7 @@ def clear():
         None,
         None,
         None,
-        gr.Slider(value=2.0),
-        gr.Slider(value=MAX_NEW_TOKENS),
-        gr.Slider(value=1.0),
-        gr.Slider(value=50),
-        gr.Slider(value=0.99),
-        gr.Checkbox(value=False),
-        gr.Slider(value=0.35),
+        gr.Number(value=5, label="Num. Retrieved Chunks", minimum=1, interactive=True)
     )
 
 
@@ -102,25 +98,12 @@ def gradio_app():
         # App Description
         # =====================================================================================================================================
         with gr.Row():
-            with gr.Column():
-
-                gr.Markdown(
-                # """# Demo de Retrieval-Augmented Generation per la Viquipèdia
-                # 🔍 **Retrieval-Augmented Generation** (RAG) és una tecnologia d'IA que permet interrogar un repositori de documents amb preguntes
-                # en llenguatge natural, i combina tècniques de recuperació d'informació avançades amb models generatius per redactar una resposta
-                # fent servir només la informació existent en els documents del repositori.
-
-                # 🎯 **Objectiu:** Aquest és un demostrador amb Viquipèdia i genera la resposta fent servir el model salamandra-7b-instruct.
-
-                # ⚠️ **Advertencies**: Aquesta versió és experimental. El contingut generat per aquest model no està supervisat i pot ser incorrecte.
-                # Si us plau, tingueu-ho en compte quan exploreu aquest recurs. El model en inferencia asociat a aquesta demo de desenvolupament no funciona continuament. Si vol fer proves,
-                # contacteu amb nosaltres a Langtech.
-                # """
-                )
+            with gr.Column():
+                gr.Markdown("""# Demo de Retrieval (only) Viquipèdia""")
+
 
-
-        # with gr.Row(equal_height=True):
         with gr.Row(equal_height=False):
+
             # User Input
             # =====================================================================================================================================
             with gr.Column(scale=2, variant="panel"):
@@ -131,69 +114,25 @@ def gradio_app():
                     placeholder="Qui va crear la guerra de les Galaxies ?",
                 )
 
-
-                # with gr.Column(variant="panel"):
                 with gr.Row(variant="default"):
-                # with gr.Row(variant="panel"):
                     clear_btn = Button("Clear",)
                     submit_btn = Button("Submit", variant="primary", interactive=False)
 
-                # with gr.Row(variant="panel"):
-                with gr.Row(variant="default"):
-                    with gr.Accordion("Model parameters (not used)", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
-                        num_chunks = Slider(
-                            minimum=1,
-                            maximum=6,
-                            step=1,
-                            value=5,
-                            label="Number of chunks"
-                        )
-                        max_new_tokens = Slider(
-                            minimum=50,
-                            maximum=2000,
-                            step=1,
-                            value=MAX_NEW_TOKENS,
-                            label="Max tokens"
-                        )
-                        repetition_penalty = Slider(
-                            minimum=0.1,
-                            maximum=2.0,
-                            step=0.1,
-                            value=1.0,
-                            label="Repetition penalty"
-                        )
-                        top_k = Slider(
-                            minimum=1,
-                            maximum=100,
-                            step=1,
-                            value=50,
-                            label="Top k"
-                        )
-                        top_p = Slider(
-                            minimum=0.01,
-                            maximum=0.99,
-                            value=0.99,
-                            label="Top p"
-                        )
-                        do_sample = Checkbox(
-                            value=False,
-                            label="Do sample"
-                        )
-                        temperature = Slider(
-                            minimum=0.1,
-                            maximum=1,
-                            value=0.35,
-                            label="Temperature"
-                        )
+                with gr.Row(variant="default"):
+                    num_chunks = gr.Number(value=5, label="Num. Retrieved Chunks", minimum=1, interactive=True)
 
-                parameters_compontents = [num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature]
 
                 # Add Examples manually
-                gr.Examples(
-                    examples=[
+                gr.Examples( examples=[
                     ["Qui va crear la guerra de les Galaxies?"],
                     ["Quin era el nom real de Voltaire?"],
-                    ["Què fan al BSC?"]
+                    ["Què fan al BSC?"],
+
+                    # No existèix aquesta entrada a la VDB
+                    # https://ca.wikipedia.org/wiki/Imperi_Gal%C3%A0ctic
+                    # ["Què és un Imperi Galàctic?"],
+                    # ["Què és l'Imperi Galàctic d'Isaac Asimov?"],
+                    # ["Què és l'Imperi Galàctic de la Guerra de les Galàxies?"]
                 ],
                 inputs=[input_], # only inputs
                 )
@@ -246,14 +185,16 @@ def gradio_app():
        clear_btn.click(
            fn=clear,
            inputs=[],
-           outputs=[input_, output, source_context, context_evaluation] + parameters_compontents,
-           queue=False,
-           api_name=False
+           outputs=[input_, output, source_context, context_evaluation, num_chunks],
+           # outputs=[input_, output, source_context, context_evaluation] + parameters_compontents,
+           queue=False,
+           api_name=False
        )
 
        submit_btn.click(
            fn=submit_input,
-           inputs=[input_]+ parameters_compontents,
+           # inputs=[input_] + parameters_compontents,
+           inputs=[input_] + [num_chunks],
            outputs=[output, source_context, context_evaluation],
            api_name="get-results"
        )
@@ -269,6 +210,7 @@ def gradio_app():
        # fn=submit_input,
        # )
 
+       # input_, output, source_context, context_evaluation, num_chunks = clear()
        demo.launch(show_api=True)
 
 
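
Because the click handler keeps api_name="get-results", the retrieved-chunks count is now also the second positional argument of the Space's HTTP API. A minimal client-side sketch using gradio_client, with a placeholder deployment URL (not part of the commit):

```python
from gradio_client import Client

client = Client("https://example-rag-demo.hf.space")  # placeholder URL, substitute the real Space

# Arguments mirror the click handler's inputs: [input_] + [num_chunks].
output, source_context, context_evaluation = client.predict(
    "Qui va crear la guerra de les Galaxies?",  # input_
    3,                                          # num_chunks
    api_name="/get-results",
)
print(output)
```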
 
handler.py DELETED
@@ -1,14 +0,0 @@
-import json
-
-class ContentHandler():
-    content_type = "application/json"
-    accepts = "application/json"
-
-    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
-        input_str = json.dumps({'inputs': prompt, 'parameters': model_kwargs})
-        return input_str.encode('utf-8')
-
-    def transform_output(self, output: bytes) -> str:
-        response_json = json.loads(output.read().decode("utf-8"))
-        return response_json[0]["generated_text"]
-
input_reader.py DELETED
@@ -1,22 +0,0 @@
-from typing import List
-
-from llama_index.core.constants import DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE
-from llama_index.core.readers import SimpleDirectoryReader
-from llama_index.core.schema import Document
-from llama_index.core import Settings
-
-
-class InputReader:
-    def __init__(self, input_dir: str) -> None:
-        self.reader = SimpleDirectoryReader(input_dir=input_dir)
-
-    def parse_documents(
-        self,
-        show_progress: bool = True,
-        chunk_size: int = DEFAULT_CHUNK_SIZE,
-        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
-    ) -> List[Document]:
-        Settings.chunk_size = chunk_size
-        Settings.chunk_overlap = chunk_overlap
-        documents = self.reader.load_data(show_progress=show_progress)
-        return documents
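
The removed reader was a thin wrapper over llama_index's SimpleDirectoryReader that also set the global chunking Settings. A minimal usage sketch, assuming a hypothetical local ./data folder with a few text files:

```python
# Reuses the InputReader class above; ./data is a hypothetical directory.
reader = InputReader(input_dir="./data")
documents = reader.parse_documents(show_progress=True,
                                   chunk_size=512,
                                   chunk_overlap=64)
print(f"{len(documents)} documents loaded")
```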
rag.py CHANGED
@@ -42,6 +42,7 @@ class RAG:
         logging.info("RAG loaded!")
         logging.info( self.vectore_store)
 
+
     def rerank_contexts(self, instruction, contexts, number_of_contexts=1):
         """
         Rerank the contexts based on their relevance to the given instruction.
@@ -86,21 +87,28 @@
 
         logging.info("RETRIEVE DOCUMENTS")
         logging.info(f"Instruction: {instruction}")
+
+        # Embed the query
+        # ==============================================================================================================
         embedding = self.vectore_store._embed_query(instruction)
         logging.info(f"Query embedding generated: {len(embedding)}")
-        documents_retrieved = self.vectore_store.similarity_search_with_score_by_vector(
-            embedding,
-            k=self.rerank_number_contexts)
+
+
+        # Retrieve documents
+        # ==============================================================================================================
+        documents_retrieved = self.vectore_store.similarity_search_with_score_by_vector(embedding, k=number_of_contexts)
         logging.info(f"Documents retrieved: {len(documents_retrieved)}")
 
-        # documents_retrieved = self.vectore_store.similarity_search_with_score(instruction, k=self.rerank_number_contexts)
 
+        # Reranking
+        # ==============================================================================================================
         if self.rerank_model:
             logging.info("RERANK DOCUMENTS")
             documents_reranked = self.rerank_contexts(instruction, documents_retrieved, number_of_contexts=number_of_contexts)
         else:
             logging.info("NO RERANKING")
             documents_reranked = documents_retrieved[:number_of_contexts]
+        # ==============================================================================================================
 
         return documents_reranked
 
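
Net effect in rag.py: the retrieval depth now follows the value coming from the UI (k=number_of_contexts) rather than the fixed self.rerank_number_contexts, and reranking, when enabled, still trims to the same count. A minimal sketch of the same retrieve-by-vector pattern against a LangChain FAISS store; the store type and FakeEmbeddings are assumptions for illustration, not the repository's actual setup:

```python
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import FAISS

embeddings = FakeEmbeddings(size=128)  # stand-in for the real embedding model
store = FAISS.from_texts(
    ["La guerra de les Galàxies la va crear George Lucas.",
     "Voltaire es deia François-Marie Arouet.",
     "El BSC és el Barcelona Supercomputing Center."],
    embeddings,
)

num_chunks = 2  # value coming from the new gr.Number control
query_embedding = embeddings.embed_query("Qui va crear la guerra de les Galaxies?")
docs_and_scores = store.similarity_search_with_score_by_vector(query_embedding, k=num_chunks)
for doc, score in docs_and_scores:
    print(round(float(score), 3), doc.page_content)
```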