add magma arena
- app.py +51 -63
- app_1p.py +213 -0
- assets/images/magma_game_thin.png +0 -0
- requirements.txt +3 -1
- vlms/__pycache__/llavanext.cpython-310.pyc +0 -0
- vlms/__pycache__/llavaov.cpython-310.pyc +0 -0
- vlms/__pycache__/magma.cpython-310.pyc +0 -0
- vlms/__pycache__/qwen25vl.cpython-310.pyc +0 -0
- vlms/__pycache__/qwen2vl.cpython-310.pyc +0 -0
- vlms/llavanext.py +43 -0
- vlms/llavaov.py +44 -0
- vlms/magma.py +40 -0
- vlms/qwen2vl.py +59 -0
app.py
CHANGED
@@ -1,5 +1,4 @@
 import os
-# add a command for installing flash-attn
 os.system('pip install flash-attn --no-build-isolation')
 os.system("pip install gradio==4.44.1")

@@ -12,12 +11,15 @@ from PIL import Image
 from transformers import AutoModelForCausalLM, AutoProcessor
 import re
 import random
+from vlms.magma import MagmaAgent
+from vlms.llavaov import LLaVAOVAgent
+from vlms.qwen2vl import Qwen2VLAgent

 pygame.mixer.quit()  # Disable sound

 # Constants
-WIDTH, HEIGHT =
-GRID_SIZE =
+WIDTH, HEIGHT = 800, 800
+GRID_SIZE = 80
 WHITE = (255, 255, 255)
 GREEN = (34, 139, 34)  # Forest green - more like an apple
 RED = (200, 50, 50)
@@ -34,29 +36,24 @@ STATIC = (0, 0)

 ACTIONS = ["up", "down", "left", "right", "static"]

-# Load AI Model
 dtype = torch.bfloat16
-
-
-magma_processor = AutoProcessor.from_pretrained(magma_model_id, trust_remote_code=True)
-magam_model.to("cuda")
+agent_1 = MagmaAgent("cuda:0", dtype)
+agent_2 = Qwen2VLAgent("cuda:0", dtype)

-# Load magma image
 magma_img = pygame.image.load("./assets/images/magma_game_thin.png")
 magma_img = pygame.transform.scale(magma_img, (GRID_SIZE, GRID_SIZE))

-target_img = pygame.image.load("./assets/images/apple.png")
-target_img = pygame.transform.scale(target_img, (GRID_SIZE, GRID_SIZE))
-
 class MagmaFindGPU:
     def __init__(self):
         self.reset()
-
+        self.step_count = 0
+
     def reset(self):
         self.snake = [(5, 5)]
         self.direction = RIGHT
         self.score = 0
         self.game_over = False
+        self.step_count = 0
         self.place_target()

     def place_target(self):
@@ -79,16 +76,18 @@ class MagmaFindGPU:
         elif action == "static":
             self.direction = STATIC

-        if self.game_over:
-
+        # if self.game_over:
+        #     self.reset()
+        #     return self.render(), self.score

         new_head = (self.snake[0][0] + self.direction[0], self.snake[0][1] + self.direction[1])
-
+
         if new_head[0] < 0 or new_head[1] < 0 or new_head[0] >= WIDTH // GRID_SIZE or new_head[1] >= HEIGHT // GRID_SIZE:
-            self.game_over = True
+            # self.game_over = True
             return self.render(), self.score

         self.snake = [new_head]  # Keep only the head (single block snake)
+        self.step_count += 1

         # Check if the target is covered by four surrounding squares
         head_x, head_y = self.snake[0]
@@ -99,7 +98,7 @@ class MagmaFindGPU:
             self.place_target()

         return self.render(), self.score
-
+
     def render(self):
         pygame.init()
         surface = pygame.Surface((WIDTH, HEIGHT))
@@ -109,10 +108,8 @@ class MagmaFindGPU:
         surface.blit(magma_img, (head_x * GRID_SIZE, head_y * GRID_SIZE))

         # pygame.draw.rect(surface, RED, (self.snake[0][0] * GRID_SIZE, self.snake[0][1] * GRID_SIZE, GRID_SIZE, GRID_SIZE))
-
-
-        surface.blit(target_img, (self.target[0] * GRID_SIZE, self.target[1] * GRID_SIZE))
-
+        pygame.draw.rect(surface, GREEN, (self.target[0] * GRID_SIZE, self.target[1] * GRID_SIZE, GRID_SIZE, GRID_SIZE))
+
         # Draw four surrounding squares with labels
         head_x, head_y = self.snake[0]
         neighbors = [(head_x, head_y - 1), (head_x, head_y + 1), (head_x - 1, head_y), (head_x + 1, head_y)]
@@ -135,41 +132,20 @@ class MagmaFindGPU:
     def get_state(self):
         return self.render()

-
+game_1 = MagmaFindGPU()
+game_2 = MagmaFindGPU()

-def play_game():
+def play_game(game, agent):
     state, state_som = game.get_state()
     pil_img = Image.fromarray(state_som)
-
-        {"role": "system", "content": "You are an agent that can see, talk, and act."},
-        {"role": "user", "content": "<image_start><image><image_end>\nWhich mark is closer to green apple? Answer with a single number."},
-    ]
-    prompt = magma_processor.tokenizer.apply_chat_template(convs, tokenize=False, add_generation_prompt=True)
-    inputs = magma_processor(images=[pil_img], texts=prompt, return_tensors="pt")
-    inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0)
-    inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
-    inputs = inputs.to("cuda").to(dtype)
-    generation_args = {
-        "max_new_tokens": 10,
-        "temperature": 0,
-        "do_sample": False,
-        "use_cache": True,
-        "num_beams": 1,
-    }
-    with torch.inference_mode():
-        generate_ids = magam_model.generate(**inputs, **generation_args)
-    generate_ids = generate_ids[:, inputs["input_ids"].shape[-1] :]
-    action = magma_processor.decode(generate_ids[0], skip_special_tokens=True).strip()
+    action = agent.generate_response(pil_img, "Which mark is closer to green block? Answer with a single number.")
     # extract mark id fro action use re
+    # print(agent.__class__.__name__, action)
     match = re.search(r'\d+', action)
     if match:
         action = match.group(0)
         if action.isdigit() and 1 <= int(action) <= 4:
-
-            if random.random() < 0.1:
-                action = random.choice(ACTIONS[:-1])
-            else:
-                action = ACTIONS[int(action) - 1]
+            action = ACTIONS[int(action) - 1]
         else:
             # random choose one from the pool
             action = random.choice(ACTIONS[:-1])
@@ -177,34 +153,46 @@ def play_game():
         action = random.choice(ACTIONS[:-1])

     img, score = game.step(action)
-
-
+    return img[0], f"Score: {score}"
+
+def play_game_1():
+    return play_game(game_1, agent_1)

-def
-
-    return game.render()[0], "Score: 0"
+def play_game_2():
+    return play_game(game_2, agent_2)

+def reset_games():
+    game_1.reset()
+    game_2.reset()
+    return game_1.render()[0], "Score: 0", game_2.render()[0], "Score: 0"
 MARKDOWN = """
 <div align="center">
 <h2>Magma: A Foundation Model for Multimodal AI Agents</h2>

 \[[arXiv Paper](https://www.arxiv.org/pdf/2502.13130)\] \[[Project Page](https://microsoft.github.io/Magma/)\] \[[Github Repo](https://github.com/microsoft/Magma)\] \[[Hugging Face Model](https://huggingface.co/microsoft/Magma-8B)\]

-
+<h3>Magma Arena: A battle between two agents to collect the green blocks by automatically moving up, down, left and right.</h3>

 This demo is powered by [Gradio](https://gradio.app/).
+
 </div>
 """

 with gr.Blocks() as interface:
     gr.Markdown(MARKDOWN)
     with gr.Row():
-
-
-
-
-
-
-
+        with gr.Column():
+            img_output_1 = gr.Image(label="{}".format(agent_1.__class__.__name__))
+            score_output_1 = gr.Text(label="Score 1")
+        with gr.Column():
+            img_output_2 = gr.Image(label="{}".format(agent_2.__class__.__name__))
+            score_output_2 = gr.Text(label="Score 2")
+
+    start_btn = gr.Button("Start/Reset Game")
+
+    interface.load(fn=play_game_1, every=1, inputs=[], outputs=[img_output_1, score_output_1])
+    interface.load(fn=play_game_2, every=1, inputs=[], outputs=[img_output_2, score_output_2])
+
+    start_btn.click(fn=reset_games, inputs=[], outputs=[img_output_1, score_output_1, img_output_2, score_output_2])

-interface.launch()
+interface.launch(server_port=7861)
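The arena in app.py runs two independent MagmaFindGPU boards, one per agent, and polls play_game_1/play_game_2 once per second via interface.load(..., every=1). LLaVAOVAgent is imported but not wired in; the hypothetical snippet below (not part of the commit) sketches how a different wrapper could take the second lane, since the loop only relies on generate_response(image, question):

# Hypothetical variant, not part of this commit: swap the second player for the
# already-imported LLaVAOVAgent. Any wrapper exposing generate_response(image, question)
# fits the same slot.
agent_2 = LLaVAOVAgent("cuda:0", dtype)   # instead of Qwen2VLAgent("cuda:0", dtype)
game_2 = MagmaFindGPU()

def play_game_2():
    # polled once per second by interface.load(fn=play_game_2, every=1, ...)
    return play_game(game_2, agent_2)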
app_1p.py
ADDED
@@ -0,0 +1,213 @@
+import os
+# add a command for installing flash-attn
+os.system('pip install flash-attn --no-build-isolation')
+os.system("pip install gradio==4.44.1")
+
+import pygame
+import numpy as np
+import gradio as gr
+import time
+import torch
+from PIL import Image
+from transformers import AutoModelForCausalLM, AutoProcessor
+import re
+import random
+
+pygame.mixer.quit()  # Disable sound
+
+# Constants
+WIDTH, HEIGHT = 800, 800
+GRID_SIZE = 80
+WHITE = (255, 255, 255)
+GREEN = (34, 139, 34)  # Forest green - more like an apple
+RED = (200, 50, 50)
+BLACK = (0, 0, 0)
+GRAY = (128, 128, 128)
+YELLOW = (218, 165, 32)  # Golden yellow color
+
+# Directions
+UP = (0, -1)
+DOWN = (0, 1)
+LEFT = (-1, 0)
+RIGHT = (1, 0)
+STATIC = (0, 0)
+
+ACTIONS = ["up", "down", "left", "right", "static"]
+
+# Load AI Model
+dtype = torch.bfloat16
+magma_model_id = "microsoft/Magma-8B"
+magam_model = AutoModelForCausalLM.from_pretrained(magma_model_id, trust_remote_code=True, torch_dtype=dtype)
+magma_processor = AutoProcessor.from_pretrained(magma_model_id, trust_remote_code=True)
+magam_model.to("cuda")
+
+magma_img = pygame.image.load("./assets/images/magma_game_thin.png")
+magma_img = pygame.transform.scale(magma_img, (GRID_SIZE, GRID_SIZE))
+
+class MagmaFindGPU:
+    def __init__(self):
+        self.reset()
+        self.step_count = 0
+
+    def reset(self):
+        self.snake = [(5, 5)]
+        self.direction = RIGHT
+        self.score = 0
+        self.game_over = False
+        self.step_count = 0
+        self.place_target()
+
+    def place_target(self):
+        while True:
+            target_x = np.random.randint(1, WIDTH // GRID_SIZE - 1)
+            target_y = np.random.randint(1, HEIGHT // GRID_SIZE - 1)
+            if (target_x, target_y) not in self.snake:
+                self.target = (target_x, target_y)
+                break
+
+    def step(self, action):
+        if action == "up":
+            self.direction = UP
+        elif action == "down":
+            self.direction = DOWN
+        elif action == "left":
+            self.direction = LEFT
+        elif action == "right":
+            self.direction = RIGHT
+        elif action == "static":
+            self.direction = STATIC
+
+        if self.game_over:
+            self.reset()
+            return self.render(), self.score
+
+        new_head = (self.snake[0][0] + self.direction[0], self.snake[0][1] + self.direction[1])
+
+        if new_head[0] < 0 or new_head[1] < 0 or new_head[0] >= WIDTH // GRID_SIZE or new_head[1] >= HEIGHT // GRID_SIZE:
+            self.game_over = True
+            return self.render(), self.score
+
+        self.snake = [new_head]  # Keep only the head (single block snake)
+        self.step_count += 1
+
+        # Check if the target is covered by four surrounding squares
+        head_x, head_y = self.snake[0]
+        neighbors = set([(head_x, head_y - 1), (head_x, head_y + 1), (head_x - 1, head_y), (head_x + 1, head_y)])
+
+        if neighbors.issuperset(set([self.target])):
+            self.score += 1
+            self.place_target()
+
+        return self.render(), self.score
+
+    def render(self):
+        pygame.init()
+        surface = pygame.Surface((WIDTH, HEIGHT))
+        surface.fill(BLACK)
+
+        head_x, head_y = self.snake[0]
+        surface.blit(magma_img, (head_x * GRID_SIZE, head_y * GRID_SIZE))
+
+        # pygame.draw.rect(surface, RED, (self.snake[0][0] * GRID_SIZE, self.snake[0][1] * GRID_SIZE, GRID_SIZE, GRID_SIZE))
+        pygame.draw.rect(surface, GREEN, (self.target[0] * GRID_SIZE, self.target[1] * GRID_SIZE, GRID_SIZE, GRID_SIZE))
+
+        # Draw four surrounding squares with labels
+        head_x, head_y = self.snake[0]
+        neighbors = [(head_x, head_y - 1), (head_x, head_y + 1), (head_x - 1, head_y), (head_x + 1, head_y)]
+        labels = ["1", "2", "3", "4"]
+        font = pygame.font.Font(None, 48)
+
+        # clone surface
+        surface_nomark = surface.copy()
+        for i, (nx, ny) in enumerate(neighbors):
+            if 0 <= nx < WIDTH // GRID_SIZE and 0 <= ny < HEIGHT // GRID_SIZE:
+                pygame.draw.rect(surface, RED, (nx * GRID_SIZE, ny * GRID_SIZE, GRID_SIZE, GRID_SIZE), GRID_SIZE)
+                # pygame.draw.rect(surface_nomark, RED, (nx * GRID_SIZE, ny * GRID_SIZE, GRID_SIZE, GRID_SIZE), GRID_SIZE)
+
+                text = font.render(labels[i], True, WHITE)
+                text_rect = text.get_rect(center=(nx * GRID_SIZE + GRID_SIZE // 2, ny * GRID_SIZE + GRID_SIZE // 2))
+                surface.blit(text, text_rect)
+
+        return np.array(pygame.surfarray.array3d(surface_nomark)).swapaxes(0, 1), np.array(pygame.surfarray.array3d(surface)).swapaxes(0, 1)
+
+    def get_state(self):
+        return self.render()
+
+game = MagmaFindGPU()
+
+def play_game():
+    state, state_som = game.get_state()
+    pil_img = Image.fromarray(state_som)
+    convs = [
+        {"role": "system", "content": "You are an agent that can see, talk, and act. Avoid hitting the wall."},
+        {"role": "user", "content": "<image_start><image><image_end>\nWhich mark is closer to green block? Answer with a single number."},
+    ]
+    prompt = magma_processor.tokenizer.apply_chat_template(convs, tokenize=False, add_generation_prompt=True)
+    inputs = magma_processor(images=[pil_img], texts=prompt, return_tensors="pt")
+    inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0)
+    inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
+    inputs = inputs.to("cuda").to(dtype)
+    generation_args = {
+        "max_new_tokens": 10,
+        "temperature": 0.3,
+        "do_sample": True,
+        "use_cache": True,
+        "num_beams": 1,
+    }
+    with torch.inference_mode():
+        generate_ids = magam_model.generate(**inputs, **generation_args)
+    generate_ids = generate_ids[:, inputs["input_ids"].shape[-1] :]
+    action = magma_processor.decode(generate_ids[0], skip_special_tokens=True).strip()
+    # extract mark id fro action use re
+    match = re.search(r'\d+', action)
+    if match:
+        action = match.group(0)
+        if action.isdigit() and 1 <= int(action) <= 4:
+            action = ACTIONS[int(action) - 1]
+        else:
+            # random choose one from the pool
+            action = random.choice(ACTIONS[:-1])
+    else:
+        action = random.choice(ACTIONS[:-1])
+
+    img, score = game.step(action)
+    img = img[0]
+    return img, f"Score: {score}"
+
+def reset_game():
+    game.reset()
+    return game.render()[0], "Score: 0"
+
+MARKDOWN = """
+<div align="center">
+<img src="./assets/images/logo.png" alt="Magma Logo" style="margin-right: 5px; height: 80px;margin-top: -10px;">
+<h2>Magma: A Foundation Model for Multimodal AI Agents</h2>
+
+\[[arXiv Paper](https://www.arxiv.org/pdf/2502.13130)\] \[[Project Page](https://microsoft.github.io/Magma/)\] \[[Github Repo](https://github.com/microsoft/Magma)\] \[[Hugging Face Model](https://huggingface.co/microsoft/Magma-8B)\]
+
+This demo is powered by [Gradio](https://gradio.app/).
+
+<b>Goal: Collects the green blocks by automatically moving up, down, left and right.</b>
+
+</div>
+"""
+
+with gr.Blocks() as interface:
+    gr.Markdown(MARKDOWN)
+    with gr.Row():
+        image_output = gr.Image(label="Game Screen")
+        with gr.Column():
+            score_output = gr.Text(label="Score", elem_classes="large-text")
+            gr.HTML("""
+            <style>
+            .large-text textarea {
+                font-size: 24px !important;
+            }
+            </style>
+            """)
+    start_btn = gr.Button("Start/Reset Game")
+
+    interface.load(fn=play_game, every=1, inputs=[], outputs=[image_output, score_output])
+    start_btn.click(fn=reset_game, inputs=[], outputs=[image_output, score_output])
+
+interface.launch()
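app_1p.py keeps the original single-player demo: play_game() renders the marked board, asks the Magma model which numbered mark is closest to the green block, and steps the game with that move. A minimal headless sketch, assuming it replaces the Gradio wiring at the bottom of app_1p.py, would drive the same loop from the command line:

# Headless smoke test (assumption: this replaces the gr.Blocks section above,
# so no UI is launched). Each iteration makes one model call and one game step.
if __name__ == "__main__":
    frame, score_text = reset_game()
    for _ in range(20):
        frame, score_text = play_game()
        print(score_text)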
assets/images/magma_game_thin.png
CHANGED
requirements.txt
CHANGED
@@ -1,7 +1,7 @@
 torch==2.3.1
 torchvision==0.18.1
 pytorch-lightning>=1.0.8
-transformers @ git+https://github.com/jwyang/transformers.git@dev/jwyang-v4.
+transformers @ git+https://github.com/jwyang/transformers.git@dev/jwyang-v4.48.2
 tokenizers>=0.15.0
 sentencepiece==0.1.99
 shortuuid
@@ -35,3 +35,5 @@ open_clip_torch
 supervision==0.18.0
 ultralytics==8.3.78
 pygame
+pyautogui
+qwen-vl-utils
vlms/__pycache__/llavanext.cpython-310.pyc
ADDED
Binary file (1.79 kB)
vlms/__pycache__/llavaov.cpython-310.pyc
ADDED
Binary file (1.81 kB)
vlms/__pycache__/magma.cpython-310.pyc
ADDED
Binary file (1.93 kB)
vlms/__pycache__/qwen25vl.cpython-310.pyc
ADDED
Binary file (2.05 kB)
vlms/__pycache__/qwen2vl.cpython-310.pyc
ADDED
Binary file (2.01 kB)
vlms/llavanext.py
ADDED
@@ -0,0 +1,43 @@
+from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
+import torch
+import torch.nn as nn
+from PIL import Image
+import requests
+
+class LLaVANextAgent(nn.Module):
+    def __init__(self, device="cuda", dtype=torch.float16):
+        super().__init__()
+
+        self.processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
+        self.model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype=dtype, low_cpu_mem_usage=True)
+        self.dtype = dtype
+        self.device = device
+
+        self.model.to(device)
+
+        self.generation_args = {
+            "max_new_tokens": 10,
+            "temperature": 0.3,
+            "do_sample": True,
+            "use_cache": True,
+            "num_beams": 1,
+        }
+
+    def generate_response(self, image, question):
+        conversation = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": question},
+                    {"type": "image"},
+                ],
+            },
+        ]
+        prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
+        inputs = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device)
+        # autoregressively complete prompt
+        self.model.generation_config.pad_token_id = self.processor.tokenizer.pad_token_id
+        with torch.inference_mode():
+            output = self.model.generate(**inputs, **self.generation_args)
+        output = output[:, inputs["input_ids"].shape[-1] :]
+        return self.processor.decode(output[0], skip_special_tokens=True).strip()
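All wrappers under vlms/ expose the same two-argument entry point, which is the only contract app.py depends on. A hypothetical Protocol (the name VLMAgent is not in the repo) makes that shared interface explicit:

# Hypothetical interface sketch; the repo does not define this Protocol, but
# MagmaAgent, LLaVANextAgent, LLaVAOVAgent and Qwen2VLAgent all satisfy it.
from typing import Protocol
from PIL import Image

class VLMAgent(Protocol):
    def generate_response(self, image: Image.Image, question: str) -> str:
        ...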
vlms/llavaov.py
ADDED
@@ -0,0 +1,44 @@
+from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
+import torch
+import torch.nn as nn
+from PIL import Image
+import requests
+
+class LLaVAOVAgent(nn.Module):
+    model_id = "llava-hf/llava-onevision-qwen2-7b-ov-hf"
+    def __init__(self, device="cuda", dtype=torch.float16):
+        super().__init__()
+
+        self.processor = AutoProcessor.from_pretrained(self.model_id)
+        self.model = LlavaOnevisionForConditionalGeneration.from_pretrained(self.model_id, torch_dtype=dtype, low_cpu_mem_usage=True)
+        self.dtype = dtype
+        self.device = device
+
+        self.model.to(device)
+
+        self.generation_args = {
+            "max_new_tokens": 10,
+            "temperature": 0.3,
+            "do_sample": True,
+            "use_cache": True,
+            "num_beams": 1,
+        }
+
+    def generate_response(self, image, question):
+        conversation = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": question},
+                    {"type": "image"},
+                ],
+            },
+        ]
+        prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
+        inputs = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device)
+        # autoregressively complete prompt
+        self.model.generation_config.pad_token_id = self.processor.tokenizer.pad_token_id
+        with torch.inference_mode():
+            output = self.model.generate(**inputs, **self.generation_args)
+        output = output[:, inputs["input_ids"].shape[-1] :]
+        return self.processor.decode(output[0], skip_special_tokens=True).strip()
vlms/magma.py
ADDED
@@ -0,0 +1,40 @@
+from transformers import AutoModelForCausalLM, AutoProcessor
+import torch
+import torch.nn as nn
+from PIL import Image
+import requests
+
+model_id = "microsoft/Magma-8B"
+class MagmaAgent(nn.Module):
+    def __init__(self, device="cuda", dtype=torch.float16):
+        super().__init__()
+
+        self.model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, torch_dtype=dtype, low_cpu_mem_usage=True)
+        self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+        self.dtype = dtype
+        self.device = device
+        self.model.to(device)
+
+        self.generation_args = {
+            "max_new_tokens": 10,
+            "temperature": 0.3,
+            "do_sample": True,
+            "use_cache": True,
+            "num_beams": 1,
+        }
+
+    def generate_response(self, image, question):
+        convs = [
+            {"role": "system", "content": "You are an agent that can see, talk, and act."},
+            {"role": "user", "content": "<image_start><image><image_end>\n{}".format(question)},
+        ]
+        prompt = self.processor.tokenizer.apply_chat_template(convs, tokenize=False, add_generation_prompt=True)
+        inputs = self.processor(images=[image], texts=prompt, return_tensors="pt").to(self.dtype).to(self.device)
+        inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0)
+        inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
+
+        with torch.inference_mode():
+            generate_ids = self.model.generate(**inputs, **self.generation_args)
+        generate_ids = generate_ids[:, inputs["input_ids"].shape[-1] :]
+        action = self.processor.decode(generate_ids[0], skip_special_tokens=True).strip()
+        return action
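A minimal usage sketch for MagmaAgent, mirroring the arguments app.py passes (the sample image is just a file that ships with the Space; any PIL image works):

# Usage sketch; assumes a CUDA device and the repo assets are available.
import torch
from PIL import Image
from vlms.magma import MagmaAgent

agent = MagmaAgent("cuda:0", torch.bfloat16)  # same construction as agent_1 in app.py
frame = Image.open("./assets/images/magma_game_thin.png").convert("RGB")
print(agent.generate_response(frame, "Which mark is closer to green block? Answer with a single number."))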
vlms/qwen2vl.py
ADDED
@@ -0,0 +1,59 @@
+from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
+from qwen_vl_utils import process_vision_info
+
+import torch
+import torch.nn as nn
+from PIL import Image
+import requests
+
+class Qwen2VLAgent(nn.Module):
+    model_id = "Qwen/Qwen2-VL-7B-Instruct"
+    def __init__(self, device="cuda", dtype=torch.float16):
+        super().__init__()
+
+        self.processor = AutoProcessor.from_pretrained(self.model_id)
+        self.model = Qwen2VLForConditionalGeneration.from_pretrained(self.model_id, torch_dtype=dtype, low_cpu_mem_usage=True)
+        self.dtype = dtype
+        self.device = device
+
+        self.model.to(device)
+
+        self.generation_args = {
+            "max_new_tokens": 10,
+            "temperature": 0.3,
+            "do_sample": True,
+            "use_cache": True,
+            "num_beams": 1,
+        }
+
+    def generate_response(self, image, question):
+        image.save('qwen25vl.png')
+        conversation = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": question},
+                    {"type": "image", "image": "qwen25vl.png"},
+                ],
+            },
+        ]
+
+        # Preparation for inference
+        text = self.processor.apply_chat_template(
+            conversation, tokenize=False, add_generation_prompt=True
+        )
+        image_inputs, video_inputs = process_vision_info(conversation)
+        inputs = self.processor(
+            text=[text],
+            images=image_inputs,
+            videos=video_inputs,
+            padding=True,
+            return_tensors="pt",
+        ).to(self.device)
+
+        # autoregressively complete prompt
+        self.model.generation_config.pad_token_id = self.processor.tokenizer.pad_token_id
+        with torch.inference_mode():
+            output = self.model.generate(**inputs, **self.generation_args)
+        output = output[:, inputs["input_ids"].shape[-1] :]
+        return self.processor.decode(output[0], skip_special_tokens=True).strip()
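Qwen2VLAgent.generate_response saves each frame to qwen25vl.png so that process_vision_info can reload it by path, which means one disk write per polling tick. A hedged alternative sketch that keeps the frame in memory; it assumes the stock Hugging Face Qwen2-VL processor accepts PIL images directly and is untested here:

# Sketch of an in-memory variant (assumption: the Qwen2-VL processor accepts PIL
# images passed via images=[...], making the temporary PNG unnecessary).
def generate_response_in_memory(self, image, question):
    conversation = [
        {"role": "user", "content": [{"type": "text", "text": question}, {"type": "image"}]},
    ]
    text = self.processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    inputs = self.processor(text=[text], images=[image], padding=True, return_tensors="pt").to(self.device)
    with torch.inference_mode():
        output = self.model.generate(**inputs, **self.generation_args)
    output = output[:, inputs["input_ids"].shape[-1]:]
    return self.processor.decode(output[0], skip_special_tokens=True).strip()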