Spaces: Running on Zero

make global rag index

app.py CHANGED
```diff
@@ -68,6 +68,8 @@ question: Prior to playing for Michigan State, Keith Nichol played football for
 answer: Norman
 """
 
+global_rag_index = None
+
 class FinchCache(DynamicCache):
     def __init__(self) -> None:
         super().__init__()
```
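This hunk is the heart of the commit: a module-level `global_rag_index`, initialized to `None`, replaces the per-session `rag_index` that was previously threaded through the Gradio `state` dict. The diff never shows `create_rag_index` itself, but the new retriever calls further down (`as_retriever(search_type="similarity", search_kwargs={"k": k})` and `retriever.invoke(query)`) match the LangChain vector-store retriever interface, so a plausible sketch of the index builder is shown below. The FAISS store, the splitter settings, and the embedding model are assumptions, not something this diff confirms.

```python
# Hypothetical sketch of create_rag_index, which the diff calls but never shows.
# Assumes a LangChain FAISS vector store; chunk_size is picked so one chunk is very
# roughly 256 tokens, matching the k = rag_token_size // 256 sizing used later.
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter

def create_rag_index(rag_text: str, device: str = "cpu"):
    """Split the document into chunks and build a similarity-search index over them."""
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = splitter.split_text(rag_text)
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",  # assumed embedding model
        model_kwargs={"device": device},
    )
    return FAISS.from_texts(chunks, embeddings)
```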
```diff
@@ -218,9 +220,9 @@ def auto_convert(file_objs, url, do_ocr, do_table_structure):
     else:
         rag_text = combined_text
     print("Creating RAG index")
-
+    global_rag_index = create_rag_index(rag_text)
     print("Done")
-    state = {
+    state = {}
 
     return (
         combined_text,
```
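`auto_convert` now builds the shared index as soon as a document is converted and resets the per-session `state` to an empty dict. The hunk only shows the assignment, so whether the function also declares `global global_rag_index` is not visible here; in Python that declaration is required for the assignment to rebind the module-level name rather than create a local. A minimal, self-contained illustration of the difference:

```python
# Why the `global` statement matters for this commit: without it, assigning inside a
# function creates a new local name and leaves the module-level index untouched.
index = None

def build_without_global():
    index = "local only"           # rebinds a *local* variable; module-level index stays None

def build_with_global():
    global index
    index = "visible module-wide"  # rebinds the module-level variable

build_without_global()
print(index)  # None
build_with_global()
print(index)  # visible module-wide
```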
```diff
@@ -438,13 +440,13 @@ def get_compressed_kv_cache(sink_tokens, step_size, target_token_size, context_i
     return cache
 
 
-def run_naive_rag_query(
+def run_naive_rag_query(query, rag_token_size, prefix, task, few_shot_examples):
     """
     For naive RAG, retrieves top-k chunks (k based on target token size)
     and generates an answer using those chunks.
     """
     k = max(1, rag_token_size // 256)
-    retriever =
+    retriever = global_rag_index.as_retriever(search_type="similarity", search_kwargs={"k": k})
     retrieved_docs = retriever.invoke(query)
     for doc in retrieved_docs:
         print("=================")
```
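`run_naive_rag_query` drops whatever index argument it previously received and reads `global_rag_index` directly; its new signature takes only the query, the retrieval token budget, and the prompt pieces. `k = max(1, rag_token_size // 256)` implies chunks are treated as roughly 256 tokens each, so a 1024-token budget retrieves 4 chunks. The retriever calls follow the LangChain retriever interface; a standalone example of that pattern, with an assumed FAISS index and embedding model as in the sketch above:

```python
# Standalone example of the retrieval pattern used by run_naive_rag_query.
# The FAISS store and embedding model are assumptions; only as_retriever(...) and
# retriever.invoke(...) are taken from the new code.
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
index = FAISS.from_texts(
    ["chunk one ...", "chunk two ...", "chunk three ...", "chunk four ..."],
    embeddings,
)

rag_token_size = 1024
k = max(1, rag_token_size // 256)  # 4 chunks for a 1024-token budget
retriever = index.as_retriever(search_type="similarity", search_kwargs={"k": k})
docs = retriever.invoke("Prior to playing for Michigan State, which team did Keith Nichol play for?")
context = "\n\n".join(doc.page_content for doc in docs)
print(context)
```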
```diff
@@ -477,9 +479,11 @@ def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_lo
     print("Target token size for compression: ", target_token_size)
     step_size = 2
     start_time_prefill = time.perf_counter()
+    print("Compressing KV Cache")
     past_key_values = copy.deepcopy(get_compressed_kv_cache(sink_tokens, step_size, target_token_size,
                                                             context_ids, context_attention_mask,
                                                             question_ids, question_attention_mask))
+    print("Done")
     compressed_length = past_key_values.get_seq_length()
     print("Context size after compression: ", compressed_length)
     print("Compression rate: ", context_ids.size(1) / compressed_length)
```
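The only change here is progress logging around the KV-cache compression call. For scale, the compression-rate print divides the original context length by the compressed cache length, so a 20,000-token context squeezed into a 2,000-entry cache would report a rate of 10.0.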
```diff
@@ -490,19 +494,17 @@ def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_lo
     compressed_length = past_key_values.get_seq_length()
 
 
-
-
-    if rag_index is None:
+
+    if global_rag_index is None:
         if combined_text.startswith(prefix):
            rag_text = combined_text[len(prefix):]
         else:
            rag_text = combined_text
-
+        global_rag_index = create_rag_index(rag_text, device)
 
     state.update({
         "compressed_cache": past_key_values,
         "compressed_length": compressed_length,
-        "rag_index": rag_index,
         "target_token_size": target_token_size,
         "global_local": percentage,
         "combined_text": combined_text,
```
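`prepare_compression_and_rag` keeps a lazy fallback: if the shared index was not built during conversion, it strips the few-shot prefix from `combined_text` and builds `global_rag_index` on the spot, here with an explicit `device` argument, which suggests `create_rag_index` accepts an optional device. The `"rag_index"` entry disappears from the session state, since the index no longer lives there. As an aside, the `startswith`/slice pair is equivalent to `str.removeprefix` on Python 3.9+:

```python
# Equivalent to the if/else prefix stripping above (Python 3.9+): removeprefix returns
# the string unchanged when it does not start with the given prefix.
rag_text = combined_text.removeprefix(prefix)
```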
```diff
@@ -523,7 +525,6 @@ def chat_response_stream(message: str, history: list, state: dict):
     user_message = message
     past_key_values = state["compressed_cache"]
     compressed_length = past_key_values.get_seq_length()
-    rag_index = state["rag_index"]
     retrieval_slider_value = state["retrieval_slider"]
     percentage = state["global_local"]
 
```
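`chat_response_stream` correspondingly stops reading the index out of the session state. One design consequence worth noting: a module-level index is shared by every session of the Space, so concurrent users would see (and overwrite) the most recently built index; that is presumably acceptable for a single-document demo, but it is a behavioural change from the per-session state entry.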
```diff
@@ -540,7 +541,7 @@ def chat_response_stream(message: str, history: list, state: dict):
     rag_few_shot = ""
     print("user message: ", user_message)
     if rag_retrieval_size != 0:
-        rag_context = run_naive_rag_query(
+        rag_context = run_naive_rag_query(user_message, rag_retrieval_size, rag_prefix, rag_task, rag_few_shot)
         new_input = rag_context + "\nquestion: " + user_message + suffix + "answer:"
     else:
         new_input = "\nquestion: " + user_message + suffix + "answer:"
```
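The call site now passes the user message as the retrieval query along with the token budget and prompt pieces, matching the new signature of `run_naive_rag_query`. When `rag_retrieval_size` is 0, retrieval is skipped entirely and the prompt is built from the question alone, with the document context carried only by the compressed KV cache.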