Spaces:
				
			
			
	
			
			
		Paused
		
	
	
	
			
			
	
	
	
	
		
		
		Paused
		
	
		X-Lai
		
	committed on
		
		
					Commit 
							
							·
						
						c899f8b
	
1
								Parent(s):
							
							3d9efe2
								
fix bug in inference
Browse files
Former-commit-id: 4b1776203d3410cd71b2d4720fc7d9cc61d1db3c
- app.py +1 -3
- chat.py +2 -3
- merge_lora_weights_and_save_hf_model.py +0 -0
- model/LISA.py +3 -1
- train_ds.py +0 -1
    	
        app.py
    CHANGED
    
    | @@ -92,7 +92,6 @@ if args.load_in_4bit: | |
| 92 | 
             
                kwargs.update(
         | 
| 93 | 
             
                    {
         | 
| 94 | 
             
                        "torch_dtype": torch.half,
         | 
| 95 | 
            -
                        "device_map": "auto",
         | 
| 96 | 
             
                        "load_in_4bit": True,
         | 
| 97 | 
             
                        "quantization_config": BitsAndBytesConfig(
         | 
| 98 | 
             
                            load_in_4bit=True,
         | 
| @@ -107,7 +106,6 @@ elif args.load_in_8bit: | |
| 107 | 
             
                kwargs.update(
         | 
| 108 | 
             
                    {
         | 
| 109 | 
             
                        "torch_dtype": torch.half,
         | 
| 110 | 
            -
                        "device_map": "auto",
         | 
| 111 | 
             
                        "quantization_config": BitsAndBytesConfig(
         | 
| 112 | 
             
                            llm_int8_skip_modules=["visual_model"],
         | 
| 113 | 
             
                            load_in_8bit=True,
         | 
| @@ -116,7 +114,7 @@ elif args.load_in_8bit: | |
| 116 | 
             
                )
         | 
| 117 |  | 
| 118 | 
             
            model = LISAForCausalLM.from_pretrained(
         | 
| 119 | 
            -
                args.version, low_cpu_mem_usage=True, seg_token_idx=args.seg_token_idx, **kwargs
         | 
| 120 | 
             
            )
         | 
| 121 |  | 
| 122 | 
             
            model.config.eos_token_id = tokenizer.eos_token_id
         | 
|  | |
| 92 | 
             
                kwargs.update(
         | 
| 93 | 
             
                    {
         | 
| 94 | 
             
                        "torch_dtype": torch.half,
         | 
|  | |
| 95 | 
             
                        "load_in_4bit": True,
         | 
| 96 | 
             
                        "quantization_config": BitsAndBytesConfig(
         | 
| 97 | 
             
                            load_in_4bit=True,
         | 
|  | |
| 106 | 
             
                kwargs.update(
         | 
| 107 | 
             
                    {
         | 
| 108 | 
             
                        "torch_dtype": torch.half,
         | 
|  | |
| 109 | 
             
                        "quantization_config": BitsAndBytesConfig(
         | 
| 110 | 
             
                            llm_int8_skip_modules=["visual_model"],
         | 
| 111 | 
             
                            load_in_8bit=True,
         | 
|  | |
| 114 | 
             
                )
         | 
| 115 |  | 
| 116 | 
             
            model = LISAForCausalLM.from_pretrained(
         | 
| 117 | 
            +
                args.version, low_cpu_mem_usage=True, vision_tower=args.vision_tower, seg_token_idx=args.seg_token_idx, **kwargs
         | 
| 118 | 
             
            )
         | 
| 119 |  | 
| 120 | 
             
            model.config.eos_token_id = tokenizer.eos_token_id
         | 
    	
        chat.py
    CHANGED
    
    | @@ -90,7 +90,6 @@ def main(args): | |
| 90 | 
             
                    kwargs.update(
         | 
| 91 | 
             
                        {
         | 
| 92 | 
             
                            "torch_dtype": torch.half,
         | 
| 93 | 
            -
                            "device_map": "auto",
         | 
| 94 | 
             
                            "load_in_4bit": True,
         | 
| 95 | 
             
                            "quantization_config": BitsAndBytesConfig(
         | 
| 96 | 
             
                                load_in_4bit=True,
         | 
| @@ -105,7 +104,6 @@ def main(args): | |
| 105 | 
             
                    kwargs.update(
         | 
| 106 | 
             
                        {
         | 
| 107 | 
             
                            "torch_dtype": torch.half,
         | 
| 108 | 
            -
                            "device_map": "auto",
         | 
| 109 | 
             
                            "quantization_config": BitsAndBytesConfig(
         | 
| 110 | 
             
                                llm_int8_skip_modules=["visual_model"],
         | 
| 111 | 
             
                                load_in_8bit=True,
         | 
| @@ -114,7 +112,7 @@ def main(args): | |
| 114 | 
             
                    )
         | 
| 115 |  | 
| 116 | 
             
                model = LISAForCausalLM.from_pretrained(
         | 
| 117 | 
            -
                    args.version, low_cpu_mem_usage=True, seg_token_idx=args.seg_token_idx, **kwargs
         | 
| 118 | 
             
                )
         | 
| 119 |  | 
| 120 | 
             
                model.config.eos_token_id = tokenizer.eos_token_id
         | 
| @@ -223,6 +221,7 @@ def main(args): | |
| 223 |  | 
| 224 | 
             
                    text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
         | 
| 225 | 
             
                    text_output = text_output.replace("\n", "").replace("  ", " ")
         | 
|  | |
| 226 |  | 
| 227 | 
             
                    for i, pred_mask in enumerate(pred_masks):
         | 
| 228 | 
             
                        if pred_mask.shape[0] == 0:
         | 
|  | |
| 90 | 
             
                    kwargs.update(
         | 
| 91 | 
             
                        {
         | 
| 92 | 
             
                            "torch_dtype": torch.half,
         | 
|  | |
| 93 | 
             
                            "load_in_4bit": True,
         | 
| 94 | 
             
                            "quantization_config": BitsAndBytesConfig(
         | 
| 95 | 
             
                                load_in_4bit=True,
         | 
|  | |
| 104 | 
             
                    kwargs.update(
         | 
| 105 | 
             
                        {
         | 
| 106 | 
             
                            "torch_dtype": torch.half,
         | 
|  | |
| 107 | 
             
                            "quantization_config": BitsAndBytesConfig(
         | 
| 108 | 
             
                                llm_int8_skip_modules=["visual_model"],
         | 
| 109 | 
             
                                load_in_8bit=True,
         | 
|  | |
| 112 | 
             
                    )
         | 
| 113 |  | 
| 114 | 
             
                model = LISAForCausalLM.from_pretrained(
         | 
| 115 | 
            +
                    args.version, low_cpu_mem_usage=True, vision_tower=args.vision_tower, seg_token_idx=args.seg_token_idx, **kwargs
         | 
| 116 | 
             
                )
         | 
| 117 |  | 
| 118 | 
             
                model.config.eos_token_id = tokenizer.eos_token_id
         | 
|  | |
| 221 |  | 
| 222 | 
             
                    text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
         | 
| 223 | 
             
                    text_output = text_output.replace("\n", "").replace("  ", " ")
         | 
| 224 | 
            +
                    print("text_output: ", text_output)
         | 
| 225 |  | 
| 226 | 
             
                    for i, pred_mask in enumerate(pred_masks):
         | 
| 227 | 
             
                        if pred_mask.shape[0] == 0:
         | 
    	
        merge_lora_weights_and_save_hf_model.py
    CHANGED
    
    | 
            File without changes
         | 
    	
        model/LISA.py
    CHANGED
    
    | @@ -134,7 +134,9 @@ class LISAForCausalLM(LlavaLlamaForCausalLM): | |
| 134 | 
             
                        self.ce_loss_weight = kwargs.pop("ce_loss_weight", None)
         | 
| 135 | 
             
                        self.dice_loss_weight = kwargs.pop("dice_loss_weight", None)
         | 
| 136 | 
             
                        self.bce_loss_weight = kwargs.pop("bce_loss_weight", None)
         | 
| 137 | 
            -
             | 
|  | |
|  | |
| 138 | 
             
                    self.seg_token_idx = kwargs.pop("seg_token_idx")
         | 
| 139 |  | 
| 140 | 
             
                    super().__init__(config)
         | 
|  | |
| 134 | 
             
                        self.ce_loss_weight = kwargs.pop("ce_loss_weight", None)
         | 
| 135 | 
             
                        self.dice_loss_weight = kwargs.pop("dice_loss_weight", None)
         | 
| 136 | 
             
                        self.bce_loss_weight = kwargs.pop("bce_loss_weight", None)
         | 
| 137 | 
            +
                    else:
         | 
| 138 | 
            +
                        config.mm_vision_tower = config.vision_tower
         | 
| 139 | 
            +
                        
         | 
| 140 | 
             
                    self.seg_token_idx = kwargs.pop("seg_token_idx")
         | 
| 141 |  | 
| 142 | 
             
                    super().__init__(config)
         | 
    	
        train_ds.py
    CHANGED
    
    | @@ -90,7 +90,6 @@ def parse_args(args): | |
| 90 | 
             
                parser.add_argument("--eval_only", action="store_true", default=False)
         | 
| 91 | 
             
                parser.add_argument("--vision_pretrained", default="PATH_TO_SAM_ViT-H", type=str)
         | 
| 92 | 
             
                parser.add_argument("--out_dim", default=256, type=int)
         | 
| 93 | 
            -
                parser.add_argument("--weight", default="", type=str)
         | 
| 94 | 
             
                parser.add_argument("--resume", default="", type=str)
         | 
| 95 | 
             
                parser.add_argument("--print_freq", default=1, type=int)
         | 
| 96 | 
             
                parser.add_argument("--start_epoch", default=0, type=int)
         | 
|  | |
| 90 | 
             
                parser.add_argument("--eval_only", action="store_true", default=False)
         | 
| 91 | 
             
                parser.add_argument("--vision_pretrained", default="PATH_TO_SAM_ViT-H", type=str)
         | 
| 92 | 
             
                parser.add_argument("--out_dim", default=256, type=int)
         | 
|  | |
| 93 | 
             
                parser.add_argument("--resume", default="", type=str)
         | 
| 94 | 
             
                parser.add_argument("--print_freq", default=1, type=int)
         | 
| 95 | 
             
                parser.add_argument("--start_epoch", default=0, type=int)
         |