Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -21,8 +21,8 @@ from dynamic_cheatsheet.language_model import LanguageModel
|
|
21 |
# --- Configuration ---
|
22 |
SAMBANOVA_API_KEY = os.environ.get("SAMBANOVA_API_KEY")
|
23 |
# SAMBANOVA_BASE_URL is now set from SAMBANOVA_DEFINED_BASE_URL to env var if not present
|
24 |
-
|
25 |
-
SAMBANOVA_MODEL_NAME = "sambanova/DeepSeek-R1-Distill-Llama-70B"
|
26 |
|
27 |
GENERATOR_PROMPT_PATH = "prompts/generator_prompt.txt"
|
28 |
CURATOR_PROMPT_PATH = "prompts/curator_prompt_for_dc_cumulative.txt"
|
@@ -43,22 +43,22 @@ except FileNotFoundError:
|
|
43 |
# --- Global variable for cheatsheet ---
|
44 |
current_cheatsheet_cache = "(empty)"
|
45 |
|
46 |
-
def initialize_model():
|
47 |
if not SAMBANOVA_API_KEY:
|
48 |
raise gr.Error("SAMBANOVA_API_KEY environment variable not set. Please set it in your Hugging Face Space secrets or local environment.")
|
49 |
# LanguageModel will be modified to handle samba/ prefix using env vars for API key/base URL via litellm
|
50 |
model = LanguageModel(
|
51 |
-
model_name=
|
52 |
)
|
53 |
return model
|
54 |
|
55 |
-
def generate_cheatsheet_func(training_data_text, progress=gr.Progress(track_tqdm=True)):
|
56 |
global current_cheatsheet_cache
|
57 |
if not training_data_text.strip():
|
58 |
current_cheatsheet_cache = "(empty)"
|
59 |
return "Training data is empty. Cheatsheet reset to (empty)."
|
60 |
|
61 |
-
model = initialize_model()
|
62 |
|
63 |
training_examples = [ex.strip() for ex in training_data_text.split("\n") if ex.strip()]
|
64 |
|
@@ -86,12 +86,12 @@ def generate_cheatsheet_func(training_data_text, progress=gr.Progress(track_tqdm
|
|
86 |
current_cheatsheet_cache = cheatsheet_content
|
87 |
return current_cheatsheet_cache
|
88 |
|
89 |
-
def get_answers_func(user_query):
|
90 |
global current_cheatsheet_cache
|
91 |
if not user_query.strip():
|
92 |
return "Query is empty.", "Query is empty."
|
93 |
|
94 |
-
model = initialize_model()
|
95 |
answer_with_cheatsheet = "Error retrieving answer."
|
96 |
answer_without_cheatsheet = "Error retrieving answer."
|
97 |
|
@@ -136,6 +136,13 @@ with gr.Blocks(title="Task Caching Demo", theme=gr.themes.Soft()) as demo:
|
|
136 |
gr.Markdown("# Task Caching Demo")
|
137 |
gr.Markdown("Demonstrates the effect of using a dynamically generated cheatsheet (Task Caching) on model inference. Uses SambaNova API via `litellm`.")
|
138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
with gr.Tabs():
|
140 |
with gr.TabItem("1. Generate Cheatsheet (Task Caching)"):
|
141 |
gr.Markdown("Paste your training data below, one example per line. This data will be used to build a cumulative cheatsheet. The process may take some time depending on the number of examples.")
|
@@ -144,7 +151,7 @@ with gr.Blocks(title="Task Caching Demo", theme=gr.themes.Soft()) as demo:
|
|
144 |
cheatsheet_output = gr.Textbox(label="Generated Cheatsheet", lines=15, interactive=False, show_label=True)
|
145 |
generate_cheatsheet_button.click(
|
146 |
fn=generate_cheatsheet_func,
|
147 |
-
inputs=training_data_input,
|
148 |
outputs=cheatsheet_output,
|
149 |
show_progress="full"
|
150 |
)
|
@@ -160,7 +167,7 @@ with gr.Blocks(title="Task Caching Demo", theme=gr.themes.Soft()) as demo:
|
|
160 |
|
161 |
get_answers_button.click(
|
162 |
fn=get_answers_func,
|
163 |
-
inputs=query_input,
|
164 |
outputs=[answer_with_cheatsheet_output, answer_without_cheatsheet_output]
|
165 |
)
|
166 |
|
|
|
21 |
# --- Configuration ---
|
22 |
SAMBANOVA_API_KEY = os.environ.get("SAMBANOVA_API_KEY")
|
23 |
# SAMBANOVA_BASE_URL is now set from SAMBANOVA_DEFINED_BASE_URL to env var if not present
|
24 |
+
|
25 |
+
# SAMBANOVA_MODEL_NAME = "sambanova/DeepSeek-R1-Distill-Llama-70B"
|
26 |
|
27 |
GENERATOR_PROMPT_PATH = "prompts/generator_prompt.txt"
|
28 |
CURATOR_PROMPT_PATH = "prompts/curator_prompt_for_dc_cumulative.txt"
|
|
|
43 |
# --- Global variable for cheatsheet ---
|
44 |
current_cheatsheet_cache = "(empty)"
|
45 |
|
46 |
+
def initialize_model(model_name_input):
|
47 |
if not SAMBANOVA_API_KEY:
|
48 |
raise gr.Error("SAMBANOVA_API_KEY environment variable not set. Please set it in your Hugging Face Space secrets or local environment.")
|
49 |
# LanguageModel will be modified to handle samba/ prefix using env vars for API key/base URL via litellm
|
50 |
model = LanguageModel(
|
51 |
+
model_name=model_name_input
|
52 |
)
|
53 |
return model
|
54 |
|
55 |
+
def generate_cheatsheet_func(training_data_text, model_name_input, progress=gr.Progress(track_tqdm=True)):
|
56 |
global current_cheatsheet_cache
|
57 |
if not training_data_text.strip():
|
58 |
current_cheatsheet_cache = "(empty)"
|
59 |
return "Training data is empty. Cheatsheet reset to (empty)."
|
60 |
|
61 |
+
model = initialize_model(model_name_input)
|
62 |
|
63 |
training_examples = [ex.strip() for ex in training_data_text.split("\n") if ex.strip()]
|
64 |
|
|
|
86 |
current_cheatsheet_cache = cheatsheet_content
|
87 |
return current_cheatsheet_cache
|
88 |
|
89 |
+
def get_answers_func(user_query, model_name_input):
|
90 |
global current_cheatsheet_cache
|
91 |
if not user_query.strip():
|
92 |
return "Query is empty.", "Query is empty."
|
93 |
|
94 |
+
model = initialize_model(model_name_input)
|
95 |
answer_with_cheatsheet = "Error retrieving answer."
|
96 |
answer_without_cheatsheet = "Error retrieving answer."
|
97 |
|
|
|
136 |
gr.Markdown("# Task Caching Demo")
|
137 |
gr.Markdown("Demonstrates the effect of using a dynamically generated cheatsheet (Task Caching) on model inference. Uses SambaNova API via `litellm`.")
|
138 |
|
139 |
+
model_name_input = gr.Textbox(
|
140 |
+
label="SambaNova Model Name",
|
141 |
+
value="sambanova/DeepSeek-R1-Distill-Llama-70B", # Default value
|
142 |
+
info="Enter the SambaNova model name (e.g., samba/DeepSeek-R1-Distill-Llama-70B). Ensure the 'samba/' prefix if required by litellm configuration."
|
143 |
+
)
|
144 |
+
# END OF ADDED PART
|
145 |
+
|
146 |
with gr.Tabs():
|
147 |
with gr.TabItem("1. Generate Cheatsheet (Task Caching)"):
|
148 |
gr.Markdown("Paste your training data below, one example per line. This data will be used to build a cumulative cheatsheet. The process may take some time depending on the number of examples.")
|
|
|
151 |
cheatsheet_output = gr.Textbox(label="Generated Cheatsheet", lines=15, interactive=False, show_label=True)
|
152 |
generate_cheatsheet_button.click(
|
153 |
fn=generate_cheatsheet_func,
|
154 |
+
inputs=[training_data_input, model_name_input],
|
155 |
outputs=cheatsheet_output,
|
156 |
show_progress="full"
|
157 |
)
|
|
|
167 |
|
168 |
get_answers_button.click(
|
169 |
fn=get_answers_func,
|
170 |
+
inputs=[query_input, model_name_input]
|
171 |
outputs=[answer_with_cheatsheet_output, answer_without_cheatsheet_output]
|
172 |
)
|
173 |
|