dlaima commited on
Commit
badada6
·
verified ·
1 Parent(s): 7de7b09

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -17,7 +17,12 @@ If you're asked for a string, don’t use articles or abbreviations (e.g. for ci
17
 
18
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
19
 
20
- # Gemini model wrapper (lightweight, no smolagents.model.base)
 
 
 
 
 
21
  class GeminiFlashModel:
22
  def __init__(self, model_id="gemini-1.5-flash", api_key=None):
23
  import google.generativeai as genai
@@ -26,14 +31,12 @@ class GeminiFlashModel:
26
  self.system_prompt = SYSTEM_PROMPT
27
 
28
  def generate(self, messages, stop_sequences=None, **kwargs):
29
- # ✅ Validate and prepend system prompt
30
  if not isinstance(messages, list) or not all(isinstance(m, dict) for m in messages):
31
  raise TypeError("Expected 'messages' to be a list of dicts")
32
 
33
  if not any(m.get("role") == "system" for m in messages):
34
  messages = [{"role": "system", "content": self.system_prompt}] + messages
35
 
36
- # ✅ Construct prompt string (you may improve formatting here)
37
  prompt = ""
38
  for m in messages:
39
  role = m["role"].capitalize()
@@ -42,10 +45,9 @@ class GeminiFlashModel:
42
 
43
  try:
44
  response = self.model.generate_content(prompt)
45
- return {"content": response.text.strip()} # ✅ SmolAgent expects this format
46
  except Exception as e:
47
- return {"content": f"GENERATION ERROR: {e}"}
48
-
49
 
50
  # Agent using Gemini
51
  class MyAgent:
 
17
 
18
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
19
 
20
+ # Wrapper class for model output
21
+ class GenerationResult:
22
+ def __init__(self, content):
23
+ self.content = content
24
+
25
+ # Gemini model wrapper
26
  class GeminiFlashModel:
27
  def __init__(self, model_id="gemini-1.5-flash", api_key=None):
28
  import google.generativeai as genai
 
31
  self.system_prompt = SYSTEM_PROMPT
32
 
33
  def generate(self, messages, stop_sequences=None, **kwargs):
 
34
  if not isinstance(messages, list) or not all(isinstance(m, dict) for m in messages):
35
  raise TypeError("Expected 'messages' to be a list of dicts")
36
 
37
  if not any(m.get("role") == "system" for m in messages):
38
  messages = [{"role": "system", "content": self.system_prompt}] + messages
39
 
 
40
  prompt = ""
41
  for m in messages:
42
  role = m["role"].capitalize()
 
45
 
46
  try:
47
  response = self.model.generate_content(prompt)
48
+ return GenerationResult(response.text.strip())
49
  except Exception as e:
50
+ return GenerationResult(f"GENERATION ERROR: {e}")
 
51
 
52
  # Agent using Gemini
53
  class MyAgent: