CryoFM2: A Generative Foundation Model for Cryo-EM Densities

Overview

CryoFM2 is a flow-based generative foundation model for cryo-EM density maps. It is pretrained on curated EMDB half maps to learn general priors of high-quality cryo-EM densities and can be fine-tuned for downstream tasks.

The model learns a continuous mapping from a simple Gaussian distribution to the complex distribution of cryo-EM densities, enabling stable generation and flexible adaptation. CryoFM2 can also act as a Bayesian prior, integrating naturally with task-specific likelihoods to support applications such as anisotropy-aware refinement, non-uniform reconstruction, and controlled density modification.
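
To make the idea of a continuous mapping from noise to data concrete, here is a minimal one-dimensional sketch. It is not CryoFM2's implementation: it assumes a linear interpolation path with noise at t = 0 and data at t = 1, and a Gaussian target whose velocity field is known in closed form, whereas CryoFM2 predicts the velocity with a 3D UNet. The names `gaussian_velocity` and `euler_sample` are illustrative only.

```python
import math
import random


def gaussian_velocity(x, t, mu, s):
    """Closed-form velocity that transports N(0, 1) at t=0 to N(mu, s^2) at t=1.

    Assumes the linear path x_t = (1 - t) * x0 + t * x1 with x0 and x1 independent,
    so x_t ~ N(t * mu, (1 - t)^2 + t^2 * s^2).
    """
    var_t = (1.0 - t) ** 2 + (t * s) ** 2
    # d/dt of the mean is mu; (d/dt sigma_t) / sigma_t scales the centered sample.
    return mu + (t * s * s - (1.0 - t)) / var_t * (x - t * mu)


def euler_sample(x0, velocity, num_steps=200):
    """Integrate dx/dt = velocity(x, t) from t=0 to t=1 with fixed-step Euler."""
    x, dt = x0, 1.0 / num_steps
    for k in range(num_steps):
        x += dt * velocity(x, k * dt)
    return x


random.seed(0)
mu, s = 3.0, 0.5
samples = [
    euler_sample(random.gauss(0.0, 1.0), lambda x, t: gaussian_velocity(x, t, mu, s))
    for _ in range(2000)
]
mean = sum(samples) / len(samples)
std = math.sqrt(sum((v - mean) ** 2 for v in samples) / len(samples))
print(f"mean={mean:.2f} std={std:.2f}")  # close to mu and s
```

The same Euler loop, applied to 64×64×64 volumes with a learned velocity, is the kind of integration that the `method="euler"` and `num_steps=200` arguments in the sampling examples of this README refer to.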

Model Details

After pretraining, the model can be fine-tuned for downstream tasks such as density map enhancement and post-processing.

Pre-training Architecture:

CryoFM2 architecture for pre-training.

Fine-tuning Architecture (for EMhancer/EMReady style post-processing):

CryoFM2 architecture for fine-tuning.

Architecture

  • Architecture Type: 3D UNet
  • Input Size: 64×64×64 voxels
  • Input Channels: 2 for pre-trained model, 3 for fine-tuned model
  • Output Channels: 1
  • Down Blocks: DownBlock3D, DownBlock3D, AttnDownBlock3D, AttnDownBlock3D
  • Up Blocks: AttnUpBlock3D, AttnUpBlock3D, UpBlock3D, UpBlock3D
  • Block Output Channels: (64, 128, 256, 512)
  • Layers per Block: 2
  • Attention Head Dimension: 8
  • Normalization: GroupNorm (32 groups)
  • Activation: SiLU
  • Time Embedding: Positional encoding
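
As a quick sanity check on the numbers above, the sketch below derives per-stage quantities for the encoder. Two points are assumptions rather than facts from this README: that every down block except the last halves the spatial side length (a common UNet convention), and that "Attention Head Dimension" means channels per head, so the head count is channels divided by 8. The `Stage` class and `encoder_stages` helper are hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Stage:
    block: str
    channels: int
    side: int  # spatial side length at which the block operates (assumed)

    @property
    def has_attention(self) -> bool:
        return self.block.startswith("Attn")


def encoder_stages(input_side=64,
                   blocks=("DownBlock3D", "DownBlock3D", "AttnDownBlock3D", "AttnDownBlock3D"),
                   channels=(64, 128, 256, 512)):
    """List encoder stages, ASSUMING each block except the last halves the side length."""
    stages, side = [], input_side
    for i, (block, ch) in enumerate(zip(blocks, channels)):
        stages.append(Stage(block, ch, side))
        if i < len(blocks) - 1:
            side //= 2
    return stages


HEAD_DIM, NORM_GROUPS = 8, 32
for st in encoder_stages():
    # Attention cost grows with the number of voxels (tokens) at that stage.
    heads = st.channels // HEAD_DIM if st.has_attention else 0
    tokens = st.side ** 3 if st.has_attention else 0
    print(f"{st.block:16s} ch={st.channels:3d} side={st.side:2d} "
          f"ch/group={st.channels // NORM_GROUPS:2d} heads={heads:2d} tokens={tokens}")
```

Under these assumptions, attention is only applied at the two coarsest resolutions, which keeps the token count per attention layer manageable for 3D inputs.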

Model Variants

  1. cryofm2-pretrain: Unconditional pretrained model for general density map generation
  2. cryofm2-emhancer: Fine-tuned model for density map enhancement (EMhancer style)
  3. cryofm2-emready: Fine-tuned model for density map enhancement (EMReady style)
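
When scripting against several variants, it can help to keep the directory names and style tags in one place. The helper below is a hypothetical convenience, not part of the `cryofm` package; it only encodes the directory naming and the style tags (1 for EMhancer, 0 for EMReady) used elsewhere in this README.

```python
from pathlib import Path

# Style tags as listed in this README: 1 selects EMhancer style, 0 selects EMReady style.
OUTPUT_TAGS = {"emhancer": 1, "emready": 0}


def resolve_variant(root, style=None):
    """Return (model_dir, output_tag); output_tag is None for the pretrained model."""
    if style is None:
        return Path(root) / "cryofm2-pretrain", None
    if style not in OUTPUT_TAGS:
        raise ValueError(f"unknown style {style!r}; expected one of {sorted(OUTPUT_TAGS)}")
    return Path(root) / f"cryofm2-{style}", OUTPUT_TAGS[style]


model_dir, tag = resolve_variant("path/to/cryofm-v2", "emready")
print(model_dir.as_posix(), tag)  # path/to/cryofm-v2/cryofm2-emready 0
```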

Play with CryoFM2

Unconditional Generation (Explore Training Data Distribution)

Generate samples from the pretrained model to explore the learned data distribution:

Pretrained Model:

import torch
from mmengine import Config

from cryofm.core.utils.mrc_io import save_mrc
from cryofm.core.utils.sampling_fm import sample_from_fm
from cryofm.projects.cryofm2.lit_modules import CryoFM2Uncond

# Update the path to your model directory
model_dir = "path/to/cryofm-v2/cryofm2-pretrain"
cfg = Config.fromfile(f"{model_dir}/config.yaml")
lit_model = CryoFM2Uncond.load_from_safetensors(f"{model_dir}/model.safetensors", cfg=cfg)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

lit_model = lit_model.to(device)
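lit_model.eval()


def v_xt_t(_xt, _t):
    """Velocity field v(x_t, t) predicted by the model; passed to the sampler."""
    return lit_model(_xt, _t)


# Enable bfloat16 autocast for faster inference if your hardware supports it.
# Using device.type keeps autocast consistent with the CPU fallback chosen above.
with torch.no_grad(), torch.autocast(device.type, dtype=torch.bfloat16):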
    out = sample_from_fm(
        v_xt_t, 
        lit_model.noise_scheduler, 
        method="euler", 
        num_steps=200, 
        num_samples=3, 
        device=lit_model.device, 
        side_shape=64
    )
    # Apply normalization if configured
    if hasattr(lit_model.cfg, "z_scale") and lit_model.cfg.z_scale.mean is not None:
        out = out * lit_model.cfg.z_scale.std + lit_model.cfg.z_scale.mean

# Save generated samples
for i in range(3):
    save_mrc(out[i].float().cpu().numpy(), f"sample-{i}.mrc", voxel_size=1.5)

Fine-tuned Models (EMhancer/EMReady):

import torch
from mmengine import Config

from cryofm.core.utils.mrc_io import save_mrc
from cryofm.core.utils.sampling_fm import sample_from_fm
from cryofm.projects.cryofm2.lit_modules import CryoFM2Cond

# Choose style: "emhancer" or "emready"
style = "emhancer"
model_dir = f"path/to/cryofm-v2/cryofm2-{style}"
cfg = Config.fromfile(f"{model_dir}/config.yaml")
lit_model = CryoFM2Cond.load_from_safetensors(f"{model_dir}/model.safetensors", cfg=cfg)
output_tag = 1 if style == "emhancer" else 0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

lit_model = lit_model.to(device)
lit_model.eval()


def v_xt_t(_xt, _t):
    """Velocity field with only the style tag set (no input or volume condition)."""
    bs = _xt.shape[0]
    unconditional_generation_conds = {
        "input_cond": None,
        "output_cond": torch.tensor([output_tag] * bs).to(device),
        "vol_cond": None,  # if provided, shape should be [bs, d, h, w]
    }
    return lit_model(_xt, _t, generation_conds=unconditional_generation_conds)


# Enable bfloat16 autocast for faster inference if your hardware supports it.
# Using device.type keeps autocast consistent with the CPU fallback chosen above.
with torch.no_grad(), torch.autocast(device.type, dtype=torch.bfloat16):
    out = sample_from_fm(
        v_xt_t, 
        lit_model.noise_scheduler, 
        method="euler", 
        num_steps=200, 
        num_samples=3, 
        device=lit_model.device, 
        side_shape=64
    )
    # Apply normalization if configured
    if hasattr(lit_model.cfg, "z_scale") and lit_model.cfg.z_scale.mean is not None:
        out = out * lit_model.cfg.z_scale.std + lit_model.cfg.z_scale.mean

# Save generated samples
for i in range(3):
    save_mrc(out[i].float().cpu().numpy(), f"{style}-sample-{i}.mrc", voxel_size=1.5)

Density Map Modification

CryoFM2 supports various density map modification operations using the pretrained model as a Bayesian prior. Supported operators include:

  • denoise: Remove noise from density maps
  • inpaint: Fill missing regions (e.g., missing wedge)
  • denoise inpaint: Combined denoising and inpainting
  • non-uniform weight: Apply non-uniform weighting during reconstruction
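
The way a prior combines with a task-specific likelihood can be illustrated with a deliberately simplified toy. The sketch below assumes an independent Gaussian prior per voxel and a real-space mask for missing data; CryoFM2 instead uses a learned generative prior over whole maps, and the missing-wedge case lives in Fourier space. The function `posterior_mean` is hypothetical and only shows why denoising and inpainting fall out of the same Bayesian formulation.

```python
def posterior_mean(y, observed, noise_var, prior_mean, prior_var):
    """Voxel-wise Gaussian posterior mean for y = x + noise on observed voxels.

    Toy stand-in: the prior here is an independent Gaussian per voxel, whereas
    CryoFM2 uses a learned generative prior over whole density maps.
    """
    out = []
    for y_i, seen in zip(y, observed):
        if seen:
            # Denoise: precision-weighted average of the measurement and the prior mean.
            w = prior_var / (prior_var + noise_var)
            out.append(w * y_i + (1.0 - w) * prior_mean)
        else:
            # Inpaint: no measurement, so the estimate falls back to the prior.
            out.append(prior_mean)
    return out


y = [2.0, 0.0, 4.0]
observed = [True, False, True]
print(posterior_mean(y, observed, noise_var=1.0, prior_mean=1.0, prior_var=1.0))
# [1.5, 1.0, 2.5]
```

Noisier measurements (larger `noise_var`) shift the estimate toward the prior, which is the qualitative behavior one would expect from any prior-plus-likelihood scheme.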

Basic Usage:

python -m cryofm.projects.cryofm2.uncond_sampling \
    -i1 half_map_1.mrc \
    -i2 half_map_2.mrc \
    -o ./output \
    --model-dir path/to/cryofm-v2/cryofm2-pretrain \
    --op denoise \
    --norm-grad \
    --use-lamb-w

Inpainting tasks additionally require the path to a RELION STAR file, passed via --data-starfile-path:

python -m cryofm.projects.cryofm2.uncond_sampling \
    -i1 half_map_1.mrc \
    -i2 half_map_2.mrc \
    -o ./output \
    --model-dir path/to/cryofm-v2/cryofm2-pretrain \
    --op inpaint \
    --data-starfile-path path/to/relion_data.star \
    --norm-grad \
    --use-lamb-w

Density Map Post-Processing

CryoFM2 provides fine-tuned models for density map enhancement in different styles, similar to EMhancer and EMReady.

EMhancer Style Enhancement

python -m cryofm.projects.cryofm2.cond_sampling \
    -i input_map.mrc \
    -o ./output_emhancer \
    --model-dir path/to/cryofm-v2/cryofm2-emhancer \
    --output-tag 1

EMReady Style Enhancement

python -m cryofm.projects.cryofm2.cond_sampling \
    -i input_map.mrc \
    -o ./output_emready \
    --model-dir path/to/cryofm-v2/cryofm2-emready \
    --output-tag 0 \
    --cfg-weight 0.5

Parameters:

  • -i: Input density map file (MRC format)
  • -o: Output directory
  • --model-dir: Path to the model directory containing config.yaml and model.safetensors
  • --output-tag: Style tag (1 for EMhancer, 0 for EMReady)
  • --cfg-weight: Classifier-free guidance weight (optional, default varies by model)
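
For readers unfamiliar with classifier-free guidance, the sketch below shows one common way a guidance weight blends conditional and unconditional predictions. The exact convention used by `--cfg-weight` is not documented in this README, so the formula here is an assumption; check the `cond_sampling` source before relying on it. The function `cfg_combine` is illustrative.

```python
def cfg_combine(v_cond, v_uncond, weight):
    """Blend conditional and unconditional velocities.

    ASSUMED convention: v = v_cond + weight * (v_cond - v_uncond), so weight = 0
    returns the conditional prediction and larger weights push further away from
    the unconditional one. Other codebases use v_uncond + weight * (v_cond - v_uncond).
    """
    return [c + weight * (c - u) for c, u in zip(v_cond, v_uncond)]


v_cond = [1.0, 2.0, -1.0]
v_uncond = [0.0, 2.0, 1.0]
print(cfg_combine(v_cond, v_uncond, 0.0))  # [1.0, 2.0, -1.0]
print(cfg_combine(v_cond, v_uncond, 0.5))  # [1.5, 2.0, -2.0]
```

Where the two predictions agree (the second element), guidance has no effect regardless of the weight.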

Performance Tips

  • Multi-GPU Inference: Use accelerate launch to run inference on multiple GPUs. The -m flag tells accelerate to run the target as a Python module:
    NCCL_DEBUG=ERROR accelerate launch --num_processes=${NUM_GPUS} --main_process_port=8881 \
        -m cryofm.projects.cryofm2.cond_sampling ...
    
  • Mixed Precision: Use --bf16 flag when available to reduce memory usage and speed up inference.
  • Batch Processing: Adjust batch size based on your GPU memory capacity.

Limitations

  • Input size is fixed at 64×64×64 voxels
  • Model performance may vary depending on the input density map quality
  • Fine-tuned models are optimized for specific enhancement styles
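
Because the network input is fixed at 64×64×64, larger maps have to be processed in pieces. This README does not say how the sampling scripts do this; the sketch below shows one common approach, overlapping sliding windows, purely as an assumption. The helpers `patch_starts` and `patch_corners`, and the stride of 32, are hypothetical.

```python
def patch_starts(length, patch=64, stride=32):
    """Start offsets of patches that cover [0, length) along one axis.

    The last patch is shifted back so it ends exactly at the boundary. Axes
    shorter than `patch` return [0]; the caller would need to pad them.
    """
    if length <= patch:
        return [0]
    starts = list(range(0, length - patch + 1, stride))
    if starts[-1] != length - patch:
        starts.append(length - patch)
    return starts


def patch_corners(shape, patch=64, stride=32):
    """All (z, y, x) corner offsets of cubic patches tiling a 3D volume."""
    zs, ys, xs = (patch_starts(n, patch, stride) for n in shape)
    return [(z, y, x) for z in zs for y in ys for x in xs]


print(patch_starts(100))                   # [0, 32, 36]
print(len(patch_corners((100, 100, 64))))  # 9
```

Overlapping patches are typically blended when written back into the full volume to avoid visible seams at patch borders.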

Ethical Considerations

This model is intended for scientific research and structural biology applications. Users should:

  • Ensure proper attribution when using generated structures
  • Validate generated structures through experimental verification
  • Be aware of potential biases in the training data
  • Use the model responsibly and in accordance with scientific best practices

Citation

TBA

License

This model is released under the Apache 2.0 License. See the LICENSE file for details.

Acknowledgments

This work is developed by the ByteDance Seed Team. For more information, visit:
