bryanmildort committed
Commit 8a98078 · Parent: 0bbebb9

Update app.py

Files changed (1)
  1. app.py +9 -5
app.py CHANGED
@@ -13,7 +13,7 @@ def summarize_function(notes):
  st.markdown("<h1 style='text-align: center; color: #489DDB;'>GPT Clinical Notes Summarizer</h1>", unsafe_allow_html=True)
  st.markdown("<h6 style='text-align: center; color: #489DDB;'>by Bryan Mildort</h1>", unsafe_allow_html=True)
 
- from transformers import AutoTokenizer, GPTJForCausalLM
+ from transformers import AutoTokenizer, AutoModelForCausalLM
  # from accelerate import infer_auto_device_map
  import torch
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -21,11 +21,15 @@ device_str = f"""Device being used: {device}"""
  st.write(device_str)
 
  # model = AutoModelForCausalLM.from_pretrained("bryanmildort/gpt-clinical-notes-summarizer", load_in_8bit=True, device_map="auto")
- # model = model.to(device)
+ #
 
- tokenizer = AutoTokenizer.from_pretrained("bryanmildort/gpt-clinical-notes-summarizer")
- checkpoint = "bryanmildort/gpt-clinical-notes-summarizer"
- model = GPTJForCausalLM.from_pretrained(checkpoint)
+ # tokenizer = AutoTokenizer.from_pretrained("bryanmildort/gpt-clinical-notes-summarizer")
+ # checkpoint = "bryanmildort/gpt-clinical-notes-summarizer"
+ # model = GPTJForCausalLM.from_pretrained(checkpoint)
+
+ model = AutoModelForCausalLM.from_pretrained("bryanmildort/gpt-neo-small-notes-summarizer")
+ tokenizer = AutoTokenizer.from_pretrained("bryanmildort/gpt-neo-small-notes-summarizer")
+ model = model.to(device)
 
  # device_map = infer_auto_device_map(model, dtype="float16")
  # st.write(device_map)
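For reference, a minimal standalone sketch of the loading path this commit switches to, assuming the bryanmildort/gpt-neo-small-notes-summarizer checkpoint named in the added lines; the sample note, prompt handling, and generation settings below are illustrative assumptions and are not part of the commit itself.

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Checkpoint name taken from the + lines of this diff.
checkpoint = "bryanmildort/gpt-neo-small-notes-summarizer"
device = "cuda:0" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

# Illustrative input; the real app collects notes through the Streamlit UI.
notes = "Pt admitted with SOB and productive cough x3 days. CXR shows RLL infiltrate."
inputs = tokenizer(notes, return_tensors="pt").to(device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64)  # generation settings assumed, not from the diff
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))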