tianzhechu commited on
Commit
df2fe71
·
1 Parent(s): 76a7db5
Files changed (2) hide show
  1. requirements.txt +2 -1
  2. src/streamlit_app.py +7 -2
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  altair
2
  pandas
3
  streamlit
4
- transformers
 
 
1
  altair
2
  pandas
3
  streamlit
4
+ transformers
5
+ torch
src/streamlit_app.py CHANGED
@@ -1,10 +1,15 @@
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 
3
 
4
  @st.cache_resource
5
  def load_model():
6
- tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
7
- model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
 
 
 
 
8
  return pipeline("text2text-generation", model=model, tokenizer=tokenizer)
9
 
10
  st.set_page_config(page_title="LLM Demo", layout="centered")
 
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
+ import os
4
 
5
  @st.cache_resource
6
  def load_model():
7
+ # Create a local cache directory
8
+ cache_dir = "./model_cache"
9
+ os.makedirs(cache_dir, exist_ok=True)
10
+
11
+ tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small", cache_dir=cache_dir)
12
+ model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small", cache_dir=cache_dir)
13
  return pipeline("text2text-generation", model=model, tokenizer=tokenizer)
14
 
15
  st.set_page_config(page_title="LLM Demo", layout="centered")