user committed
Commit becd78e · 1 Parent(s): fe293b8
Implement data persistence for improved performance and reusability
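In short: the processed text chunks, their embeddings, and the FAISS index are now written to disk under data/, keyed by the selected model combination, and reloaded on later runs instead of being recomputed. A minimal round-trip sketch of that persistence scheme, using the same file layout as the diff below (the two dummy chunks, the 384-dimensional embeddings, and the "demo" key are placeholders for illustration; the app keys its files by st.session_state.model_combination):

import os
import pickle

import faiss
import numpy as np

# Placeholder data standing in for the app's chunks / embeddings / index.
model_combination = "demo"
chunks = ["first chunk", "second chunk"]
embeddings = np.random.rand(2, 384).astype("float32")
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)

# Persist all three artifacts, one set of files per model combination.
os.makedirs("data", exist_ok=True)
with open(f"data/chunks_{model_combination}.pkl", "wb") as f:
    pickle.dump(chunks, f)
np.save(f"data/embeddings_{model_combination}.npy", embeddings)
faiss.write_index(index, f"data/faiss_index_{model_combination}.bin")

# Reload and confirm the artifacts survive the trip to disk.
with open(f"data/chunks_{model_combination}.pkl", "rb") as f:
    assert pickle.load(f) == chunks
assert np.allclose(np.load(f"data/embeddings_{model_combination}.npy"), embeddings)
assert faiss.read_index(f"data/faiss_index_{model_combination}.bin").ntotal == index.ntotal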
app.py
CHANGED
@@ -43,12 +43,16 @@ MODEL_COMBINATIONS = {
 }
 
 @st.cache_resource
-def load_models(
+def load_models(combination):
     try:
-
-
-
-
+        embedding_model_name = MODEL_COMBINATIONS[combination]["embedding"]
+        generation_model_name = MODEL_COMBINATIONS[combination]["generation"]
+
+        embedding_tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
+        embedding_model = AutoModel.from_pretrained(embedding_model_name)
+        generation_tokenizer = AutoTokenizer.from_pretrained(generation_model_name)
+        generation_model = AutoModelForCausalLM.from_pretrained(generation_model_name)
+
         return embedding_tokenizer, embedding_model, generation_tokenizer, generation_model
     except Exception as e:
         st.error(f"Error loading models: {str(e)}")
@@ -81,8 +85,8 @@ def create_faiss_index(embeddings):
     index.add(embeddings)
     return index
 
-def generate_response(query,
-    inputs =
+def generate_response(query, embedding_tokenizer, generation_tokenizer, generation_model, embedding_model, index, chunks):
+    inputs = embedding_tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512)
     with torch.no_grad():
         outputs = embedding_model(**inputs)
     query_embedding = outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
@@ -94,25 +98,28 @@ def generate_response(query, tokenizer, generation_model, embedding_model, index
 
     prompt = f"As the Muse of A.R. Ammons, respond to this query: {query}\nContext: {context}\nMuse:"
 
-    input_ids =
+    input_ids = generation_tokenizer.encode(prompt, return_tensors="pt")
     output = generation_model.generate(input_ids, max_new_tokens=100, num_return_sequences=1, temperature=0.7)
-    response =
+    response = generation_tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
 
     muse_response = response.split("Muse:")[-1].strip()
     return muse_response
 
-def save_data(chunks, embeddings, index):
-
+def save_data(chunks, embeddings, index, model_combination):
+    os.makedirs('data', exist_ok=True)
+    with open(f'data/chunks_{model_combination}.pkl', 'wb') as f:
         pickle.dump(chunks, f)
-    np.save('
-    faiss.write_index(index, '
-
-def load_data():
-    if os.path.exists('
-
+    np.save(f'data/embeddings_{model_combination}.npy', embeddings)
+    faiss.write_index(index, f'data/faiss_index_{model_combination}.bin')
+
+def load_data(model_combination):
+    if os.path.exists(f'data/chunks_{model_combination}.pkl') and \
+       os.path.exists(f'data/embeddings_{model_combination}.npy') and \
+       os.path.exists(f'data/faiss_index_{model_combination}.bin'):
+        with open(f'data/chunks_{model_combination}.pkl', 'rb') as f:
            chunks = pickle.load(f)
-    embeddings = np.load('
-    index = faiss.read_index('
+        embeddings = np.load(f'data/embeddings_{model_combination}.npy')
+        index = faiss.read_index(f'data/faiss_index_{model_combination}.bin')
        return chunks, embeddings, index
    return None, None, None
 
@@ -167,12 +174,22 @@ st.info(f"Potential time saved compared to slowest option: {MODEL_COMBINATIONS[s
 if st.button("Load Selected Models"):
     with st.spinner("Loading models and data..."):
         embedding_tokenizer, embedding_model, generation_tokenizer, generation_model = load_models(st.session_state.model_combination)
-
-
-        index =
+
+        # Try to load existing data
+        chunks, embeddings, index = load_data(st.session_state.model_combination)
+
+        # If data doesn't exist, process it and save
+        if chunks is None or embeddings is None or index is None:
+            chunks = load_and_process_text('ammons_muse.txt')
+            embeddings = create_embeddings(chunks, embedding_model)
+            index = create_faiss_index(embeddings)
+            save_data(chunks, embeddings, index, st.session_state.model_combination)
 
         st.session_state.models_loaded = True
-        st.
+        st.session_state.chunks = chunks
+        st.session_state.embeddings = embeddings
+        st.session_state.index = index
+        st.success("Models and data loaded successfully!")
 
 if 'models_loaded' not in st.session_state or not st.session_state.models_loaded:
     st.warning("Please load the models before chatting.")
@@ -194,7 +211,7 @@ if prompt := st.chat_input("What would you like to ask the Muse?"):
 
     with st.spinner("The Muse is contemplating..."):
         try:
-            response = generate_response(prompt,
+            response = generate_response(prompt, embedding_tokenizer, generation_tokenizer, generation_model, embedding_model, st.session_state.index, st.session_state.chunks)
         except Exception as e:
             response = f"I apologize, but I encountered an error: {str(e)}"
 
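A side note on the load_models change: st.cache_resource caches its return value per distinct argument, so each model combination is downloaded and instantiated at most once per server process, and switching back to an already-loaded combination is effectively free. A stripped-down sketch of that behavior (the print statement and the string keys are illustrative only, not code from app.py):

import streamlit as st

@st.cache_resource
def load_models(combination):
    print(f"loading {combination}")  # executes only on a cache miss
    return combination

load_models("fast")     # first call: prints "loading fast"
load_models("fast")     # cache hit: nothing printed, same object returned
load_models("quality")  # new argument: prints "loading quality"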
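The retrieval step that turns query_embedding into the context string used in the prompt sits between the hunks above and is unchanged by this commit, so it does not appear in the diff. A hedged sketch of how the cached FAISS index is typically queried at that point (the k value, the helper name retrieve_context, and the newline join are assumptions, not code from app.py):

import numpy as np

def retrieve_context(query_embedding, index, chunks, k=3):
    # FAISS expects a 2-D float32 array of shape (n_queries, dim).
    query = np.asarray(query_embedding, dtype="float32").reshape(1, -1)
    distances, indices = index.search(query, k)
    # Map nearest-neighbour row ids back to the original text chunks; -1 marks an empty slot.
    return "\n".join(chunks[i] for i in indices[0] if i != -1)

With something like this in place, generate_response only needs the tokenizers, models, index, and chunks that the load button now stashes in st.session_state.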