jhansss committed
Commit
7a41e86
·
1 Parent(s): 749c8c3

Add more iterations to prompt gemini

Files changed (1)
  1. modules/llm/gemini.py +14 -11
modules/llm/gemini.py CHANGED
@@ -28,6 +28,7 @@ class GeminiLLM(AbstractLLMModel):
         prompt: str,
         system_prompt: Optional[str] = None,
         max_output_tokens: int = 1024,
+        max_iterations: int = 3,
         **kwargs,
     ) -> str:
         generation_config_dict = {
@@ -36,15 +37,17 @@ class GeminiLLM(AbstractLLMModel):
         }
         if system_prompt:
             generation_config_dict["system_instruction"] = system_prompt
-        response = self.client.models.generate_content(
-            model=self.model_id,
-            contents=prompt,
-            config=types.GenerateContentConfig(**generation_config_dict),
-        )
-        if response.text:
-            return response.text
-        else:
-            print(
-                f"No response from Gemini. May need to increase max_new_tokens. Current max_new_tokens: {max_new_tokens}"
-            )
-            return ""
+        for _ in range(max_iterations):
+            response = self.client.models.generate_content(
+                model=self.model_id,
+                contents=prompt,
+                config=types.GenerateContentConfig(**generation_config_dict),
+            )
+            if response.text:
+                return response.text
+            else:
+                print(
+                    f"No response from Gemini. May need to increase max_output_tokens. Current {max_output_tokens=}. Try again."
+                )
+        print(f"Failed to generate response from Gemini after {max_iterations} attempts.")
+        return ""
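For context, a minimal self-contained sketch of how GeminiLLM.generate reads after this commit. The google-genai imports and the generate_content call match the diff; the constructor, the client setup, and the contents of generation_config_dict between the two hunks are assumptions, and the real class subclasses AbstractLLMModel, which is omitted here.

from typing import Optional

from google import genai
from google.genai import types


class GeminiLLM:
    def __init__(self, model_id: str):
        # Assumption: genai.Client() reads the API key (GOOGLE_API_KEY /
        # GEMINI_API_KEY) from the environment; the diff does not show
        # client construction.
        self.client = genai.Client()
        self.model_id = model_id

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_output_tokens: int = 1024,
        max_iterations: int = 3,
        **kwargs,
    ) -> str:
        # Assumption: the dict body is hidden between the two hunks;
        # max_output_tokens presumably belongs here given the old log message.
        generation_config_dict = {"max_output_tokens": max_output_tokens}
        if system_prompt:
            generation_config_dict["system_instruction"] = system_prompt
        # Retry up to max_iterations times: response.text can be empty when
        # generation stops before producing any text.
        for _ in range(max_iterations):
            response = self.client.models.generate_content(
                model=self.model_id,
                contents=prompt,
                config=types.GenerateContentConfig(**generation_config_dict),
            )
            if response.text:
                return response.text
            print(
                f"No response from Gemini. May need to increase max_output_tokens. "
                f"Current {max_output_tokens=}. Try again."
            )
        print(f"Failed to generate response from Gemini after {max_iterations} attempts.")
        return ""

Note that a nonempty response returns immediately, so the extra attempts only run when Gemini returns no text at all; a truncated but nonempty response still returns on the first pass.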