Added retrieval num chunks options
Changed files:
- .gitignore +1 -0
- app.py +24 -82
- handler.py +0 -14
- input_reader.py +0 -22
- rag.py +12 -4
.gitignore CHANGED
@@ -3,3 +3,4 @@
 .env
 __pycache__
 __pycache__/*
+__DELETE__*
app.py CHANGED
@@ -65,6 +65,8 @@ def submit_input(input_, num_chunks, max_new_tokens, repetition_penalty, top_k,
         "temperature": temperature
     }

+    print("Model parameters: ", model_parameters)
+
     output, context, source = generate(input_, model_parameters)
     sources_markup = ""

@@ -87,13 +89,7 @@ def clear():
         None,
         None,
         None,
-        gr.
-        gr.Slider(value=MAX_NEW_TOKENS),
-        gr.Slider(value=1.0),
-        gr.Slider(value=50),
-        gr.Slider(value=0.99),
-        gr.Checkbox(value=False),
-        gr.Slider(value=0.35),
+        gr.Number(value=5, label="Num. Retrieved Chunks", minimum=1, interactive=True)
     )


@@ -102,25 +98,12 @@ def gradio_app():
     # App Description
     # =====================================================================================================================================
     with gr.Row():
-        with gr.Column():
-
-
-            # """# Demo de Retrieval-Augmented Generation per la Viquipèdia
-            # 🔍 **Retrieval-Augmented Generation** (RAG) és una tecnologia d'IA que permet interrogar un repositori de documents amb preguntes
-            # en llenguatge natural, i combina tècniques de recuperació d'informació avançades amb models generatius per redactar una resposta
-            # fent servir només la informació existent en els documents del repositori.
-
-            # 🎯 **Objectiu:** Aquest és un demostrador amb Viquipèdia i genera la resposta fent servir el model salamandra-7b-instruct.
-
-            # ⚠️ **Advertencies**: Aquesta versió és experimental. El contingut generat per aquest model no està supervisat i pot ser incorrecte.
-            # Si us plau, tingueu-ho en compte quan exploreu aquest recurs. El model en inferencia asociat a aquesta demo de desenvolupament no funciona continuament. Si vol fer proves,
-            # contacteu amb nosaltres a Langtech.
-            # """
-        )
+        with gr.Column():
+            gr.Markdown("""# Demo de Retrieval (only) Viquipèdia""")
+

-
-    # with gr.Row(equal_height=True):
     with gr.Row(equal_height=False):
+
         # User Input
         # =====================================================================================================================================
         with gr.Column(scale=2, variant="panel"):
@@ -131,69 +114,25 @@ def gradio_app():
                 placeholder="Qui va crear la guerra de les Galaxies ?",
             )

-
-            # with gr.Column(variant="panel"):
             with gr.Row(variant="default"):
-                # with gr.Row(variant="panel"):
                 clear_btn = Button("Clear",)
                 submit_btn = Button("Submit", variant="primary", interactive=False)

-
-
-            with gr.Accordion("Model parameters (not used)", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
-                num_chunks = Slider(
-                    minimum=1,
-                    maximum=6,
-                    step=1,
-                    value=5,
-                    label="Number of chunks"
-                )
-                max_new_tokens = Slider(
-                    minimum=50,
-                    maximum=2000,
-                    step=1,
-                    value=MAX_NEW_TOKENS,
-                    label="Max tokens"
-                )
-                repetition_penalty = Slider(
-                    minimum=0.1,
-                    maximum=2.0,
-                    step=0.1,
-                    value=1.0,
-                    label="Repetition penalty"
-                )
-                top_k = Slider(
-                    minimum=1,
-                    maximum=100,
-                    step=1,
-                    value=50,
-                    label="Top k"
-                )
-                top_p = Slider(
-                    minimum=0.01,
-                    maximum=0.99,
-                    value=0.99,
-                    label="Top p"
-                )
-                do_sample = Checkbox(
-                    value=False,
-                    label="Do sample"
-                )
-                temperature = Slider(
-                    minimum=0.1,
-                    maximum=1,
-                    value=0.35,
-                    label="Temperature"
-                )
+            with gr.Row(variant="default"):
+                num_chunks = gr.Number(value=5, label="Num. Retrieved Chunks", minimum=1, interactive=True)

-            parameters_compontents = [num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature]

             # Add Examples manually
-            gr.Examples(
-                examples=[
+            gr.Examples( examples=[
                 ["Qui va crear la guerra de les Galaxies?"],
                 ["Quin era el nom real de Voltaire?"],
-                ["Què fan al BSC?"]
+                ["Què fan al BSC?"],
+
+                # No existèix aquesta entrada a la VDB
+                # https://ca.wikipedia.org/wiki/Imperi_Gal%C3%A0ctic
+                # ["Què és un Imperi Galàctic?"],
+                # ["Què és l'Imperi Galàctic d'Isaac Asimov?"],
+                # ["Què és l'Imperi Galàctic de la Guerra de les Galàxies?"]
             ],
             inputs=[input_], # only inputs
             )
@@ -246,14 +185,16 @@ def gradio_app():
     clear_btn.click(
         fn=clear,
         inputs=[],
-        outputs=[input_, output, source_context, context_evaluation]
-
-
+        outputs=[input_, output, source_context, context_evaluation, num_chunks],
+        # outputs=[input_, output, source_context, context_evaluation] + parameters_compontents,
+        queue=False,
+        api_name=False
     )

     submit_btn.click(
         fn=submit_input,
-        inputs=[input_]+ parameters_compontents,
+        # inputs=[input_] + parameters_compontents,
+        inputs=[input_] + [num_chunks],
         outputs=[output, source_context, context_evaluation],
         api_name="get-results"
     )
@@ -269,6 +210,7 @@ def gradio_app():
     # fn=submit_input,
     # )

+    # input_, output, source_context, context_evaluation, num_chunks = clear()
     demo.launch(show_api=True)


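For context, a minimal runnable sketch of the wiring this diff sets up: a gr.Number feeds submit_input through submit_btn.click, and clear() resets it by returning a fresh component. The stub submit_input body and the reduced component set here are placeholders, not the Space's real generate() pipeline.

import gradio as gr

def submit_input(input_, num_chunks):
    # gr.Number delivers a float; cast before using it as a chunk count.
    return f"Would retrieve {int(num_chunks)} chunks for: {input_}"

def clear():
    # Returning fresh components resets the textbox, output, and counter.
    return (
        None,
        None,
        gr.Number(value=5, label="Num. Retrieved Chunks", minimum=1, interactive=True),
    )

with gr.Blocks() as demo:
    input_ = gr.Textbox(label="Input", placeholder="Qui va crear la guerra de les Galaxies ?")
    num_chunks = gr.Number(value=5, label="Num. Retrieved Chunks", minimum=1, interactive=True)
    output = gr.Textbox(label="Output")
    submit_btn = gr.Button("Submit", variant="primary")
    clear_btn = gr.Button("Clear")

    submit_btn.click(
        fn=submit_input,
        inputs=[input_] + [num_chunks],
        outputs=[output],
        api_name="get-results",
    )
    clear_btn.click(
        fn=clear,
        inputs=[],
        outputs=[input_, output, num_chunks],
        queue=False,
        api_name=False,
    )

demo.launch(show_api=True)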
handler.py DELETED
@@ -1,14 +0,0 @@
-import json
-
-class ContentHandler():
-    content_type = "application/json"
-    accepts = "application/json"
-
-    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
-        input_str = json.dumps({'inputs': prompt, 'parameters': model_kwargs})
-        return input_str.encode('utf-8')
-
-    def transform_output(self, output: bytes) -> str:
-        response_json = json.loads(output.read().decode("utf-8"))
-        return response_json[0]["generated_text"]
-
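The deleted ContentHandler serialized a prompt plus parameters into a JSON request body and pulled generated_text out of the endpoint's reply. A self-contained sketch of that round trip; the reply below is fabricated purely to exercise the code, not output from any real endpoint:

import io
import json

class ContentHandler:
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
        # Serialize prompt + parameters into the endpoint's JSON request body.
        return json.dumps({"inputs": prompt, "parameters": model_kwargs}).encode("utf-8")

    def transform_output(self, output) -> str:
        # The endpoint replies with [{"generated_text": ...}].
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]

handler = ContentHandler()
request_body = handler.transform_input("Qui va crear la guerra de les Galaxies?", {"max_new_tokens": 50})
# Stand-in for a real HTTP response body (file-like, as transform_output expects):
fake_reply = io.BytesIO(json.dumps([{"generated_text": "George Lucas."}]).encode("utf-8"))
print(handler.transform_output(fake_reply))  # -> George Lucas.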
input_reader.py DELETED
@@ -1,22 +0,0 @@
-from typing import List
-
-from llama_index.core.constants import DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE
-from llama_index.core.readers import SimpleDirectoryReader
-from llama_index.core.schema import Document
-from llama_index.core import Settings
-
-
-class InputReader:
-    def __init__(self, input_dir: str) -> None:
-        self.reader = SimpleDirectoryReader(input_dir=input_dir)
-
-    def parse_documents(
-        self,
-        show_progress: bool = True,
-        chunk_size: int = DEFAULT_CHUNK_SIZE,
-        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
-    ) -> List[Document]:
-        Settings.chunk_size = chunk_size
-        Settings.chunk_overlap = chunk_overlap
-        documents = self.reader.load_data(show_progress=show_progress)
-        return documents
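The deleted InputReader was a thin wrapper over llama_index's SimpleDirectoryReader. An equivalent sketch, assuming llama-index is installed; the ./data directory and the 512/64 chunk values are placeholders (the class defaulted to DEFAULT_CHUNK_SIZE / DEFAULT_CHUNK_OVERLAP):

from llama_index.core import Settings
from llama_index.core.readers import SimpleDirectoryReader

# Chunking is configured globally, as parse_documents() did.
Settings.chunk_size = 512
Settings.chunk_overlap = 64

documents = SimpleDirectoryReader(input_dir="./data").load_data(show_progress=True)
print(f"Loaded {len(documents)} documents")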
rag.py CHANGED
@@ -42,6 +42,7 @@ class RAG:
         logging.info("RAG loaded!")
         logging.info( self.vectore_store)

+
     def rerank_contexts(self, instruction, contexts, number_of_contexts=1):
         """
         Rerank the contexts based on their relevance to the given instruction.
@@ -86,21 +87,28 @@ class RAG:

         logging.info("RETRIEVE DOCUMENTS")
         logging.info(f"Instruction: {instruction}")
+
+        # Embed the query
+        # ==============================================================================================================
         embedding = self.vectore_store._embed_query(instruction)
         logging.info(f"Query embedding generated: {len(embedding)}")
-
-
-
+
+
+        # Retrieve documents
+        # ==============================================================================================================
+        documents_retrieved = self.vectore_store.similarity_search_with_score_by_vector(embedding, k=number_of_contexts)
         logging.info(f"Documents retrieved: {len(documents_retrieved)}")

-        # documents_retrieved = self.vectore_store.similarity_search_with_score(instruction, k=self.rerank_number_contexts)

+        # Reranking
+        # ==============================================================================================================
         if self.rerank_model:
             logging.info("RERANK DOCUMENTS")
             documents_reranked = self.rerank_contexts(instruction, documents_retrieved, number_of_contexts=number_of_contexts)
         else:
             logging.info("NO RERANKING")
             documents_reranked = documents_retrieved[:number_of_contexts]
+        # ==============================================================================================================

         return documents_reranked

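A minimal sketch of the new embed-then-search flow. The diff's vectore_store calls (_embed_query, similarity_search_with_score_by_vector) match LangChain's FAISS API, so this sketch assumes a LangChain FAISS store; the FakeEmbeddings model, toy corpus, and absent rerank model are placeholders, not the Space's real components.

from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import FAISS

embedder = FakeEmbeddings(size=128)  # placeholder embedding model
store = FAISS.from_texts(
    [
        "La guerra de les Galàxies la va crear George Lucas.",
        "Voltaire es deia François-Marie Arouet.",
        "El BSC és el Barcelona Supercomputing Center.",
    ],
    embedder,
)

def retrieve(instruction: str, number_of_contexts: int = 5, rerank_model=None):
    # Embed the query...
    embedding = embedder.embed_query(instruction)
    # ...then search by vector with k=number_of_contexts, as the diff now does.
    documents_retrieved = store.similarity_search_with_score_by_vector(embedding, k=number_of_contexts)
    if rerank_model is not None:
        pass  # rerank_contexts() would reorder documents_retrieved here
    return documents_retrieved[:number_of_contexts]

for doc, score in retrieve("Qui va crear la guerra de les Galaxies?", number_of_contexts=2):
    print(score, doc.page_content)

Note the design change this diff makes: k comes straight from the UI's num_chunks value rather than a fixed self.rerank_number_contexts, so the retriever fetches exactly as many chunks as the user asks for.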