import ast
import json
import logging
import os
from typing import Any, Dict, List, Mapping, Optional

import google.generativeai as genai
import requests
from google.generativeai import types
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from pydantic import Extra, Field

logger = logging.getLogger("llm")

# Harm categories whose blocking threshold is set to BLOCK_NONE on every request.
_UNBLOCKED_HARM_CATEGORIES = (
    "HARM_CATEGORY_SEXUALLY_EXPLICIT",
    "HARM_CATEGORY_HATE_SPEECH",
    "HARM_CATEGORY_HARASSMENT",
    "HARM_CATEGORY_DANGEROUS_CONTENT",
)


def _blocked_details(response: Any) -> str:
    """Return a best-effort description of why *response* carries no text.

    Only used to enrich log messages, so it never raises: any failure while
    inspecting the response yields whatever was collected so far.
    """
    details: List[str] = []
    try:
        feedback = getattr(response, "prompt_feedback", None)
        if feedback is not None:
            details.append(f"prompt_feedback:{feedback}")
        # ``filters`` is what the original code inspected; kept for responses
        # that still expose it.
        for item in getattr(response, "filters", None) or []:
            details.extend(f"{key}:{value}" for key, value in item.items())
    except Exception:  # diagnostics must not mask the real failure
        logger.debug("Could not collect block details from response", exc_info=True)
    return ", ".join(details)


def _parse_json_or_literal(text: str) -> Any:
    """Parse *text* as JSON, falling back to a Python literal.

    The literal fallback is round-tripped through ``json`` so the result only
    contains JSON-compatible types (tuples become lists, keys become strings).

    Raises:
        ValueError, SyntaxError: if *text* is neither valid JSON nor a valid
            Python literal (raised by ``ast.literal_eval``).
    """
    try:
        return json.loads(text)
    except (ValueError, TypeError):
        # ast.literal_eval only evaluates literals, never arbitrary code.
        return json.loads(json.dumps(ast.literal_eval(text)))


class GeminiLLM(LLM):
    """LangChain ``LLM`` implementation backed by ``genai.GenerativeModel``.

    The most recent ``_call`` arguments are remembered on the instance
    (``prev_prompt``, ``prev_stop``, ``prev_run_manager``) so that
    ``extractJson`` can re-issue the request when a fenced JSON answer looks
    truncated.
    """

    model_name: str = "gemini-1.5-flash"
    temperature: float = 0
    max_tokens: int = 2048
    # Default stop sequences, used when ``_call`` receives ``stop=None``.
    stop: Optional[List] = []
    prev_prompt: Optional[str] = ""
    # Holds the ``stop`` list of the last call (or None), not a string.
    prev_stop: Optional[List[str]] = None
    prev_run_manager: Optional[Any] = None
    model: Optional[Any] = None

    def __init__(self, **kwargs):
        """Initialise the pydantic fields, then build the Gemini model client."""
        super().__init__(**kwargs)
        self.model = genai.GenerativeModel(self.model_name)

    @property
    def _llm_type(self) -> str:
        """Return the type tag of this LLM."""
        return "text2text-generation"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Send *prompt* to the model and return the generated text.

        Args:
            prompt: Text passed to ``generate_content``.
            stop: Stop sequences for this call; ``self.stop`` is used when None.
            run_manager: Stored for later re-use by ``extractJson``; not
                otherwise used here.

        Returns:
            The ``text`` of the model response.

        Raises:
            RuntimeError: if the response has no usable text. Any exception
                raised by ``generate_content`` itself propagates unchanged.
        """
        # Remember the request as received so extractJson can replay it.
        self.prev_prompt = prompt
        self.prev_stop = stop
        self.prev_run_manager = run_manager

        if stop is None:
            stop = self.stop

        logger.debug("\nLLM in use is:%s", self._llm_type)
        logger.debug("Request to LLM is %s", prompt)

        response = self.model.generate_content(
            prompt,
            generation_config={
                # Use the resolved per-call value, not unconditionally self.stop.
                "stop_sequences": stop,
                "temperature": self.temperature,
                "max_output_tokens": self.max_tokens,
            },
            safety_settings=[
                {"category": category, "threshold": "BLOCK_NONE"}
                for category in _UNBLOCKED_HARM_CATEGORIES
            ],
            stream=False,
        )

        try:
            text = response.text
        except Exception as ex:
            logger.error(
                "Will switch to fallback LLM as response from palm is None::%s",
                _blocked_details(response),
            )
            raise RuntimeError("LLM response has no usable text") from ex

        if text is None:
            logger.debug("Response from LLM was None\n")
            logger.error(
                "Will switch to fallback LLM as response from palm is None::%s",
                _blocked_details(response),
            )
            raise RuntimeError("LLM response text is None")

        logger.debug("Response from LLM %s", text)
        return text

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"name": self.model_name, "type": "palm"}

    def extractJson(self, val: str) -> Any:
        """Helper function to extract json from this LLMs output.

        Assumes the JSON is the first fenced (```` ``` ````) item in *val*.
        If *val* opens a ```` ```json ```` fence but has no trailing fence,
        generation is called again (at most 7 times) with the previous prompt
        plus the partial output. When no fenced item can be parsed, the whole
        of *val* is parsed instead.

        Args:
            val: Raw text produced by the LLM.

        Returns:
            The decoded JSON-compatible Python object.

        Raises:
            ValueError, SyntaxError: if neither the fenced content nor the
                whole text can be parsed.
        """
        try:
            attempts = 0
            # NOTE(review): each retry *replaces* val with the new response
            # instead of appending to it, and _call also overwrites
            # prev_prompt with the extended prompt. Confirm whether the model
            # is expected to return the full answer or only a continuation.
            while (
                val.startswith("```json")
                and not val.endswith("```")
                and attempts < 7
            ):
                val = self._call(
                    prompt=self.prev_prompt + " " + val,
                    stop=self.prev_stop,
                    run_manager=self.prev_run_manager,
                )
                attempts += 1
            fenced = val.replace("```json", "```").split("```")[1]
            return _parse_json_or_literal(fenced)
        except Exception:
            # No fence, unparsable fenced content, or a failed retry: fall
            # back to treating the entire text as the payload.
            return _parse_json_or_literal(val)

    def extractPython(self, val: str) -> Any:
        """Helper function to extract python from this LLMs output.

        Assumes the python code is the first fenced (```` ``` ````) item.

        Raises:
            IndexError: if *val* contains no fence.
        """
        return val.replace("```python", "```").split("```")[1]