wt002 committed on
Commit
8ec51fb
·
verified ·
1 Parent(s): 1db7a1e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -10
app.py CHANGED
@@ -3,8 +3,7 @@ import gradio as gr
3
  import requests
4
  import inspect
5
  import pandas as pd
6
- from smolagents import tool, Tool, CodeAgent, DuckDuckGoSearchTool, VisitWebpageTool, SpeechToTextTool, FinalAnswerTool
7
- #from smolagents import tool, Tool, CodeAgent, DuckDuckGoSearchTool, HfApiModel, VisitWebpageTool, SpeechToTextTool, FinalAnswerTool
8
  from dotenv import load_dotenv
9
  import heapq
10
  from collections import Counter
@@ -15,8 +14,7 @@ from langchain_community.tools.tavily_search import TavilySearchResults
15
  from langchain_community.document_loaders import WikipediaLoader
16
  from langchain_community.utilities import WikipediaAPIWrapper
17
  from langchain_community.document_loaders import ArxivLoader
18
- from langchain_community.llms import HfApiModel
19
-
20
 
21
  # (Keep Constants as is)
22
  # --- Constants ---
@@ -184,12 +182,22 @@ class VideoTranscriptionTool(Tool):
184
 
185
  class BasicAgent:
186
  def __init__(self):
187
- token = os.environ.get("HF_API_TOKEN")
188
- self.model = HfApiModel(
189
- "google/gemini-2.5-flash",
190
- temperature=0.1,
191
- token=token
 
 
 
 
192
  )
 
 
 
 
 
 
193
 
194
  search_tool = DuckDuckGoSearchTool()
195
  wiki_search_tool = WikiSearchTool()
@@ -210,8 +218,49 @@ If the answer is a string, do not use articles or abbreviations (e.g., for citie
210
  If the answer is a comma-separated list, apply the above rules for each element based on whether it is a number or a string.
211
  """
212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  self.agent = CodeAgent(
214
- model=self.model,
215
  tools=[search_tool, wiki_search_tool, str_reverse_tool, keywords_extract_tool, speech_to_text_tool, visit_webpage_tool, final_answer_tool, parse_excel_to_json, video_transcription_tool],
216
  add_base_tools=True
217
  )
@@ -223,6 +272,7 @@ If the answer is a comma-separated list, apply the above rules for each element
223
  print(f"Agent returning answer: {answer}")
224
  return answer
225
 
 
226
  def run_and_submit_all( profile: gr.OAuthProfile | None):
227
  """
228
  Fetches all questions, runs the BasicAgent on them, submits all answers,
 
3
  import requests
4
  import inspect
5
  import pandas as pd
6
+ from smolagents import tool, Tool, CodeAgent, DuckDuckGoSearchTool, HfApiModel, VisitWebpageTool, SpeechToTextTool, FinalAnswerTool
 
7
  from dotenv import load_dotenv
8
  import heapq
9
  from collections import Counter
 
14
  from langchain_community.document_loaders import WikipediaLoader
15
  from langchain_community.utilities import WikipediaAPIWrapper
16
  from langchain_community.document_loaders import ArxivLoader
17
+ from transformers import AutoModelForCausalLM, AutoTokenizer
 
18
 
19
  # (Keep Constants as is)
20
  # --- Constants ---
 
182
 
183
  class BasicAgent:
184
  def __init__(self):
185
+ # Configuration for Qwen2.5-Coder-32B-Instruct
186
+ model_name = "Qwen/Qwen2.5-Coder-32B-Instruct"
187
+
188
+ # Load the model and tokenizer directly using Hugging Face Transformers
189
+ # This will download the model weights and load them onto your device (GPU if available)
190
+ self.model = AutoModelForCausalLM.from_pretrained(
191
+ model_name,
192
+ torch_dtype="auto", # Loads the weights in the dtype stored in the checkpoint (config.torch_dtype) instead of defaulting to float32
193
+ device_map="auto" # Automatically maps model layers to available devices (e.g., GPU(s), CPU)
194
  )
195
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
196
+
197
+ # Note: You won't use 'token' for direct Hugging Face model loading unless
198
+ # the model is private and requires authentication. For public models like Qwen,
199
+ # it's usually not needed for loading.
200
+ token = os.environ.get("HF_API_TOKEN") # This line might not be needed now
201
 
202
  search_tool = DuckDuckGoSearchTool()
203
  wiki_search_tool = WikiSearchTool()
 
218
  If the answer is a comma-separated list, apply the above rules for each element based on whether it is a number or a string.
219
  """
220
 
221
+ # CodeAgent comes from smolagents, not LangChain, so no LangChain LLM wrapper is needed.
222
+ # smolagents calls `model` with a list of chat messages (plus keyword
223
+ # arguments such as `stop_sequences`) and reads `.content` from the result,
224
+ # so the raw Hugging Face model/tokenizer pair has to be adapted to that interface.
225
+ # The simplest option is smolagents' own TransformersModel class, which
226
+ # already does this; alternatively, define a small callable wrapper
227
+ # around the loaded model and tokenizer, as attempted below.
228
+
229
+ # For demonstration, let's assume CodeAgent can handle a custom callable
230
+ # that performs inference using your loaded model and tokenizer.
231
+ # This is a simplification and might require adjustment to CodeAgent.
232
+ class CustomQwenLLM:
233
+ def __init__(self, model, tokenizer):
234
+ self.model = model
235
+ self.tokenizer = tokenizer
236
+
237
+ def __call__(self, prompt: str) -> str:
238
+ messages = [
239
+ {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
240
+ {"role": "user", "content": prompt}
241
+ ]
242
+ text = self.tokenizer.apply_chat_template(
243
+ messages,
244
+ tokenize=False,
245
+ add_generation_prompt=True
246
+ )
247
+ model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
248
+
249
+ generated_ids = self.model.generate(
250
+ **model_inputs,
251
+ max_new_tokens=512,
252
+ do_sample=True, # Added for better response quality
253
+ temperature=0.7 # Added for better response quality
254
+ )
255
+ input_length = model_inputs.input_ids.shape[1]
256
+ generated_text = self.tokenizer.batch_decode(generated_ids[:, input_length:], skip_special_tokens=True)[0]
257
+ return generated_text
258
+
259
+ self.llm_for_agent = CustomQwenLLM(self.model, self.tokenizer)
260
+
261
+
262
  self.agent = CodeAgent(
263
+ model=self.llm_for_agent, # Pass the custom wrapper
264
  tools=[search_tool, wiki_search_tool, str_reverse_tool, keywords_extract_tool, speech_to_text_tool, visit_webpage_tool, final_answer_tool, parse_excel_to_json, video_transcription_tool],
265
  add_base_tools=True
266
  )
 
272
  print(f"Agent returning answer: {answer}")
273
  return answer
274
 
275
+
276
  def run_and_submit_all( profile: gr.OAuthProfile | None):
277
  """
278
  Fetches all questions, runs the BasicAgent on them, submits all answers,