akashraut committed
Commit 2223d6d · verified · 1 Parent(s): 14da30b

Update streamlit_app.py

Files changed (1)
  1. streamlit_app.py +287 -82
streamlit_app.py CHANGED
@@ -1,9 +1,10 @@
  # streamlit_app.py
- # A minimal Streamlit app rebuilt with the LangChain framework.
+ # A robust Streamlit app with proper error handling and fallback options

  import streamlit as st
  import torch
- from transformers import pipeline
+ import logging
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM

  # Updated LangChain imports for modern versions
  from langchain_community.llms import HuggingFacePipeline
@@ -11,122 +12,326 @@ from langchain.prompts import PromptTemplate
  from langchain.chains import LLMChain
  from langchain.memory import ConversationBufferMemory

+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
  # -----------------------------------------------------------------------------
- # CORE MODEL LOGIC (Rebuilt with LangChain)
+ # CORE MODEL LOGIC (Rebuilt with LangChain and Error Handling)
  # -----------------------------------------------------------------------------
  class LangChainBot:
      def __init__(self):
          """
-         Loads the models and wraps them in LangChain components.
+         Loads the models and wraps them in LangChain components with fallback options.
          """
+         self.chain = None
+         self.translator = None
+         self.memory = None
+
          try:
-             # 1. Load the base Hugging Face pipelines
-             generator_pipeline = pipeline(
-                 "text2text-generation",
-                 model="ai4bharat/IndicBARTSS",
-                 device=0 if torch.cuda.is_available() else -1,
-                 torch_dtype=(torch.float16 if torch.cuda.is_available() else torch.float32),
-                 max_new_tokens=150,
-                 repetition_penalty=1.2
-             )
+             # Check CUDA availability
+             device = 0 if torch.cuda.is_available() else -1
+             torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+             st.info(f"Using device: {'CUDA' if device == 0 else 'CPU'}")
+
+             # Try to load the main model with error handling
+             self._load_main_model(device, torch_dtype)

-             # Added `trust_remote_code=True` to allow the special translator model to load.
-             self.translator = pipeline(
-                 "translation",
-                 model="ai4bharat/indictrans2-indic-indic-1B",
-                 device=0 if torch.cuda.is_available() else -1,
-                 trust_remote_code=True
-             )
-
-             # 2. Wrap the generator in a LangChain LLM object
-             llm = HuggingFacePipeline(pipeline=generator_pipeline)
-
-             # 3. Create a Prompt Template
-             template = """
-             You are a helpful conversational AI. Respond to the user's message.
-
-             {history}
-             मनुष्य: {input}
-             सहायक:
-             """
-             prompt_template = PromptTemplate(input_variables=["history", "input"], template=template)
-
-             # 4. Set up conversational memory
-             self.memory = ConversationBufferMemory(memory_key="history")
-
-             # 5. Create the final LLMChain
-             self.chain = LLMChain(
-                 llm=llm,
-                 prompt=prompt_template,
-                 verbose=True,
-                 memory=self.memory
-             )
+             # Try to load the translator with error handling
+             self._load_translator(device)

          except Exception as e:
-             st.error(f"Fatal: Could not load models. Error: {e}")
-             self.chain = None
-             self.translator = None
+             logger.error(f"Fatal error during initialization: {e}")
+             st.error(f"Fatal: Could not initialize the bot. Error: {e}")
+
+     def _load_main_model(self, device, torch_dtype):
+         """Load the main generation model with fallback options."""
+         models_to_try = [
+             "ai4bharat/IndicBARTSS",
+             "google/flan-t5-small",  # Fallback option
+             "t5-small"  # Another fallback
+         ]
+
+         for model_name in models_to_try:
+             try:
+                 st.info(f"Attempting to load model: {model_name}")
+
+                 # Try loading with pipeline first
+                 generator_pipeline = pipeline(
+                     "text2text-generation",
+                     model=model_name,
+                     device=device,
+                     torch_dtype=torch_dtype,
+                     max_new_tokens=150,
+                     repetition_penalty=1.2,
+                     trust_remote_code=True  # Added this for compatibility
+                 )
+
+                 # Wrap in LangChain LLM
+                 llm = HuggingFacePipeline(pipeline=generator_pipeline)
+
+                 # Create prompt template
+                 template = """
+                 You are a helpful conversational AI. Respond to the user's message appropriately.
+                 Previous conversation:
+                 {history}
+
+                 Human: {input}
+                 Assistant:
+                 """
+                 prompt_template = PromptTemplate(
+                     input_variables=["history", "input"],
+                     template=template
+                 )
+
+                 # Set up memory
+                 self.memory = ConversationBufferMemory(memory_key="history")
+
+                 # Create the chain
+                 self.chain = LLMChain(
+                     llm=llm,
+                     prompt=prompt_template,
+                     verbose=True,
+                     memory=self.memory
+                 )
+
+                 st.success(f"Successfully loaded model: {model_name}")
+                 return  # Success, exit the loop
+
+             except Exception as e:
+                 logger.warning(f"Failed to load {model_name}: {e}")
+                 st.warning(f"Failed to load {model_name}, trying next option...")
+                 continue
+
+         raise Exception("All model loading attempts failed")
+
+     def _load_translator(self, device):
+         """Load the translator with fallback options."""
+         translators_to_try = [
+             "ai4bharat/indictrans2-indic-indic-1B",
+             "Helsinki-NLP/opus-mt-en-hi",  # Fallback for English-Hindi
+         ]
+
+         for translator_name in translators_to_try:
+             try:
+                 st.info(f"Attempting to load translator: {translator_name}")
+
+                 self.translator = pipeline(
+                     "translation",
+                     model=translator_name,
+                     device=device,
+                     trust_remote_code=True
+                 )
+
+                 st.success(f"Successfully loaded translator: {translator_name}")
+                 return  # Success
+
+             except Exception as e:
+                 logger.warning(f"Failed to load translator {translator_name}: {e}")
+                 st.warning(f"Failed to load translator {translator_name}, trying next option...")
+                 continue
+
+         st.warning("No translator loaded - translation features will be limited")

      def _translate(self, text, source_lang, target_lang):
-         """Translation logic remains the same."""
+         """Translation logic with improved error handling."""
          if not self.translator or source_lang == target_lang:
              return text
+
          try:
-             codes = {'english': 'eng_Latn', 'hindi': 'hin_Deva', 'tamil': 'tam_Taml', 'telugu': 'tel_Telu'}
-             result = self.translator(text, src_lang=codes[source_lang], tgt_lang=codes[target_lang])
-             return result[0]['translation_text']
+             # Define language codes
+             codes = {
+                 'english': 'eng_Latn',
+                 'hindi': 'hin_Deva',
+                 'tamil': 'tam_Taml',
+                 'telugu': 'tel_Telu'
+             }
+
+             if source_lang in codes and target_lang in codes:
+                 result = self.translator(
+                     text,
+                     src_lang=codes[source_lang],
+                     tgt_lang=codes[target_lang]
+                 )
+                 return result[0]['translation_text']
+             else:
+                 # Fallback for simple English-Hindi translation
+                 if source_lang == 'english' and target_lang == 'hindi':
+                     result = self.translator(text)
+                     return result[0]['translation_text'] if result else text
+
          except Exception as e:
-             st.warning(f"Translation failed. Error: {e}")
-             return text
+             logger.warning(f"Translation failed: {e}")
+             st.warning(f"Translation failed, using original text. Error: {e}")
+
+         return text

      def get_response(self, user_message, input_lang, output_lang):
-         """The main function to get a response."""
+         """Generate response with comprehensive error handling."""
          if not self.chain:
-             return "Error: The LangChain chain is not initialized."
+             return "Error: The LangChain chain is not initialized. Please check the logs above."

-         hindi_message = self._translate(user_message, input_lang, 'hindi')
-         hindi_response = self.chain.run(hindi_message)
-         final_response = self._translate(hindi_response, 'hindi', output_lang)
-
-         return final_response
+         try:
+             # Translate input to a common language if needed
+             if input_lang != 'english':
+                 processed_message = self._translate(user_message, input_lang, 'english')
+             else:
+                 processed_message = user_message
+
+             # Generate response
+             response = self.chain.run(processed_message)
+
+             # Translate output if needed
+             if output_lang != 'english':
+                 final_response = self._translate(response, 'english', output_lang)
+             else:
+                 final_response = response
+
+             return final_response
+
+         except Exception as e:
+             logger.error(f"Error generating response: {e}")
+             return f"I apologize, but I encountered an error while processing your request: {str(e)}"

  # -----------------------------------------------------------------------------
- # MINIMAL STREAMLIT UI (This part remains mostly the same)
+ # STREAMLIT UI WITH BETTER ERROR HANDLING
  # -----------------------------------------------------------------------------

- st.set_page_config(layout="centered")
- st.title("LangChain Model Interface")
+ st.set_page_config(
+     page_title="LangChain Model Interface",
+     page_icon="🤖",
+     layout="centered"
+ )
+
+ st.title("🤖 LangChain Model Interface")
+ st.markdown("*Multi-language conversational AI powered by LangChain*")

+ # Initialize the bot with progress tracking
  @st.cache_resource
  def load_bot():
-     return LangChainBot()
+     with st.spinner("Loading models... This may take a few minutes on first run."):
+         return LangChainBot()

+ # Load the bot
  bot = load_bot()

- if bot and bot.chain:  # Only show the UI if the bot loaded successfully
+ # Check if bot loaded successfully
+ if bot and bot.chain:
+     st.success("✅ Bot loaded successfully!")
+
      st.markdown("---")
+
+     # Language selection
      language_options = ["english", "hindi", "tamil", "telugu"]
      col1, col2 = st.columns(2)
+
      with col1:
-         input_lang = st.selectbox("Input Language", options=language_options, index=0)
+         input_lang = st.selectbox(
+             "🔤 Input Language",
+             options=language_options,
+             index=0,
+             help="Select the language you'll type in"
+         )
      with col2:
-         output_lang = st.selectbox("Output Language", options=language_options, index=1)
+         output_lang = st.selectbox(
+             "🗣️ Output Language",
+             options=language_options,
+             index=1,
+             help="Select the language for the response"
+         )

-     user_input = st.text_area("Your Message:", height=100)
+     # Chat interface
+     st.markdown("### 💬 Chat Interface")
+     user_input = st.text_area(
+         "Your Message:",
+         height=100,
+         placeholder=f"Type your message in {input_lang}..."
+     )

-     if st.button("Get Response"):
-         if user_input:
-             with st.spinner("LangChain is processing your request..."):
-                 response = bot.get_response(user_input, input_lang, output_lang)
-             st.markdown("### Model Response:")
+     col1, col2 = st.columns([3, 1])
+
+     with col1:
+         if st.button("🚀 Get Response", type="primary"):
+             if user_input.strip():
+                 with st.spinner("🤔 LangChain is processing your request..."):
+                     response = bot.get_response(user_input, input_lang, output_lang)
+
+                 st.markdown("### 🤖 Model Response:")
                  st.info(response)
-         else:
-             st.warning("Please enter a message.")
-
-     # Add a button to clear LangChain's memory
-     if st.button("Clear Conversation Memory"):
-         if hasattr(bot, 'memory'):
-             bot.memory.clear()
-             st.success("Conversation memory has been cleared.")
+
+                 # Add to conversation history display
+                 if 'conversation_history' not in st.session_state:
+                     st.session_state.conversation_history = []
+
+                 st.session_state.conversation_history.append({
+                     'user': user_input,
+                     'bot': response,
+                     'input_lang': input_lang,
+                     'output_lang': output_lang
+                 })
+
+             else:
+                 st.warning("⚠️ Please enter a message.")
+
+     with col2:
+         if st.button("🧹 Clear Memory"):
+             if hasattr(bot, 'memory') and bot.memory:
+                 bot.memory.clear()
+             if 'conversation_history' in st.session_state:
+                 del st.session_state.conversation_history
+             st.success("✅ Conversation memory cleared!")
+
+     # Display conversation history
+     if 'conversation_history' in st.session_state and st.session_state.conversation_history:
+         st.markdown("### 📝 Conversation History")
+         for i, conv in enumerate(reversed(st.session_state.conversation_history[-5:])):  # Show last 5
+             with st.expander(f"Exchange {len(st.session_state.conversation_history) - i}"):
+                 st.markdown(f"**You ({conv['input_lang']})**: {conv['user']}")
+                 st.markdown(f"**Bot ({conv['output_lang']})**: {conv['bot']}")
+
  else:
-     st.error("Application could not start. Please check the logs.")
+     st.error("Application could not start. Please check the error messages above.")
+
+     # Show some troubleshooting tips
+     st.markdown("### 🔧 Troubleshooting Tips:")
+     st.markdown("""
+     1. **Model Loading Issues**: The models might be too large for the available resources
+     2. **Memory Issues**: Try restarting the application
+     3. **Network Issues**: Ensure stable internet connection for model downloads
+     4. **Compatibility Issues**: Some models might not be compatible with the current environment
+     """)
+
+     if st.button("🔄 Retry Loading"):
+         st.cache_resource.clear()
+         st.rerun()
+
+ # Add sidebar with information
+ with st.sidebar:
+     st.markdown("### ℹ️ Information")
+     st.markdown("""
+     This application uses:
+     - **LangChain** for conversation management
+     - **Hugging Face Transformers** for AI models
+     - **Multi-language support** via translation models
+
+     **Supported Languages:**
+     - English
+     - Hindi
+     - Tamil
+     - Telugu
+     """)
+
+     if torch.cuda.is_available():
+         st.success("🚀 CUDA GPU detected - faster processing!")
+     else:
+         st.info("💻 Using CPU - processing may be slower")
+
+     st.markdown("### 🔧 System Status")
+     st.markdown(f"- PyTorch: {torch.__version__}")
+     st.markdown(f"- Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")
+     if bot and bot.chain:
+         st.markdown("- Model: ✅ Loaded")
+         st.markdown(f"- Translator: {'✅ Loaded' if bot.translator else '❌ Not loaded'}")
+     else:
+         st.markdown("- Model: ❌ Failed to load")
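
Note: recent LangChain releases deprecate LLMChain and chain.run() in favor of runnable composition, so the pattern above may emit deprecation warnings. A minimal sketch of the equivalent chain in the newer style (illustrative only, not part of this commit; it reuses the "t5-small" fallback model named in the diff, and history handling by the caller is an assumption, since ConversationBufferMemory is not wired in automatically with runnables):

    # Sketch: prompt | llm composition instead of LLMChain (LangChain >= 0.1)
    from transformers import pipeline
    from langchain_community.llms import HuggingFacePipeline
    from langchain.prompts import PromptTemplate

    llm = HuggingFacePipeline(
        pipeline=pipeline("text2text-generation", model="t5-small")
    )
    prompt = PromptTemplate(
        input_variables=["history", "input"],
        template="{history}\nHuman: {input}\nAssistant:",
    )
    chain = prompt | llm  # RunnableSequence; replaces LLMChain(llm=..., prompt=...)
    reply = chain.invoke({"history": "", "input": "Hello!"})  # caller threads history

With this style, each call passes the accumulated conversation history string explicitly, instead of relying on the chain's attached memory object.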