File size: 9,911 Bytes
b43abc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05ea985
 
b43abc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05ea985
b43abc8
 
 
 
05ea985
b43abc8
 
 
05ea985
b43abc8
 
 
 
2bace82
 
05ea985
b43abc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad30f0a
b43abc8
 
 
 
 
 
 
 
 
 
05ea985
b43abc8
 
 
 
2bace82
05ea985
b43abc8
 
 
 
 
 
 
 
 
 
 
 
 
ad30f0a
b43abc8
 
 
 
 
 
 
 
 
 
 
 
2bace82
b43abc8
 
 
ad30f0a
b43abc8
 
 
 
 
 
 
 
 
87b381f
 
b43abc8
 
 
be89223
 
 
 
 
04958a9
 
 
 
 
 
 
 
 
 
 
05ea985
b43abc8
04958a9
b43abc8
b112410
b43abc8
 
 
 
05ea985
b43abc8
 
 
 
 
 
b112410
b43abc8
 
 
b112410
 
b43abc8
 
 
ef72a9c
b43abc8
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import os
# Ensure SAMBANOVA_BASE_URL is in the environment for litellm
# This should be set before dynamic_cheatsheet.language_model is imported if it relies on it at import time,
# but it's generally used at runtime when making the API call.
# Setting it here early in app.py is a safeguard.
SAMBANOVA_DEFINED_BASE_URL = "https://api.sambanova.ai/v1"
if "SAMBANOVA_BASE_URL" not in os.environ:
    os.environ["SAMBANOVA_BASE_URL"] = SAMBANOVA_DEFINED_BASE_URL
    print(f"SAMBANOVA_BASE_URL environment variable set to: {SAMBANOVA_DEFINED_BASE_URL}")
elif os.environ["SAMBANOVA_BASE_URL"] != SAMBANOVA_DEFINED_BASE_URL:
    # Respect a pre-existing value (e.g. set by the hosting environment) but
    # surface the mismatch so misconfiguration is visible in the logs.
    print(f"Warning: SAMBANOVA_BASE_URL environment variable is already set to {os.environ['SAMBANOVA_BASE_URL']}, but app expects {SAMBANOVA_DEFINED_BASE_URL}. Using the existing one.")

import gradio as gr
import sys

# Add the project root to the Python path to allow importing dynamic_cheatsheet
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".")))

from dynamic_cheatsheet.language_model import LanguageModel

# --- Configuration ---
# NOTE(review): this global is rebound later during UI construction; prefer
# reading os.environ directly where the key is actually needed.
SAMBANOVA_API_KEY = os.environ.get("SAMBANOVA_API_KEY")
# SAMBANOVA_BASE_URL is now set from SAMBANOVA_DEFINED_BASE_URL to env var if not present

# SAMBANOVA_MODEL_NAME = "sambanova/DeepSeek-R1-Distill-Llama-70B" 

# Prompt templates are loaded relative to the working directory — assumes the
# app is launched from the project root; TODO confirm for deployed environments.
GENERATOR_PROMPT_PATH = "prompts/generator_prompt.txt"
CURATOR_PROMPT_PATH = "prompts/curator_prompt_for_dc_cumulative.txt"

GENERATOR_PROMPT = ""
CURATOR_PROMPT = ""

try:
    with open(GENERATOR_PROMPT_PATH, "r") as f:
        GENERATOR_PROMPT = f.read()
    with open(CURATOR_PROMPT_PATH, "r") as f:
        CURATOR_PROMPT = f.read()
except FileNotFoundError:
    # Fall back to minimal inline prompts so the demo still runs without the
    # prompt files; the placeholders ([[CHEATSHEET]], [[QUESTION]], ...) match
    # what the templating in advanced_generate expects.
    print(f"Error: Prompt files not found at {GENERATOR_PROMPT_PATH} or {CURATOR_PROMPT_PATH}. Please ensure they exist.")
    GENERATOR_PROMPT = "You are a helpful assistant. Given a question and a cheatsheet, provide an answer. Cheatsheet: [[CHEATSHEET]] Question: [[QUESTION]] FINAL ANSWER: <answer></answer>"
    CURATOR_PROMPT = "You are a helpful assistant. Given a question, a model answer, and a previous cheatsheet, update the cheatsheet. Previous Cheatsheet: [[PREVIOUS_CHEATSHEET]] Question: [[QUESTION]] Model Answer: [[MODEL_ANSWER]] NEW CHEATSHEET: <cheatsheet></cheatsheet>"

# --- Global variable for cheatsheet ---
# Shared between the "generate" and "answer" callbacks; "(empty)" is the
# sentinel for "no cheatsheet yet".
current_cheatsheet_cache = "(empty)"

def initialize_model(model_name_input):
    """Construct a LanguageModel for the given SambaNova model name.

    Args:
        model_name_input: Model identifier string, e.g.
            "sambanova/Meta-Llama-3.1-8B-Instruct".

    Returns:
        A LanguageModel configured for that model name.

    Raises:
        gr.Error: If SAMBANOVA_API_KEY is not present in the environment.
    """
    # Read the key from the environment at call time instead of trusting the
    # module-level SAMBANOVA_API_KEY name: that name is rebound to a
    # gr.Textbox component during UI construction, which is always truthy and
    # would make this guard pass even when no key is configured.
    if not os.environ.get("SAMBANOVA_API_KEY"):
        raise gr.Error("SAMBANOVA_API_KEY environment variable not set. Please set it in your Hugging Face Space secrets or local environment.")
    # LanguageModel will be modified to handle samba/ prefix using env vars for API key/base URL via litellm
    return LanguageModel(
        model_name=model_name_input
    )

def generate_cheatsheet_func(training_data_text, model_name_input, progress=gr.Progress(track_tqdm=True)):
    """Build a cumulative cheatsheet from newline-separated training examples.

    Each non-blank line of ``training_data_text`` is run through the
    DynamicCheatsheet_Cumulative approach in order; the cheatsheet produced
    for one example is fed into the next. The final result is stored in the
    module-level cache and returned for display.
    """
    global current_cheatsheet_cache

    # An empty training set just resets the cache — no model calls.
    if not training_data_text.strip():
        current_cheatsheet_cache = "(empty)"
        return "Training data is empty. Cheatsheet reset to (empty)."

    print('generate_cheatsheet_func model_name_input', model_name_input)
    lm = initialize_model(model_name_input)

    examples = [line.strip() for line in training_data_text.split("\n") if line.strip()]

    sheet = "(empty)"

    progress(0, desc="Initializing Cheatsheet Generation")
    for idx, example in enumerate(progress.tqdm(examples, desc="Generating Cheatsheet")):
        print(f"Processing training example {idx + 1}/{len(examples)}: {example[:50]}...")
        try:
            outcome = lm.advanced_generate(
                approach_name="DynamicCheatsheet_Cumulative",
                input_txt=example,
                cheatsheet=sheet,
                generator_template=GENERATOR_PROMPT,
                cheatsheet_template=CURATOR_PROMPT,
                temperature=0.1,
                max_tokens=2048
            )
            # Keep the previous sheet if this pass produced nothing new.
            sheet = outcome.get("final_cheatsheet", sheet)
        except Exception as e:
            # Best-effort: log, surface a UI warning, and move on.
            print(f"Error processing example '{example[:50]}...': {e}")
            gr.Warning(f"Error on example '{example[:30]}...': {e}. Skipping this example.")

    current_cheatsheet_cache = sheet
    return current_cheatsheet_cache

def _single_inference(model, user_query, cheatsheet, error_label):
    """Run one DynamicCheatsheet_Cumulative pass and return the final answer.

    On any failure, returns an error-description string (mirrors the original
    inline error handling) rather than raising, so one failed pass does not
    prevent the other from running.
    """
    try:
        results = model.advanced_generate(
            approach_name="DynamicCheatsheet_Cumulative",
            input_txt=user_query,
            cheatsheet=cheatsheet,
            generator_template=GENERATOR_PROMPT,
            cheatsheet_template=CURATOR_PROMPT,
            temperature=0.1,
            max_tokens=2048
        )
        return results.get("final_answer", "Error: Could not extract answer.")
    except Exception as e:
        print(f"Error ({error_label}): {e}")
        return f"Error during inference {error_label}: {e}"

def get_answers_func(user_query, model_name_input):
    """Answer ``user_query`` twice: with the cached cheatsheet and without.

    Returns:
        Tuple of (answer_with_cheatsheet, answer_without_cheatsheet); each
        element is either the model's answer or an error description string.
    """
    global current_cheatsheet_cache
    if not user_query.strip():
        return "Query is empty.", "Query is empty."

    print('get_answers_func model_name_input', model_name_input)
    model = initialize_model(model_name_input)

    # Inference WITH cheatsheet
    print(f"Querying WITH cheatsheet ({current_cheatsheet_cache[:50]}...)")
    answer_with_cheatsheet = _single_inference(
        model, user_query, current_cheatsheet_cache, "with cheatsheet"
    )

    # Inference WITHOUT cheatsheet ("(empty)" is the no-cheatsheet sentinel)
    print(f"Querying WITHOUT cheatsheet...")
    answer_without_cheatsheet = _single_inference(
        model, user_query, "(empty)", "without cheatsheet"
    )

    return answer_with_cheatsheet, answer_without_cheatsheet

# --- Gradio Interface ---

with gr.Blocks(title="Task Caching Demo", theme=gr.themes.Default(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"])) as demo:
    gr.Markdown("# Task Caching Demo")
    gr.Markdown("Demonstrates the effect of using a dynamically generated cheatsheet (Task Caching) on model inference. Uses SambaNova API via `litellm`.")

    training_data_example = '''
Solve for 24: 1 2 3 4
Solve for 24: 3 4 5 6
Solve for 24: 4 5 6 7
'''
    with gr.Tabs():
        model_name_input = gr.Textbox(
            label="SambaNova Model Name",
            value="sambanova/Meta-Llama-3.1-8B-Instruct",  # Default value
            info="Enter the SambaNova model name (e.g., sambanova/DeepSeek-R1-Distill-Llama-70B). Ensure the 'sambanova/' prefix if required by litellm configuration."
        )
        # BUG FIX: this component previously rebound the module-level
        # SAMBANOVA_API_KEY (the key string read from the environment) to a
        # gr.Textbox object, which is always truthy and so defeated every
        # later `if not SAMBANOVA_API_KEY` check. Renamed so the global keeps
        # its value. NOTE(review): the field is not wired into any callback
        # yet — confirm whether it should be passed to the handlers.
        api_key_input = gr.Textbox(
            label="SambaNova API Key",
            value="",  # Default value
            info="Please Enter your SambaNova API Key, otherwise by default will use Changran's key, but RPM is low"
        )

    with gr.Tabs():
        with gr.TabItem("1. Task Caching (Generate Task-Specific Cheatsheet from Training Data)"):
            gr.Markdown("Paste your training data below, one example per line. This data will be used to build a cumulative cheatsheet. The process may take some time depending on the number of examples.")
            training_data_input = gr.Textbox(lines=10, label="Training Data", value=training_data_example)
            generate_cheatsheet_button = gr.Button("Generate Cheatsheet (Task Caching)", variant="primary")
            cheatsheet_output = gr.Textbox(label="Generated Cheatsheet", lines=15, interactive=False, show_label=True)
            generate_cheatsheet_button.click(
                fn=generate_cheatsheet_func,
                inputs=[training_data_input, model_name_input],
                outputs=cheatsheet_output,
                show_progress="full"
            )

        with gr.TabItem("2. Test Inference"):
            gr.Markdown("Enter your query below. The model will attempt to answer it twice: once using the generated cheatsheet (if any), and once without it.")
            query_input = gr.Textbox(lines=3, label="Your Query", value="e.g., What is the solution to 5 6 6 8 in the Game of 24?")
            get_answers_button = gr.Button("Get Answers", variant="primary")

            with gr.Row():
                answer_with_cheatsheet_output = gr.Textbox(label="Answer WITH Task Caching", lines=10, interactive=False, show_label=True)
                answer_without_cheatsheet_output = gr.Textbox(label="Answer WITHOUT Task Caching", lines=10, interactive=False, show_label=True)

            get_answers_button.click(
                fn=get_answers_func,
                inputs=[query_input, model_name_input],
                outputs=[answer_with_cheatsheet_output, answer_without_cheatsheet_output]
            )

    gr.Markdown("**Important:** Ensure `SAMBANOVA_API_KEY` is set as a secret in your Hugging Face Space or as an environment variable if running locally. `SAMBANOVA_BASE_URL` is set to `https://api.sambanova.ai/v1` by default if not found in environment.")

if __name__ == "__main__":
    # Warn (but do not abort) when the key is missing so the UI can still be
    # explored locally; API calls will fail without it.
    if not SAMBANOVA_API_KEY:
        print("Warning: SAMBANOVA_API_KEY is not set. The application will likely fail to contact the SambaNova API.")
        print("Please set the SAMBANOVA_API_KEY environment variable.")
    demo.launch()