import os
import pickle
import warnings

import faiss
import numpy as np
import streamlit as st
import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer

warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")


@st.cache_resource
def load_models():
    """Load the embedding model (DistilBERT) and the generator (GPT-2)."""
    try:
        # DistilBERT and GPT-2 use different vocabularies, so each model
        # needs its own tokenizer; sharing one would decode garbage.
        embed_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        embedding_model = AutoModel.from_pretrained("distilbert-base-uncased")
        gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")
        generation_model = AutoModelForCausalLM.from_pretrained("gpt2")
        return embed_tokenizer, embedding_model, gen_tokenizer, generation_model
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        return None, None, None, None


@st.cache_data
def load_and_process_text(file_path):
    """Read the source text and split it into fixed-size 512-character chunks."""
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            text = file.read()
        chunks = [text[i:i + 512] for i in range(0, len(text), 512)]
        return chunks
    except Exception as e:
        st.error(f"Error loading text file: {str(e)}")
        return []


@st.cache_data
def create_embeddings(chunks, _embed_tokenizer, _embedding_model):
    """Embed each chunk as the mean of DistilBERT's final hidden states.

    The leading underscores tell Streamlit not to hash the tokenizer/model
    arguments when caching.
    """
    embeddings = []
    for chunk in chunks:
        inputs = _embed_tokenizer(chunk, return_tensors="pt", padding=True,
                                  truncation=True, max_length=512)
        with torch.no_grad():
            outputs = _embedding_model(**inputs)
        embeddings.append(outputs.last_hidden_state.mean(dim=1).squeeze().numpy())
    return np.array(embeddings, dtype="float32")  # FAISS expects float32


@st.cache_resource
def create_faiss_index(embeddings):
    """Build an exact L2 (Euclidean) index over the chunk embeddings."""
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index


def generate_response(query, embed_tokenizer, embedding_model,
                      gen_tokenizer, generation_model, index, chunks):
    """Retrieve the k nearest chunks to the query and generate a persona reply."""
    # Embed the query with the same model and pooling used for the chunks.
    inputs = embed_tokenizer(query, return_tensors="pt", padding=True,
                             truncation=True, max_length=512)
    with torch.no_grad():
        outputs = embedding_model(**inputs)
    query_embedding = outputs.last_hidden_state.mean(dim=1).squeeze().numpy()

    # Retrieve the three closest chunks as context.
    k = 3
    _, I = index.search(query_embedding.reshape(1, -1).astype("float32"), k)
    context = " ".join(chunks[i] for i in I[0])

    prompt = (f"As the Muse of A.R. Ammons, respond to this query: {query}\n"
              f"Context: {context}\nMuse:")
    # Truncate so prompt + 100 new tokens stays inside GPT-2's 1024-token window.
    input_ids = gen_tokenizer.encode(prompt, return_tensors="pt",
                                     truncation=True, max_length=924)
    output = generation_model.generate(
        input_ids,
        max_new_tokens=100,
        num_return_sequences=1,
        do_sample=True,          # required for temperature to take effect
        temperature=0.7,
        pad_token_id=gen_tokenizer.eos_token_id,  # GPT-2 has no pad token
    )
    response = gen_tokenizer.decode(output[0], skip_special_tokens=True)
    muse_response = response.split("Muse:")[-1].strip()
    return muse_response


def save_data(chunks, embeddings, index):
    """Persist chunks, embeddings, and the FAISS index across app restarts."""
    with open("chunks.pkl", "wb") as f:
        pickle.dump(chunks, f)
    np.save("embeddings.npy", embeddings)
    faiss.write_index(index, "faiss_index.bin")


def load_data():
    """Load previously saved chunks, embeddings, and index, if all three exist."""
    if (os.path.exists("chunks.pkl") and os.path.exists("embeddings.npy")
            and os.path.exists("faiss_index.bin")):
        with open("chunks.pkl", "rb") as f:
            chunks = pickle.load(f)
        embeddings = np.load("embeddings.npy")
        index = faiss.read_index("faiss_index.bin")
        return chunks, embeddings, index
    return None, None, None


# Streamlit UI
st.set_page_config(page_title="A.R. Ammons' Muse Chatbot", page_icon="🎭")

st.title("A.R. Ammons' Muse Chatbot 🎭")

st.markdown("""

Chat with the Muse of A.R. Ammons. Ask questions or discuss poetry!

""")

# Load models and data
with st.spinner("Loading models and data..."):
    embed_tokenizer, embedding_model, gen_tokenizer, generation_model = load_models()
    # Check the models before using them, so a failed download can't crash
    # the embedding step below.
    if (embed_tokenizer is None or embedding_model is None
            or gen_tokenizer is None or generation_model is None):
        st.error("Failed to load necessary components. Please try again later.")
        st.stop()

    chunks, embeddings, index = load_data()
    if chunks is None or embeddings is None or index is None:
        chunks = load_and_process_text("ammons_muse.txt")
        embeddings = create_embeddings(chunks, embed_tokenizer, embedding_model)
        index = create_faiss_index(embeddings)
        save_data(chunks, embeddings, index)

if not chunks:
    st.error("Failed to load the source text. Please try again later.")
    st.stop()

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# React to user input
if prompt := st.chat_input("What would you like to ask the Muse?"):
    st.chat_message("user").markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.spinner("The Muse is contemplating..."):
        try:
            response = generate_response(prompt, embed_tokenizer, embedding_model,
                                         gen_tokenizer, generation_model, index, chunks)
        except Exception as e:
            response = f"I apologize, but I encountered an error: {str(e)}"

    with st.chat_message("assistant"):
        st.markdown(response)
    st.session_state.messages.append({"role": "assistant", "content": response})

# Add a button to clear chat history
if st.button("Clear Chat History"):
    st.session_state.messages = []
    st.rerun()  # st.experimental_rerun() was removed in recent Streamlit releases

# Add a footer
st.markdown("---")
st.markdown("*Powered by the spirit of A.R. Ammons and the magic of AI*")
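
# Usage sketch (the filename app.py is an assumption; save this script under any
# name and adjust accordingly, with ammons_muse.txt in the same directory):
#
#   pip install streamlit torch transformers faiss-cpu numpy
#   streamlit run app.py
#
# On first launch the app builds and saves chunks.pkl, embeddings.npy, and
# faiss_index.bin; subsequent runs reuse them via load_data().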