# indic_bot / streamlit_app.py
# Last change: commit afb7b3f by akashraut ("Update streamlit_app.py", verified)
# streamlit_app.py
# A robust Streamlit app with proper error handling and fallback options
import logging
from collections import Counter

import streamlit as st
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM

# Updated LangChain imports for modern versions
from langchain_community.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferMemory
# Route INFO-and-above records through the root handler, then obtain this
# module's own named logger for the rest of the file.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
# -----------------------------------------------------------------------------
# CORE MODEL LOGIC (Rebuilt with LangChain and Error Handling)
# -----------------------------------------------------------------------------
class LangChainBot:
    """Chat bot that wraps a Hugging Face seq2seq model in a LangChain chain.

    Construction tries each model in ``MODELS_TO_TRY`` and each translator in
    ``TRANSLATOR_DIRECTIONS`` in order, keeping the first of each that loads.
    Progress and failures are reported on the Streamlit page. Loading errors
    are caught rather than raised, so check ``chain`` (and ``translator``)
    before relying on them.

    Attributes:
        chain: ``LLMChain`` used to generate replies, or ``None`` if no
            generation model could be loaded.
        memory: ``ConversationBufferMemory`` attached to ``chain``, or ``None``.
        translator: Hugging Face translation pipeline, or ``None``.
        translator_directions: ``(source_lang, target_lang)`` pairs that the
            loaded translator supports; empty when no translator is loaded.

    NOTE(review): ``memory`` belongs to the instance. If one instance is shared
    between sessions (e.g. via ``st.cache_resource``), all users share a single
    conversation history.
    """

    # Generation models, tried in order; later entries are smaller fallbacks.
    MODELS_TO_TRY = (
        "ai4bharat/IndicBARTSS",
        "google/flan-t5-small",  # Fallback option
        "t5-small",  # Another fallback
    )

    # Translators, tried in order (dicts keep insertion order). Each is mapped
    # to the (source_lang, target_lang) pairs it can actually translate.
    # opus-mt-en-hi is a one-direction English -> Hindi model. The IndicTrans2
    # "indic-indic" checkpoint only translates between Indic languages.
    # NOTE(review): get_response only ever asks for English <-> X translation,
    # so the IndicTrans2 fallback can load but is never used. Consider an
    # en-indic / indic-en checkpoint instead.
    TRANSLATOR_DIRECTIONS = {
        "Helsinki-NLP/opus-mt-en-hi": frozenset({("english", "hindi")}),
        "ai4bharat/indictrans2-indic-indic-1B": frozenset({
            ("hindi", "tamil"), ("hindi", "telugu"),
            ("tamil", "hindi"), ("tamil", "telugu"),
            ("telugu", "hindi"), ("telugu", "tamil"),
        }),
    }

    # IndicTrans2-style language codes passed as src_lang / tgt_lang.
    INDICTRANS_CODES = {
        "english": "eng_Latn",
        "hindi": "hin_Deva",
        "tamil": "tam_Taml",
        "telugu": "tel_Telu",
    }

    PROMPT_TEMPLATE = """You are a helpful AI assistant. Please provide a clear and concise response to the user's question.
Previous conversation:
{history}
User: {input}
Assistant:"""

    def __init__(self):
        """Load the generation model and the translator, reporting progress in the UI.

        Any exception raised while loading is logged and shown with
        ``st.error``. It is not propagated, and ``chain`` then stays ``None``.
        """
        self.chain = None
        self.translator = None
        self.translator_directions = frozenset()
        self.memory = None
        try:
            use_cuda = torch.cuda.is_available()
            # transformers' pipeline() takes a GPU index, or -1 for CPU.
            device = 0 if use_cuda else -1
            # Half precision only when a CUDA device is available.
            torch_dtype = torch.float16 if use_cuda else torch.float32
            st.info(f"Using device: {'CUDA' if device == 0 else 'CPU'}")
            self._load_main_model(device, torch_dtype)
            self._load_translator(device)
        except Exception as e:
            logger.exception("Fatal error during initialization: %s", e)
            st.error(f"Fatal: Could not initialize the bot. Error: {e}")

    def _load_main_model(self, device, torch_dtype):
        """Build ``self.chain`` from the first model in ``MODELS_TO_TRY`` that loads.

        Args:
            device: pipeline device index (0 for the first GPU, -1 for CPU).
            torch_dtype: dtype used for the model weights.

        Raises:
            RuntimeError: if every candidate model fails to load.
        """
        prompt_template = PromptTemplate(
            input_variables=["history", "input"],
            template=self.PROMPT_TEMPLATE,
        )
        for model_name in self.MODELS_TO_TRY:
            try:
                st.info(f"Attempting to load model: {model_name}")
                generator_pipeline = pipeline(
                    "text2text-generation",
                    model=model_name,
                    device=device,
                    torch_dtype=torch_dtype,
                    max_new_tokens=100,  # Keep replies short
                    repetition_penalty=1.5,  # Discourage repeated tokens
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    no_repeat_ngram_size=3,  # Forbid repeating any 3-gram
                    # NOTE(review): executes code shipped in the model repo;
                    # confirm whether any candidate here actually needs it.
                    trust_remote_code=True,
                )
                llm = HuggingFacePipeline(pipeline=generator_pipeline)
                memory = ConversationBufferMemory(memory_key="history")
                chain = LLMChain(
                    llm=llm,
                    prompt=prompt_template,
                    verbose=True,
                    memory=memory,
                )
            except Exception as e:
                logger.warning("Failed to load %s: %s", model_name, e)
                st.warning(f"Failed to load {model_name}, trying next option...")
                continue
            # Publish memory and chain together, only once the whole setup succeeded.
            self.memory = memory
            self.chain = chain
            st.success(f"Successfully loaded model: {model_name}")
            return
        raise RuntimeError("All model loading attempts failed")

    def _load_translator(self, device):
        """Load the first translator in ``TRANSLATOR_DIRECTIONS`` that succeeds.

        Sets ``self.translator`` and ``self.translator_directions``. Failures
        are logged and shown as warnings rather than raised. If nothing loads,
        ``self.translator`` stays ``None``.

        Args:
            device: pipeline device index (0 for the first GPU, -1 for CPU).
        """
        for translator_name, directions in self.TRANSLATOR_DIRECTIONS.items():
            try:
                st.info(f"Attempting to load translator: {translator_name}")
                translator = pipeline(
                    "translation",
                    model=translator_name,
                    device=device,
                    trust_remote_code=True,
                )
            except Exception as e:
                logger.warning("Failed to load translator %s: %s", translator_name, e)
                st.warning(f"Failed to load translator {translator_name}, trying next option...")
                continue
            self.translator = translator
            self.translator_directions = directions
            st.success(f"Successfully loaded translator: {translator_name}")
            return
        st.warning("No translator loaded - translation features will be limited")

    def _translate(self, text, source_lang, target_lang):
        """Translate ``text`` from ``source_lang`` to ``target_lang``, best effort.

        Returns ``text`` unchanged in any of these cases:
        - no translator is loaded;
        - the two languages are equal;
        - the loaded translator does not support that direction;
        - every translation attempt fails.
        """
        if self.translator is None or source_lang == target_lang:
            return text
        # Running a one-direction model (e.g. en->hi) on other pairs yields
        # wrong-language output, so unsupported directions pass through untouched.
        if (source_lang, target_lang) not in self.translator_directions:
            return text

        src_code = self.INDICTRANS_CODES.get(source_lang)
        tgt_code = self.INDICTRANS_CODES.get(target_lang)
        if src_code and tgt_code:
            try:
                result = self.translator(text, src_lang=src_code, tgt_lang=tgt_code)
                if result and "translation_text" in result[0]:
                    return result[0]["translation_text"]
            except Exception as e:
                logger.warning("Indictrans2 translation failed: %s", e)

        # Fallback: plain call, for models that take no language arguments.
        try:
            result = self.translator(text)
            if result:
                first = result[0]
                for key in ("translation_text", "generated_text"):
                    if key in first:
                        return first[key]
        except Exception as e:
            logger.warning("Simple translation failed: %s", e)
        return text

    @staticmethod
    def _suppress_repetition(response, processed_message):
        """Replace a degenerate, looping reply with a generic fallback.

        A reply is treated as a repetition loop when it has more than 10 words
        and any single word occurs more than 5 times.
        """
        words = response.split()
        if len(words) > 10 and Counter(words).most_common(1)[0][1] > 5:
            return f"I understand you said '{processed_message[:50]}...' Let me provide a helpful response to that."
        return response

    def get_response(self, user_message, input_lang, output_lang):
        """Generate a reply to ``user_message`` in ``output_lang``.

        The chain is prompted in English. Non-English input is translated to
        English first, when the translator supports that direction. The reply
        is then translated from English to ``output_lang`` the same way. An
        unavailable translation falls back to the untranslated text.

        Args:
            user_message: raw text typed by the user.
            input_lang: language name of ``user_message`` (e.g. ``"hindi"``).
            output_lang: language name the reply should be returned in.

        Returns:
            The reply text, or a user-facing error message. Exceptions raised
            during processing are logged, not propagated.
        """
        if not self.chain:
            return "Error: The LangChain chain is not initialized. Please check the logs above."
        try:
            user_message = user_message.strip()
            # Reject blank input before spending time on translation/generation.
            if not user_message:
                return "I didn't receive a valid message. Please try again."

            processed_message = self._translate(user_message, input_lang, "english")
            if not processed_message.strip():
                return "I didn't receive a valid message. Please try again."

            response = self.chain.run(input=processed_message).strip()
            response = self._suppress_repetition(response, processed_message)

            # The prompt is English, so the reply is treated as English. It is
            # translated whenever output_lang is not English, including when
            # output_lang equals input_lang.
            return self._translate(response, "english", output_lang)
        except Exception as e:
            logger.exception("Error generating response: %s", e)
            return "I apologize, but I encountered an error while processing your request. Please try rephrasing your message."
# -----------------------------------------------------------------------------
# STREAMLIT UI WITH BETTER ERROR HANDLING
# -----------------------------------------------------------------------------
# Page chrome: browser tab title/icon, layout, and the visible header.
# The emoji here were mis-decoded (e.g. "πŸ€–") and are restored to the intended characters.
st.set_page_config(
    page_title="LangChain Model Interface",
    page_icon="🤖",
    layout="centered",
)
st.title("🤖 LangChain Model Interface")
st.markdown("*Multi-language conversational AI powered by LangChain*")
# Initialize the bot with progress tracking
@st.cache_resource
def load_bot():
    """Construct the LangChainBot, showing a spinner while the models load.

    Wrapped in ``st.cache_resource``, so the returned bot is cached and reused
    instead of reloading the models on every script rerun.
    """
    loading_message = "Loading models... This may take a few minutes on first run."
    with st.spinner(loading_message):
        bot_instance = LangChainBot()
    return bot_instance
# Load the bot
bot = load_bot()

if bot and bot.chain:
    st.success("✅ Bot loaded successfully!")
    st.markdown("---")

    # Language selection with helpful notes
    language_options = ["english", "hindi", "tamil", "telugu"]
    col1, col2 = st.columns(2)
    with col1:
        input_lang = st.selectbox(
            "🔤 Input Language",
            options=language_options,
            index=0,
            help="Select the language you'll type in",
        )
    with col2:
        output_lang = st.selectbox(
            "🗣️ Output Language",
            options=language_options,
            index=0,  # Default to English for now
            help="Select the language for the response",
        )

    # Show translation status
    if not bot.translator:
        st.info("ℹ️ Translation is currently limited. For best results, use English input and output.")
    elif input_lang != "english" or output_lang != "english":
        st.warning("⚠️ Translation is experimental. If you encounter issues, try using English.")

    # Chat interface
    st.markdown("### 💬 Chat Interface")
    user_input = st.text_area(
        "Your Message:",
        height=100,
        placeholder=f"Type your message in {input_lang}...",
    )

    col1, col2 = st.columns([3, 1])
    with col1:
        if st.button("🚀 Get Response", type="primary"):
            if user_input.strip():
                with st.spinner("🤔 LangChain is processing your request..."):
                    response = bot.get_response(user_input, input_lang, output_lang)
                st.markdown("### 🤖 Model Response:")
                st.info(response)
                # Per-session display log; separate from the chain's own memory.
                if "conversation_history" not in st.session_state:
                    st.session_state.conversation_history = []
                st.session_state.conversation_history.append({
                    "user": user_input,
                    "bot": response,
                    "input_lang": input_lang,
                    "output_lang": output_lang,
                })
            else:
                st.warning("⚠️ Please enter a message.")
    with col2:
        if st.button("🧹 Clear Memory"):
            # NOTE(review): the bot comes from st.cache_resource, so its chain
            # memory is presumably shared by every session; clearing it here
            # affects all users, not just this one.
            if getattr(bot, "memory", None):
                bot.memory.clear()
            if "conversation_history" in st.session_state:
                del st.session_state.conversation_history
            st.success("✅ Conversation memory cleared!")

    # Display the last five exchanges, newest first.
    history = st.session_state.get("conversation_history")
    if history:
        st.markdown("### 📝 Conversation History")
        total = len(history)
        for offset, conv in enumerate(reversed(history[-5:])):
            with st.expander(f"Exchange {total - offset}"):
                st.markdown(f"**You ({conv['input_lang']})**: {conv['user']}")
                st.markdown(f"**Bot ({conv['output_lang']})**: {conv['bot']}")
else:
    st.error("❌ Application could not start. Please check the error messages above.")

    # Show some troubleshooting tips. Built by concatenation so that no source
    # indentation leaks into the markdown (indented lines would render as code).
    st.markdown("### 🔧 Troubleshooting Tips:")
    st.markdown(
        "\n"
        "1. **Model Loading Issues**: The models might be too large for the available resources\n"
        "2. **Memory Issues**: Try restarting the application\n"
        "3. **Network Issues**: Ensure stable internet connection for model downloads\n"
        "4. **Compatibility Issues**: Some models might not be compatible with the current environment\n"
    )

    if st.button("🔄 Retry Loading"):
        # Drop the cached (failed) bot so load_bot() runs again on rerun.
        st.cache_resource.clear()
        st.rerun()
# Add sidebar with information
with st.sidebar:
    st.markdown("### ℹ️ Information")
    # The blank line before "Supported Languages" keeps it from being absorbed
    # into the preceding bullet as a lazy continuation line.
    st.markdown(
        "This application uses:\n"
        "- **LangChain** for conversation management\n"
        "- **Hugging Face Transformers** for AI models\n"
        "- **Multi-language support** via translation models\n"
        "\n"
        "**Supported Languages:**\n"
        "- English\n"
        "- Hindi\n"
        "- Tamil\n"
        "- Telugu\n"
    )

    cuda_available = torch.cuda.is_available()
    if cuda_available:
        st.success("🚀 CUDA GPU detected - faster processing!")
    else:
        st.info("💻 Using CPU - processing may be slower")

    st.markdown("### 🔧 System Status")
    st.markdown(f"- PyTorch: {torch.__version__}")
    st.markdown(f"- Device: {'CUDA' if cuda_available else 'CPU'}")
    if bot and bot.chain:
        st.markdown("- Model: ✅ Loaded")
        st.markdown(f"- Translator: {'✅ Loaded' if bot.translator else '❌ Not loaded'}")
    else:
        st.markdown("- Model: ❌ Failed to load")