# MATH-LLM-7B / Vision_Project.py
# Uploaded by ALmonster ("Upload 17 files", commit 9487267, verified).
import math
import re
import torch
import torch.nn as nn
class IdentityMap(nn.Module):
    """No-op projector that hands its input back untouched.

    Extra positional or keyword arguments passed to ``forward`` are accepted
    and discarded. The module registers no parameters or buffers of its own.
    """

    def __init__(self):
        super(IdentityMap, self).__init__()

    def forward(self, x, *args, **kwargs):
        """Return ``x`` itself (the very same object, not a copy)."""
        del args, kwargs  # unused; accepted and thrown away
        return x

    @property
    def config(self):
        """Descriptor dict naming this projector type as ``'identity'``."""
        return dict(mm_projector_type='identity')
def mlp2x_gelu(projector_type, *, mm_hidden_size=1280, hidden_size=3584):
    """Build a projector module from its type string.

    Despite its name, this builder handles every ``mlp<N>x_gelu`` variant
    (not only ``mlp2x_gelu``) as well as ``'identity'``.

    Args:
        projector_type: ``'mlp<N>x_gelu'`` with N >= 1, or ``'identity'``.
        mm_hidden_size: Input width of the first ``nn.Linear``. Defaults to
            1280, the value previously hard-coded here (an older setting of
            1024 had been left commented out). Presumably the vision-encoder
            feature width -- confirm against the model config.
        hidden_size: Output width of every ``nn.Linear``. Defaults to 3584,
            the value previously hard-coded here. Presumably the language
            model's hidden size -- confirm against the model config.

    Returns:
        For ``mlp<N>x_gelu``: an ``nn.Sequential`` made of
        ``Linear(mm_hidden_size, hidden_size)``, followed by N-1 pairs of
        ``GELU`` and ``Linear(hidden_size, hidden_size)``.
        For ``'identity'``: an ``IdentityMap`` instance.

    Raises:
        ValueError: If ``projector_type`` is not recognised, or if it asks
            for an MLP depth of 0.
    """
    # fullmatch rather than match + '^...$': '$' also matches just before a
    # trailing newline, which would let e.g. 'mlp2x_gelu\n' through.
    mlp_gelu_match = re.fullmatch(r'mlp(\d+)x_gelu', projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        if mlp_depth < 1:
            # Depth 0 used to fall through and build the same single Linear
            # as depth 1; reject it instead of silently guessing.
            raise ValueError(
                f'Invalid MLP depth {mlp_depth} in projector type: {projector_type}'
            )
        modules = [nn.Linear(mm_hidden_size, hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(hidden_size, hidden_size))
        return nn.Sequential(*modules)
    if projector_type == 'identity':
        return IdentityMap()
    raise ValueError(f'Unknown projector type: {projector_type}')