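"""Streamlit chat app: a multilingual conversational bot built on LangChain.

A Hugging Face text2text-generation model (loaded with fallbacks) is wrapped
in a LangChain LLMChain with conversation memory; an optional translation
pipeline adds Hindi, Tamil, and Telugu support.
"""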
import logging

import streamlit as st
import torch
from transformers import pipeline

from langchain_community.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferMemory

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class LangChainBot:
    def __init__(self):
        """Load the models and wrap them in LangChain components, with fallback options."""
        self.chain = None
        self.translator = None
        self.memory = None

        try:
            # Prefer the GPU (device 0) when CUDA is available; otherwise CPU.
            device = 0 if torch.cuda.is_available() else -1
            torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

            st.info(f"Using device: {'CUDA' if device == 0 else 'CPU'}")

            self._load_main_model(device, torch_dtype)
            self._load_translator(device)

        except Exception as e:
            logger.error(f"Fatal error during initialization: {e}")
            st.error(f"Fatal: Could not initialize the bot. Error: {e}")

    def _load_main_model(self, device, torch_dtype):
        """Load the main generation model, trying each candidate in order."""
        models_to_try = [
            "ai4bharat/IndicBARTSS",
            "google/flan-t5-small",
            "t5-small",
        ]

        for model_name in models_to_try:
            try:
                st.info(f"Attempting to load model: {model_name}")

                # Generation settings are forwarded to model.generate();
                # sampling with a repetition penalty keeps small seq2seq
                # models from looping on the same phrase.
                generator_pipeline = pipeline(
                    "text2text-generation",
                    model=model_name,
                    device=device,
                    torch_dtype=torch_dtype,
                    max_new_tokens=100,
                    repetition_penalty=1.5,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    no_repeat_ngram_size=3,
                    trust_remote_code=True,
                )

                llm = HuggingFacePipeline(pipeline=generator_pipeline)

                template = """You are a helpful AI assistant. Please provide a clear and concise response to the user's question.

Previous conversation:
{history}

User: {input}
Assistant:"""
                prompt_template = PromptTemplate(
                    input_variables=["history", "input"],
                    template=template,
                )

                # ConversationBufferMemory fills the {history} slot on every call.
                self.memory = ConversationBufferMemory(memory_key="history")

                self.chain = LLMChain(
                    llm=llm,
                    prompt=prompt_template,
                    verbose=True,
                    memory=self.memory,
                )

                st.success(f"Successfully loaded model: {model_name}")
                return

            except Exception as e:
                logger.warning(f"Failed to load {model_name}: {e}")
                st.warning(f"Failed to load {model_name}, trying next option...")
                continue

        raise RuntimeError("All model loading attempts failed")
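
    # For reference, a sketch of what the model actually sees after one
    # exchange: ConversationBufferMemory renders {history} with its default
    # "Human:" / "AI:" prefixes, so the filled-in prompt looks like this
    # (illustrative values):
    #
    #   You are a helpful AI assistant. Please provide a clear and concise response to the user's question.
    #
    #   Previous conversation:
    #   Human: hello
    #   AI: Hi there! How can I help?
    #
    #   User: what can you do?
    #   Assistant: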

    def _load_translator(self, device):
        """Load a translation pipeline, trying each candidate in order."""
        # Note: opus-mt-en-hi only covers English -> Hindi; the IndicTrans2
        # fallback covers Indic <-> Indic pairs via src_lang/tgt_lang codes.
        translators_to_try = [
            "Helsinki-NLP/opus-mt-en-hi",
            "ai4bharat/indictrans2-indic-indic-1B",
        ]

        for translator_name in translators_to_try:
            try:
                st.info(f"Attempting to load translator: {translator_name}")

                self.translator = pipeline(
                    "translation",
                    model=translator_name,
                    device=device,
                    trust_remote_code=True,
                )

                st.success(f"Successfully loaded translator: {translator_name}")
                return

            except Exception as e:
                logger.warning(f"Failed to load translator {translator_name}: {e}")
                st.warning(f"Failed to load translator {translator_name}, trying next option...")
                continue

        st.warning("No translator loaded - translation features will be limited")

    def _translate(self, text, source_lang, target_lang):
        """Translate text between languages, returning the input unchanged on failure."""
        if not self.translator or source_lang == target_lang:
            return text

        try:
            # FLORES-style language codes used by IndicTrans2.
            indictrans_codes = {
                'english': 'eng_Latn',
                'hindi': 'hin_Deva',
                'tamil': 'tam_Taml',
                'telugu': 'tel_Telu',
            }

            # First try an explicit source/target pair (required by
            # multilingual models such as IndicTrans2; fixed-pair models may
            # reject these kwargs, in which case we fall through below).
            if source_lang in indictrans_codes and target_lang in indictrans_codes:
                try:
                    result = self.translator(
                        text,
                        src_lang=indictrans_codes[source_lang],
                        tgt_lang=indictrans_codes[target_lang],
                    )
                    if result and len(result) > 0 and 'translation_text' in result[0]:
                        return result[0]['translation_text']
                except Exception as e:
                    logger.warning(f"IndicTrans2 translation failed: {e}")

            # Fall back to calling the pipeline with no language hints
            # (fixed-pair models such as opus-mt-en-hi work this way).
            try:
                result = self.translator(text)
                if result and len(result) > 0:
                    if 'translation_text' in result[0]:
                        return result[0]['translation_text']
                    elif 'generated_text' in result[0]:
                        return result[0]['generated_text']
            except Exception as e:
                logger.warning(f"Simple translation failed: {e}")

        except Exception as e:
            logger.warning(f"Translation failed: {e}")

        # On any failure, return the original text untranslated.
        return text
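
    # Usage sketch (hypothetical values): bot._translate("Hello", "english",
    # "hindi") returns Hindi text when a translator is loaded, or "Hello"
    # unchanged if both translation attempts fail.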

    def get_response(self, user_message, input_lang, output_lang):
        """Generate a response, translating input and output as needed."""
        if not self.chain:
            return "Error: The LangChain chain is not initialized. Please check the logs above."

        try:
            user_message = user_message.strip()

            # Translate non-English input to English before generation.
            if input_lang == 'english':
                processed_message = user_message
            else:
                processed_message = self._translate(user_message, input_lang, 'english')

            if len(processed_message.strip()) == 0:
                return "I didn't receive a valid message. Please try again."

            response = self.chain.run(input=processed_message)
            response = response.strip()

            # Guard against degenerate output: if any single word appears more
            # than five times, the model is probably looping, so substitute a
            # generic acknowledgement instead.
            words = response.split()
            if len(words) > 10:
                word_counts = {}
                for word in words:
                    word_counts[word] = word_counts.get(word, 0) + 1

                max_count = max(word_counts.values()) if word_counts else 0
                if max_count > 5:
                    response = f"I understand you said '{processed_message[:50]}...' Let me provide a helpful response to that."

            # Translate the English response whenever a non-English output
            # language is requested.
            if output_lang != 'english':
                return self._translate(response, 'english', output_lang)
            return response

        except Exception as e:
            logger.error(f"Error generating response: {e}")
            return "I apologize, but I encountered an error while processing your request. Please try rephrasing your message."
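

# --- Streamlit UI ---
# Everything below runs top-to-bottom on every interaction; this is
# Streamlit's standard execution model, which is why the bot object is
# cached below rather than rebuilt on each rerun.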

st.set_page_config(
    page_title="LangChain Model Interface",
    page_icon="🤖",
    layout="centered"
)

st.title("🤖 LangChain Model Interface")
st.markdown("*Multi-language conversational AI powered by LangChain*")


@st.cache_resource
def load_bot():
    with st.spinner("Loading models... This may take a few minutes on first run."):
        return LangChainBot()


bot = load_bot()
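
# st.cache_resource keeps a single LangChainBot instance per server process,
# so the models are loaded once rather than on every script rerun.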


if bot and bot.chain:
    st.success("✅ Bot loaded successfully!")

    st.markdown("---")

    language_options = ["english", "hindi", "tamil", "telugu"]
    col1, col2 = st.columns(2)

    with col1:
        input_lang = st.selectbox(
            "🎤 Input Language",
            options=language_options,
            index=0,
            help="Select the language you'll type in"
        )
    with col2:
        output_lang = st.selectbox(
            "🗣️ Output Language",
            options=language_options,
            index=0,
            help="Select the language for the response"
        )

    if not bot.translator:
        st.info("ℹ️ Translation is currently limited. For best results, use English input and output.")
    elif input_lang != 'english' or output_lang != 'english':
        st.warning("⚠️ Translation is experimental. If you encounter issues, try using English.")

    st.markdown("### 💬 Chat Interface")
    user_input = st.text_area(
        "Your Message:",
        height=100,
        placeholder=f"Type your message in {input_lang}..."
    )

    col1, col2 = st.columns([3, 1])

    with col1:
        if st.button("🚀 Get Response", type="primary"):
            if user_input.strip():
                with st.spinner("🤔 LangChain is processing your request..."):
                    response = bot.get_response(user_input, input_lang, output_lang)

                st.markdown("### 🤖 Model Response:")
                st.info(response)

                if 'conversation_history' not in st.session_state:
                    st.session_state.conversation_history = []

                st.session_state.conversation_history.append({
                    'user': user_input,
                    'bot': response,
                    'input_lang': input_lang,
                    'output_lang': output_lang
                })

            else:
                st.warning("⚠️ Please enter a message.")
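
    # st.session_state persists across Streamlit reruns for the lifetime of a
    # browser session, so conversation_history accumulates as the user chats.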

    with col2:
        if st.button("🧹 Clear Memory"):
            if hasattr(bot, 'memory') and bot.memory:
                bot.memory.clear()
            if 'conversation_history' in st.session_state:
                del st.session_state.conversation_history
            st.success("✅ Conversation memory cleared!")

    # Show up to the last five exchanges, most recent first.
    if 'conversation_history' in st.session_state and st.session_state.conversation_history:
        st.markdown("### 📜 Conversation History")
        for i, conv in enumerate(reversed(st.session_state.conversation_history[-5:])):
            with st.expander(f"Exchange {len(st.session_state.conversation_history) - i}"):
                st.markdown(f"**You ({conv['input_lang']})**: {conv['user']}")
                st.markdown(f"**Bot ({conv['output_lang']})**: {conv['bot']}")

else:
    st.error("❌ Application could not start. Please check the error messages above.")

    st.markdown("### 🔧 Troubleshooting Tips:")
    st.markdown("""
1. **Model Loading Issues**: The models might be too large for the available resources
2. **Memory Issues**: Try restarting the application
3. **Network Issues**: Ensure a stable internet connection for model downloads
4. **Compatibility Issues**: Some models might not be compatible with the current environment
""")

    if st.button("🔄 Retry Loading"):
        st.cache_resource.clear()
        st.rerun()


with st.sidebar:
    st.markdown("### ℹ️ Information")
    st.markdown("""
This application uses:
- **LangChain** for conversation management
- **Hugging Face Transformers** for AI models
- **Multi-language support** via translation models

**Supported Languages:**
- English
- Hindi
- Tamil
- Telugu
""")

    if torch.cuda.is_available():
        st.success("🚀 CUDA GPU detected - faster processing!")
    else:
        st.info("💻 Using CPU - processing may be slower")

    st.markdown("### 🔧 System Status")
    st.markdown(f"- PyTorch: {torch.__version__}")
    st.markdown(f"- Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")
    if bot and bot.chain:
        st.markdown("- Model: ✅ Loaded")
        st.markdown(f"- Translator: {'✅ Loaded' if bot.translator else '❌ Not loaded'}")
    else:
        st.markdown("- Model: ❌ Failed to load")