File size: 748 Bytes
779abe8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import argparse

from open_flamingo.eval.models.mistral_model import EvalModel
from open_flamingo.train.distributed import init_distributed_device, world_info_from_env

# Command-line interface for this script: a single optional string flag,
# ``--model``, whose default is "open_flamingo".
# NOTE(review): main() as written never calls parser.parse_args(), so the
# value of --model is not currently read anywhere — confirm whether it should be.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--model",
    default="open_flamingo",
    type=str,
    help="Model name. Currently only `OpenFlamingo` is supported.",
)


def main(
    config_yaml="configs/mlm_multi_source_v1_zephyr_ift_zero2.yaml",
    checkpoint_path="cruise_logs/zephyr_freeze_ift/mp_rank_00_model_states.pt",
    precision="bf16",
    save_dir=None,
):
    """Build an ``EvalModel`` from a config/checkpoint pair and return its tokenizer.

    All arguments default to the values that were previously hard-coded, so
    calling ``main()`` with no arguments constructs the same model as before.

    Args:
        config_yaml: Path passed to ``EvalModel`` under the ``"config_yaml"`` key.
        checkpoint_path: Path passed to ``EvalModel`` under the
            ``"checkpoint_path"`` key.
        precision: Value passed to ``EvalModel`` under the ``"precision"`` key
            (default ``"bf16"``).
        save_dir: If not ``None``, ``tokenizer.save_pretrained(save_dir)`` is
            called (presumably a Hugging Face-style export — confirm against
            the tokenizer class). If ``None`` (the default), nothing is saved.

    Returns:
        The ``tokenizer`` attribute of the constructed ``EvalModel``.
    """
    model_args = {
        "config_yaml": config_yaml,
        "checkpoint_path": checkpoint_path,
        "precision": precision,
    }
    eval_model = EvalModel(model_args)
    tokenizer = eval_model.tokenizer

    # Saving used to be a commented-out line targeting 'hf_weights'. It is now
    # opt-in through save_dir, so the default run keeps its previous behavior.
    if save_dir is not None:
        tokenizer.save_pretrained(save_dir)
    return tokenizer


if __name__ == "__main__":
    main()