File size: 748 Bytes
779abe8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import argparse
from open_flamingo.eval.models.mistral_model import EvalModel
from open_flamingo.train.distributed import init_distributed_device, world_info_from_env
# Command-line interface for this script: a single optional `--model` string.
# NOTE(review): `main()` never calls `parser.parse_args()`, so `--model` does not
# currently influence what gets loaded — confirm whether that is intended.
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="open_flamingo", type=str,
                    help="Model name. Currently only `OpenFlamingo` is supported.")
def main(model_args=None, save_dir=None):
    """Load the evaluation model and return its tokenizer.

    Args:
        model_args: Configuration dict passed to ``EvalModel``. When ``None``
            (the default), the script's built-in configuration is used:
            ``configs/mlm_multi_source_v1_zephyr_ift_zero2.yaml``, the checkpoint
            ``cruise_logs/zephyr_freeze_ift/mp_rank_00_model_states.pt`` and
            ``bf16`` precision.
        save_dir: Optional output directory. When given, the tokenizer is written
            there with ``tokenizer.save_pretrained(save_dir)`` (assumes a Hugging
            Face-style tokenizer — TODO confirm). When ``None``, nothing is written.

    Returns:
        The ``tokenizer`` attribute of the constructed ``EvalModel``.
    """
    if model_args is None:
        # Built per call rather than as a default argument, so callers that
        # mutate the dict never affect later calls.
        model_args = {
            "config_yaml": "configs/mlm_multi_source_v1_zephyr_ift_zero2.yaml",
            "checkpoint_path": "cruise_logs/zephyr_freeze_ift/mp_rank_00_model_states.pt",
            "precision": "bf16",
        }
    eval_model = EvalModel(model_args)
    tokenizer = eval_model.tokenizer
    if save_dir is not None:
        # Opt-in replacement for the previously commented-out export step.
        tokenizer.save_pretrained(save_dir)
    return tokenizer
if __name__ == "__main__":
    # Script entry point only: importing this module builds the parser but
    # does not call main(), so no model is loaded on import.
    main()
|