# maya-persistence/src/llm/geminiLLM.py
from typing import Any, List, Mapping, Optional, Dict
from pydantic import Extra, Field # , root_validator, model_validator
import os, json
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
import google.generativeai as genai
from google.generativeai import types
import ast
# from langchain.llms import GooglePalm
import requests, logging
logger = logging.getLogger("llm")
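# NOTE: google.generativeai needs an API key configured before generate_content()
# is called; it is assumed this happens elsewhere in the application, for example
# (the GOOGLE_API_KEY variable name here is illustrative):
#   genai.configure(api_key=os.environ["GOOGLE_API_KEY"])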
class GeminiLLM(LLM):
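    """LangChain LLM wrapper around Google's Gemini models via google.generativeai."""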
model_name: str = "gemini-1.5-flash" # "gemini-pro"
temperature: float = 0
max_tokens: int = 2048
    stop: Optional[List[str]] = []
    prev_prompt: Optional[str] = ""
    prev_stop: Optional[List[str]] = None
prev_run_manager: Optional[Any] = None
model: Optional[Any] = None
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.model = genai.GenerativeModel(self.model_name)
# self.model = palm.Text2Text(self.model_name)
@property
def _llm_type(self) -> str:
return "text2text-generation"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
self.prev_prompt = prompt
self.prev_stop = stop
self.prev_run_manager = run_manager
# print(types.SafetySettingDict)
        if stop is None:
            stop = self.stop
logger.debug("\nLLM in use is:" + self._llm_type)
logger.debug("Request to LLM is " + prompt)
response = self.model.generate_content(
prompt,
generation_config={
"stop_sequences": self.stop,
"temperature": self.temperature,
"max_output_tokens": self.max_tokens,
},
safety_settings=[
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE",
},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE",
},
],
stream=False,
)
        try:
            # response.text raises if the candidate was blocked or empty; handled below.
            val = response.text
            if val is None:
                logger.debug("Response from LLM was None\n")
                # prompt_feedback carries the block reason / safety ratings when output is withheld.
                logger.error(
                    "Will switch to fallback LLM as response from Gemini is None::"
                    + str(response.prompt_feedback)
                )
                raise Exception("Empty response from Gemini")
            else:
                logger.debug("Response from LLM " + val)
        except Exception as ex:
            logger.error(
                "Will switch to fallback LLM, no usable response from Gemini::" + str(ex)
            )
            raise Exception("No usable response from Gemini") from ex
if run_manager:
pass
# run_manager.on_llm_end(val)
return val
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {"name": self.model_name, "type": "palm"}
def extractJson(self, val: str) -> Any:
"""Helper function to extract json from this LLMs output"""
# This is assuming the json is the first item within ````
# palm is responding always with ```json and ending with ```, however sometimes response is not complete
# in case trailing ``` is not seen, we will call generation again with prev_prompt and result appended to it
        try:
            count = 0
            # Retry generation while a ```json fence was opened but never closed (truncated output).
            while val.startswith("```json") and not val.endswith("```") and count < 7:
                val = self._call(
                    prompt=self.prev_prompt + " " + val,
                    stop=self.prev_stop,
                    run_manager=self.prev_run_manager,
                )
                count += 1
            # Take the first fenced block as the JSON payload.
            v2 = val.replace("```json", "```").split("```")[1]
            try:
                v4 = json.loads(v2)
            except Exception:
                # Fall back to parsing as a Python literal (e.g. single-quoted keys), then re-serialize.
                v3 = json.dumps(ast.literal_eval(v2))
                v4 = json.loads(v3)
        except Exception:
            # No usable fenced block; try parsing the whole response as a Python literal.
            v3 = json.dumps(ast.literal_eval(val))
            v4 = json.loads(v3)
        return v4
def extractPython(self, val: str) -> Any:
"""Helper function to extract python from this LLMs output"""
# This is assuming the python is the first item within ````
v2 = val.replace("```python", "```").split("```")[1]
return v2
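# Illustrative usage sketch (assumptions: an API key is already configured via
# genai.configure(), and the installed LangChain version exposes LLM.__call__ / invoke):
#
#   llm = GeminiLLM(temperature=0.2, max_tokens=1024)
#   raw = llm("Return the primary colors as a JSON list, wrapped in ```json fences.")
#   colors = llm.extractJson(raw)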