alessandro trinca tornidor committed on
Commit 8facf64 · 1 Parent(s): ca06190

[refactor] try to wrap model generation code in a dedicated function

Files changed (1)
  1. app.py +78 -81
app.py CHANGED
@@ -116,91 +116,88 @@ def preprocess(
     return x
 
 
-args = parse_args(sys.argv[1:])
-os.makedirs(args.vis_save_path, exist_ok=True)
-
-# Create model
-tokenizer = AutoTokenizer.from_pretrained(
-    args.version,
-    cache_dir=None,
-    model_max_length=args.model_max_length,
-    padding_side="right",
-    use_fast=False,
-)
-tokenizer.pad_token = tokenizer.unk_token
-args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
-
-torch_dtype = torch.float32
-if args.precision == "bf16":
-    torch_dtype = torch.bfloat16
-elif args.precision == "fp16":
-    torch_dtype = torch.half
-
-kwargs = {"torch_dtype": torch_dtype}
-if args.load_in_4bit:
-    kwargs.update(
-        {
-            "torch_dtype": torch.half,
-            "load_in_4bit": True,
-            "quantization_config": BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_compute_dtype=torch.float16,
-                bnb_4bit_use_double_quant=True,
-                bnb_4bit_quant_type="nf4",
-                llm_int8_skip_modules=["visual_model"],
-            ),
-        }
-    )
-elif args.load_in_8bit:
-    kwargs.update(
-        {
-            "torch_dtype": torch.half,
-            "quantization_config": BitsAndBytesConfig(
-                llm_int8_skip_modules=["visual_model"],
-                load_in_8bit=True,
-            ),
-        }
-    )
-
-model = LISAForCausalLM.from_pretrained(
-    args.version, low_cpu_mem_usage=True, vision_tower=args.vision_tower, seg_token_idx=args.seg_token_idx, **kwargs
-)
-
-model.config.eos_token_id = tokenizer.eos_token_id
-model.config.bos_token_id = tokenizer.bos_token_id
-model.config.pad_token_id = tokenizer.pad_token_id
-
-model.get_model().initialize_vision_modules(model.get_model().config)
-vision_tower = model.get_model().get_vision_tower()
-vision_tower.to(dtype=torch_dtype)
-
-if args.precision == "bf16":
-    model = model.bfloat16().cuda()
-elif (
-    args.precision == "fp16" and (not args.load_in_4bit) and (not args.load_in_8bit)
-):
-    vision_tower = model.get_model().get_vision_tower()
-    model.model.vision_tower = None
-    import deepspeed
-
-    model_engine = deepspeed.init_inference(
-        model=model,
-        dtype=torch.half,
-        replace_with_kernel_inject=True,
-        replace_method="auto",
-    )
-    model = model_engine.module
-    model.model.vision_tower = vision_tower.half().cuda()
-elif args.precision == "fp32":
-    model = model.float().cuda()
-
-vision_tower = model.get_model().get_vision_tower()
-vision_tower.to(device=args.local_rank)
-
-clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
-transform = ResizeLongestSide(args.image_size)
-
-model.eval()
+def get_model(args_to_parse):
+    os.makedirs(args_to_parse.vis_save_path, exist_ok=True)
+
+    # global tokenizer, tokenizer
+    # Create model
+    _tokenizer = AutoTokenizer.from_pretrained(
+        args_to_parse.version,
+        cache_dir=None,
+        model_max_length=args_to_parse.model_max_length,
+        padding_side="right",
+        use_fast=False,
+    )
+    _tokenizer.pad_token = _tokenizer.unk_token
+    args_to_parse.seg_token_idx = _tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
+    torch_dtype = torch.float32
+    if args_to_parse.precision == "bf16":
+        torch_dtype = torch.bfloat16
+    elif args_to_parse.precision == "fp16":
+        torch_dtype = torch.half
+    kwargs = {"torch_dtype": torch_dtype}
+    if args_to_parse.load_in_4bit:
+        kwargs.update(
+            {
+                "torch_dtype": torch.half,
+                "load_in_4bit": True,
+                "quantization_config": BitsAndBytesConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_compute_dtype=torch.float16,
+                    bnb_4bit_use_double_quant=True,
+                    bnb_4bit_quant_type="nf4",
+                    llm_int8_skip_modules=["visual_model"],
+                ),
+            }
+        )
+    elif args_to_parse.load_in_8bit:
+        kwargs.update(
+            {
+                "torch_dtype": torch.half,
+                "quantization_config": BitsAndBytesConfig(
+                    llm_int8_skip_modules=["visual_model"],
+                    load_in_8bit=True,
+                ),
+            }
+        )
+    _model = LISAForCausalLM.from_pretrained(
+        args_to_parse.version, low_cpu_mem_usage=True, vision_tower=args_to_parse.vision_tower, seg_token_idx=args_to_parse.seg_token_idx, **kwargs
+    )
+    _model.config.eos_token_id = _tokenizer.eos_token_id
+    _model.config.bos_token_id = _tokenizer.bos_token_id
+    _model.config.pad_token_id = _tokenizer.pad_token_id
+    _model.get_model().initialize_vision_modules(_model.get_model().config)
+    vision_tower = _model.get_model().get_vision_tower()
+    vision_tower.to(dtype=torch_dtype)
+    if args_to_parse.precision == "bf16":
+        _model = _model.bfloat16().cuda()
+    elif (
+        args_to_parse.precision == "fp16" and (not args_to_parse.load_in_4bit) and (not args_to_parse.load_in_8bit)
+    ):
+        vision_tower = _model.get_model().get_vision_tower()
+        _model.model.vision_tower = None
+        import deepspeed
+
+        model_engine = deepspeed.init_inference(
+            model=_model,
+            dtype=torch.half,
+            replace_with_kernel_inject=True,
+            replace_method="auto",
+        )
+        _model = model_engine.module
+        _model.model.vision_tower = vision_tower.half().cuda()
+    elif args_to_parse.precision == "fp32":
+        _model = _model.float().cuda()
+    vision_tower = _model.get_model().get_vision_tower()
+    vision_tower.to(device=args_to_parse.local_rank)
+    _clip_image_processor = CLIPImageProcessor.from_pretrained(_model.config.vision_tower)
+    _transform = ResizeLongestSide(args_to_parse.image_size)
+    _model.eval()
+    return _model, _clip_image_processor, _tokenizer, _transform
 
 
+args = parse_args(sys.argv[1:])
+model, clip_image_processor, tokenizer, transform = get_model(args)
 
 
 ## to be implemented
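
The precision and quantization branching inside get_model() is the densest part of the change. The sketch below is an illustrative extraction of that logic into a standalone helper, using the same transformers BitsAndBytesConfig arguments shown in the diff; it is not code from this commit.

# Illustrative extraction of the dtype/quantization branching in get_model();
# not code from this commit. The BitsAndBytesConfig arguments are copied from
# the diff above.
import torch
from transformers import BitsAndBytesConfig

def build_load_kwargs(precision: str, load_in_4bit: bool, load_in_8bit: bool) -> dict:
    # fp32 is the fallback when precision is neither "bf16" nor "fp16"
    torch_dtype = {"bf16": torch.bfloat16, "fp16": torch.half}.get(precision, torch.float32)
    kwargs = {"torch_dtype": torch_dtype}
    if load_in_4bit:
        kwargs.update(
            torch_dtype=torch.half,
            load_in_4bit=True,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                # keep the vision tower out of bitsandbytes quantization
                llm_int8_skip_modules=["visual_model"],
            ),
        )
    elif load_in_8bit:
        kwargs.update(
            torch_dtype=torch.half,
            quantization_config=BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_skip_modules=["visual_model"],
            ),
        )
    return kwargs

Note that both quantized paths override torch_dtype with torch.half regardless of the --precision flag, matching the behavior in the diff.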
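Since get_model() only reads plain attributes off its argument, it can also be exercised without the CLI. In the sketch below the attribute names are exactly the ones the function reads in this diff; the values are illustrative placeholders, not defaults from parse_args (which is not shown here).

# Sketch: driving the new get_model() without parse_args().
# Attribute names match what get_model(args_to_parse) reads in this commit;
# the values are placeholders only.
from argparse import Namespace

demo_args = Namespace(
    version="<hub-id-or-local-checkpoint>",  # passed to from_pretrained
    vis_save_path="./vis_output",            # created with exist_ok=True
    model_max_length=512,
    precision="bf16",                        # "fp32" | "bf16" | "fp16"
    load_in_4bit=False,
    load_in_8bit=False,
    vision_tower="<vision-tower-id>",
    local_rank=0,
    image_size=1024,                         # longest side for ResizeLongestSide
)
model, clip_image_processor, tokenizer, transform = get_model(demo_args)

One side effect to be aware of: get_model() mutates its argument, writing the id of the "[SEG]" token back onto it as seg_token_idx, so any downstream code that reads args.seg_token_idx has to run after this call.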
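The trailing "## to be implemented" marks where the inference path will consume the returned objects. As a hypothetical sketch of that consumption (modeled on how these objects are typically used in LISA-style demo scripts, not on code in this file yet), assuming transform is SAM's ResizeLongestSide and the image path is a placeholder:

# Hypothetical downstream usage; nothing below exists in app.py yet.
import cv2

image_np = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)

# CLIP branch: pixel values for the multimodal language model
image_clip = clip_image_processor.preprocess(image_np, return_tensors="pt")["pixel_values"][0]

# SAM branch: resize the longest side to args.image_size before further preprocessing
image_sam = transform.apply_image(image_np)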