Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -5,8 +5,8 @@
|
|
5 |
# This script has been updated to run as a Hugging Face Space.
|
6 |
#
|
7 |
# Key Upgrades from the original script:
|
8 |
-
# 1. **Hugging Face Model Integration**: Uses the 'google/gemma-
|
9 |
-
# from the Hugging Face Hub for argument extraction
|
10 |
# 2. **Environment Variable Management**: Securely accesses the
|
11 |
# HUGGING_FACE_HUB_TOKEN using os.environ.get(), which is the standard
|
12 |
# for Hugging Face Spaces.
|
@@ -45,16 +45,18 @@ try:
|
|
45 |
raise ValueError("HUGGING_FACE_HUB_TOKEN secret not found.")
|
46 |
|
47 |
print("⚙️ Loading Hugging Face model for argument extraction...")
|
48 |
-
# Using
|
49 |
-
|
|
|
|
|
50 |
hf_model = AutoModelForCausalLM.from_pretrained(
|
51 |
-
|
52 |
token=HF_TOKEN,
|
53 |
torch_dtype=torch.bfloat16, # Use bfloat16 for efficiency
|
54 |
device_map="auto" # Automatically use GPU if available
|
55 |
)
|
56 |
USE_HF_LLM = True
|
57 |
-
print("✅ Successfully loaded '
|
58 |
|
59 |
except Exception as e:
|
60 |
USE_HF_LLM = False
|
@@ -86,7 +88,7 @@ class Tool:
|
|
86 |
"""
|
87 |
schema_str = json.dumps(self.args_schema, indent=2)
|
88 |
examples_str = "\n".join([f" - Example: {ex['prompt']} -> Args: {json.dumps(ex['args'])}" for ex in self.examples])
|
89 |
-
|
90 |
embedding_text = (
|
91 |
f"Tool Name: {self.name}\n"
|
92 |
f"Description: {self.description}\n"
|
@@ -106,10 +108,10 @@ def get_weather_forecast(location: str, days: int = 1):
|
|
106 |
"""Simulates fetching a weather forecast."""
|
107 |
if not isinstance(location, str) or not isinstance(days, int):
|
108 |
return {"error": "Invalid argument types. 'location' must be a string and 'days' an integer."}
|
109 |
-
|
110 |
weather_conditions = ["Sunny", "Cloudy", "Rainy", "Windy", "Snowy"]
|
111 |
response = {"location": location, "forecast": []}
|
112 |
-
|
113 |
for i in range(days):
|
114 |
date = (datetime.now() + timedelta(days=i)).strftime('%Y-%m-%d')
|
115 |
condition = np.random.choice(weather_conditions)
|
@@ -272,23 +274,23 @@ def extract_arguments_hf(user_prompt: str, tool: Tool):
|
|
272 |
chat = [
|
273 |
{"role": "user", "content": f"{system_prompt}\n\nUser Prompt: \"{user_prompt}\""},
|
274 |
]
|
275 |
-
|
276 |
prompt = hf_tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
|
277 |
|
278 |
try:
|
279 |
inputs = hf_tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt").to(hf_model.device)
|
280 |
-
|
281 |
# Generate with the model
|
282 |
outputs = hf_model.generate(input_ids=inputs, max_new_tokens=256, do_sample=False)
|
283 |
decoded_output = hf_tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
|
284 |
-
|
285 |
# Clean the response to find the JSON object
|
286 |
json_str = decoded_output.strip()
|
287 |
-
|
288 |
# Find the first '{' and the last '}' to get the JSON part
|
289 |
json_start = json_str.find('{')
|
290 |
json_end = json_str.rfind('}')
|
291 |
-
|
292 |
if json_start != -1 and json_end != -1:
|
293 |
json_str = json_str[json_start : json_end + 1]
|
294 |
return json.loads(json_str)
|
@@ -302,7 +304,7 @@ def extract_arguments_hf(user_prompt: str, tool: Tool):
|
|
302 |
def execute_tool(user_prompt: str):
|
303 |
"""The main pipeline: Find tool, extract args, execute."""
|
304 |
selected_tool, score, _ = find_best_tool(user_prompt)
|
305 |
-
|
306 |
if USE_HF_LLM:
|
307 |
print(f"⚙️ Selected Tool: {selected_tool.name}. Extracting arguments with Gemma...")
|
308 |
extracted_args = extract_arguments_hf(user_prompt, selected_tool)
|
@@ -312,12 +314,17 @@ def execute_tool(user_prompt: str):
|
|
312 |
|
313 |
if 'error' in extracted_args:
|
314 |
print(f"❌ Argument extraction failed: {extracted_args['error']}")
|
|
|
|
|
|
|
|
|
|
|
315 |
return (
|
316 |
user_prompt,
|
317 |
selected_tool.name,
|
318 |
f"{score:.3f}",
|
319 |
json.dumps(extracted_args, indent=2),
|
320 |
-
|
321 |
)
|
322 |
|
323 |
print(f"✅ Arguments extracted: {json.dumps(extracted_args, indent=2)}")
|
@@ -349,7 +356,7 @@ def plot_tool_world(user_intent=None):
|
|
349 |
tool_vectors = [tool.embedding.cpu().numpy() for tool in tools]
|
350 |
labels = [tool.name for tool in tools]
|
351 |
all_vectors = tool_vectors
|
352 |
-
|
353 |
if user_intent and user_intent.strip():
|
354 |
intent_vector = embedder.encode(user_intent, convert_to_tensor=True).cpu().numpy()
|
355 |
all_vectors.append(intent_vector)
|
@@ -361,7 +368,7 @@ def plot_tool_world(user_intent=None):
|
|
361 |
n_neighbors = 1
|
362 |
|
363 |
reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=0.3, metric='cosine', random_state=42)
|
364 |
-
|
365 |
# UMAP fit_transform requires at least 2 samples
|
366 |
if len(all_vectors) < 2:
|
367 |
# Create a dummy plot if there's not enough data
|
@@ -387,7 +394,7 @@ def plot_tool_world(user_intent=None):
|
|
387 |
ax.set_xlabel("UMAP Dimension 1", fontsize=12)
|
388 |
ax.set_ylabel("UMAP Dimension 2", fontsize=12)
|
389 |
ax.grid(True)
|
390 |
-
|
391 |
handles, labels_legend = ax.get_legend_handles_labels()
|
392 |
by_label = dict(zip(labels_legend, handles))
|
393 |
ax.legend(by_label.values(), by_label.keys())
|
@@ -406,7 +413,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
406 |
gr.Markdown("# 🛠️ Tool World: Advanced Prototype (Hugging Face Version)")
|
407 |
gr.Markdown(
|
408 |
"Enter a natural language command. The system will select the best tool, "
|
409 |
-
"extract structured arguments with **google/gemma-
|
410 |
)
|
411 |
|
412 |
with gr.Row():
|
@@ -417,7 +424,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
417 |
lines=3
|
418 |
)
|
419 |
run_btn = gr.Button("Invoke Tool", variant="primary")
|
420 |
-
|
421 |
gr.Markdown("---")
|
422 |
gr.Markdown("### Examples")
|
423 |
gr.Examples(
|
@@ -435,7 +442,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
435 |
with gr.Row():
|
436 |
out_tool = gr.Textbox(label="Selected Tool", interactive=False)
|
437 |
out_score = gr.Textbox(label="Similarity Score", interactive=False)
|
438 |
-
|
439 |
out_args = gr.JSON(label="Extracted Arguments")
|
440 |
out_result = gr.JSON(label="Tool Execution Output")
|
441 |
|
@@ -448,10 +455,10 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
448 |
if not user_prompt or not user_prompt.strip():
|
449 |
# Return empty state and the default plot
|
450 |
return "", "", {}, {}, plot_tool_world()
|
451 |
-
|
452 |
prompt, tool_name, score, args_json, result_json = execute_tool(user_prompt)
|
453 |
fig = plot_tool_world(user_prompt)
|
454 |
-
|
455 |
# Safely load JSON strings into objects for the UI
|
456 |
try:
|
457 |
args_obj = json.loads(args_json)
|
@@ -470,7 +477,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
470 |
inputs=inp,
|
471 |
outputs=[out_tool, out_score, out_args, out_result, plot_output]
|
472 |
)
|
473 |
-
|
474 |
# Load the initial plot when the app starts
|
475 |
demo.load(fn=lambda: plot_tool_world(None), inputs=None, outputs=plot_output)
|
476 |
|
|
|
5 |
# This script has been updated to run as a Hugging Face Space.
|
6 |
#
|
7 |
# Key Upgrades from the original script:
|
8 |
+
# 1. **Hugging Face Model Integration**: Uses the 'google/gemma-3n-E4B' model
|
9 |
+
# from the Hugging Face Hub for argument extraction.
|
10 |
# 2. **Environment Variable Management**: Securely accesses the
|
11 |
# HUGGING_FACE_HUB_TOKEN using os.environ.get(), which is the standard
|
12 |
# for Hugging Face Spaces.
|
|
|
45 |
raise ValueError("HUGGING_FACE_HUB_TOKEN secret not found.")
|
46 |
|
47 |
print("⚙️ Loading Hugging Face model for argument extraction...")
|
48 |
+
# Using the user-specified Gemma 3n model
|
49 |
+
model_id = "google/gemma-3n-E4B"
|
50 |
+
|
51 |
+
hf_tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
|
52 |
hf_model = AutoModelForCausalLM.from_pretrained(
|
53 |
+
model_id,
|
54 |
token=HF_TOKEN,
|
55 |
torch_dtype=torch.bfloat16, # Use bfloat16 for efficiency
|
56 |
device_map="auto" # Automatically use GPU if available
|
57 |
)
|
58 |
USE_HF_LLM = True
|
59 |
+
print(f"✅ Successfully loaded '{model_id}' model.")
|
60 |
|
61 |
except Exception as e:
|
62 |
USE_HF_LLM = False
|
|
|
88 |
"""
|
89 |
schema_str = json.dumps(self.args_schema, indent=2)
|
90 |
examples_str = "\n".join([f" - Example: {ex['prompt']} -> Args: {json.dumps(ex['args'])}" for ex in self.examples])
|
91 |
+
|
92 |
embedding_text = (
|
93 |
f"Tool Name: {self.name}\n"
|
94 |
f"Description: {self.description}\n"
|
|
|
108 |
"""Simulates fetching a weather forecast."""
|
109 |
if not isinstance(location, str) or not isinstance(days, int):
|
110 |
return {"error": "Invalid argument types. 'location' must be a string and 'days' an integer."}
|
111 |
+
|
112 |
weather_conditions = ["Sunny", "Cloudy", "Rainy", "Windy", "Snowy"]
|
113 |
response = {"location": location, "forecast": []}
|
114 |
+
|
115 |
for i in range(days):
|
116 |
date = (datetime.now() + timedelta(days=i)).strftime('%Y-%m-%d')
|
117 |
condition = np.random.choice(weather_conditions)
|
|
|
274 |
chat = [
|
275 |
{"role": "user", "content": f"{system_prompt}\n\nUser Prompt: \"{user_prompt}\""},
|
276 |
]
|
277 |
+
|
278 |
prompt = hf_tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
|
279 |
|
280 |
try:
|
281 |
inputs = hf_tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt").to(hf_model.device)
|
282 |
+
|
283 |
# Generate with the model
|
284 |
outputs = hf_model.generate(input_ids=inputs, max_new_tokens=256, do_sample=False)
|
285 |
decoded_output = hf_tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
|
286 |
+
|
287 |
# Clean the response to find the JSON object
|
288 |
json_str = decoded_output.strip()
|
289 |
+
|
290 |
# Find the first '{' and the last '}' to get the JSON part
|
291 |
json_start = json_str.find('{')
|
292 |
json_end = json_str.rfind('}')
|
293 |
+
|
294 |
if json_start != -1 and json_end != -1:
|
295 |
json_str = json_str[json_start : json_end + 1]
|
296 |
return json.loads(json_str)
|
|
|
304 |
def execute_tool(user_prompt: str):
|
305 |
"""The main pipeline: Find tool, extract args, execute."""
|
306 |
selected_tool, score, _ = find_best_tool(user_prompt)
|
307 |
+
|
308 |
if USE_HF_LLM:
|
309 |
print(f"⚙️ Selected Tool: {selected_tool.name}. Extracting arguments with Gemma...")
|
310 |
extracted_args = extract_arguments_hf(user_prompt, selected_tool)
|
|
|
314 |
|
315 |
if 'error' in extracted_args:
|
316 |
print(f"❌ Argument extraction failed: {extracted_args['error']}")
|
317 |
+
# Ensure the final output string is valid JSON
|
318 |
+
final_output_str = json.dumps({
|
319 |
+
"error": "Execution failed during argument extraction.",
|
320 |
+
"details": extracted_args['error']
|
321 |
+
})
|
322 |
return (
|
323 |
user_prompt,
|
324 |
selected_tool.name,
|
325 |
f"{score:.3f}",
|
326 |
json.dumps(extracted_args, indent=2),
|
327 |
+
final_output_str
|
328 |
)
|
329 |
|
330 |
print(f"✅ Arguments extracted: {json.dumps(extracted_args, indent=2)}")
|
|
|
356 |
tool_vectors = [tool.embedding.cpu().numpy() for tool in tools]
|
357 |
labels = [tool.name for tool in tools]
|
358 |
all_vectors = tool_vectors
|
359 |
+
|
360 |
if user_intent and user_intent.strip():
|
361 |
intent_vector = embedder.encode(user_intent, convert_to_tensor=True).cpu().numpy()
|
362 |
all_vectors.append(intent_vector)
|
|
|
368 |
n_neighbors = 1
|
369 |
|
370 |
reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=0.3, metric='cosine', random_state=42)
|
371 |
+
|
372 |
# UMAP fit_transform requires at least 2 samples
|
373 |
if len(all_vectors) < 2:
|
374 |
# Create a dummy plot if there's not enough data
|
|
|
394 |
ax.set_xlabel("UMAP Dimension 1", fontsize=12)
|
395 |
ax.set_ylabel("UMAP Dimension 2", fontsize=12)
|
396 |
ax.grid(True)
|
397 |
+
|
398 |
handles, labels_legend = ax.get_legend_handles_labels()
|
399 |
by_label = dict(zip(labels_legend, handles))
|
400 |
ax.legend(by_label.values(), by_label.keys())
|
|
|
413 |
gr.Markdown("# 🛠️ Tool World: Advanced Prototype (Hugging Face Version)")
|
414 |
gr.Markdown(
|
415 |
"Enter a natural language command. The system will select the best tool, "
|
416 |
+
"extract structured arguments with **google/gemma-3n-E4B**, and execute it."
|
417 |
)
|
418 |
|
419 |
with gr.Row():
|
|
|
424 |
lines=3
|
425 |
)
|
426 |
run_btn = gr.Button("Invoke Tool", variant="primary")
|
427 |
+
|
428 |
gr.Markdown("---")
|
429 |
gr.Markdown("### Examples")
|
430 |
gr.Examples(
|
|
|
442 |
with gr.Row():
|
443 |
out_tool = gr.Textbox(label="Selected Tool", interactive=False)
|
444 |
out_score = gr.Textbox(label="Similarity Score", interactive=False)
|
445 |
+
|
446 |
out_args = gr.JSON(label="Extracted Arguments")
|
447 |
out_result = gr.JSON(label="Tool Execution Output")
|
448 |
|
|
|
455 |
if not user_prompt or not user_prompt.strip():
|
456 |
# Return empty state and the default plot
|
457 |
return "", "", {}, {}, plot_tool_world()
|
458 |
+
|
459 |
prompt, tool_name, score, args_json, result_json = execute_tool(user_prompt)
|
460 |
fig = plot_tool_world(user_prompt)
|
461 |
+
|
462 |
# Safely load JSON strings into objects for the UI
|
463 |
try:
|
464 |
args_obj = json.loads(args_json)
|
|
|
477 |
inputs=inp,
|
478 |
outputs=[out_tool, out_score, out_args, out_result, plot_output]
|
479 |
)
|
480 |
+
|
481 |
# Load the initial plot when the app starts
|
482 |
demo.load(fn=lambda: plot_tool_world(None), inputs=None, outputs=plot_output)
|
483 |
|