# Magma-Gaming / app.py
# Author: jw2yang (profile picture)
# Commit 4f00e93: "add magma arena"
# (raw / history / blame; file size 7.03 kB)
# Runtime dependency setup for the hosted demo.
# These installs run on every start. The gradio pin must come before
# `import gradio` below so that version 4.44.1 is the one imported.
# The exit status returned by os.system is not checked, so a failed install
# only surfaces later as an import or runtime error.
import os
# flash-attn is presumably needed by the VLM agents imported below. TODO confirm.
os.system('pip install flash-attn --no-build-isolation')
os.system("pip install gradio==4.44.1")
import pygame
import numpy as np
import gradio as gr
# NOTE(review): `time`, `AutoModelForCausalLM`, `AutoProcessor` and
# `LLaVAOVAgent` are not referenced elsewhere in this file.
import time
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
import re
import random
from vlms.magma import MagmaAgent
from vlms.llavaov import LLaVAOVAgent
from vlms.qwen2vl import Qwen2VLAgent
# The game only renders off-screen surfaces; audio output is not used.
pygame.mixer.quit() # Disable sound
# --- Board geometry (pixels) ------------------------------------------------
WIDTH, HEIGHT = 800, 800
GRID_SIZE = 80  # one cell is GRID_SIZE x GRID_SIZE px -> a 10 x 10 board

# --- RGB colours ----------------------------------------------------------------
WHITE, BLACK = (255, 255, 255), (0, 0, 0)
GRAY = (128, 128, 128)
RED = (200, 50, 50)
GREEN = (34, 139, 34)  # forest green, reads like an apple
YELLOW = (218, 165, 32)  # golden yellow

# --- Moves as (dx, dy) grid offsets; negative dy moves up ---------------------
UP, DOWN = (0, -1), (0, 1)
LEFT, RIGHT = (-1, 0), (1, 0)
STATIC = (0, 0)

# Action names understood by MagmaFindGPU.step(). The first four are in the
# same order as the numbered marks 1-4 drawn around the player.
ACTIONS = ["up", "down", "left", "right", "static"]
# Both agents are constructed on "cuda:0" with a bfloat16 dtype.
# Presumably each constructor loads its model weights. TODO confirm.
dtype = torch.bfloat16
agent_1 = MagmaAgent("cuda:0", dtype)
agent_2 = Qwen2VLAgent("cuda:0", dtype)

# Player sprite, scaled to exactly one grid cell. The asset path is resolved
# relative to this file rather than the current working directory, so the app
# still finds it when launched from another directory.
_ASSET_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "images")
magma_img = pygame.image.load(os.path.join(_ASSET_DIR, "magma_game_thin.png"))
magma_img = pygame.transform.scale(magma_img, (GRID_SIZE, GRID_SIZE))
class MagmaFindGPU:
    """Grid game in which a single-cell player hunts a green target block.

    The board has ``WIDTH // GRID_SIZE`` x ``HEIGHT // GRID_SIZE`` cells. Each
    step moves the player one cell. Whenever the target ends up orthogonally
    adjacent to the player, the score increases by one and a new target is
    placed. Moves that would leave the board are ignored, so the game never
    ends on its own.

    Attributes:
        snake: One-element list holding the player's ``(x, y)`` cell.
        direction: Last ``(dx, dy)`` offset selected by an action.
        target: ``(x, y)`` cell of the green block.
        score: Targets collected since the last reset.
        step_count: In-bounds moves since the last reset.
        game_over: Set to ``False`` by ``reset()``; never changed by this class.
    """

    # Action name -> (dx, dy). Unrecognised actions keep the current direction.
    _DIRECTIONS = {
        "up": UP,
        "down": DOWN,
        "left": LEFT,
        "right": RIGHT,
        "static": STATIC,
    }
    # Label font, created lazily on the first render and reused afterwards.
    _font = None

    def __init__(self):
        self.reset()

    def reset(self):
        """Put the player back at (5, 5), clear the counters and place a new target."""
        self.snake = [(5, 5)]
        self.direction = RIGHT
        self.score = 0
        self.game_over = False
        self.step_count = 0
        self.place_target()

    def place_target(self):
        """Place the target on a random cell not occupied by the player.

        ``np.random.randint``'s upper bound is exclusive, so the target is
        never placed on the outermost ring of cells.
        """
        cols, rows = WIDTH // GRID_SIZE, HEIGHT // GRID_SIZE
        while True:
            candidate = (np.random.randint(1, cols - 1), np.random.randint(1, rows - 1))
            if candidate not in self.snake:
                self.target = candidate
                return

    @staticmethod
    def _in_bounds(cell):
        """Return True if the ``(x, y)`` cell lies on the board."""
        x, y = cell
        return 0 <= x < WIDTH // GRID_SIZE and 0 <= y < HEIGHT // GRID_SIZE

    def _neighbors(self):
        """Return the cells above, below, left and right of the player, in that order.

        The order matches ``ACTIONS[:4]``, so mark ``i`` (1-based) corresponds
        to the action ``ACTIONS[i - 1]``.
        """
        x, y = self.snake[0]
        return [(x, y - 1), (x, y + 1), (x - 1, y), (x + 1, y)]

    def step(self, action):
        """Apply ``action`` and advance the game by one step.

        Args:
            action: One of ``ACTIONS``. Any other value keeps the previous
                direction.

        Returns:
            ``(self.render(), self.score)``, where ``render()`` yields the
            ``(unmarked, marked)`` frames.
        """
        self.direction = self._DIRECTIONS.get(action, self.direction)
        dx, dy = self.direction
        head_x, head_y = self.snake[0]
        new_head = (head_x + dx, head_y + dy)
        if not self._in_bounds(new_head):
            # Off-board moves are a no-op: no movement, no step counted.
            return self.render(), self.score

        self.snake = [new_head]  # single-cell "snake": keep only the head
        self.step_count += 1
        if self.target in self._neighbors():
            self.score += 1
            self.place_target()
        return self.render(), self.score

    @classmethod
    def _label_font(cls):
        """Return the shared size-48 default font, initialising pygame.font once."""
        if cls._font is None:
            # Only the font module is needed for rendering. A full
            # pygame.init() would also re-initialise the mixer that is
            # deliberately shut down at import time.
            pygame.font.init()
            cls._font = pygame.font.Font(None, 48)
        return cls._font

    @staticmethod
    def _to_array(surface):
        """Convert a surface to an RGB array indexed as [row, column, channel]."""
        return pygame.surfarray.array3d(surface).swapaxes(0, 1)

    def render(self):
        """Draw the board off-screen.

        Returns:
            A pair ``(plain, marked)`` of RGB arrays of shape
            ``(HEIGHT, WIDTH, 3)``. ``plain`` shows the player and target only.
            ``marked`` additionally fills each in-bounds neighbour cell in red
            and labels it "1"-"4".
        """
        surface = pygame.Surface((WIDTH, HEIGHT))
        surface.fill(BLACK)
        head_x, head_y = self.snake[0]
        surface.blit(magma_img, (head_x * GRID_SIZE, head_y * GRID_SIZE))
        target_x, target_y = self.target
        pygame.draw.rect(surface, GREEN, (target_x * GRID_SIZE, target_y * GRID_SIZE, GRID_SIZE, GRID_SIZE))

        # Snapshot taken before the numbered marks are drawn.
        plain = surface.copy()

        font = self._label_font()
        for label, (nx, ny) in enumerate(self._neighbors(), start=1):
            if not self._in_bounds((nx, ny)):
                continue
            cell = (nx * GRID_SIZE, ny * GRID_SIZE, GRID_SIZE, GRID_SIZE)
            # The line width equals the cell size, which presumably is meant
            # to fill the whole cell.
            pygame.draw.rect(surface, RED, cell, GRID_SIZE)
            text = font.render(str(label), True, WHITE)
            center = (nx * GRID_SIZE + GRID_SIZE // 2, ny * GRID_SIZE + GRID_SIZE // 2)
            surface.blit(text, text.get_rect(center=center))

        return self._to_array(plain), self._to_array(surface)

    def get_state(self):
        """Return the current ``(plain, marked)`` frames without advancing the game."""
        return self.render()
# One independent board per agent; game_1 is created first.
game_1, game_2 = MagmaFindGPU(), MagmaFindGPU()
def play_game(game, agent):
state, state_som = game.get_state()
pil_img = Image.fromarray(state_som)
action = agent.generate_response(pil_img, "Which mark is closer to green block? Answer with a single number.")
# extract mark id fro action use re
# print(agent.__class__.__name__, action)
match = re.search(r'\d+', action)
if match:
action = match.group(0)
if action.isdigit() and 1 <= int(action) <= 4:
action = ACTIONS[int(action) - 1]
else:
# random choose one from the pool
action = random.choice(ACTIONS[:-1])
else:
action = random.choice(ACTIONS[:-1])
img, score = game.step(action)
return img[0], f"Score: {score}"
def play_game_1():
    """Advance the first arena (game_1, driven by agent_1) by one step."""
    frame, score_text = play_game(game_1, agent_1)
    return frame, score_text
def play_game_2():
    """Advance the second arena (game_2, driven by agent_2) by one step."""
    frame, score_text = play_game(game_2, agent_2)
    return frame, score_text
def reset_games():
    """Restart both arenas and return their fresh frames and score texts.

    Returns:
        ``(image_1, score_1, image_2, score_2)``, where each image is the
        unmarked frame. Both games are reset before either one is rendered.
    """
    arenas = (game_1, game_2)
    for arena in arenas:
        arena.reset()
    return tuple(
        item
        for arena in arenas
        for item in (arena.render()[0], "Score: 0")
    )
# Header shown at the top of the demo page. It is a raw string so the
# Markdown-escaped brackets (\[ ... \]) are kept literally without triggering
# Python's invalid-escape-sequence warning. The value is unchanged.
MARKDOWN = r"""
<div align="center">
<h2>Magma: A Foundation Model for Multimodal AI Agents</h2>
\[[arXiv Paper](https://www.arxiv.org/pdf/2502.13130)\] &nbsp; \[[Project Page](https://microsoft.github.io/Magma/)\] &nbsp; \[[Github Repo](https://github.com/microsoft/Magma)\] &nbsp; \[[Hugging Face Model](https://huggingface.co/microsoft/Magma-8B)\] &nbsp;
<h3>Magma Arena: A battle between two agents to collect the green blocks by automatically moving up, down, left and right.</h3>
This demo is powered by [Gradio](https://gradio.app/).
</div>
"""
# UI: two side-by-side arenas, each advanced periodically by its own agent,
# plus a button that resets both.
with gr.Blocks() as interface:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            img_output_1 = gr.Image(label=agent_1.__class__.__name__)
            score_output_1 = gr.Text(label="Score 1")
        with gr.Column():
            img_output_2 = gr.Image(label=agent_2.__class__.__name__)
            score_output_2 = gr.Text(label="Score 2")
    start_btn = gr.Button("Start/Reset Game")

    arena_1_outputs = [img_output_1, score_output_1]
    arena_2_outputs = [img_output_2, score_output_2]
    # Per the Gradio 4 docs, `every=1` re-runs the load handler about once
    # per second while the page is open.
    interface.load(fn=play_game_1, inputs=[], outputs=arena_1_outputs, every=1)
    interface.load(fn=play_game_2, inputs=[], outputs=arena_2_outputs, every=1)
    start_btn.click(fn=reset_games, inputs=[], outputs=arena_1_outputs + arena_2_outputs)

interface.launch(server_port=7861)