# Source: Magma-Gaming / vlms/magma.py
# Last commit: 4f00e93 "add magma arena" (jw2yang)
# File size: 1.75 kB
from transformers import AutoModelForCausalLM, AutoProcessor
import torch
import torch.nn as nn
from PIL import Image
import requests
# Hugging Face Hub checkpoint loaded by MagmaAgent when no other is requested.
model_id = "microsoft/Magma-8B"


class MagmaAgent(nn.Module):
    """Wrapper around a Magma checkpoint that answers a text question about one image.

    The model and its processor are loaded with ``trust_remote_code=True`` and
    the model is moved to ``device``. Decoding settings live in
    ``self.generation_args`` and are passed to ``model.generate`` on every call;
    they can be overridden at construction time or per call.
    """

    # Default decoding settings: short (at most 10 new tokens), lightly sampled output.
    DEFAULT_GENERATION_ARGS = {
        "max_new_tokens": 10,
        "temperature": 0.3,
        "do_sample": True,
        "use_cache": True,
        "num_beams": 1,
    }

    # System message prepended to every conversation.
    SYSTEM_PROMPT = "You are an agent that can see, talk, and act."

    def __init__(self, device="cuda", dtype=torch.float16, *, model_name=None, generation_args=None):
        """Load the model and processor and move the model to ``device``.

        Args:
            device: torch device the model and every input batch are moved to.
            dtype: dtype used to load the weights and to cast floating-point inputs.
            model_name: checkpoint to load; defaults to the module-level ``model_id``.
            generation_args: optional mapping of ``generate`` kwargs merged over
                ``DEFAULT_GENERATION_ARGS``.
        """
        super().__init__()
        checkpoint = model_id if model_name is None else model_name
        self.model = AutoModelForCausalLM.from_pretrained(
            checkpoint, trust_remote_code=True, torch_dtype=dtype, low_cpu_mem_usage=True
        )
        self.processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)
        self.dtype = dtype
        self.device = device
        self.model.to(device)
        self.generation_args = self._merge_generation_args(generation_args)

    @classmethod
    def _merge_generation_args(cls, overrides=None):
        """Return a new dict of the default generation settings updated with *overrides*."""
        merged = dict(cls.DEFAULT_GENERATION_ARGS)  # copy so the class default is never mutated
        if overrides:
            merged.update(overrides)
        return merged

    @classmethod
    def _build_conversation(cls, question):
        """Return chat messages: the system prompt, then one image placeholder followed by *question*."""
        return [
            {"role": "system", "content": cls.SYSTEM_PROMPT},
            {"role": "user", "content": "<image_start><image><image_end>\n{}".format(question)},
        ]

    def generate_response(self, image, question, **generation_overrides):
        """Ask the model *question* about *image* and return its decoded reply.

        Args:
            image: a single image in whatever form the Magma processor accepts
                (presumably a PIL image — confirm against callers).
            question: the user text placed after the image placeholder.
            **generation_overrides: optional ``generate`` kwargs that take
                precedence over ``self.generation_args`` for this call only.

        Returns:
            The newly generated text (prompt tokens stripped), with special
            tokens removed and surrounding whitespace trimmed.
        """
        prompt = self.processor.tokenizer.apply_chat_template(
            self._build_conversation(question), tokenize=False, add_generation_prompt=True
        )
        inputs = self.processor(images=[image], texts=prompt, return_tensors="pt")
        # BatchFeature.to(dtype) casts only floating-point tensors, so input_ids keep their integer dtype.
        inputs = inputs.to(self.dtype).to(self.device)
        # Add a leading batch axis to the image tensors, mirroring the upstream
        # Magma usage example. NOTE(review): confirm if the processor's output shape changes.
        inputs["pixel_values"] = inputs["pixel_values"].unsqueeze(0)
        inputs["image_sizes"] = inputs["image_sizes"].unsqueeze(0)

        call_args = {**self.generation_args, **generation_overrides}
        with torch.inference_mode():
            generate_ids = self.model.generate(**inputs, **call_args)

        # generate() returns prompt + continuation; keep only the new tokens.
        prompt_length = inputs["input_ids"].shape[-1]
        new_token_ids = generate_ids[:, prompt_length:]
        return self.processor.decode(new_token_ids[0], skip_special_tokens=True).strip()

    def forward(self, image, question, **generation_overrides):
        """Make ``agent(image, question)`` work like ``generate_response``; nn.Module requires a forward."""
        return self.generate_response(image, question, **generation_overrides)