dlaima commited on
Commit
4fcc621
·
verified ·
1 Parent(s): 86df5d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -16
app.py CHANGED
@@ -19,11 +19,11 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
19
 
20
  # Generation result wrapper to match smolagents expectations
21
  class GenerationResult:
22
- def __init__(self, content, token_usage=None, input_tokens=0, output_tokens=0):
23
  self.content = content
24
- self.token_usage = token_usage or {}
25
  self.input_tokens = input_tokens
26
  self.output_tokens = output_tokens
 
27
 
28
  # Gemini model wrapper
29
  class GeminiFlashModel:
@@ -32,14 +32,15 @@ class GeminiFlashModel:
32
  self.model = genai.GenerativeModel(model_id)
33
  self.system_prompt = SYSTEM_PROMPT
34
 
35
- # Accept stop_sequences explicitly to avoid unexpected kwarg errors
36
- def generate(self, messages, stop_sequences=None, **kwargs):
37
  if not isinstance(messages, list) or not all(isinstance(m, dict) for m in messages):
38
  raise TypeError("Expected 'messages' to be a list of dicts")
39
 
 
40
  if not any(m.get("role") == "system" for m in messages):
41
  messages = [{"role": "system", "content": self.system_prompt}] + messages
42
 
 
43
  prompt = ""
44
  for m in messages:
45
  role = m["role"].capitalize()
@@ -47,20 +48,19 @@ class GeminiFlashModel:
47
  prompt += f"{role}: {content}\n"
48
 
49
  try:
50
- # Note: genai.GenerativeModel.generate_content may not support stop_sequences
51
  response = self.model.generate_content(prompt)
 
52
  return GenerationResult(
53
  content=response.text.strip(),
54
- token_usage={}, # you can extend if API provides token info
55
- input_tokens=0,
56
- output_tokens=0
57
  )
58
  except Exception as e:
 
59
  return GenerationResult(
60
  content=f"GENERATION ERROR: {e}",
61
- token_usage={},
62
  input_tokens=0,
63
- output_tokens=0
64
  )
65
 
66
  # Agent wrapper
@@ -70,15 +70,13 @@ class MyAgent:
70
  self.agent = CodeAgent(tools=[DuckDuckGoSearchTool()], model=self.model)
71
 
72
  def __call__(self, question: str) -> str:
73
- # The agent.run expects a string answer
74
  result = self.agent.run(question)
75
- # If result is GenerationResult or dict-like, convert to string
76
  if hasattr(result, "content"):
77
  return result.content
78
- elif isinstance(result, dict):
79
  return result.get("content", str(result))
80
- else:
81
- return str(result)
82
 
83
  # Main evaluation function
84
  def run_and_submit_all(profile: gr.OAuthProfile | None):
@@ -172,4 +170,3 @@ if __name__ == "__main__":
172
 
173
 
174
 
175
-
 
19
 
20
  # Generation result wrapper to match smolagents expectations
21
  class GenerationResult:
22
+ def __init__(self, content, input_tokens=0, output_tokens=0, token_usage=None):
23
  self.content = content
 
24
  self.input_tokens = input_tokens
25
  self.output_tokens = output_tokens
26
+ self.token_usage = token_usage or {}
27
 
28
  # Gemini model wrapper
29
  class GeminiFlashModel:
 
32
  self.model = genai.GenerativeModel(model_id)
33
  self.system_prompt = SYSTEM_PROMPT
34
 
35
+ def generate(self, messages, **kwargs):
 
36
  if not isinstance(messages, list) or not all(isinstance(m, dict) for m in messages):
37
  raise TypeError("Expected 'messages' to be a list of dicts")
38
 
39
+ # Ensure system prompt is first message
40
  if not any(m.get("role") == "system" for m in messages):
41
  messages = [{"role": "system", "content": self.system_prompt}] + messages
42
 
43
+ # Build prompt text by concatenating messages with roles
44
  prompt = ""
45
  for m in messages:
46
  role = m["role"].capitalize()
 
48
  prompt += f"{role}: {content}\n"
49
 
50
  try:
 
51
  response = self.model.generate_content(prompt)
52
+ # Always wrap the result in GenerationResult
53
  return GenerationResult(
54
  content=response.text.strip(),
55
+ input_tokens=0, # Could add token counts here if available
56
+ output_tokens=0,
 
57
  )
58
  except Exception as e:
59
+ # Wrap errors too, so agent doesn't fail
60
  return GenerationResult(
61
  content=f"GENERATION ERROR: {e}",
 
62
  input_tokens=0,
63
+ output_tokens=0,
64
  )
65
 
66
  # Agent wrapper
 
70
  self.agent = CodeAgent(tools=[DuckDuckGoSearchTool()], model=self.model)
71
 
72
  def __call__(self, question: str) -> str:
 
73
  result = self.agent.run(question)
74
+ # result can be GenerationResult or maybe dict or str - normalize:
75
  if hasattr(result, "content"):
76
  return result.content
77
+ if isinstance(result, dict):
78
  return result.get("content", str(result))
79
+ return str(result)
 
80
 
81
  # Main evaluation function
82
  def run_and_submit_all(profile: gr.OAuthProfile | None):
 
170
 
171
 
172