import torch
from transformers import AutoModelForCausalLM, Mistral3ForConditionalGeneration, AutoTokenizer
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from tqdm import tqdm


# Key prefixes of the decoder stack in each model's state dict.
_DEVSTRAL_PREFIX = "model."
_MISTRAL_PREFIX = "model.language_model."

# Parameters copied once (not per layer), relative to the prefixes above.
_GLOBAL_PARAM_SUFFIXES = (
    "embed_tokens.weight",
    "norm.weight",
)

# Parameters copied for every decoder layer, relative to "<prefix>layers.<i>.".
_LAYER_PARAM_SUFFIXES = (
    "input_layernorm.weight",
    "mlp.down_proj.weight",
    "mlp.gate_proj.weight",
    "mlp.up_proj.weight",
    "post_attention_layernorm.weight",
    "self_attn.k_proj.weight",
    "self_attn.o_proj.weight",
    "self_attn.q_proj.weight",
    "self_attn.v_proj.weight",
)

# Output projection key; looked up under the same name in both state dicts.
_LM_HEAD_KEY = "lm_head.weight"


def _fix_generation_config(model):
    """Clear ``temperature`` when sampling is disabled, then validate the config."""
    gen_config = getattr(model, "generation_config", None)
    if gen_config is None:
        return

    # A temperature combined with do_sample=False is the conflicting pair
    # this script resolves; dropping the temperature keeps decoding greedy.
    if hasattr(gen_config, "do_sample") and hasattr(gen_config, "temperature"):
        if not gen_config.do_sample and gen_config.temperature is not None:
            gen_config.temperature = None
            print(" - Removed temperature setting (keeping do_sample=False)")

    # Validation is best-effort: a failure is reported but must not abort
    # the (expensive) weight transplant.
    try:
        gen_config.validate()
        print(" - Generation config is now valid")
    except Exception as e:
        print(f" - Warning: Generation config validation failed: {e}")


def _count_decoder_layers(state, prefix):
    """Return 1 + the highest layer index found under ``prefix`` in ``state``."""
    indices = set()
    for key in state:
        if key.startswith(prefix):
            index = key[len(prefix):].split(".", 1)[0]
            if index.isdigit():
                indices.add(int(index))
    return max(indices) + 1 if indices else 0


def _copy_weights(src_state, dst_state, mappings, verbose=False):
    """Point ``dst_state[dst_key]`` at ``src_state[src_key]`` for each mapping.

    Raises:
        KeyError: If a source or destination key is absent.
    """
    for src_key, dst_key in mappings:
        if verbose:
            print(f"Copying {src_key} to {dst_key}")
        if src_key not in src_state or dst_key not in dst_state:
            # Abort rather than save a partially transplanted model.
            raise KeyError(f"Missing key: {src_key} or {dst_key}")
        # No clone needed: load_state_dict() copies the values into the
        # destination parameters, so cloning here would only hold a second
        # full copy of the language model in memory.
        dst_state[dst_key] = src_state[src_key]


def copy_devstral_weights_to_mistral(
    devstral_id,
    mistral_id,
    output_path,
    *,
    num_layers=None,
    copy_lm_head=True,
):
    """
    Copy Devstral language model weights to Mistral-Small model,
    preserving Mistral's vision components.

    Args:
        devstral_id: Hub id or local path of the Devstral checkpoint
            (loaded with ``AutoModelForCausalLM``).
        mistral_id: Hub id or local path of the Mistral-Small checkpoint
            (loaded with ``Mistral3ForConditionalGeneration``).
        output_path: Directory the merged model is written to.
        num_layers: Number of decoder layers to copy. When ``None`` it is
            inferred from the ``model.layers.<i>.`` keys of the Devstral
            state dict.
        copy_lm_head: Also copy ``lm_head.weight`` when both models expose
            it. Pass ``False`` to keep Mistral-Small's output head.

    Returns:
        The updated ``Mistral3ForConditionalGeneration`` instance that was
        saved to ``output_path``.

    Raises:
        KeyError: If an expected weight is missing from either state dict,
            or no decoder layers can be found in the Devstral state dict.
    """
    print(f"Loading Devstral model from {devstral_id}...")
    devstral_model = AutoModelForCausalLM.from_pretrained(
        devstral_id,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )

    print(f"Loading Mistral-Small model from {mistral_id}...")
    mistral_model = Mistral3ForConditionalGeneration.from_pretrained(
        mistral_id,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )

    print("Fixing generation configuration...")
    _fix_generation_config(mistral_model)

    devstral_state = devstral_model.state_dict()
    mistral_state = mistral_model.state_dict()

    print("Copying weights from Devstral to Mistral-Small...")
    global_mappings = [
        (_DEVSTRAL_PREFIX + suffix, _MISTRAL_PREFIX + suffix)
        for suffix in _GLOBAL_PARAM_SUFFIXES
    ]
    _copy_weights(devstral_state, mistral_state, global_mappings, verbose=True)

    if copy_lm_head:
        # NOTE(review): the output head is not under the language-model
        # prefix, so it has to be transferred explicitly; otherwise the
        # merged model would pair Devstral's decoder with Mistral-Small's
        # head. Confirm the key layout for the installed transformers version.
        if _LM_HEAD_KEY in devstral_state and _LM_HEAD_KEY in mistral_state:
            _copy_weights(
                devstral_state,
                mistral_state,
                [(_LM_HEAD_KEY, _LM_HEAD_KEY)],
                verbose=True,
            )
        else:
            print(f" - Warning: {_LM_HEAD_KEY} not found in both models; output head left unchanged")

    if num_layers is None:
        num_layers = _count_decoder_layers(devstral_state, _DEVSTRAL_PREFIX + "layers.")
        if num_layers == 0:
            raise KeyError(f"Missing key: no '{_DEVSTRAL_PREFIX}layers.<i>.' weights found in Devstral model")

    # Copy all other weights from Devstral to Mistral
    for i in tqdm(range(num_layers), desc="Copying layer weights"):
        layer_mappings = [
            (
                f"{_DEVSTRAL_PREFIX}layers.{i}.{suffix}",
                f"{_MISTRAL_PREFIX}layers.{i}.{suffix}",
            )
            for suffix in _LAYER_PARAM_SUFFIXES
        ]
        _copy_weights(devstral_state, mistral_state, layer_mappings)

    print("Saving updated Mistral-Small model...")
    mistral_model.load_state_dict(mistral_state)
    mistral_model.save_pretrained(output_path, safe_serialization=True)
    return mistral_model


if __name__ == "__main__":
    devstral_id = "mistralai/Devstral-Small-2507"
    mistral_id = "mistralai/Mistral-Small-3.2-24B-Instruct-2506"
    output_path = "./Devstral-Vision-Small-2507"
    model = copy_devstral_weights_to_mistral(devstral_id, mistral_id, output_path)