user committed on
Commit becd78e · 1 Parent(s): fe293b8

Implement data persistence for improved performance and reusability

Files changed (1)
  1. app.py +41 -24
app.py CHANGED
@@ -43,12 +43,16 @@ MODEL_COMBINATIONS = {
 }
 
 @st.cache_resource
-def load_models(model_combination):
+def load_models(combination):
     try:
-        embedding_tokenizer = AutoTokenizer.from_pretrained(MODEL_COMBINATIONS[model_combination]['embedding'])
-        embedding_model = AutoModel.from_pretrained(MODEL_COMBINATIONS[model_combination]['embedding'])
-        generation_tokenizer = AutoTokenizer.from_pretrained(MODEL_COMBINATIONS[model_combination]['generation'])
-        generation_model = AutoModelForCausalLM.from_pretrained(MODEL_COMBINATIONS[model_combination]['generation'])
+        embedding_model_name = MODEL_COMBINATIONS[combination]["embedding"]
+        generation_model_name = MODEL_COMBINATIONS[combination]["generation"]
+
+        embedding_tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
+        embedding_model = AutoModel.from_pretrained(embedding_model_name)
+        generation_tokenizer = AutoTokenizer.from_pretrained(generation_model_name)
+        generation_model = AutoModelForCausalLM.from_pretrained(generation_model_name)
+
         return embedding_tokenizer, embedding_model, generation_tokenizer, generation_model
     except Exception as e:
         st.error(f"Error loading models: {str(e)}")
@@ -81,8 +85,8 @@ def create_faiss_index(embeddings):
     index.add(embeddings)
     return index
 
-def generate_response(query, tokenizer, generation_model, embedding_model, index, chunks):
-    inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512)
+def generate_response(query, embedding_tokenizer, generation_tokenizer, generation_model, embedding_model, index, chunks):
+    inputs = embedding_tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512)
     with torch.no_grad():
         outputs = embedding_model(**inputs)
         query_embedding = outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
@@ -94,25 +98,28 @@ def generate_response(query, tokenizer, generation_model, embedding_model, index
 
     prompt = f"As the Muse of A.R. Ammons, respond to this query: {query}\nContext: {context}\nMuse:"
 
-    input_ids = tokenizer.encode(prompt, return_tensors="pt")
+    input_ids = generation_tokenizer.encode(prompt, return_tensors="pt")
     output = generation_model.generate(input_ids, max_new_tokens=100, num_return_sequences=1, temperature=0.7)
-    response = tokenizer.decode(output[0], skip_special_tokens=True)
+    response = generation_tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
 
     muse_response = response.split("Muse:")[-1].strip()
     return muse_response
 
-def save_data(chunks, embeddings, index):
-    with open('chunks.pkl', 'wb') as f:
+def save_data(chunks, embeddings, index, model_combination):
+    os.makedirs('data', exist_ok=True)
+    with open(f'data/chunks_{model_combination}.pkl', 'wb') as f:
         pickle.dump(chunks, f)
-    np.save('embeddings.npy', embeddings)
-    faiss.write_index(index, 'faiss_index.bin')
-
-def load_data():
-    if os.path.exists('chunks.pkl') and os.path.exists('embeddings.npy') and os.path.exists('faiss_index.bin'):
-        with open('chunks.pkl', 'rb') as f:
+    np.save(f'data/embeddings_{model_combination}.npy', embeddings)
+    faiss.write_index(index, f'data/faiss_index_{model_combination}.bin')
+
+def load_data(model_combination):
+    if os.path.exists(f'data/chunks_{model_combination}.pkl') and \
+       os.path.exists(f'data/embeddings_{model_combination}.npy') and \
+       os.path.exists(f'data/faiss_index_{model_combination}.bin'):
+        with open(f'data/chunks_{model_combination}.pkl', 'rb') as f:
            chunks = pickle.load(f)
-        embeddings = np.load('embeddings.npy')
-        index = faiss.read_index('faiss_index.bin')
+        embeddings = np.load(f'data/embeddings_{model_combination}.npy')
+        index = faiss.read_index(f'data/faiss_index_{model_combination}.bin')
         return chunks, embeddings, index
     return None, None, None
 
@@ -167,12 +174,22 @@ st.info(f"Potential time saved compared to slowest option: {MODEL_COMBINATIONS[s
 if st.button("Load Selected Models"):
     with st.spinner("Loading models and data..."):
         embedding_tokenizer, embedding_model, generation_tokenizer, generation_model = load_models(st.session_state.model_combination)
-        chunks = load_and_process_text('ammons_muse.txt')
-        embeddings = create_embeddings(chunks, embedding_model)
-        index = create_faiss_index(embeddings)
+
+        # Try to load existing data
+        chunks, embeddings, index = load_data(st.session_state.model_combination)
+
+        # If data doesn't exist, process it and save
+        if chunks is None or embeddings is None or index is None:
+            chunks = load_and_process_text('ammons_muse.txt')
+            embeddings = create_embeddings(chunks, embedding_model)
+            index = create_faiss_index(embeddings)
+            save_data(chunks, embeddings, index, st.session_state.model_combination)
 
         st.session_state.models_loaded = True
-        st.success("Models loaded successfully!")
+        st.session_state.chunks = chunks
+        st.session_state.embeddings = embeddings
+        st.session_state.index = index
+        st.success("Models and data loaded successfully!")
 
 if 'models_loaded' not in st.session_state or not st.session_state.models_loaded:
     st.warning("Please load the models before chatting.")
@@ -194,7 +211,7 @@ if prompt := st.chat_input("What would you like to ask the Muse?"):
 
     with st.spinner("The Muse is contemplating..."):
         try:
-            response = generate_response(prompt, tokenizer, generation_model, embedding_model, index, chunks)
+            response = generate_response(prompt, embedding_tokenizer, generation_tokenizer, generation_model, embedding_model, st.session_state.index, st.session_state.chunks)
         except Exception as e:
            response = f"I apologize, but I encountered an error: {str(e)}"
 
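
The cached artifacts written by the new save_data() can also be checked outside the Streamlit app. Below is a minimal sketch assuming the data/ naming scheme this commit introduces; the "default" combination key is a placeholder for illustration, not a key defined in app.py. It verifies that the pickled chunks, the saved embeddings, and the FAISS index agree in size.

import os
import pickle

import faiss
import numpy as np

combination = "default"  # hypothetical key; use a real entry from MODEL_COMBINATIONS
chunks_path = f"data/chunks_{combination}.pkl"
embeddings_path = f"data/embeddings_{combination}.npy"
index_path = f"data/faiss_index_{combination}.bin"

if all(os.path.exists(p) for p in (chunks_path, embeddings_path, index_path)):
    with open(chunks_path, "rb") as f:
        chunks = pickle.load(f)
    embeddings = np.load(embeddings_path)
    index = faiss.read_index(index_path)
    # Each chunk should have exactly one embedding row, and the FAISS index
    # should hold the same number of vectors.
    assert len(chunks) == embeddings.shape[0] == index.ntotal
    print(f"{len(chunks)} chunks, embedding dim {embeddings.shape[1]}, index OK")
else:
    print("No cached data for this combination yet; load the models in the app once to build it.")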