#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import mimetypes
import os
import re
import shutil
from typing import Optional

from smolagents.agent_types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
from smolagents.agents import ActionStep, MultiStepAgent
from smolagents.memory import MemoryStep
from smolagents.utils import _is_package_available


def pull_messages_from_step(
    step_log: MemoryStep,
):
    """Extract ChatMessage objects from agent steps with proper nesting"""
    import gradio as gr

    if isinstance(step_log, ActionStep):
        # Output the step number
        step_number = f"Step {step_log.step_number}" if step_log.step_number is not None else ""
        yield gr.ChatMessage(role="assistant", content=f"**{step_number}**")

        # First yield the thought/reasoning from the LLM
        if hasattr(step_log, "model_output") and step_log.model_output is not None:
            # Clean up the LLM output
            model_output = step_log.model_output.strip()
            # Remove any trailing <end_code> and extra backticks, handling multiple possible formats
            model_output = re.sub(r"```\s*<end_code>", "```", model_output)  # handles ```<end_code>
            model_output = re.sub(r"<end_code>\s*```", "```", model_output)  # handles <end_code>```
            model_output = re.sub(r"```\s*\n\s*<end_code>", "```", model_output)  # handles ```\n<end_code>
            model_output = model_output.strip()
            yield gr.ChatMessage(role="assistant", content=model_output)

        # For tool calls, create a parent message
        if hasattr(step_log, "tool_calls") and step_log.tool_calls is not None:
            first_tool_call = step_log.tool_calls[0]
            used_code = first_tool_call.name == "python_interpreter"
            parent_id = f"call_{len(step_log.tool_calls)}_{step_log.step_number or 'x'}"  # id unique to this step, used to nest child messages

            # Tool call becomes the parent message with timing info
            args = first_tool_call.arguments
            if isinstance(args, dict):
                content = str(args.get("answer", str(args)))
            else:
                content = str(args).strip()

            if used_code:
                content = re.sub(r"```.*?\n", "", content)  # Remove existing code blocks
                content = re.sub(r"\s*<end_code>\s*", "", content)  # Remove end_code tags
                content = content.strip()
                if not content.startswith("```python"):  # wrap bare code in a python block
                    content = f"```python\n{content}\n```"

            parent_message_tool = gr.ChatMessage(
                role="assistant",
                content=content,
                metadata={
                    "title": f"🛠️ Used tool {first_tool_call.name}",
                    "id": parent_id,
                    "status": "pending",
                },
            )
            yield parent_message_tool

            if hasattr(step_log, "observations") and step_log.observations is not None:
                log_content = step_log.observations.strip()
                if log_content:  # only yield if there is actual content
                    log_content = re.sub(r"^Execution logs:\s*", "", log_content)
                    yield gr.ChatMessage(
                        role="assistant",
                        content=log_content,
                        metadata={"title": "📝 Execution Logs", "parent_id": parent_id, "status": "done"},
                    )

            if hasattr(step_log, "error") and step_log.error is not None:
                yield gr.ChatMessage(
                    role="assistant",
                    content=str(step_log.error),
                    metadata={"title": "💥 Error", "parent_id": parent_id, "status": "done"},
                )
            # Mark the tool-call message as done. The object was already yielded,
            # so this mutation is only visible to consumers that kept a reference to it.
            parent_message_tool.metadata["status"] = "done"

        elif hasattr(step_log, "error") and step_log.error is not None: # Standalone errors
            yield gr.ChatMessage(role="assistant", content=str(step_log.error), metadata={"title": "💥 Error"})

        step_footnote_parts = [step_number]
        if hasattr(step_log, "input_token_count") and step_log.input_token_count is not None and \
           hasattr(step_log, "output_token_count") and step_log.output_token_count is not None:
            token_str = (
                f" | Input-tokens:{step_log.input_token_count:,} | Output-tokens:{step_log.output_token_count:,}"
            )
            step_footnote_parts.append(token_str)
        if hasattr(step_log, "duration") and step_log.duration is not None:
            step_footnote_parts.append(f" | Duration: {round(float(step_log.duration), 2)}s")

        step_footnote_text = "".join(filter(None, step_footnote_parts))
        if step_footnote_text:
            step_footnote = f"""<span style="color: #bbbbc2; font-size: 12px;">{step_footnote_text}</span>"""
            yield gr.ChatMessage(role="assistant", content=step_footnote)
        yield gr.ChatMessage(role="assistant", content="-----")
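
# Illustrative sketch (comment only, not part of the module API): how a single
# step might be rendered. It assumes ActionStep's fields can be set as below;
# the literal values are made up.
#
#     from smolagents.agents import ActionStep
#
#     step = ActionStep(step_number=1)
#     step.model_output = "I will read the uploaded file first."
#     for message in pull_messages_from_step(step):
#         print(message.role, "->", message.content)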


def stream_to_gradio(
    agent,
    task: str,
    reset_agent_memory: bool = False,
    additional_args: Optional[dict] = None,
):
    """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
    if not _is_package_available("gradio"):
        raise ModuleNotFoundError(
            "Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[gradio]'`"
        )
    import gradio as gr

    # Reset interaction logs for the new run if the agent has this attribute
    if hasattr(agent, 'interaction_logs'):
        agent.interaction_logs.clear()
        print("DEBUG: Cleared agent interaction_logs for new run.")


    for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args):
        if hasattr(agent.model, "last_input_token_count") and agent.model.last_input_token_count is not None: # Check for None
            if isinstance(step_log, ActionStep): # Only add token counts to ActionSteps
                step_log.input_token_count = agent.model.last_input_token_count
                step_log.output_token_count = agent.model.last_output_token_count

        for message in pull_messages_from_step(step_log):
            yield message

    # After the loop, step_log holds the last value yielded by agent.run, i.e. the final answer
    final_answer_content = step_log
    final_answer_processed = handle_agent_output_types(final_answer_content)

    if isinstance(final_answer_processed, AgentText):
        yield gr.ChatMessage(role="assistant", content=f"**Final answer:**\n{final_answer_processed.to_string()}\n")
    elif isinstance(final_answer_processed, AgentImage):
        yield gr.ChatMessage(role="assistant", content={"path": final_answer_processed.to_string(), "mime_type": "image/png"})
    elif isinstance(final_answer_processed, AgentAudio):
        yield gr.ChatMessage(role="assistant", content={"path": final_answer_processed.to_string(), "mime_type": "audio/wav"})
    else:
        yield gr.ChatMessage(role="assistant", content=f"**Final answer:** {str(final_answer_processed)}")
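
# Minimal usage sketch (comment only; `my_agent` is assumed to be an
# already-configured smolagents agent): stream_to_gradio can also feed a plain
# gr.ChatInterface callback directly, without the GradioUI class below.
#
#     import gradio as gr
#
#     def respond(message, history):
#         collected = []
#         for msg in stream_to_gradio(my_agent, task=message):
#             collected.append(msg)
#             yield collected
#
#     gr.ChatInterface(respond, type="messages").launch()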


class GradioUI:
    """A one-line interface to launch your agent in Gradio"""

    def __init__(self, agent: MultiStepAgent, file_upload_folder: str | None = None):
        if not _is_package_available("gradio"):
            raise ModuleNotFoundError(
                "Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[gradio]'`"
            )
        self.agent = agent
        self.file_upload_folder = file_upload_folder
        if self.file_upload_folder is not None:
            os.makedirs(self.file_upload_folder, exist_ok=True)

        self._latest_file_path_for_download = None # For download button state

    def _check_for_created_file(self):
        """Helper function to check interaction logs for a created file path."""
        self._latest_file_path_for_download = None # Reset
        if hasattr(self.agent, 'interaction_logs') and self.agent.interaction_logs:
            print(f"DEBUG UI: Checking {len(self.agent.interaction_logs)} interaction log entries.")
            for log_entry in self.agent.interaction_logs:
                if log_entry.get("tool_name") == "create_document":
                    tool_output_value = log_entry.get("tool_output")
                    print(f"DEBUG UI: Log for 'create_document', output: {tool_output_value}")
                    if tool_output_value and isinstance(tool_output_value, str):
                        if not tool_output_value.strip().startswith("ERROR:"):
                            normalized_path = os.path.normpath(tool_output_value)
                            if os.path.exists(normalized_path):
                                self._latest_file_path_for_download = normalized_path
                                print(f"DEBUG UI: File path for download set: {self._latest_file_path_for_download}")
                                return True # Found a valid file
                            else:
                                print(f"DEBUG UI: Path from log ('{normalized_path}') does not exist.")
                        else:
                            print(f"DEBUG UI: 'create_document' tool reported error: {tool_output_value}")
        return False


    def interact_with_agent(self, prompt, messages_history, download_btn_state, file_output_state):
        import gradio as gr
        
        messages_history.append(gr.ChatMessage(role="user", content=prompt))
        yield messages_history, gr.update(visible=False), gr.update(value=None, visible=False) # Hide download items initially

        # Stream agent messages to chatbot
        for msg in stream_to_gradio(self.agent, task=prompt, reset_agent_memory=False):
            messages_history.append(msg)
            yield messages_history, gr.update(visible=False), gr.update(value=None, visible=False) # Keep hidden during streaming

        # After streaming all agent messages, check for created file
        file_found = self._check_for_created_file()
        
        # Update UI based on whether a file was found
        # Yielding final state for chatbot, download button, and file component
        yield messages_history, gr.update(visible=file_found), gr.update(value=None, visible=False)


    def upload_file(
        self,
        file,
        file_uploads_log,
        allowed_file_types=[
            "application/pdf",
            "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            "text/plain",
        ],
    ):
        import gradio as gr

        if file is None:
            return gr.Textbox("No file uploaded", visible=True), file_uploads_log

        try:
            mime_type, _ = mimetypes.guess_type(file.name)
            if mime_type is None:
                # guess_type could not identify the extension; fall back to a generic
                # type rather than relying on a 'type' attribute that may not exist
                mime_type = getattr(file, "type", "application/octet-stream")
        except Exception as e:
            return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log
        
        if mime_type not in allowed_file_types:
            return gr.Textbox(f"File type '{mime_type}' disallowed", visible=True), file_uploads_log

        original_name = os.path.basename(file.name)
        sanitized_name = re.sub(r"[^\w\-.]", "_", original_name)
        
        # Ensure correct extension based on mime type, if possible
        base_name, current_ext = os.path.splitext(sanitized_name)
        
        type_to_ext_map = {
            "application/pdf": ".pdf",
            "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
            "text/plain": ".txt",
        }
        expected_ext = type_to_ext_map.get(mime_type)
        if expected_ext and current_ext.lower() != expected_ext:
            sanitized_name = base_name + expected_ext
        
        file_path = os.path.join(self.file_upload_folder, sanitized_name)
        shutil.copy(file.name, file_path) # file.name is the temp path of the uploaded file

        return gr.Textbox(f"File uploaded: {file_path}", visible=True), file_uploads_log + [file_path]

    def log_user_message(self, text_input, file_uploads_log):
        # This function prepares the prompt that goes to the agent.
        # It also clears the text_input box.
        full_prompt = text_input
        if file_uploads_log: # Check if list is not empty
            full_prompt += (
                f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}"
            )
        return full_prompt, "" # Return the full prompt and an empty string to clear input

    def prepare_and_show_download_file(self):
        import gradio as gr
        if self._latest_file_path_for_download and os.path.exists(self._latest_file_path_for_download):
            print(f"DEBUG UI: Preparing download for UI: {self._latest_file_path_for_download}")
            # gr.File.update was removed in Gradio 4; use gr.update, as elsewhere in this file
            return gr.update(
                value=self._latest_file_path_for_download,
                label=os.path.basename(self._latest_file_path_for_download),
                visible=True,
            )
        else:
            print("DEBUG UI: No valid file path to prepare for download component.")
            gr.Warning("No file available for download or path is invalid.")
            return gr.update(visible=False)

    def launch(self, **kwargs):
        import gradio as gr

        with gr.Blocks(fill_height=True, theme=gr.themes.Soft()) as demo:
            # --- State Variables ---
            # The chatbot component itself carries the displayed message history
            # (it is wired as both input and output of interact_with_agent below),
            # so no separate chat-history state is needed.
            file_uploads_log = gr.State([])  # tracks paths of uploaded files

            # --- UI Layout ---
            gr.Markdown("# Smol Talk with your Agent") # Title

            with gr.Row():
                with gr.Column(scale=3): # Main chat area
                    chatbot = gr.Chatbot(
                        label="Agent Interaction",
                        type="messages",  # required so gr.ChatMessage objects render correctly
                        avatar_images=(
                            None,  # user avatar
                            "https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo-round.png",  # agent avatar
                        ),
                        height=600,
                    )
                    text_input = gr.Textbox(
                        lines=1, 
                        label="Your Message to the Agent", 
                        placeholder="Type your message and press Enter..."
                    )
                
                with gr.Column(scale=1): # Sidebar for uploads and downloads
                    if self.file_upload_folder is not None:
                        gr.Markdown("### File Upload")
                        upload_file_component = gr.File(label="Upload a supporting file")
                        upload_status_display = gr.Textbox(label="Upload Status", interactive=False, visible=True, lines=2)
                        upload_file_component.upload( # Use 'upload' event for gr.File
                            self.upload_file,
                            [upload_file_component, file_uploads_log],
                            [upload_status_display, file_uploads_log],
                        )
                    
                    gr.Markdown("### Generated File")
                    # This button becomes visible if a file is created by the agent
                    download_btn = gr.Button("Download Generated File", visible=False) 
                    # This gr.File component becomes visible and populated when the button above is clicked
                    file_output_display = gr.File(label="Downloadable Document", visible=False, interactive=False) 

            # --- Event Handling ---
            
            # When user submits text_input:
            # 1. log_user_message: prepares the prompt (text + file info), clears text_input.
            #    The output 'prepared_prompt' is then passed to interact_with_agent.
            # 2. interact_with_agent: streams agent's responses to chatbot, updates download button.
            
            # State holding the prepared prompt between the two chained callbacks
            prepared_prompt_state = gr.State("")

            text_input.submit(
                self.log_user_message,
                [text_input, file_uploads_log],
                [prepared_prompt_state, text_input],  # fill the prompt state and clear the input box
            ).then(
                self.interact_with_agent,
                [prepared_prompt_state, chatbot, download_btn, file_output_display],  # current UI values
                [chatbot, download_btn, file_output_display],  # stream into the chatbot, toggle download widgets
            )

            # When download_btn is clicked:
            download_btn.click(
                self.prepare_and_show_download_file, 
                [], # No inputs needed from UI for this action
                [file_output_display] # Update the file_output_display component
            )

        # Launch the Gradio app. Use setdefault so callers can override debug/share
        # via kwargs without triggering duplicate-keyword errors.
        kwargs.setdefault("debug", True)
        kwargs.setdefault("share", False)  # share=True can be an issue locally or on Spaces
        demo.launch(**kwargs)


__all__ = ["stream_to_gradio", "GradioUI"]
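

if __name__ == "__main__":
    # Quick manual test: a minimal launch sketch, assuming default Hugging Face
    # Hub credentials; swap in your own model and tools. CodeAgent and
    # HfApiModel are smolagents' standard entry points, but this block is only
    # an example, not part of the module API.
    from smolagents import CodeAgent, HfApiModel

    demo_agent = CodeAgent(tools=[], model=HfApiModel())
    GradioUI(demo_agent, file_upload_folder="./uploads").launch()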