Spaces:
Runtime error
Runtime error
Jens Grivolla
committed on
Commit
·
f005840
1
Parent(s):
99162ec
make sys prompt configurable
Browse files
app.py
CHANGED
|
@@ -22,9 +22,9 @@ rag = RAG(
|
|
| 22 |
)
|
| 23 |
|
| 24 |
|
| 25 |
-
def generate(prompt, model_parameters):
|
| 26 |
try:
|
| 27 |
-
output, context, source = rag.get_response(prompt, model_parameters)
|
| 28 |
return output, context, source
|
| 29 |
except HTTPError as err:
|
| 30 |
if err.code == 400:
|
|
@@ -37,7 +37,7 @@ def generate(prompt, model_parameters):
|
|
| 37 |
)
|
| 38 |
|
| 39 |
|
| 40 |
-
def submit_input(input_, num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature):
|
| 41 |
if input_.strip() == "":
|
| 42 |
gr.Warning("Not possible to inference an empty input")
|
| 43 |
return None
|
|
@@ -53,7 +53,7 @@ def submit_input(input_, num_chunks, max_new_tokens, repetition_penalty, top_k,
|
|
| 53 |
"temperature": temperature
|
| 54 |
}
|
| 55 |
|
| 56 |
-
output, context, source = generate(input_, model_parameters)
|
| 57 |
sources_markup = ""
|
| 58 |
|
| 59 |
for url in source:
|
|
@@ -112,6 +112,12 @@ def gradio_app():
|
|
| 112 |
placeholder="Quina és la finalitat del Servei Meteorològic de Catalunya?",
|
| 113 |
# value = "Quina és la finalitat del Servei Meteorològic de Catalunya?"
|
| 114 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
with gr.Row(variant="panel"):
|
| 116 |
clear_btn = Button(
|
| 117 |
"Clear",
|
|
@@ -201,8 +207,8 @@ def gradio_app():
|
|
| 201 |
inputs=[input_],
|
| 202 |
api_name=False,
|
| 203 |
js="""(i, m) => {
|
| 204 |
-
document.getElementById('
|
| 205 |
-
document.getElementById('
|
| 206 |
}""",
|
| 207 |
)
|
| 208 |
|
|
@@ -216,7 +222,7 @@ def gradio_app():
|
|
| 216 |
|
| 217 |
submit_btn.click(
|
| 218 |
fn=submit_input,
|
| 219 |
-
inputs=[input_]+ parameters_compontents,
|
| 220 |
outputs=[output, source_context, context_evaluation],
|
| 221 |
api_name="get-results"
|
| 222 |
)
|
|
|
|
| 22 |
)
|
| 23 |
|
| 24 |
|
| 25 |
+
def generate(prompt, sys_prompt, model_parameters):
|
| 26 |
try:
|
| 27 |
+
output, context, source = rag.get_response(prompt, sys_prompt, model_parameters)
|
| 28 |
return output, context, source
|
| 29 |
except HTTPError as err:
|
| 30 |
if err.code == 400:
|
|
|
|
| 37 |
)
|
| 38 |
|
| 39 |
|
| 40 |
+
def submit_input(input_, sysprompt_, num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature):
|
| 41 |
if input_.strip() == "":
|
| 42 |
gr.Warning("Not possible to inference an empty input")
|
| 43 |
return None
|
|
|
|
| 53 |
"temperature": temperature
|
| 54 |
}
|
| 55 |
|
| 56 |
+
output, context, source = generate(input_, sysprompt_, model_parameters)
|
| 57 |
sources_markup = ""
|
| 58 |
|
| 59 |
for url in source:
|
|
|
|
| 112 |
placeholder="Quina és la finalitat del Servei Meteorològic de Catalunya?",
|
| 113 |
# value = "Quina és la finalitat del Servei Meteorològic de Catalunya?"
|
| 114 |
)
|
| 115 |
+
sysprompt_ = Textbox(
|
| 116 |
+
lines=2,
|
| 117 |
+
label="System",
|
| 118 |
+
placeholder="Below is a question that you should answer based on the given context. Write a response that answers the question using only information provided in the context.",
|
| 119 |
+
value = "Below is a question that you should answer based on the given context. Write a response that answers the question using only information provided in the context."
|
| 120 |
+
)
|
| 121 |
with gr.Row(variant="panel"):
|
| 122 |
clear_btn = Button(
|
| 123 |
"Clear",
|
|
|
|
| 207 |
inputs=[input_],
|
| 208 |
api_name=False,
|
| 209 |
js="""(i, m) => {
|
| 210 |
+
document.getElementById('inputlength').textContent = i.length + ' '
|
| 211 |
+
document.getElementById('inputlength').style.color = (i.length > m) ? "#ef4444" : "";
|
| 212 |
}""",
|
| 213 |
)
|
| 214 |
|
|
|
|
| 222 |
|
| 223 |
submit_btn.click(
|
| 224 |
fn=submit_input,
|
| 225 |
+
inputs=[input_, sysprompt_]+ parameters_compontents,
|
| 226 |
outputs=[output, source_context, context_evaluation],
|
| 227 |
api_name="get-results"
|
| 228 |
)
|
rag.py
CHANGED
|
@@ -32,7 +32,7 @@ class RAG:
|
|
| 32 |
|
| 33 |
return documentos
|
| 34 |
|
| 35 |
-
def predict(self, instruction, context, model_parameters):
|
| 36 |
|
| 37 |
from openai import OpenAI
|
| 38 |
|
|
@@ -42,9 +42,10 @@ class RAG:
|
|
| 42 |
api_key=os.getenv("HF_TOKEN")
|
| 43 |
)
|
| 44 |
|
| 45 |
-
sys_prompt = "You are a helpful assistant. Answer the question using only the context you are provided with. If it is not possible to do it with the context, just say 'I can't answer'. <|endoftext|>"
|
| 46 |
#query = f"Context:\n{context}\n\nQuestion:\n{instruction}"
|
| 47 |
query = f"{sys_prompt}\n\nContext:\n{context}\n\nQuestion:\n{instruction}"
|
|
|
|
| 48 |
#query = f"{sys_prompt}\n\nQuestion:\n{instruction}\n\nContext:\n{context}"
|
| 49 |
chat_completion = client.chat.completions.create(
|
| 50 |
model="tgi",
|
|
@@ -77,14 +78,14 @@ class RAG:
|
|
| 77 |
|
| 78 |
return text_context, full_context, source_context
|
| 79 |
|
| 80 |
-
def get_response(self, prompt: str, model_parameters: dict) -> str:
|
| 81 |
try:
|
| 82 |
docs = self.get_context(prompt, model_parameters["NUM_CHUNKS"])
|
| 83 |
text_context, full_context, source = self.beautiful_context(docs)
|
| 84 |
|
| 85 |
del model_parameters["NUM_CHUNKS"]
|
| 86 |
|
| 87 |
-
response = self.predict(prompt, text_context, model_parameters)
|
| 88 |
|
| 89 |
if not response:
|
| 90 |
return self.NO_ANSWER_MESSAGE
|
|
|
|
| 32 |
|
| 33 |
return documentos
|
| 34 |
|
| 35 |
+
def predict(self, instruction, sys_prompt, context, model_parameters):
|
| 36 |
|
| 37 |
from openai import OpenAI
|
| 38 |
|
|
|
|
| 42 |
api_key=os.getenv("HF_TOKEN")
|
| 43 |
)
|
| 44 |
|
| 45 |
+
#sys_prompt = "You are a helpful assistant. Answer the question using only the context you are provided with. If it is not possible to do it with the context, just say 'I can't answer'. <|endoftext|>"
|
| 46 |
#query = f"Context:\n{context}\n\nQuestion:\n{instruction}"
|
| 47 |
query = f"{sys_prompt}\n\nContext:\n{context}\n\nQuestion:\n{instruction}"
|
| 48 |
+
print(query)
|
| 49 |
#query = f"{sys_prompt}\n\nQuestion:\n{instruction}\n\nContext:\n{context}"
|
| 50 |
chat_completion = client.chat.completions.create(
|
| 51 |
model="tgi",
|
|
|
|
| 78 |
|
| 79 |
return text_context, full_context, source_context
|
| 80 |
|
| 81 |
+
def get_response(self, prompt: str, sys_prompt: str, model_parameters: dict) -> str:
|
| 82 |
try:
|
| 83 |
docs = self.get_context(prompt, model_parameters["NUM_CHUNKS"])
|
| 84 |
text_context, full_context, source = self.beautiful_context(docs)
|
| 85 |
|
| 86 |
del model_parameters["NUM_CHUNKS"]
|
| 87 |
|
| 88 |
+
response = self.predict(prompt, sys_prompt, text_context, model_parameters)
|
| 89 |
|
| 90 |
if not response:
|
| 91 |
return self.NO_ANSWER_MESSAGE
|