gaonkarrs committed on
Commit b0693c6 · 1 Parent(s): 200d2b6

New changes

Files changed (3)
  1. .gitignore +1 -0
  2. app.py +14 -9
  3. requirements.txt +2 -1
.gitignore ADDED
@@ -0,0 +1 @@
+.env
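Ignoring .env keeps the Groq credential that app.py now reads via python-dotenv out of version control. A hypothetical .env file for this Space would contain just the line below; the key value is a placeholder, not part of the commit:

# .env — kept local, never committed (hypothetical example)
GROQ_API_KEY=gsk_your_key_here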
app.py CHANGED
@@ -22,6 +22,11 @@ import traceback
 import shutil
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from tqdm import tqdm
+from dotenv import load_dotenv
+import os
+
+load_dotenv()
+GROQ_API_KEY = os.getenv("GROQ_API_KEY")
 
 embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
 
@@ -29,13 +34,11 @@ def build_index_and_dataset(domain, subsets, embedder_type="sentence-transformer
     dataset_path = f"{domain}_dataset"
     index_path = f"{domain}_index/faiss.index"
 
-    # ❌ Always remove previous
-    if os.path.exists(dataset_path):
-        shutil.rmtree(dataset_path)
-    if os.path.exists(index_path):
-        os.remove(index_path)
+    if os.path.exists(dataset_path) and os.path.exists(index_path):
+        print(f"✅ Using cached dataset and index for domain: {domain}")
+        return Dataset.load_from_disk(dataset_path), faiss.read_index(index_path)
 
-    print(f"🚀 Rebuilding dataset and index for domain: {domain}")
+    print(f"🚀 Building dataset and index for domain: {domain}")
 
     all_docs = []
     for subset in subsets:
@@ -107,6 +110,8 @@ gk_dataset = load_dataset("rungalileo/ragbench", "hotpotqa", split="test")
 cs_dataset = load_dataset("rungalileo/ragbench", "emanual", split="test")
 fin_dataset = load_dataset("rungalileo/ragbench", "finqa", split="test")
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 # Load BGE reranker
 reranker = CrossEncoder("BAAI/bge-reranker-base", max_length=512)
 
@@ -173,7 +178,7 @@ def retrieve_top_c(query, domain, embedder, k=5):
 
 
 client = Groq(
-    api_key= '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',#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',
+    api_key=GROQ_API_KEY,
 )
 
 
@@ -584,7 +589,7 @@ def evaluate_rag_pipeline(domain, q_indices):
         result["AUC-ROC (Adherence)"] = round(roc_auc_score(gt_adherence, pred_adherence), 4)
     else:
         result["Adherence"] = compute_rmse(gt_adherence, pred_adherence)
-        result["AUC-ROC (Adherence)"] = "N/A - one class only"
+        # result["AUC-ROC (Adherence)"] = "N/A - one class only"
 
     return result
 
@@ -627,4 +632,4 @@ iface = gr.Interface(
 )
 
 # Launch app
-iface.launch(server_name="0.0.0.0", server_port=7860, debug=True)
+iface.launch(server_name="0.0.0.0", server_port=7860, debug=True)
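Taken together, the app.py changes move the Groq API key out of source and into the environment, add a cached dataset/index fast path, and select a torch device up front. A minimal, self-contained sketch of the key-loading pattern the commit adopts (names follow the diff; the missing-key guard is an extra safeguard I am assuming, not part of the commit):

import os

from dotenv import load_dotenv   # provided by the new python-dotenv dependency
from groq import Groq

load_dotenv()                                # read .env into the process environment
GROQ_API_KEY = os.getenv("GROQ_API_KEY")     # None if the variable is absent

if GROQ_API_KEY is None:                     # extra guard, not in the commit
    raise RuntimeError("GROQ_API_KEY is not set; add it to .env or the Space secrets")

client = Groq(api_key=GROQ_API_KEY)          # pass the loaded value to the client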
requirements.txt CHANGED
@@ -7,4 +7,5 @@ datasets
 scikit-learn
 groq
 langchain
-tqdm
+tqdm
+python-dotenv