File size: 1,240 Bytes
3283950
 
 
5bca4fe
 
 
3283950
 
5bca4fe
3283950
 
 
5bca4fe
 
3283950
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from sonique import get_pretrained_model
from sonique.interface.gradio import create_ui
import json 

import torch

def main(args):
    """Build the Gradio interface from ``args`` and launch it.

    Args:
        args: Namespace-like object exposing the attributes ``model_config``,
            ``ckpt_path``, ``pretrained_name``, ``pretransform_ckpt_path``,
            ``username`` and ``password``.

    Raises:
        ValueError: If ``username`` is provided but ``password`` is not.
    """
    # Fail fast, before any model loading: a (username, None) auth tuple would
    # bring up a login page that nobody can get past.
    if args.username is not None and args.password is None:
        raise ValueError("--password is required when --username is given")

    # Authentication is only enabled when a username was supplied.
    auth = (args.username, args.password) if args.username is not None else None

    # Seed torch's RNG with a fixed value.
    torch.manual_seed(42)

    interface = create_ui(
        model_config_path=args.model_config,
        ckpt_path=args.ckpt_path,
        pretrained_name=args.pretrained_name,
        pretransform_ckpt_path=args.pretransform_ckpt_path,
    )
    interface.queue()
    # share=True is passed unconditionally, as in the original script.
    interface.launch(share=True, auth=auth)

if __name__ == "__main__":
    import argparse

    # Every command-line option is an optional string, so the flags and their
    # help texts are declared once as (flag, help) pairs and registered in a loop.
    cli_options = (
        ('--pretrained-name', 'Name of pretrained model'),
        ('--model-config', 'Path to model config'),
        ('--ckpt-path', 'Path to model checkpoint'),
        ('--pretransform-ckpt-path', 'Optional to model pretransform checkpoint'),
        ('--username', 'Gradio username'),
        ('--password', 'Gradio password'),
    )

    parser = argparse.ArgumentParser(description='Run gradio interface')
    for flag, help_text in cli_options:
        parser.add_argument(flag, type=str, help=help_text, required=False)

    # Parse sys.argv and hand the resulting namespace to the entry point.
    main(parser.parse_args())