ejschwartz commited on
Commit
33e5ee4
·
1 Parent(s): 0ec1c46

Try device map

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -16,11 +16,15 @@ huggingface_hub.login(token=hf_key)
16
 
17
  tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoderbase-3b")
18
  vardecoder_model = AutoModelForCausalLM.from_pretrained(
19
- "ejschwartz/resym-vardecoder", torch_dtype=torch.bfloat16
20
- ).cuda()
 
 
21
  fielddecoder_model = AutoModelForCausalLM.from_pretrained(
22
- "ejschwartz/resym-fielddecoder", torch_dtype=torch.bfloat16
23
- ).cuda()
 
 
24
 
25
  make_gradio_client = lambda: Client("https://ejschwartz-resym-field-helper.hf.space/")
26
 
 
16
 
17
  tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoderbase-3b")
18
  vardecoder_model = AutoModelForCausalLM.from_pretrained(
19
+ "ejschwartz/resym-vardecoder",
20
+ torch_dtype=torch.bfloat16,
21
+ device_map="auto"
22
+ )
23
  fielddecoder_model = AutoModelForCausalLM.from_pretrained(
24
+ "ejschwartz/resym-fielddecoder",
25
+ torch_dtype=torch.bfloat16,
26
+ device_map="auto"
27
+ )
28
 
29
  make_gradio_client = lambda: Client("https://ejschwartz-resym-field-helper.hf.space/")
30