from typing import Any, List, Mapping, Optional, Dict

import ast
import json
import logging
import os

import requests

from pydantic import Extra, Field  # , root_validator, model_validator

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM

import google.generativeai as genai
from google.generativeai import types

# from langchain.llms import GooglePalm

logger = logging.getLogger("llm")
class GeminiLLM(LLM):
    """LangChain-compatible LLM wrapper around Google's Gemini models.

    Sends prompts through ``google.generativeai`` with every safety filter
    set to ``BLOCK_NONE``, and offers helpers to pull JSON / Python code out
    of the model's fenced-markdown replies.
    """

    model_name: str = "gemini-1.5-flash"  # "gemini-pro"
    temperature: float = 0
    max_tokens: int = 2048
    stop: Optional[List] = []
    # State of the most recent request, kept so extractJson() can re-issue
    # the prompt when the model returns a truncated ```json block.
    prev_prompt: Optional[str] = ""
    prev_stop: Optional[str] = ""
    prev_run_manager: Optional[Any] = None
    model: Optional[Any] = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model = genai.GenerativeModel(self.model_name)
        # self.model = palm.Text2Text(self.model_name)

    @property
    def _llm_type(self) -> str:
        # Must be a property, as declared on the LangChain LLM base class:
        # _call() logs it via string concatenation, which would raise
        # TypeError if this were left as a plain (bound) method.
        return "text2text-generation"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Generate a completion for *prompt*.

        Raises ``Exception`` when Gemini yields no usable text so a caller
        can switch to a fallback LLM.
        """
        # Remember the request so extractJson() can retry truncated output.
        self.prev_prompt = prompt
        self.prev_stop = stop
        self.prev_run_manager = run_manager
        # print(types.SafetySettingDict)
        if stop is None:
            stop = self.stop
        logger.debug("\nLLM in use is:" + self._llm_type)
        logger.debug("Request to LLM is " + prompt)
        response = self.model.generate_content(
            prompt,
            generation_config={
                # Honor caller-supplied stop sequences (falling back to
                # self.stop above); previously self.stop was always used
                # and the parameter was silently ignored.
                "stop_sequences": stop,
                "temperature": self.temperature,
                "max_output_tokens": self.max_tokens,
            },
            safety_settings=[
                {
                    "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                    "threshold": "BLOCK_NONE",
                },
                {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
                {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
                {
                    "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                    "threshold": "BLOCK_NONE",
                },
            ],
            stream=False,
        )
        try:
            val = response.text
            if val is None:
                logger.debug("Response from LLM was None\n")
                # Collect safety-filter diagnostics for the error log.
                # (Use distinct loop names: the original clobbered `val`.)
                filter_str = "".join(
                    key + ":" + str(reason)
                    for item in response.filters
                    for key, reason in item.items()
                )
                logger.error(
                    "Will switch to fallback LLM as response from palm is None::"
                    + filter_str
                )
                raise Exception(filter_str)
            else:
                logger.debug("Response from LLM " + val)
        except Exception as ex:
            logger.error("Will switch to fallback LLM as response from palm is None::")
            # Preserve the original cause for debugging instead of the
            # bare `raise (Exception)` that discarded all context.
            raise Exception("No usable response from Gemini") from ex
        if run_manager:
            pass
            # run_manager.on_llm_end(val)
        return val

    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"name": self.model_name, "type": "palm"}

    def extractJson(self, val: str) -> Any:
        """Helper function to extract JSON from this LLM's output.

        Assumes the JSON is the first fenced item. Gemini normally replies
        with ```json ... ``` but sometimes truncates; when the closing
        fence is missing, the previous prompt (plus the partial answer) is
        re-issued up to 7 times so the model can complete it.
        """
        try:
            count = 0
            while val.startswith("```json") and not val.endswith("```") and count < 7:
                val = self._call(
                    prompt=self.prev_prompt + " " + val,
                    stop=self.prev_stop,
                    run_manager=self.prev_run_manager,
                )
                count += 1
            fenced = val.replace("```json", "```").split("```")[1]
            try:
                parsed = json.loads(fenced)
            except Exception:
                # Not strict JSON (e.g. single quotes): parse as a Python
                # literal and round-trip through json.
                parsed = json.loads(json.dumps(ast.literal_eval(fenced)))
        except Exception:
            # No code fence at all: try the raw text as a Python literal.
            parsed = json.loads(json.dumps(ast.literal_eval(val)))
        return parsed

    def extractPython(self, val: str) -> Any:
        """Helper function to extract Python code from this LLM's output."""
        # This is assuming the python is the first item within ``` fences.
        return val.replace("```python", "```").split("```")[1]