File size: 6,140 Bytes
c9803a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60c7a7f
 
 
c9803a3
 
 
 
 
 
 
 
 
 
 
 
 
60c7a7f
 
 
 
 
 
c9803a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60c7a7f
 
 
 
 
 
 
 
 
 
 
c9803a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import base64
import re
import textwrap
from io import BytesIO

from PIL import Image, ImageDraw, ImageFont

from proxy_lite.environments.environment_base import Action, Observation
from proxy_lite.recorder import Run


def create_run_gif(
    run: Run, output_path: str, white_panel_width: int = 300, duration: int = 1500, resize_factor: int = 4
) -> None:
    """
    Generate an animated GIF from the Run object's history.

    For each Observation record, the observation image is decoded from its
    base64-encoded string. If the next record is an Action, its text is drawn
    onto a white side panel. The observation image and the panel are then
    concatenated horizontally to produce one frame.

    Parameters:
        run (Run): A Run object whose history contains Observation and Action records.
        output_path (str): The path where the GIF will be saved.
        white_panel_width (int): Width in pixels of the white text panel. Default 300.
        duration (int): Delay between frames in milliseconds. Default 1500.
        resize_factor (int): Integer factor by which each observation image is
            scaled down before compositing. Default 4.

    Raises:
        ValueError: If no frames could be generated from the run history.
    """
    frames = []
    history = run.history
    i = 0
    while i < len(history):
        record = history[i]
        if not isinstance(record, Observation):
            i += 1
            continue
        image_data = record.state.image
        if not image_data:
            # Observation without a screenshot: skip it (its Action, if any, is
            # left in place and considered on the next iteration, matching the
            # original control flow).
            i += 1
            continue

        obs_img = Image.open(BytesIO(base64.b64decode(image_data))).convert("RGB")
        # Scale down; clamp each dimension to >= 1 so images smaller than
        # resize_factor never hit PIL's zero-size error.
        obs_img = obs_img.resize(
            (max(1, obs_img.width // resize_factor), max(1, obs_img.height // resize_factor))
        )

        # Pair this observation with the immediately following Action's text, if any.
        has_action = i + 1 < len(history) and isinstance(history[i + 1], Action)
        action_text = ""
        if has_action:
            action = history[i + 1]
            if getattr(action, "text", None):
                action_text = action.text

        panel = _render_text_panel(
            _format_action_text(action_text), white_panel_width, obs_img.height
        )

        # Concatenate observation image (left) and text panel (right) into one frame.
        combined = Image.new("RGB", (obs_img.width + white_panel_width, obs_img.height))
        combined.paste(obs_img, (0, 0))
        combined.paste(panel, (obs_img.width, 0))
        frames.append(combined)

        # Skip past the Action we just consumed with this Observation.
        i += 2 if has_action else 1

    if frames:
        frames[0].save(output_path, save_all=True, append_images=frames[1:], duration=duration, loop=0)
    else:
        raise ValueError("No frames were generated from the Run object's history.")


def _format_action_text(action_text: str) -> str:
    """Extract <observation>/<thinking> tag contents from *action_text*.

    When both tags are present, return a labelled two-section summary;
    otherwise return the text unchanged.
    """
    observation_match = re.search(r"<observation>(.*?)</observation>", action_text, re.DOTALL)
    observation_content = observation_match.group(1).strip() if observation_match else None

    thinking_match = re.search(r"<thinking>(.*?)</thinking>", action_text, re.DOTALL)
    thinking_content = thinking_match.group(1).strip() if thinking_match else None

    if observation_content and thinking_content:
        return f"**OBSERVATION**\n{observation_content}\n\n**THINKING**\n{thinking_content}"
    return action_text


def _render_text_panel(text: str, width: int, height: int) -> Image.Image:
    """Render *text*, word-wrapped and centered, on a white width x height panel."""
    panel = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(panel)
    font = ImageFont.load_default()

    # Wrap each line independently so explicit newlines (e.g. the blank line
    # between the OBSERVATION and THINKING sections) survive wrapping —
    # textwrap.fill on the whole string would collapse them into spaces.
    max_chars_per_line = 40  # Adjusted for the default bitmap font size.
    wrapped_text = "\n".join(textwrap.fill(line, width=max_chars_per_line) for line in text.splitlines())

    text_width, text_height = _measure_text(draw, wrapped_text, font)
    text_x = (width - text_width) // 2
    text_y = (height - text_height) // 2
    draw.multiline_text((text_x, text_y), wrapped_text, fill="black", font=font, align="center")
    return panel


def _measure_text(draw: ImageDraw.ImageDraw, text: str, font: ImageFont.ImageFont) -> tuple[int, int]:
    """Return the (width, height) of the rendered text block.

    Falls back through progressively older Pillow text-measurement APIs.
    """
    try:
        # Preferred: multiline bounding box (Pillow >= 8.0).
        left, top, right, bottom = draw.multiline_textbbox((0, 0), text, font=font)
        return right - left, bottom - top
    except AttributeError:
        lines = text.splitlines() or [text]
        try:
            # Per-line bounding boxes (avoid shadowing an outer name here).
            boxes = [draw.textbbox((0, 0), line, font=font) for line in lines]
            return (
                max(box[2] - box[0] for box in boxes),
                sum(box[3] - box[1] for box in boxes),
            )
        except AttributeError:
            # Final fallback: rough per-character estimate for very old Pillow.
            estimated_char_width = 8
            estimated_line_height = 16
            return (
                max(len(line) for line in lines) * estimated_char_width,
                len(lines) * estimated_line_height,
            )


# Example usage:
if __name__ == "__main__":
    from proxy_lite.recorder import Run

    dummy_run = Run.load("0abdb4cb-f289-48b0-ba13-35ed1210f7c1")

    # Each step is an Observation/Action pair, hence half the history length.
    step_count = len(dummy_run.history) // 2
    print(f"Number of steps: {step_count}")

    gif_path = "trajectory.gif"
    create_run_gif(dummy_run, gif_path, duration=1000)
    print(f"Trajectory GIF saved to {gif_path}")