DurgaDeepak commited on
Commit
9b1fba6
·
verified ·
1 Parent(s): 4eb79ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -26
app.py CHANGED
@@ -1,22 +1,31 @@
1
  import gradio as gr
2
  import spaces
3
- from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
 
 
 
 
4
 
5
- # ——— Configuration ———
6
- MODEL_NAME = "facebook/rag-sequence-nq"
7
  DATASET_NAME = "username/mealplan-chunks"
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- # ——— Initialize RAG ———
10
  tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)
11
- retriever = RagRetriever.from_pretrained(
12
- MODEL_NAME,
13
- index_name="exact",
14
- use_dummy_dataset=False,
15
- dataset_name=DATASET_NAME
16
- )
17
- model = RagSequenceForGeneration.from_pretrained(MODEL_NAME, retriever=retriever)
18
 
19
- # ——— Core chat callback ———
20
  @spaces.GPU
21
  def respond(
22
  message: str,
@@ -27,38 +36,45 @@ def respond(
27
  avoid: str,
28
  weeks: str,
29
  ):
30
- # parse avoidances
31
  avoid_list = [a.strip() for a in avoid.split(",") if a.strip()]
32
- # build prefs string
33
  prefs = (
34
- f"Goal={goal}; "
35
- f"Diet={','.join(diet)}; "
36
- f"Meals={meals}/day; "
37
- f"Avoid={','.join(avoid_list)}; "
38
- f"Duration={weeks}"
39
  )
40
- # system guardrail + prefs + question
 
 
 
 
 
 
 
 
 
41
  prompt = (
42
  "SYSTEM: Only answer using the provided CONTEXT. "
43
- "If it’s not there, say “I’m sorry, I don’t know.”\n"
44
  f"PREFS: {prefs}\n"
 
45
  f"Q: {message}\n"
46
  )
47
- # generate
 
48
  inputs = tokenizer([prompt], return_tensors="pt")
49
  outputs = model.generate(**inputs, num_beams=2, max_new_tokens=200)
50
  answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
51
 
 
52
  history = history or []
53
  history.append((message, answer))
54
  return history
55
 
56
- # ——— Build the Chat Interface ———
57
- # preference controls
58
  goal = gr.Dropdown(["Lose weight","Bulk","Maintain"], value="Lose weight", label="Goal")
59
  diet = gr.CheckboxGroup(["Omnivore","Vegetarian","Vegan","Keto","Paleo","Low-Carb"], label="Diet Style")
60
- meals = gr.Slider(1, 6, value=3, step=1, label="Meals per day")
61
- avoid = gr.Textbox(placeholder="e.g. Gluten, Dairy, Nuts, Eggs, Soy…", label="Avoidances (comma-separated)")
62
  weeks = gr.Dropdown(["1 week","2 weeks","3 weeks","4 weeks"], value="1 week", label="Plan Length")
63
 
64
  demo = gr.ChatInterface(
 
1
  import gradio as gr
2
  import spaces
3
+ import faiss
4
+ import numpy as np
5
+ from datasets import load_dataset
6
+ from sentence_transformers import SentenceTransformer
7
+ from transformers import RagTokenizer, RagSequenceForGeneration
8
 
9
+ # — Config —
 
10
  DATASET_NAME = "username/mealplan-chunks"
11
+ INDEX_PATH = "mealplan.index"
12
+ MODEL_NAME = "facebook/rag-sequence-nq"
13
+
14
+ # — Load chunks & FAISS index —
15
+ ds = load_dataset(DATASET_NAME, split="train")
16
+ texts = ds["text"]
17
+ sources = ds["source"]
18
+ pages = ds["page"]
19
+
20
+ # — Embedding model & chunk embeddings —
21
+ embedder = SentenceTransformer("all-MiniLM-L6-v2")
22
+ chunk_embeddings = embedder.encode(texts, convert_to_numpy=True)
23
+ index = faiss.read_index(INDEX_PATH)
24
 
25
+ # — RAG generator —
26
  tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)
27
+ model = RagSequenceForGeneration.from_pretrained(MODEL_NAME)
 
 
 
 
 
 
28
 
 
29
  @spaces.GPU
30
  def respond(
31
  message: str,
 
36
  avoid: str,
37
  weeks: str,
38
  ):
39
+ # Parse preferences
40
  avoid_list = [a.strip() for a in avoid.split(",") if a.strip()]
 
41
  prefs = (
42
+ f"Goal={goal}; Diet={','.join(diet)}; "
43
+ f"Meals={meals}/day; Avoid={','.join(avoid_list)}; Duration={weeks}"
 
 
 
44
  )
45
+
46
+ # 1) Query embedding & FAISS search
47
+ q_emb = embedder.encode([message], convert_to_numpy=True)
48
+ D, I = index.search(q_emb, 5) # top-5
49
+ ctx_chunks = [
50
+ f"[{sources[i]} p{pages[i]}] {texts[i]}" for i in I[0]
51
+ ]
52
+ context = "\n".join(ctx_chunks)
53
+
54
+ # 2) Build prompt
55
  prompt = (
56
  "SYSTEM: Only answer using the provided CONTEXT. "
57
+ "If it’s not there, say \"I'm sorry, I don't know.\"\n"
58
  f"PREFS: {prefs}\n"
59
+ f"CONTEXT:\n{context}\n"
60
  f"Q: {message}\n"
61
  )
62
+
63
+ # 3) Generate
64
  inputs = tokenizer([prompt], return_tensors="pt")
65
  outputs = model.generate(**inputs, num_beams=2, max_new_tokens=200)
66
  answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
67
 
68
+ # 4) Update history
69
  history = history or []
70
  history.append((message, answer))
71
  return history
72
 
73
+ # — Build Gradio chat interface —
 
74
  goal = gr.Dropdown(["Lose weight","Bulk","Maintain"], value="Lose weight", label="Goal")
75
  diet = gr.CheckboxGroup(["Omnivore","Vegetarian","Vegan","Keto","Paleo","Low-Carb"], label="Diet Style")
76
+ meals = gr.Slider(1,6,value=3,step=1,label="Meals per day")
77
+ avoid = gr.Textbox(placeholder="e.g. Gluten, Dairy, Nuts...", label="Avoidances (comma-separated)")
78
  weeks = gr.Dropdown(["1 week","2 weeks","3 weeks","4 weeks"], value="1 week", label="Plan Length")
79
 
80
  demo = gr.ChatInterface(