TuringsSolutions committed (verified)
Commit 61c80fb · Parent(s): dfb63e9

Update app.py

Files changed (1):
  app.py  +32 -25
app.py CHANGED
@@ -5,8 +5,8 @@
 # This script has been updated to run as a Hugging Face Space.
 #
 # Key Upgrades from the original script:
-# 1. **Hugging Face Model Integration**: Uses the 'google/gemma-3b-it' model
-#    from the Hugging Face Hub for argument extraction, instead of the Gemini API.
+# 1. **Hugging Face Model Integration**: Uses the 'google/gemma-3n-E4B' model
+#    from the Hugging Face Hub for argument extraction.
 # 2. **Environment Variable Management**: Securely accesses the
 #    HUGGING_FACE_HUB_TOKEN using os.environ.get(), which is the standard
 #    for Hugging Face Spaces.
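As a note for readers skimming the diff: the token-handling pattern referenced in upgrade #2 is only a few lines. A minimal sketch, reusing the `HF_TOKEN` name and error message the script itself uses:

```python
import os

# os.environ.get() returns None when the secret is missing instead of
# raising KeyError, so the failure can be surfaced as a clear message.
HF_TOKEN = os.environ.get("HUGGING_FACE_HUB_TOKEN")
if not HF_TOKEN:
    raise ValueError("HUGGING_FACE_HUB_TOKEN secret not found.")
```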
@@ -45,16 +45,18 @@ try:
         raise ValueError("HUGGING_FACE_HUB_TOKEN secret not found.")
 
     print("⚙️ Loading Hugging Face model for argument extraction...")
-    # Using a smaller, instruction-tuned model for efficient argument extraction
-    hf_tokenizer = AutoTokenizer.from_pretrained("google/gemma-3b-it", token=HF_TOKEN)
+    # Using the user-specified Gemma 3n model
+    model_id = "google/gemma-3n-E4B"
+
+    hf_tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
     hf_model = AutoModelForCausalLM.from_pretrained(
-        "google/gemma-3b-it",
+        model_id,
         token=HF_TOKEN,
         torch_dtype=torch.bfloat16,  # Use bfloat16 for efficiency
         device_map="auto"  # Automatically use GPU if available
     )
     USE_HF_LLM = True
-    print("✅ Successfully loaded 'google/gemma-3b-it' model.")
+    print(f"✅ Successfully loaded '{model_id}' model.")
 
 except Exception as e:
     USE_HF_LLM = False
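Outside the diff, the load-or-degrade pattern in this hunk looks roughly like the sketch below. The `tokenizer`/`model`/`use_llm` names are illustrative stand-ins; `device_map="auto"` additionally requires the `accelerate` package, and whether the `google/gemma-3n-E4B` checkpoint loads through `AutoModelForCausalLM` depends on the installed `transformers` version — which is one reason the broad `except` matters here.

```python
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-3n-E4B"  # gated repo: needs an accepted license and a valid token
token = os.environ.get("HUGGING_FACE_HUB_TOKEN")

try:
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=token,
        torch_dtype=torch.bfloat16,  # half the memory of float32 on supported hardware
        device_map="auto",           # needs `accelerate`; places weights on GPU if present
    )
    use_llm = True
except Exception as e:
    use_llm = False  # degrade gracefully instead of crashing the Space at startup
    print(f"Model load failed: {e}")
```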
@@ -86,7 +88,7 @@ class Tool:
         """
         schema_str = json.dumps(self.args_schema, indent=2)
         examples_str = "\n".join([f"  - Example: {ex['prompt']} -> Args: {json.dumps(ex['args'])}" for ex in self.examples])
-
+
         embedding_text = (
             f"Tool Name: {self.name}\n"
             f"Description: {self.description}\n"
@@ -106,10 +108,10 @@ def get_weather_forecast(location: str, days: int = 1):
     """Simulates fetching a weather forecast."""
     if not isinstance(location, str) or not isinstance(days, int):
         return {"error": "Invalid argument types. 'location' must be a string and 'days' an integer."}
-
+
     weather_conditions = ["Sunny", "Cloudy", "Rainy", "Windy", "Snowy"]
     response = {"location": location, "forecast": []}
-
+
     for i in range(days):
         date = (datetime.now() + timedelta(days=i)).strftime('%Y-%m-%d')
         condition = np.random.choice(weather_conditions)
@@ -272,23 +274,23 @@ def extract_arguments_hf(user_prompt: str, tool: Tool):
     chat = [
         {"role": "user", "content": f"{system_prompt}\n\nUser Prompt: \"{user_prompt}\""},
     ]
-
+
     prompt = hf_tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
 
     try:
         inputs = hf_tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt").to(hf_model.device)
-
+
         # Generate with the model
         outputs = hf_model.generate(input_ids=inputs, max_new_tokens=256, do_sample=False)
         decoded_output = hf_tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
-
+
         # Clean the response to find the JSON object
        json_str = decoded_output.strip()
-
+
         # Find the first '{' and the last '}' to get the JSON part
         json_start = json_str.find('{')
         json_end = json_str.rfind('}')
-
+
         if json_start != -1 and json_end != -1:
             json_str = json_str[json_start : json_end + 1]
             return json.loads(json_str)
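The brace-slicing heuristic in this hunk is worth seeing in isolation, since it does the real work of tolerating chatty model output around the JSON payload. A self-contained sketch — the helper name is hypothetical, not from app.py:

```python
import json

def extract_first_json_object(text: str) -> dict:
    """Slice from the first '{' to the last '}' and parse the result.

    Same heuristic as the diff: it tolerates prose before and after the
    JSON, but assumes a single top-level object in the output.
    """
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end == -1 or end < start:
        return {"error": "No JSON object found in model output."}
    try:
        return json.loads(text[start:end + 1])
    except json.JSONDecodeError as e:
        return {"error": f"Malformed JSON in model output: {e}"}

# Typical chatty output around the payload:
print(extract_first_json_object('Sure! Here you go: {"location": "Paris", "days": 3}'))
# -> {'location': 'Paris', 'days': 3}
```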
@@ -302,7 +304,7 @@ def extract_arguments_hf(user_prompt: str, tool: Tool):
 def execute_tool(user_prompt: str):
     """The main pipeline: Find tool, extract args, execute."""
     selected_tool, score, _ = find_best_tool(user_prompt)
-
+
     if USE_HF_LLM:
         print(f"⚙️ Selected Tool: {selected_tool.name}. Extracting arguments with Gemma...")
         extracted_args = extract_arguments_hf(user_prompt, selected_tool)
@@ -312,12 +314,17 @@ def execute_tool(user_prompt: str):
 
     if 'error' in extracted_args:
         print(f"❌ Argument extraction failed: {extracted_args['error']}")
+        # Ensure the final output string is valid JSON
+        final_output_str = json.dumps({
+            "error": "Execution failed during argument extraction.",
+            "details": extracted_args['error']
+        })
         return (
             user_prompt,
             selected_tool.name,
             f"{score:.3f}",
             json.dumps(extracted_args, indent=2),
-            "Execution failed during argument extraction."
+            final_output_str
         )
 
     print(f"✅ Arguments extracted: {json.dumps(extracted_args, indent=2)}")
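The point of the new `final_output_str` is that the last return slot always reaches the `gr.JSON` output as parseable JSON rather than a bare string. The resulting shape, sketched with a hypothetical `details` value:

```python
import json

details = "missing required field 'location'"  # stand-in for extracted_args['error']
final_output_str = json.dumps({
    "error": "Execution failed during argument extraction.",
    "details": details,
})
# The UI-side json.loads() now succeeds instead of raising on a plain string.
assert json.loads(final_output_str)["error"].startswith("Execution failed")
```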
@@ -349,7 +356,7 @@ def plot_tool_world(user_intent=None):
     tool_vectors = [tool.embedding.cpu().numpy() for tool in tools]
     labels = [tool.name for tool in tools]
     all_vectors = tool_vectors
-
+
     if user_intent and user_intent.strip():
         intent_vector = embedder.encode(user_intent, convert_to_tensor=True).cpu().numpy()
         all_vectors.append(intent_vector)
@@ -361,7 +368,7 @@
         n_neighbors = 1
 
     reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=0.3, metric='cosine', random_state=42)
-
+
     # UMAP fit_transform requires at least 2 samples
     if len(all_vectors) < 2:
         # Create a dummy plot if there's not enough data
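The clamping of `n_neighbors` above exists because UMAP requires `n_neighbors` to be smaller than the number of samples, which a tiny tool set can violate. A runnable sketch of the projection step with toy embeddings (the vector count and dimension are illustrative):

```python
import numpy as np
import umap  # pip install umap-learn

vectors = np.random.rand(6, 384)        # stand-ins for tool/intent embeddings
n_neighbors = min(5, len(vectors) - 1)  # keep n_neighbors below the sample count

reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=0.3,
                    metric="cosine", random_state=42)
coords = reducer.fit_transform(vectors)  # (6, 2) array of plot coordinates
print(coords.shape)
```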
@@ -387,7 +394,7 @@
     ax.set_xlabel("UMAP Dimension 1", fontsize=12)
     ax.set_ylabel("UMAP Dimension 2", fontsize=12)
     ax.grid(True)
-
+
     handles, labels_legend = ax.get_legend_handles_labels()
     by_label = dict(zip(labels_legend, handles))
     ax.legend(by_label.values(), by_label.keys())
@@ -406,7 +413,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown("# 🛠️ Tool World: Advanced Prototype (Hugging Face Version)")
     gr.Markdown(
         "Enter a natural language command. The system will select the best tool, "
-        "extract structured arguments with **google/gemma-3b-it**, and execute it."
+        "extract structured arguments with **google/gemma-3n-E4B**, and execute it."
     )
 
     with gr.Row():
@@ -417,7 +424,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             lines=3
         )
         run_btn = gr.Button("Invoke Tool", variant="primary")
-
+
     gr.Markdown("---")
     gr.Markdown("### Examples")
     gr.Examples(
@@ -435,7 +442,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     with gr.Row():
         out_tool = gr.Textbox(label="Selected Tool", interactive=False)
         out_score = gr.Textbox(label="Similarity Score", interactive=False)
-
+
     out_args = gr.JSON(label="Extracted Arguments")
     out_result = gr.JSON(label="Tool Execution Output")
 
@@ -448,10 +455,10 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         if not user_prompt or not user_prompt.strip():
             # Return empty state and the default plot
             return "", "", {}, {}, plot_tool_world()
-
+
         prompt, tool_name, score, args_json, result_json = execute_tool(user_prompt)
         fig = plot_tool_world(user_prompt)
-
+
         # Safely load JSON strings into objects for the UI
         try:
             args_obj = json.loads(args_json)
@@ -470,7 +477,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         inputs=inp,
         outputs=[out_tool, out_score, out_args, out_result, plot_output]
     )
-
+
     # Load the initial plot when the app starts
     demo.load(fn=lambda: plot_tool_world(None), inputs=None, outputs=plot_output)
 
 
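Taken together, the UI hunks follow the standard `gr.Blocks` event pattern: a click handler plus a `demo.load` callback for initial state. A minimal self-contained version, where the `run` handler is a placeholder for the real `execute_tool` + plotting pipeline:

```python
import gradio as gr

def run(prompt: str) -> str:
    # Placeholder for: find_best_tool -> extract arguments -> execute.
    return f"echo: {prompt}"

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    inp = gr.Textbox(label="Command", lines=3)
    btn = gr.Button("Invoke Tool", variant="primary")
    out = gr.Textbox(label="Result", interactive=False)

    btn.click(fn=run, inputs=inp, outputs=out)               # wire button to handler
    demo.load(fn=lambda: "ready", inputs=None, outputs=out)  # populate on startup

demo.launch()
```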