Commit
·
8a98078
1
Parent(s):
0bbebb9
Update app.py
Browse files
app.py
CHANGED
@@ -13,7 +13,7 @@ def summarize_function(notes):
|
|
13 |
st.markdown("<h1 style='text-align: center; color: #489DDB;'>GPT Clinical Notes Summarizer</h1>", unsafe_allow_html=True)
|
14 |
st.markdown("<h6 style='text-align: center; color: #489DDB;'>by Bryan Mildort</h1>", unsafe_allow_html=True)
|
15 |
|
16 |
-
from transformers import AutoTokenizer,
|
17 |
# from accelerate import infer_auto_device_map
|
18 |
import torch
|
19 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
@@ -21,11 +21,15 @@ device_str = f"""Device being used: {device}"""
|
|
21 |
st.write(device_str)
|
22 |
|
23 |
# model = AutoModelForCausalLM.from_pretrained("bryanmildort/gpt-clinical-notes-summarizer", load_in_8bit=True, device_map="auto")
|
24 |
-
#
|
25 |
|
26 |
-
tokenizer = AutoTokenizer.from_pretrained("bryanmildort/gpt-clinical-notes-summarizer")
|
27 |
-
checkpoint = "bryanmildort/gpt-clinical-notes-summarizer"
|
28 |
-
model = GPTJForCausalLM.from_pretrained(checkpoint)
|
|
|
|
|
|
|
|
|
29 |
|
30 |
# device_map = infer_auto_device_map(model, dtype="float16")
|
31 |
# st.write(device_map)
|
|
|
13 |
st.markdown("<h1 style='text-align: center; color: #489DDB;'>GPT Clinical Notes Summarizer</h1>", unsafe_allow_html=True)
|
14 |
st.markdown("<h6 style='text-align: center; color: #489DDB;'>by Bryan Mildort</h1>", unsafe_allow_html=True)
|
15 |
|
16 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
17 |
# from accelerate import infer_auto_device_map
|
18 |
import torch
|
19 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
|
21 |
st.write(device_str)
|
22 |
|
23 |
# model = AutoModelForCausalLM.from_pretrained("bryanmildort/gpt-clinical-notes-summarizer", load_in_8bit=True, device_map="auto")
|
24 |
+
#
|
25 |
|
26 |
+
# tokenizer = AutoTokenizer.from_pretrained("bryanmildort/gpt-clinical-notes-summarizer")
|
27 |
+
# checkpoint = "bryanmildort/gpt-clinical-notes-summarizer"
|
28 |
+
# model = GPTJForCausalLM.from_pretrained(checkpoint)
|
29 |
+
|
30 |
+
model = AutoModelForCausalLM.from_pretrained("bryanmildort/gpt-neo-small-notes-summarizer")
|
31 |
+
tokenizer = AutoTokenizer.from_pretrained("bryanmildort/gpt-neo-small-notes-summarizer")
|
32 |
+
model = model.to(device)
|
33 |
|
34 |
# device_map = infer_auto_device_map(model, dtype="float16")
|
35 |
# st.write(device_map)
|