import os
# Ensure SAMBANOVA_BASE_URL is in the environment for litellm.
# The value is normally read at request time, but setting it here, before
# dynamic_cheatsheet.language_model is imported, also covers the case where
# that module reads it at import time.
SAMBANOVA_DEFINED_BASE_URL = "https://api.sambanova.ai/v1"
if "SAMBANOVA_BASE_URL" not in os.environ:
os.environ["SAMBANOVA_BASE_URL"] = SAMBANOVA_DEFINED_BASE_URL
print(f"SAMBANOVA_BASE_URL environment variable set to: {SAMBANOVA_DEFINED_BASE_URL}")
elif os.environ["SAMBANOVA_BASE_URL"] != SAMBANOVA_DEFINED_BASE_URL:
print(f"Warning: SAMBANOVA_BASE_URL environment variable is already set to {os.environ['SAMBANOVA_BASE_URL']}, but app expects {SAMBANOVA_DEFINED_BASE_URL}. Using the existing one.")
import gradio as gr
import sys
# Add the project root to the Python path to allow importing dynamic_cheatsheet
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".")))
from dynamic_cheatsheet.language_model import LanguageModel
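# LanguageModel dispatches generation requests through litellm; for "sambanova/..."
# model names the API key and base URL are expected to be picked up from the
# SAMBANOVA_API_KEY / SAMBANOVA_BASE_URL environment variables handled above.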
# --- Configuration ---
SAMBANOVA_API_KEY = os.environ.get("SAMBANOVA_API_KEY")
# SAMBANOVA_BASE_URL was set in the environment above (defaulting to SAMBANOVA_DEFINED_BASE_URL) if it was not already present
# SAMBANOVA_MODEL_NAME = "sambanova/DeepSeek-R1-Distill-Llama-70B"
GENERATOR_PROMPT_PATH = "prompts/generator_prompt.txt"
CURATOR_PROMPT_PATH = "prompts/curator_prompt_for_dc_cumulative.txt"
GENERATOR_PROMPT = ""
CURATOR_PROMPT = ""
try:
    with open(GENERATOR_PROMPT_PATH, "r") as f:
        GENERATOR_PROMPT = f.read()
    with open(CURATOR_PROMPT_PATH, "r") as f:
        CURATOR_PROMPT = f.read()
except FileNotFoundError:
    print(f"Error: Prompt files not found at {GENERATOR_PROMPT_PATH} or {CURATOR_PROMPT_PATH}. Please ensure they exist.")
    GENERATOR_PROMPT = "You are a helpful assistant. Given a question and a cheatsheet, provide an answer. Cheatsheet: [[CHEATSHEET]] Question: [[QUESTION]] FINAL ANSWER: <answer></answer>"
    CURATOR_PROMPT = "You are a helpful assistant. Given a question, a model answer, and a previous cheatsheet, update the cheatsheet. Previous Cheatsheet: [[PREVIOUS_CHEATSHEET]] Question: [[QUESTION]] Model Answer: [[MODEL_ANSWER]] NEW CHEATSHEET: <cheatsheet></cheatsheet>"
# --- Global variable for cheatsheet ---
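# Module-level cache: generate_cheatsheet_func writes the latest cheatsheet here and
# get_answers_func reads it. Because this is process-level state, it is shared by every
# session of the running app.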
current_cheatsheet_cache = "(empty)"
def initialize_model(model_name_input):
    if not SAMBANOVA_API_KEY:
        raise gr.Error("SAMBANOVA_API_KEY environment variable not set. Please set it in your Hugging Face Space secrets or local environment.")
    # LanguageModel is expected to handle the "sambanova/" prefix via litellm, using the
    # API key and base URL from the environment variables set above.
    model = LanguageModel(model_name=model_name_input)
    return model
def generate_cheatsheet_func(training_data_text, model_name_input, progress=gr.Progress(track_tqdm=True)):
    global current_cheatsheet_cache
    if not training_data_text.strip():
        current_cheatsheet_cache = "(empty)"
        return "Training data is empty. Cheatsheet reset to (empty)."
    print('generate_cheatsheet_func model_name_input', model_name_input)
    model = initialize_model(model_name_input)
    training_examples = [ex.strip() for ex in training_data_text.split("\n") if ex.strip()]
    cheatsheet_content = "(empty)"
    progress(0, desc="Initializing Cheatsheet Generation")
    for i, example_input in enumerate(progress.tqdm(training_examples, desc="Generating Cheatsheet")):
        print(f"Processing training example {i+1}/{len(training_examples)}: {example_input[:50]}...")
        try:
            results_dict = model.advanced_generate(
                approach_name="DynamicCheatsheet_Cumulative",
                input_txt=example_input,
                cheatsheet=cheatsheet_content,
                generator_template=GENERATOR_PROMPT,
                cheatsheet_template=CURATOR_PROMPT,
                temperature=0.1,
                max_tokens=2048
            )
            cheatsheet_content = results_dict.get("final_cheatsheet", cheatsheet_content)
        except Exception as e:
            print(f"Error processing example '{example_input[:50]}...': {e}")
            # Keep the current cheatsheet and surface the error in the UI.
            gr.Warning(f"Error on example '{example_input[:30]}...': {e}. Skipping this example.")
    current_cheatsheet_cache = cheatsheet_content
    return current_cheatsheet_cache
def get_answers_func(user_query, model_name_input):
    global current_cheatsheet_cache
    if not user_query.strip():
        return "Query is empty.", "Query is empty."
    print('get_answers_func model_name_input', model_name_input)
    model = initialize_model(model_name_input)
    answer_with_cheatsheet = "Error retrieving answer."
    answer_without_cheatsheet = "Error retrieving answer."
    # Inference WITH cheatsheet
    try:
        print(f"Querying WITH cheatsheet ({current_cheatsheet_cache[:50]}...)")
        results_with_cheatsheet = model.advanced_generate(
            approach_name="DynamicCheatsheet_Cumulative",
            input_txt=user_query,
            cheatsheet=current_cheatsheet_cache,
            generator_template=GENERATOR_PROMPT,
            cheatsheet_template=CURATOR_PROMPT,
            temperature=0.1,
            max_tokens=2048
        )
        answer_with_cheatsheet = results_with_cheatsheet.get("final_answer", "Error: Could not extract answer.")
    except Exception as e:
        print(f"Error (with cheatsheet): {e}")
        answer_with_cheatsheet = f"Error during inference with cheatsheet: {e}"
    # Inference WITHOUT cheatsheet
    try:
        print("Querying WITHOUT cheatsheet...")
        results_without_cheatsheet = model.advanced_generate(
            approach_name="DynamicCheatsheet_Cumulative",
            input_txt=user_query,
            cheatsheet="(empty)",
            generator_template=GENERATOR_PROMPT,
            cheatsheet_template=CURATOR_PROMPT,
            temperature=0.1,
            max_tokens=2048
        )
        answer_without_cheatsheet = results_without_cheatsheet.get("final_answer", "Error: Could not extract answer.")
    except Exception as e:
        print(f"Error (without cheatsheet): {e}")
        answer_without_cheatsheet = f"Error during inference without cheatsheet: {e}"
    return answer_with_cheatsheet, answer_without_cheatsheet
# --- Gradio Interface ---
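# UI layout: shared model-name and API-key inputs at the top, then two tabs:
# (1) generate the cheatsheet from training data, (2) compare answers with and without it.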
with gr.Blocks(title="Task Caching Demo", theme=gr.themes.Default(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"])) as demo:
    gr.Markdown("# Task Caching Demo")
    gr.Markdown("Demonstrates the effect of using a dynamically generated cheatsheet (Task Caching) on model inference. Uses SambaNova API via `litellm`.")
    training_data_example = '''
Solve for 24: 1 2 3 4
Solve for 24: 3 4 5 6
Solve for 24: 4 5 6 7
'''
    with gr.Row():
        model_name_input = gr.Textbox(
            label="SambaNova Model Name",
            value="sambanova/Meta-Llama-3.1-8B-Instruct",  # Default value
            info="Enter the SambaNova model name (e.g., sambanova/DeepSeek-R1-Distill-Llama-70B). Include the 'sambanova/' prefix if your litellm configuration requires it."
        )
        # Informational only for now: the click handlers read SAMBANOVA_API_KEY from the environment.
        sambanova_api_key_input = gr.Textbox(
            label="SambaNova API Key",
            value="",
            info="Please enter your SambaNova API key; otherwise Changran's key is used by default, but its RPM limit is low."
        )
    with gr.Tabs():
        with gr.TabItem("1. Task Caching (Generate Task-Specific Cheatsheet from Training Data)"):
            gr.Markdown("Paste your training data below, one example per line. This data will be used to build a cumulative cheatsheet. The process may take some time depending on the number of examples.")
            training_data_input = gr.Textbox(lines=10, label="Training Data", value=training_data_example)
            generate_cheatsheet_button = gr.Button("Generate Cheatsheet (Task Caching)", variant="primary")
            cheatsheet_output = gr.Textbox(label="Generated Cheatsheet", lines=15, interactive=False, show_label=True)
            generate_cheatsheet_button.click(
                fn=generate_cheatsheet_func,
                inputs=[training_data_input, model_name_input],
                outputs=cheatsheet_output,
                show_progress="full"
            )
with gr.TabItem("2. Test Inference"):
gr.Markdown("Enter your query below. The model will attempt to answer it twice: once using the generated cheatsheet (if any), and once without it.")
query_input = gr.Textbox(lines=3, label="Your Query", value="e.g., What is the solution to 5 6 6 8 in the Game of 24?")
get_answers_button = gr.Button("Get Answers", variant="primary")
            with gr.Row():
                answer_with_cheatsheet_output = gr.Textbox(label="Answer WITH Task Caching", lines=10, interactive=False, show_label=True)
                answer_without_cheatsheet_output = gr.Textbox(label="Answer WITHOUT Task Caching", lines=10, interactive=False, show_label=True)
            get_answers_button.click(
                fn=get_answers_func,
                inputs=[query_input, model_name_input],
                outputs=[answer_with_cheatsheet_output, answer_without_cheatsheet_output]
            )
    gr.Markdown("**Important:** Ensure `SAMBANOVA_API_KEY` is set as a secret in your Hugging Face Space or as an environment variable if running locally. `SAMBANOVA_BASE_URL` is set to `https://api.sambanova.ai/v1` by default if not found in the environment.")
if __name__ == "__main__":
if not SAMBANOVA_API_KEY:
print("Warning: SAMBANOVA_API_KEY is not set. The application will likely fail to contact the SambaNova API.")
print("Please set the SAMBANOVA_API_KEY environment variable.")
demo.launch()