giulio98 committed
Commit 2212763 · 1 Parent(s): 1a243e9

make global rag index

Files changed (1):
  1. app.py +12 -11
app.py CHANGED
@@ -68,6 +68,8 @@ question: Prior to playing for Michigan State, Keith Nichol played football for
 answer: Norman
 """
 
+global_rag_index = None
+
 class FinchCache(DynamicCache):
     def __init__(self) -> None:
         super().__init__()
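This hunk hoists the RAG index to module scope. For the later assignments inside auto_convert and prepare_compression_and_rag to rebind this module-level name rather than create function locals, Python requires a global declaration in each writing function. A minimal, self-contained sketch of the pattern (build_index, converter, and reader are illustrative stand-ins, not the real app.py code):

global_rag_index = None

def build_index(chunks):
    # Stand-in for the app's create_rag_index; the "index" here is just a dict.
    return {i: chunk for i, chunk in enumerate(chunks)}

def converter(chunks):
    global global_rag_index  # required: this function rebinds the module-level name
    global_rag_index = build_index(chunks)

def reader(i):
    return global_rag_index[i]  # pure readers need no global declaration

converter(["chunk a", "chunk b"])
print(reader(0))  # -> chunk a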
@@ -218,9 +220,9 @@ def auto_convert(file_objs, url, do_ocr, do_table_structure):
     else:
         rag_text = combined_text
     print("Creating RAG index")
-    rag_index = create_rag_index(rag_text)
+    global_rag_index = create_rag_index(rag_text)
     print("Done")
-    state = {"rag_index": rag_index}
+    state = {}
 
     return (
         combined_text,
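auto_convert now builds the shared index once at document-ingestion time and drops it from the per-session state dict. create_rag_index itself is not part of this diff; assuming it builds a LangChain vectorstore (consistent with the as_retriever call in the next hunk), a hypothetical sketch could look like the following. The splitter settings and embedding model are guesses, not the app's actual choices:

from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

def create_rag_index_sketch(text):
    # chunk_size is measured in characters here, not tokens: an approximation only
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = splitter.split_text(text)
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    return FAISS.from_texts(chunks, embeddings)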
@@ -438,13 +440,13 @@ def get_compressed_kv_cache(sink_tokens, step_size, target_token_size, context_i
     return cache
 
 
-def run_naive_rag_query(vectorstore, query, rag_token_size, prefix, task, few_shot_examples):
+def run_naive_rag_query(query, rag_token_size, prefix, task, few_shot_examples):
     """
     For naive RAG, retrieves top-k chunks (k based on target token size)
     and generates an answer using those chunks.
     """
     k = max(1, rag_token_size // 256)
-    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": k})
+    retriever = global_rag_index.as_retriever(search_type="similarity", search_kwargs={"k": k})
     retrieved_docs = retriever.invoke(query)
     for doc in retrieved_docs:
         print("=================")
@@ -477,9 +479,11 @@ def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_lo
     print("Target token size for compression: ", target_token_size)
     step_size = 2
     start_time_prefill = time.perf_counter()
+    print("Compressing KV Cache")
     past_key_values = copy.deepcopy(get_compressed_kv_cache(sink_tokens, step_size, target_token_size,
                                                             context_ids, context_attention_mask,
                                                             question_ids, question_attention_mask))
+    print("Done")
     compressed_length = past_key_values.get_seq_length()
     print("Context size after compression: ", compressed_length)
     print("Compression rate: ", context_ids.size(1) / compressed_length)
@@ -490,19 +494,17 @@ def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_lo
     compressed_length = past_key_values.get_seq_length()
 
 
-    # Use the precomputed rag_index from state.
-    rag_index = state.get("rag_index", None)
-    if rag_index is None:
+
+    if global_rag_index is None:
         if combined_text.startswith(prefix):
             rag_text = combined_text[len(prefix):]
         else:
             rag_text = combined_text
-        rag_index = create_rag_index(rag_text, device)
+        global_rag_index = create_rag_index(rag_text, device)
 
     state.update({
         "compressed_cache": past_key_values,
         "compressed_length": compressed_length,
-        "rag_index": rag_index,
         "target_token_size": target_token_size,
         "global_local": percentage,
         "combined_text": combined_text,
@@ -523,7 +525,6 @@ def chat_response_stream(message: str, history: list, state: dict):
     user_message = message
     past_key_values = state["compressed_cache"]
     compressed_length = past_key_values.get_seq_length()
-    rag_index = state["rag_index"]
     retrieval_slider_value = state["retrieval_slider"]
     percentage = state["global_local"]
 
@@ -540,7 +541,7 @@ def chat_response_stream(message: str, history: list, state: dict):
     rag_few_shot = ""
     print("user message: ", user_message)
     if rag_retrieval_size != 0:
-        rag_context = run_naive_rag_query(rag_index, user_message, rag_retrieval_size, rag_prefix, rag_task, rag_few_shot)
+        rag_context = run_naive_rag_query(user_message, rag_retrieval_size, rag_prefix, rag_task, rag_few_shot)
         new_input = rag_context + "\nquestion: " + user_message + suffix + "answer:"
     else:
         new_input = "\nquestion: " + user_message + suffix + "answer:"
 