import streamlit as st
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM
)
import torch

# ----- Streamlit page config -----
st.set_page_config(page_title="Chat", layout="wide")

# ----- Sidebar: Model controls -----
st.sidebar.title("Model Controls")
model_options = {
    "1": "karthikeyan-r/calculation_model_11k",
    "2": "karthikeyan-r/slm-custom-model_6k"
}
model_choice = st.sidebar.selectbox(
    "Select Model",
    options=list(model_options.values())
)
load_model_button = st.sidebar.button("Load Model")
clear_conversation_button = st.sidebar.button("Clear Conversation")
clear_model_button = st.sidebar.button("Clear Model")

# ----- Session state -----
if "model" not in st.session_state:
    st.session_state["model"] = None
if "tokenizer" not in st.session_state:
    st.session_state["tokenizer"] = None
if "qa_pipeline" not in st.session_state:
    st.session_state["qa_pipeline"] = None
if "conversation" not in st.session_state:
    # The conversation is stored as a list of dicts, e.g.
    # [{"role": "assistant", "content": "Hello..."}, {"role": "user", "content": "..."}]
    st.session_state["conversation"] = []

# ----- Load Model -----
if load_model_button:
    with st.spinner("Loading model..."):
        try:
            if model_choice == model_options["1"]:
                # Load the calculation model (causal LM)
                tokenizer = AutoTokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
                model = AutoModelForCausalLM.from_pretrained(model_choice, cache_dir="./model_cache")

                # Add special tokens if needed, resizing the embedding
                # matrix whenever the vocabulary grows
                if tokenizer.pad_token is None:
                    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                    model.resize_token_embeddings(len(tokenizer))
                if tokenizer.eos_token is None:
                    tokenizer.add_special_tokens({'eos_token': '[EOS]'})
                    model.resize_token_embeddings(len(tokenizer))

                model.config.pad_token_id = tokenizer.pad_token_id
                model.config.eos_token_id = tokenizer.eos_token_id

                st.session_state["model"] = model
                st.session_state["tokenizer"] = tokenizer
                st.session_state["qa_pipeline"] = None  # Not needed for the calculation model

            elif model_choice == model_options["2"]:
                # Load the T5 model for general QA
                device = 0 if torch.cuda.is_available() else -1
                model = T5ForConditionalGeneration.from_pretrained(model_choice, cache_dir="./model_cache")
                tokenizer = T5Tokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
                qa_pipe = pipeline(
                    "text2text-generation",
                    model=model,
                    tokenizer=tokenizer,
                    device=device
                )
                st.session_state["model"] = model
                st.session_state["tokenizer"] = tokenizer
                st.session_state["qa_pipeline"] = qa_pipe

            # If the conversation is empty, insert a welcome message
            if len(st.session_state["conversation"]) == 0:
                st.session_state["conversation"].append({
                    "role": "assistant",
                    "content": "Hello! I’m your assistant. How can I help you today?"
                })

            st.success("Model loaded successfully and ready!")
        except Exception as e:
            st.error(f"Error loading model: {e}")

# ----- Clear Model -----
if clear_model_button:
    st.session_state["model"] = None
    st.session_state["tokenizer"] = None
    st.session_state["qa_pipeline"] = None
    st.success("Model cleared.")

# ----- Clear Conversation -----
if clear_conversation_button:
    st.session_state["conversation"] = []
    st.success("Conversation cleared.")

# ----- Title -----
st.title("Chat Conversation UI")

user_input = None
if st.session_state["qa_pipeline"]:
    # T5 pipeline
    user_input = st.chat_input("Enter your query:")
    if user_input:
        # 1) Save the user message
        st.session_state["conversation"].append({
            "role": "user",
            "content": user_input
        })
        # 2) Generate the assistant response
        try:
            response = st.session_state["qa_pipeline"](
                f"Q: {user_input}",
                max_length=250
            )
            answer = response[0]["generated_text"]
        except Exception as e:
            answer = f"Error: {str(e)}"
        # 3) Append the assistant message to the conversation
        st.session_state["conversation"].append({
            "role": "assistant",
            "content": answer
        })
elif st.session_state["model"] and (model_choice == model_options["1"]):
    # Calculation model
    user_input = st.chat_input("Enter your query for calculation:")
    if user_input:
        # 1) Save the user message
        st.session_state["conversation"].append({
            "role": "user",
            "content": user_input
        })
        # 2) Generate the assistant response
        tokenizer = st.session_state["tokenizer"]
        model = st.session_state["model"]
        try:
            inputs = tokenizer(
                f"Input: {user_input}\nOutput:",
                return_tensors="pt",
                padding=True,
                truncation=True
            )
            input_ids = inputs.input_ids
            attention_mask = inputs.attention_mask
            output = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=250,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                do_sample=False
            )
            decoded_output = tokenizer.decode(
                output[0],
                skip_special_tokens=True
            )
            # Extract the answer after 'Output:' if present
            if "Output:" in decoded_output:
                answer = decoded_output.split("Output:")[-1].strip()
            else:
                answer = decoded_output.strip()
        except Exception as e:
            answer = f"Error: {str(e)}"
        # 3) Append the assistant message to the conversation
        st.session_state["conversation"].append({
            "role": "assistant",
            "content": answer
        })
else:
    # If no model is loaded:
    st.info("No model is loaded. Please select a model and click 'Load Model' from the sidebar.")

# ----- Render the conversation -----
for message in st.session_state["conversation"]:
    if message["role"] == "user":
        with st.chat_message("user"):
            st.write(message["content"])
    else:
        with st.chat_message("assistant"):
            st.write(message["content"])
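# ----- Running the app -----
# A minimal usage sketch, assuming this script is saved as app.py
# (the filename is an assumption, not fixed by the script itself):
#
#   pip install streamlit torch transformers sentencepiece
#   streamlit run app.py
#
# sentencepiece is required because T5Tokenizer is a SentencePiece-based
# tokenizer; Streamlit serves the app at http://localhost:8501 by default.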