logasanjeev committed
Commit 10965e6 · verified · 1 Parent(s): df0e564

Update app.py

Files changed (1): app.py (+14, -10)
app.py CHANGED
@@ -16,6 +16,7 @@ import chromadb
 import tempfile
 from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
 import requests
+from transformers import BitsAndBytesConfig
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
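Note: the new import pulls the full transformers package into a Space that otherwise only needs API clients. A minimal sketch of a guarded import, assuming slim deployments may omit the dependency (the `None` fallback is illustrative, not from the commit):

```python
# Sketch: degrade gracefully if transformers is not installed.
# BitsAndBytesConfig is only needed when a model is loaded locally.
try:
    from transformers import BitsAndBytesConfig
except ImportError:  # assumption: slim deployments omit transformers
    BitsAndBytesConfig = None
```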
@@ -28,9 +29,9 @@ if os.environ["HUGGINGFACEHUB_API_TOKEN"] == "default-token":
 
 # Model and embedding options
 LLM_MODELS = {
-    "Balanced (Mixtral-8x7B)": "mistralai/Mixtral-8x7B-Instruct-v0.1",
-    "Lightweight (Gemma-2B)": "google/gemma-2-2b-it",
-    "High Accuracy (Llama-3-8B)": "meta-llama/Llama-3-8b-hf"
+    "High Accuracy (Mixtral-8x7B)": "mistralai/Mixtral-8x7B-Instruct-v0.1",
+    "Balanced (Gemma-2-9B)": "google/gemma-2-9b-it",
+    "Lightweight (Mistral-7B)": "mistralai/Mistral-7B-Instruct-v0.2"
 }
 
 EMBEDDING_MODELS = {
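Note: all three new repo ids are gated on the Hub, so a token without the right grants surfaces later as a 403 in the handlers below. A preflight sketch using huggingface_hub (already a dependency of the endpoint client; the helper name is ours, not the app's):

```python
import logging
from huggingface_hub import model_info
from huggingface_hub.utils import HfHubHTTPError

logger = logging.getLogger(__name__)

def check_model_access(repo_id: str, token: str) -> bool:
    """Return True if the token can see repo_id (hypothetical helper)."""
    try:
        model_info(repo_id, token=token)
        return True
    except HfHubHTTPError as e:
        # Gated repos answer 401/403 until access is granted on the Hub.
        logger.warning(f"No access to {repo_id}: {e}")
        return False
```

Calling this once per entry in LLM_MODELS at startup would let the UI flag models the token cannot reach instead of failing mid-chat.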
@@ -160,13 +161,16 @@ def initialize_qa_chain(llm_model, temperature):
         return "Please process documents first.", None
 
     try:
+        # Enable quantization for Mixtral-8x7B to reduce memory usage
+        quantization_config = BitsAndBytesConfig(load_in_4bit=True) if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1" else None
         llm = HuggingFaceEndpoint(
             repo_id=LLM_MODELS[llm_model],
             task="text-generation",
             temperature=float(temperature),
             max_new_tokens=512,
             huggingfacehub_api_token=os.environ["HUGGINGFACEHUB_API_TOKEN"],
-            timeout=30
+            timeout=30,
+            quantization_config=quantization_config
         )
         # Dynamically set k based on vector store size
         collection = vector_store._collection
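Two caveats worth flagging on this hunk. First, elsewhere in the function `llm_model` holds the dropdown key (e.g. "High Accuracy (Mixtral-8x7B)"), so the comparison against the raw repo id likely never matches; comparing `LLM_MODELS[llm_model]` appears to be what was intended. Second, `BitsAndBytesConfig` configures local transformers loading; to our knowledge a hosted `HuggingFaceEndpoint` does not honor a `quantization_config` kwarg, since quantization on the Inference API happens server-side. A sketch of where the config does take effect, assuming a local CUDA GPU and the bitsandbytes package:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

repo_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # assumption: fp16-capable GPU
)
# quantization_config only applies when weights are loaded locally.
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    quantization_config=bnb_config,
    device_map="auto",
)
```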
@@ -182,13 +186,13 @@ def initialize_qa_chain(llm_model, temperature):
     except requests.exceptions.HTTPError as e:
         logger.error(f"HTTP error initializing QA chain for {llm_model}: {str(e)}")
         if "503" in str(e):
-            return f"Error: Hugging Face API temporarily unavailable for {llm_model}. Try 'Balanced (Mixtral-8x7B)' or wait and retry.", None
+            return f"Error: Hugging Face API temporarily unavailable for {llm_model}. Try 'Lightweight (Mistral-7B)' or wait and retry.", None
         elif "403" in str(e):
-            return f"Error: Access denied for {llm_model}. Ensure your HF token has access.", None
+            return f"Error: Access denied for {llm_model}. Ensure your HF token is valid.", None
         return f"Error initializing QA chain: {str(e)}.", None
     except Exception as e:
         logger.error(f"Error initializing QA chain for {llm_model}: {str(e)}")
-        return f"Error initializing QA chain: {str(e)}. Ensure your HF token has access to {llm_model}.", None
+        return f"Error initializing QA chain: {str(e)}. Ensure your HF token is valid.", None
 
 # Function to handle user query with retry logic
 @retry(
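The decorator body is cut off by the hunk boundary. Given the tenacity imports at the top of the file and the returns below, a plausible configuration looks like the following (the attempt count and wait bounds are assumptions, and the parameter list is reconstructed from the truncated hunk header plus the `chat_history` returns):

```python
import requests
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

@retry(
    stop=stop_after_attempt(3),  # assumption: the file's actual limits may differ
    wait=wait_exponential(multiplier=1, min=2, max=10),
    retry=retry_if_exception_type(requests.exceptions.HTTPError),
)
def answer_question(question, llm_model, embedding_model, temperature,
                    chunk_size, chat_history):
    ...
```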
@@ -214,9 +218,9 @@ def answer_question(question, llm_model, embedding_model, temperature, chunk_siz
     except requests.exceptions.HTTPError as e:
         logger.error(f"HTTP error answering question: {str(e)}")
         if "503" in str(e):
-            return f"Error: Hugging Face API temporarily unavailable for {llm_model}. Try 'Balanced (Mixtral-8x7B)' or wait and retry.", chat_history
+            return f"Error: Hugging Face API temporarily unavailable for {llm_model}. Try 'Lightweight (Mistral-7B)' or wait and retry.", chat_history
         elif "403" in str(e):
-            return f"Error: Access denied for {llm_model}. Ensure your HF token has access.", chat_history
+            return f"Error: Access denied for {llm_model}. Ensure your HF token is valid.", chat_history
         return f"Error answering question: {str(e)}", chat_history
     except Exception as e:
         logger.error(f"Error answering question: {str(e)}")
@@ -272,7 +276,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="DocTalk: Document Q&A Chatbot") as
             status = gr.Textbox(label="Status", interactive=False)
 
         with gr.Column(scale=1):
-            llm_model = gr.Dropdown(choices=list(LLM_MODELS.keys()), label="Select LLM Model", value="Balanced (Mixtral-8x7B)")
+            llm_model = gr.Dropdown(choices=list(LLM_MODELS.keys()), label="Select LLM Model", value="High Accuracy (Mixtral-8x7B)")
             embedding_model = gr.Dropdown(choices=list(EMBEDDING_MODELS.keys()), label="Select Embedding Model", value="Lightweight (MiniLM-L6)")
             temperature = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.7, label="Temperature")
             chunk_size = gr.Slider(minimum=500, maximum=2000, step=100, value=1000, label="Chunk Size")
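Because the dropdown default is a hard-coded string, a rename in LLM_MODELS (as in this very commit) silently breaks it. Deriving the default from the dict keeps the two in sync; a sketch, assuming the surrounding app context provides LLM_MODELS:

```python
import gradio as gr

# Dicts preserve insertion order, so the first key is a stable default.
llm_model = gr.Dropdown(
    choices=list(LLM_MODELS.keys()),
    label="Select LLM Model",
    value=next(iter(LLM_MODELS)),  # e.g. "High Accuracy (Mixtral-8x7B)"
)
```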
 