TuringsSolutions committed · verified · Commit dfb63e9 · 1 Parent(s): 8b613ad

Create app.py

Files changed (1): app.py (+477, -0)
app.py ADDED
@@ -0,0 +1,477 @@
+ # ==============================================================================
+ # Tool World: Advanced Prototype (Hugging Face Space Version)
+ # ==============================================================================
+ #
+ # This script has been updated to run as a Hugging Face Space.
+ #
+ # Key Upgrades from the original script:
+ # 1. **Hugging Face Model Integration**: Uses the 'google/gemma-2b-it' model
+ #    from the Hugging Face Hub for argument extraction, instead of the Gemini API.
+ # 2. **Environment Variable Management**: Securely accesses the
+ #    HUGGING_FACE_HUB_TOKEN using os.environ.get(), which is the standard
+ #    approach for Hugging Face Spaces.
+ # 3. **Standard Dependencies**: All dependencies are managed via a
+ #    `requirements.txt` file (an illustrative sketch follows this header block).
+ #
+ # ==============================================================================
+
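+ # An illustrative requirements.txt for this Space (an assumption inferred from
+ # the imports below and the use of device_map="auto"; it is not part of this
+ # commit) might contain:
+ #
+ #   gradio
+ #   numpy
+ #   torch
+ #   transformers
+ #   accelerate
+ #   sentence-transformers
+ #   umap-learn
+ #   matplotlib
+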
+ # ------------------------------
+ # 1. INSTALL & IMPORT PACKAGES
+ # ------------------------------
+ import numpy as np
+ import umap
+ import gradio as gr
+ from sentence_transformers import SentenceTransformer, util
+ import matplotlib.pyplot as plt
+ import json
+ import os
+ from datetime import datetime, timedelta
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ # ------------------------------
+ # 2. CONFIGURE & LOAD MODELS
+ # ------------------------------
+
+ print("⚙️ Loading embedding model...")
+ # Using a powerful model for better semantic understanding
+ embedder = SentenceTransformer('all-mpnet-base-v2')
+ print("✅ Embedding model loaded.")
+
+ # --- Configuration for Hugging Face Model-based Argument Extraction ---
+ try:
+     HF_TOKEN = os.environ.get('HUGGING_FACE_HUB_TOKEN')
+     if HF_TOKEN is None:
+         raise ValueError("HUGGING_FACE_HUB_TOKEN secret not found.")
+
+     print("⚙️ Loading Hugging Face model for argument extraction...")
+     # Using a smaller, instruction-tuned model for efficient argument extraction
+     hf_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=HF_TOKEN)
+     hf_model = AutoModelForCausalLM.from_pretrained(
+         "google/gemma-2b-it",
+         token=HF_TOKEN,
+         torch_dtype=torch.bfloat16,  # Use bfloat16 for efficiency
+         device_map="auto"            # Automatically use GPU if available
+     )
+     USE_HF_LLM = True
+     print("✅ Successfully loaded 'google/gemma-2b-it' model.")
+
+ except Exception as e:
+     USE_HF_LLM = False
+     print(f"⚠️ WARNING: Could not load the Hugging Face model. Reason: {e}")
+     print("   Argument extraction will be disabled.")
+
+
+
+ # ------------------------------
+ # 3. ADVANCED TOOL DEFINITION
+ # ------------------------------
+ class Tool:
+     """
+     Represents a tool with structured arguments and rich descriptive data
+     for high-quality embedding.
+     """
+     def __init__(self, name, description, args_schema, function, examples=None):
+         self.name = name
+         self.description = description
+         self.args_schema = args_schema
+         self.function = function
+         self.examples = examples or []
+         self.embedding = self._create_embedding()
+
+     def _create_embedding(self):
+         """
+         Creates a rich embedding by combining the tool's name, description,
+         argument structure, and examples.
+         """
+         schema_str = json.dumps(self.args_schema, indent=2)
+         examples_str = "\n".join([f" - Example: {ex['prompt']} -> Args: {json.dumps(ex['args'])}" for ex in self.examples])
+
+         embedding_text = (
+             f"Tool Name: {self.name}\n"
+             f"Description: {self.description}\n"
+             f"Argument Schema: {schema_str}\n"
+             f"Usage Examples:\n{examples_str}"
+         )
+         return embedder.encode(embedding_text, convert_to_tensor=True)
+
+     def __repr__(self):
+         return f"<Tool: {self.name}>"
+
+ # ------------------------------
+ # 4. TOOL IMPLEMENTATIONS
+ # ------------------------------
+
+ def get_weather_forecast(location: str, days: int = 1):
+     """Simulates fetching a weather forecast."""
+     if not isinstance(location, str) or not isinstance(days, int):
+         return {"error": "Invalid argument types. 'location' must be a string and 'days' an integer."}
+
+     weather_conditions = ["Sunny", "Cloudy", "Rainy", "Windy", "Snowy"]
+     response = {"location": location, "forecast": []}
+
+     for i in range(days):
+         date = (datetime.now() + timedelta(days=i)).strftime('%Y-%m-%d')
+         condition = np.random.choice(weather_conditions)
+         temp = int(np.random.randint(5, 25))  # Cast to int so the result stays JSON-serializable
+         response["forecast"].append({
+             "date": date,
+             "condition": condition,
+             "temperature_celsius": temp
+         })
+     return response
+
+ def create_calendar_event(title: str, date: str, duration_minutes: int = 60, participants: list = None):
+     """Simulates creating a calendar event."""
+     try:
+         event_time = datetime.strptime(date, '%Y-%m-%d %H:%M')
+         return {
+             "status": "success",
+             "event_created": {
+                 "title": title,
+                 "start_time": event_time.isoformat(),
+                 "end_time": (event_time + timedelta(minutes=duration_minutes)).isoformat(),
+                 "participants": participants or ["organizer"]
+             }
+         }
+     except ValueError:
+         return {"error": "Invalid date format. Please use 'YYYY-MM-DD HH:MM'."}
+
+ def summarize_text(text: str, compression_level: str = 'medium'):
+     """Summarizes a given text based on a compression level."""
+     word_count = len(text.split())
+     ratios = {'high': 0.2, 'medium': 0.4, 'low': 0.7}
+     ratio = ratios.get(compression_level, 0.4)
+     summary_length = int(word_count * ratio)
+     summary = " ".join(text.split()[:summary_length])
+     return {"summary": summary + "...", "original_word_count": word_count, "summary_word_count": summary_length}
+
+ def search_web(query: str, domain: str = None):
+     """Simulates a web search, with an optional domain filter."""
+     results = [
+         f"Simulated result 1 for '{query}'",
+         f"Simulated result 2 for '{query}'",
+         f"Simulated result 3 for '{query}'"
+     ]
+     if domain:
+         return {"status": f"Searching for '{query}' within '{domain}'...", "results": results}
+     return {"status": f"Searching for '{query}'...", "results": results}
+
+
+ # ------------------------------
+ # 5. DEFINE THE TOOLSET
+ # ------------------------------
+
+ tools = [
+     Tool(
+         name="weather_reporter",
+         description="Provides the weather forecast for a specific location for a given number of days.",
+         args_schema={
+             "type": "object",
+             "properties": {
+                 "location": {"type": "string", "description": "The city and state, e.g., 'San Francisco, CA'"},
+                 "days": {"type": "integer", "description": "The number of days to forecast", "default": 1}
+             },
+             "required": ["location"]
+         },
+         function=get_weather_forecast,
+         examples=[
+             {"prompt": "what's the weather like in London for the next 3 days", "args": {"location": "London", "days": 3}},
+             {"prompt": "forecast for New York tomorrow", "args": {"location": "New York", "days": 1}}
+         ]
+     ),
+     Tool(
+         name="calendar_creator",
+         description="Creates a new event in the user's calendar.",
+         args_schema={
+             "type": "object",
+             "properties": {
+                 "title": {"type": "string", "description": "The title of the calendar event"},
+                 "date": {"type": "string", "description": "The start date and time in 'YYYY-MM-DD HH:MM' format"},
+                 "duration_minutes": {"type": "integer", "description": "The duration of the event in minutes", "default": 60},
+                 "participants": {"type": "array", "items": {"type": "string"}, "description": "List of email addresses of participants"}
+             },
+             "required": ["title", "date"]
+         },
+         function=create_calendar_event,
+         examples=[
+             {"prompt": "Schedule a 'Project Sync' for tomorrow at 3pm with [email protected]", "args": {"title": "Project Sync", "date": (datetime.now() + timedelta(days=1)).strftime('%Y-%m-%d 15:00'), "participants": ["[email protected]"]}},
+             {"prompt": "new event: Dentist appointment on 2025-12-20 at 10:00 for 45 mins", "args": {"title": "Dentist appointment", "date": "2025-12-20 10:00", "duration_minutes": 45}}
+         ]
+     ),
+     Tool(
+         name="text_summarizer",
+         description="Summarizes a long piece of text. Can be set to high, medium, or low compression.",
+         args_schema={
+             "type": "object",
+             "properties": {
+                 "text": {"type": "string", "description": "The text to be summarized."},
+                 "compression_level": {"type": "string", "enum": ["high", "medium", "low"], "description": "The level of summarization.", "default": "medium"}
+             },
+             "required": ["text"]
+         },
+         function=summarize_text,
+         examples=[
+             {"prompt": "summarize this article for me, make it very short: [long text...]", "args": {"text": "[long text...]", "compression_level": "high"}}
+         ]
+     ),
+     Tool(
+         name="web_search",
+         description="Performs a web search to find information on a topic.",
+         args_schema={
+             "type": "object",
+             "properties": {
+                 "query": {"type": "string", "description": "The search query."},
+                 "domain": {"type": "string", "description": "Optional: a specific website domain to search within (e.g., 'wikipedia.org')."}
+             },
+             "required": ["query"]
+         },
+         function=search_web,
+         examples=[
+             {"prompt": "who invented the light bulb", "args": {"query": "who invented the light bulb"}},
+             {"prompt": "search for 'transformer models' on arxiv.org", "args": {"query": "transformer models", "domain": "arxiv.org"}}
+         ]
+     )
+ ]
+
+ print(f"✅ {len(tools)} tools defined and embedded.")
+
+ # ------------------------------
+ # 6. CORE LOGIC: TOOL SELECTION & ARGUMENT EXTRACTION
+ # ------------------------------
+
+ def find_best_tool(user_intent: str):
+     """Finds the most semantically similar tool for a user's intent."""
+     intent_embedding = embedder.encode(user_intent, convert_to_tensor=True)
+     # Move tool embeddings to the same device as the intent embedding
+     tool_embeddings = [tool.embedding.to(intent_embedding.device) for tool in tools]
+     similarities = [util.pytorch_cos_sim(intent_embedding, tool_emb).item() for tool_emb in tool_embeddings]
+     best_index = int(np.argmax(similarities))
+     best_tool = tools[best_index]
+     best_score = similarities[best_index]
+     return best_tool, best_score, similarities
+
+ def extract_arguments_hf(user_prompt: str, tool: Tool):
+     """
+     Uses a local Hugging Face model to extract structured arguments.
+     """
+     system_prompt = f"""
+ You are an expert at extracting structured data from natural language.
+ Your task is to analyze the user's prompt and extract the arguments required to call the tool: '{tool.name}'.
+
+ You must adhere to the following JSON schema for the arguments:
+ {json.dumps(tool.args_schema, indent=2)}
+
+ - If a value is not present in the prompt for a non-required field, omit it from the JSON.
+ - If a required value is missing, return a JSON object with an "error" key explaining what is missing.
+ - Today's date is {datetime.now().strftime('%Y-%m-%d')}.
+ - Respond ONLY with a valid JSON object. Do not include any other text, explanation, or markdown code blocks.
+ """
+
+     # Gemma instruction-following format
+     chat = [
+         {"role": "user", "content": f"{system_prompt}\n\nUser Prompt: \"{user_prompt}\""},
+     ]
+
+     prompt = hf_tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
+
+     try:
+         inputs = hf_tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt").to(hf_model.device)
+
+         # Generate with the model
+         outputs = hf_model.generate(input_ids=inputs, max_new_tokens=256, do_sample=False)
+         decoded_output = hf_tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
+
+         # Clean the response to find the JSON object
+         json_str = decoded_output.strip()
+
+         # Find the first '{' and the last '}' to get the JSON part
+         json_start = json_str.find('{')
+         json_end = json_str.rfind('}')
+
+         if json_start != -1 and json_end != -1:
+             json_str = json_str[json_start : json_end + 1]
+             return json.loads(json_str)
+         else:
+             raise json.JSONDecodeError("No JSON object found in the model output.", json_str, 0)
+
+     except Exception as e:
+         print(f"Error during HF model inference or JSON parsing: {e}")
+         return {"error": f"Failed to extract arguments with the local LLM. Details: {str(e)}"}
+
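+ # Illustrative expectation (not from an actual run): for a prompt like
+ # "What's the weather in Paris for 2 days?" routed to the 'weather_reporter' tool,
+ # extract_arguments_hf should return a dict along the lines of {"location": "Paris", "days": 2}.
+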
+ def execute_tool(user_prompt: str):
+     """The main pipeline: Find tool, extract args, execute."""
+     selected_tool, score, _ = find_best_tool(user_prompt)
+
+     if USE_HF_LLM:
+         print(f"⚙️ Selected Tool: {selected_tool.name}. Extracting arguments with Gemma...")
+         extracted_args = extract_arguments_hf(user_prompt, selected_tool)
+     else:
+         # Fallback if the model failed to load
+         extracted_args = {"error": "Argument extraction is disabled because the Hugging Face model could not be loaded."}
+
+     if 'error' in extracted_args:
+         print(f"❌ Argument extraction failed: {extracted_args['error']}")
+         return (
+             user_prompt,
+             selected_tool.name,
+             f"{score:.3f}",
+             json.dumps(extracted_args, indent=2),
+             "Execution failed during argument extraction."
+         )
+
+     print(f"✅ Arguments extracted: {json.dumps(extracted_args, indent=2)}")
+
+     try:
+         print(f"🚀 Executing tool function: {selected_tool.name}...")
+         output = selected_tool.function(**extracted_args)
+         print("✅ Execution complete.")
+         output_str = json.dumps(output, indent=2)
+     except Exception as e:
+         print(f"❌ Tool execution failed: {e}")
+         output_str = f'{{"error": "Tool execution failed", "details": "{str(e)}"}}'
+
+     return (
+         user_prompt,
+         selected_tool.name,
+         f"{score:.3f}",
+         json.dumps(extracted_args, indent=2),
+         output_str
+     )
+
+
+ # ------------------------------
+ # 7. VISUALIZATION
+ # ------------------------------
+
+ def plot_tool_world(user_intent=None):
+     """Generates a 2D UMAP plot of the tool latent space."""
+     tool_vectors = [tool.embedding.cpu().numpy() for tool in tools]
+     labels = [tool.name for tool in tools]
+     all_vectors = tool_vectors
+
+     if user_intent and user_intent.strip():
+         intent_vector = embedder.encode(user_intent, convert_to_tensor=True).cpu().numpy()
+         all_vectors.append(intent_vector)
+         labels.append("Your Intent")
+
+     # UMAP requires at least 2 neighbors
+     n_neighbors = min(len(all_vectors) - 1, 5)
+     if n_neighbors < 2:
+         n_neighbors = 2
+
+     reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=0.3, metric='cosine', random_state=42)
+
+     # UMAP fit_transform requires at least 2 samples
+     if len(all_vectors) < 2:
+         # Create a dummy plot if there's not enough data
+         fig, ax = plt.subplots(figsize=(10, 7))
+         ax.text(0.5, 0.5, "Not enough data to create a plot.", ha='center', va='center')
+         return fig
+
+     reduced_vectors = reducer.fit_transform(all_vectors)
+
+     plt.style.use('seaborn-v0_8-whitegrid')
+     fig, ax = plt.subplots(figsize=(10, 7))
+
+     for i, label in enumerate(labels):
+         x, y = reduced_vectors[i]
+         if label == "Your Intent":
+             ax.scatter(x, y, color='red', s=150, zorder=5, label=label, marker='*')
+             ax.text(x, y + 0.05, label, fontsize=12, ha='center', color='red', weight='bold')
+         else:
+             ax.scatter(x, y, s=100, alpha=0.8, label=label)
+             ax.text(x, y + 0.05, label, fontsize=10, ha='center')
+
+     ax.set_title("Tool World: Latent Space Map", fontsize=16)
+     ax.set_xlabel("UMAP Dimension 1", fontsize=12)
+     ax.set_ylabel("UMAP Dimension 2", fontsize=12)
+     ax.grid(True)
+
+     handles, labels_legend = ax.get_legend_handles_labels()
+     by_label = dict(zip(labels_legend, handles))
+     ax.legend(by_label.values(), by_label.keys())
+
+     plt.tight_layout()
+     return fig
+
+
+ # ------------------------------
+ # 8. GRADIO INTERFACE
+ # ------------------------------
+
+ print("🚀 Launching Gradio interface...")
+
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown("# 🛠️ Tool World: Advanced Prototype (Hugging Face Version)")
+     gr.Markdown(
+         "Enter a natural language command. The system will select the best tool, "
+         "extract structured arguments with **google/gemma-2b-it**, and execute it."
+     )
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             inp = gr.Textbox(
+                 label="Your Intent",
+                 placeholder="e.g., What's the weather in Paris for 2 days?",
+                 lines=3
+             )
+             run_btn = gr.Button("Invoke Tool", variant="primary")
+
+             gr.Markdown("---")
+             gr.Markdown("### Examples")
+             gr.Examples(
+                 examples=[
+                     "Schedule a 'Team Meeting' for tomorrow at 10:30 am",
+                     "What is the weather forecast in Tokyo for the next 5 days?",
+                     "search for the latest news on generative AI on reuters.com",
+                     "Please give me a very short summary of this text: The Industrial Revolution was the transition to new manufacturing processes in Europe and the United States, in the period from about 1760 to sometime between 1820 and 1840."
+                 ],
+                 inputs=inp
+             )
+
+         with gr.Column(scale=2):
+             gr.Markdown("### Invocation Details")
+             with gr.Row():
+                 out_tool = gr.Textbox(label="Selected Tool", interactive=False)
+                 out_score = gr.Textbox(label="Similarity Score", interactive=False)
+
+             out_args = gr.JSON(label="Extracted Arguments")
+             out_result = gr.JSON(label="Tool Execution Output")
+
+     with gr.Row():
+         gr.Markdown("---")
+         gr.Markdown("### Latent Space Visualization")
+     plot_output = gr.Plot(label="Tool World Map")
+
+     def process_and_plot(user_prompt):
+         if not user_prompt or not user_prompt.strip():
+             # Return empty state and the default plot
+             return "", "", {}, {}, plot_tool_world()
+
+         prompt, tool_name, score, args_json, result_json = execute_tool(user_prompt)
+         fig = plot_tool_world(user_prompt)
+
+         # Safely load JSON strings into objects for the UI
+         try:
+             args_obj = json.loads(args_json)
+         except (json.JSONDecodeError, TypeError):
+             args_obj = {"error": "Invalid JSON in arguments", "raw": args_json}
+
+         try:
+             result_obj = json.loads(result_json)
+         except (json.JSONDecodeError, TypeError):
+             result_obj = {"error": "Invalid JSON in result", "raw": result_json}
+
+         return tool_name, score, args_obj, result_obj, fig
+
+     run_btn.click(
+         fn=process_and_plot,
+         inputs=inp,
+         outputs=[out_tool, out_score, out_args, out_result, plot_output]
+     )
+
+     # Load the initial plot when the app starts
+     demo.load(fn=lambda: plot_tool_world(None), inputs=None, outputs=plot_output)
+
+ demo.launch()