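"""Graft Devstral-Small's language-model weights onto Mistral-Small-3.2.

Both checkpoints are loaded on CPU; the text-model weights (embeddings, final
norm, lm_head, and every transformer layer) are copied from Devstral into the
matching language_model slots of Mistral3ForConditionalGeneration, while
Mistral's vision components are left untouched. The merged model is then saved
to output_path in safetensors format.
"""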

import torch
from transformers import AutoModelForCausalLM, Mistral3ForConditionalGeneration
from tqdm import tqdm

def copy_devstral_weights_to_mistral(devstral_id, mistral_id, output_path):
    """
    Copy Devstral language model weights to Mistral-Small model,
    preserving Mistral's vision components.
    """
    
    print(f"Loading Devstral model from {devstral_id}...")
    devstral_model = AutoModelForCausalLM.from_pretrained(
        devstral_id,
        torch_dtype=torch.bfloat16,
        device_map="cpu"
    )
    
    print(f"Loading Mistral-Small model from {mistral_id}...")
    mistral_model = Mistral3ForConditionalGeneration.from_pretrained(
        mistral_id,
        torch_dtype=torch.bfloat16,
        device_map="cpu"
    )

    print("Fixing generation configuration...")
    if hasattr(mistral_model, 'generation_config') and mistral_model.generation_config is not None:
        gen_config = mistral_model.generation_config
        
        # Fix the conflicting settings
        if hasattr(gen_config, 'do_sample') and hasattr(gen_config, 'temperature'):
            if not gen_config.do_sample and gen_config.temperature is not None:
                # Option 1: Remove temperature (recommended for deterministic generation)
                gen_config.temperature = None
                print("  - Removed temperature setting (keeping do_sample=False)")
                
                # Option 2: Alternative - enable sampling
                # gen_config.do_sample = True
                # print("  - Enabled sampling (keeping temperature=0.15)")
        
        # Validate the config
        try:
            gen_config.validate()
            print("  - Generation config is now valid")
        except Exception as e:
            print(f"  - Warning: Generation config validation failed: {e}")

    devstral_state = devstral_model.state_dict()
    mistral_state = mistral_model.state_dict()

    print("Copying weights from Devstral to Mistral-Small...")

    weight_mappings = [
        ("model.embed_tokens.weight", "model.language_model.embed_tokens.weight"),
        ("model.norm.weight", "model.language_model.norm.weight"),
        # lm_head is untied in these checkpoints (tie_word_embeddings=False),
        # so it has to be copied as well
        ("lm_head.weight", "lm_head.weight"),
    ]
    for devstral_key, mistral_key in weight_mappings:
        print(f"Copying {devstral_key} to {mistral_key}")
        if devstral_key not in devstral_state or mistral_key not in mistral_state:
            # Abort if any expected key is missing
            raise KeyError(f"Missing key: {devstral_key} or {mistral_key}")
        if devstral_state[devstral_key].shape != mistral_state[mistral_key].shape:
            raise ValueError(f"Shape mismatch for {devstral_key}: "
                             f"{tuple(devstral_state[devstral_key].shape)} vs "
                             f"{tuple(mistral_state[mistral_key].shape)}")
        mistral_state[mistral_key] = devstral_state[devstral_key].clone()

    # Copy all remaining per-layer weights from Devstral to Mistral
    num_layers = devstral_model.config.num_hidden_layers  # 40 in these checkpoints
    for i in tqdm(range(num_layers), desc="Copying layer weights"):
        # Mistral's consolidated-checkpoint names (layers.{i}.attention.wk.weight,
        # feed_forward.w1, ...) do not exist in these HF-format state dicts, so
        # only the HF module names are mapped below.
        layer_mappings = [
            (f"model.layers.{i}.input_layernorm.weight",          f"model.language_model.layers.{i}.input_layernorm.weight"),
            (f"model.layers.{i}.mlp.down_proj.weight",            f"model.language_model.layers.{i}.mlp.down_proj.weight"),
            (f"model.layers.{i}.mlp.gate_proj.weight",            f"model.language_model.layers.{i}.mlp.gate_proj.weight"),
            (f"model.layers.{i}.mlp.up_proj.weight",              f"model.language_model.layers.{i}.mlp.up_proj.weight"),
            (f"model.layers.{i}.post_attention_layernorm.weight", f"model.language_model.layers.{i}.post_attention_layernorm.weight"),
            (f"model.layers.{i}.self_attn.k_proj.weight",         f"model.language_model.layers.{i}.self_attn.k_proj.weight"),
            (f"model.layers.{i}.self_attn.o_proj.weight",         f"model.language_model.layers.{i}.self_attn.o_proj.weight"),
            (f"model.layers.{i}.self_attn.q_proj.weight",         f"model.language_model.layers.{i}.self_attn.q_proj.weight"),
            (f"model.layers.{i}.self_attn.v_proj.weight",         f"model.language_model.layers.{i}.self_attn.v_proj.weight"),
        ]

        for devstral_key, mistral_key in layer_mappings:
            if devstral_key not in devstral_state or mistral_key not in mistral_state:
                raise KeyError(f"Missing key: {devstral_key} or {mistral_key}")
            if devstral_state[devstral_key].shape != mistral_state[mistral_key].shape:
                raise ValueError(f"Shape mismatch for {devstral_key}")
            mistral_state[mistral_key] = devstral_state[devstral_key].clone()

    print("Saving updated Mistral-Small model...")

    mistral_model.load_state_dict(mistral_state)
    mistral_model.save_pretrained(output_path, safe_serialization=True)

    return mistral_model

if __name__ == "__main__":
    devstral_id = "mistralai/Devstral-Small-2507"
    mistral_id = "mistralai/Mistral-Small-3.2-24B-Instruct-2506"
    output_path = "./Devstral-Vision-Small-2507"
    
    model = copy_devstral_weights_to_mistral(devstral_id, mistral_id, output_path)
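
    # Quick sanity check (a minimal sketch): report the merged model's total
    # parameter count without reloading the saved checkpoint from disk.
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Merged model has {n_params / 1e9:.1f}B parameters")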