Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -1,22 +1,31 @@ | |
| 1 | 
             
            import gradio as gr
         | 
| 2 | 
             
            import spaces
         | 
| 3 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
| 4 |  | 
| 5 | 
            -
            #  | 
| 6 | 
            -
            MODEL_NAME   = "facebook/rag-sequence-nq"
         | 
| 7 | 
             
            DATASET_NAME = "username/mealplan-chunks"
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 8 |  | 
| 9 | 
            -
            #  | 
| 10 | 
             
            tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)
         | 
| 11 | 
            -
             | 
| 12 | 
            -
                MODEL_NAME,
         | 
| 13 | 
            -
                index_name="exact",
         | 
| 14 | 
            -
                use_dummy_dataset=False,
         | 
| 15 | 
            -
                dataset_name=DATASET_NAME
         | 
| 16 | 
            -
            )
         | 
| 17 | 
            -
            model = RagSequenceForGeneration.from_pretrained(MODEL_NAME, retriever=retriever)
         | 
| 18 |  | 
| 19 | 
            -
            # βββ Core chat callback βββ
         | 
| 20 | 
             
            @spaces.GPU
         | 
| 21 | 
             
            def respond(
         | 
| 22 | 
             
                message: str,
         | 
| @@ -27,38 +36,45 @@ def respond( | |
| 27 | 
             
                avoid: str,
         | 
| 28 | 
             
                weeks: str,
         | 
| 29 | 
             
            ):
         | 
| 30 | 
            -
                #  | 
| 31 | 
             
                avoid_list = [a.strip() for a in avoid.split(",") if a.strip()]
         | 
| 32 | 
            -
                # build prefs string
         | 
| 33 | 
             
                prefs = (
         | 
| 34 | 
            -
                    f"Goal={goal}; "
         | 
| 35 | 
            -
                    f" | 
| 36 | 
            -
                    f"Meals={meals}/day; "
         | 
| 37 | 
            -
                    f"Avoid={','.join(avoid_list)}; "
         | 
| 38 | 
            -
                    f"Duration={weeks}"
         | 
| 39 | 
             
                )
         | 
| 40 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 41 | 
             
                prompt = (
         | 
| 42 | 
             
                    "SYSTEM: Only answer using the provided CONTEXT. "
         | 
| 43 | 
            -
                    "If itβs not there, say  | 
| 44 | 
             
                    f"PREFS: {prefs}\n"
         | 
|  | |
| 45 | 
             
                    f"Q: {message}\n"
         | 
| 46 | 
             
                )
         | 
| 47 | 
            -
             | 
|  | |
| 48 | 
             
                inputs  = tokenizer([prompt], return_tensors="pt")
         | 
| 49 | 
             
                outputs = model.generate(**inputs, num_beams=2, max_new_tokens=200)
         | 
| 50 | 
             
                answer  = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
         | 
| 51 |  | 
|  | |
| 52 | 
             
                history = history or []
         | 
| 53 | 
             
                history.append((message, answer))
         | 
| 54 | 
             
                return history
         | 
| 55 |  | 
| 56 | 
            -
            #  | 
| 57 | 
            -
            # preference controls
         | 
| 58 | 
             
            goal  = gr.Dropdown(["Lose weight","Bulk","Maintain"], value="Lose weight", label="Goal")
         | 
| 59 | 
             
            diet  = gr.CheckboxGroup(["Omnivore","Vegetarian","Vegan","Keto","Paleo","Low-Carb"], label="Diet Style")
         | 
| 60 | 
            -
            meals = gr.Slider(1, | 
| 61 | 
            -
            avoid = gr.Textbox(placeholder="e.g. Gluten, Dairy, Nuts | 
| 62 | 
             
            weeks = gr.Dropdown(["1 week","2 weeks","3 weeks","4 weeks"], value="1 week", label="Plan Length")
         | 
| 63 |  | 
| 64 | 
             
            demo = gr.ChatInterface(
         | 
|  | |
| 1 | 
             
            import gradio as gr
         | 
| 2 | 
             
            import spaces
         | 
| 3 | 
            +
            import faiss
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
            from datasets import load_dataset
         | 
| 6 | 
            +
            from sentence_transformers import SentenceTransformer
         | 
| 7 | 
            +
            from transformers import RagTokenizer, RagSequenceForGeneration
         | 
| 8 |  | 
| 9 | 
            +
            # β Config β
         | 
|  | |
| 10 | 
             
            DATASET_NAME = "username/mealplan-chunks"
         | 
| 11 | 
            +
            INDEX_PATH   = "mealplan.index"
         | 
| 12 | 
            +
            MODEL_NAME   = "facebook/rag-sequence-nq"
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            # β Load chunks & FAISS index β
         | 
| 15 | 
            +
            ds = load_dataset(DATASET_NAME, split="train")
         | 
| 16 | 
            +
            texts   = ds["text"]
         | 
| 17 | 
            +
            sources = ds["source"]
         | 
| 18 | 
            +
            pages   = ds["page"]
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            # β Embeddings embedder & FAISS β
         | 
| 21 | 
            +
            embedder         = SentenceTransformer("all-MiniLM-L6-v2")
         | 
| 22 | 
            +
            chunk_embeddings = embedder.encode(texts, convert_to_numpy=True)
         | 
| 23 | 
            +
            index            = faiss.read_index(INDEX_PATH)
         | 
| 24 |  | 
| 25 | 
            +
            # β RAG generator β
         | 
| 26 | 
             
            tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)
         | 
| 27 | 
            +
            model     = RagSequenceForGeneration.from_pretrained(MODEL_NAME)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 28 |  | 
|  | |
| 29 | 
             
            @spaces.GPU
         | 
| 30 | 
             
            def respond(
         | 
| 31 | 
             
                message: str,
         | 
|  | |
| 36 | 
             
                avoid: str,
         | 
| 37 | 
             
                weeks: str,
         | 
| 38 | 
             
            ):
         | 
| 39 | 
            +
                # Parse preferences
         | 
| 40 | 
             
                avoid_list = [a.strip() for a in avoid.split(",") if a.strip()]
         | 
|  | |
| 41 | 
             
                prefs = (
         | 
| 42 | 
            +
                    f"Goal={goal}; Diet={','.join(diet)}; "
         | 
| 43 | 
            +
                    f"Meals={meals}/day; Avoid={','.join(avoid_list)}; Duration={weeks}"
         | 
|  | |
|  | |
|  | |
| 44 | 
             
                )
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                # 1) Query embedding & FAISS search
         | 
| 47 | 
            +
                q_emb = embedder.encode([message], convert_to_numpy=True)
         | 
| 48 | 
            +
                D, I  = index.search(q_emb, 5)  # top-5
         | 
| 49 | 
            +
                ctx_chunks = [
         | 
| 50 | 
            +
                    f"[{sources[i]} p{pages[i]}] {texts[i]}" for i in I[0]
         | 
| 51 | 
            +
                ]
         | 
| 52 | 
            +
                context = "\n".join(ctx_chunks)
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                # 2) Build prompt
         | 
| 55 | 
             
                prompt = (
         | 
| 56 | 
             
                    "SYSTEM: Only answer using the provided CONTEXT. "
         | 
| 57 | 
            +
                    "If itβs not there, say \"I'm sorry, I don't know.\"\n"
         | 
| 58 | 
             
                    f"PREFS: {prefs}\n"
         | 
| 59 | 
            +
                    f"CONTEXT:\n{context}\n"
         | 
| 60 | 
             
                    f"Q: {message}\n"
         | 
| 61 | 
             
                )
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                # 3) Generate
         | 
| 64 | 
             
                inputs  = tokenizer([prompt], return_tensors="pt")
         | 
| 65 | 
             
                outputs = model.generate(**inputs, num_beams=2, max_new_tokens=200)
         | 
| 66 | 
             
                answer  = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
         | 
| 67 |  | 
| 68 | 
            +
                # 4) Update history
         | 
| 69 | 
             
                history = history or []
         | 
| 70 | 
             
                history.append((message, answer))
         | 
| 71 | 
             
                return history
         | 
| 72 |  | 
| 73 | 
            +
            # β Build Gradio chat interface β
         | 
|  | |
| 74 | 
             
            goal  = gr.Dropdown(["Lose weight","Bulk","Maintain"], value="Lose weight", label="Goal")
         | 
| 75 | 
             
            diet  = gr.CheckboxGroup(["Omnivore","Vegetarian","Vegan","Keto","Paleo","Low-Carb"], label="Diet Style")
         | 
| 76 | 
            +
            meals = gr.Slider(1,6,value=3,step=1,label="Meals per day")
         | 
| 77 | 
            +
            avoid = gr.Textbox(placeholder="e.g. Gluten, Dairy, Nuts...", label="Avoidances (comma-separated)")
         | 
| 78 | 
             
            weeks = gr.Dropdown(["1 week","2 weeks","3 weeks","4 weeks"], value="1 week", label="Plan Length")
         | 
| 79 |  | 
| 80 | 
             
            demo = gr.ChatInterface(
         | 
