from typing import Any, List, Mapping, Optional
import ast
import json
import logging

import google.generativeai as genai
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM

logger = logging.getLogger("llm")


class GeminiLLM(LLM):
    model_name: str = "gemini-1.5-flash"  # previously "gemini-pro"
    temperature: float = 0
    max_tokens: int = 2048
    stop: Optional[List[str]] = []
    prev_prompt: Optional[str] = ""
    prev_stop: Optional[List[str]] = None
    prev_run_manager: Optional[Any] = None
    model: Optional[Any] = None
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Assumes genai.configure(api_key=...) has already been called by the
        # application before this wrapper is instantiated.
        self.model = genai.GenerativeModel(self.model_name)
@property
def _llm_type(self) -> str:
return "text2text-generation"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
self.prev_prompt = prompt
self.prev_stop = stop
self.prev_run_manager = run_manager
        # Fall back to the instance-level stop sequences when none are passed in.
        if stop is None:
            stop = self.stop
        logger.debug("LLM in use is: %s", self._llm_type)
        logger.debug("Request to LLM is %s", prompt)
        response = self.model.generate_content(
            prompt,
            generation_config={
                # Use the effective stop sequences (argument or instance default).
                "stop_sequences": stop,
                "temperature": self.temperature,
                "max_output_tokens": self.max_tokens,
            },
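            # All four standard harm categories are set to BLOCK_NONE below,
            # which effectively disables SDK-side safety filtering.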
safety_settings=[
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE",
},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE",
},
],
stream=False,
)
        try:
            val = response.text
            if val is None:
                logger.debug("Response from LLM was None\n")
                # response.filters is carried over from the legacy PaLM API;
                # collect any filter reasons for the error message.
                filter_str = ""
                for item in response.filters:
                    for key, value in item.items():
                        filter_str += key + ":" + str(value)
                logger.error(
                    "Will switch to fallback LLM as response from Gemini is None::"
                    + filter_str
                )
                raise ValueError("Empty response from Gemini: " + filter_str)
            logger.debug("Response from LLM %s", val)
        except Exception:
            logger.error("Will switch to fallback LLM as response from Gemini is None::")
            raise
        if run_manager:
            pass  # run_manager.on_llm_end(val)
        return val
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"name": self.model_name, "type": "gemini"}
    def extractJson(self, val: str) -> Any:
        """Helper to extract JSON from this LLM's output."""
        # Assumes the JSON is the first item within ``` fences. Gemini usually
        # wraps it as ```json ... ```, but the response is sometimes truncated;
        # if the trailing ``` is missing, regenerate with the previous prompt
        # plus the partial result appended (up to 7 attempts).
        try:
            count = 0
            while val.startswith("```json") and not val.endswith("```") and count < 7:
                val = self._call(
                    prompt=self.prev_prompt + " " + val,
                    stop=self.prev_stop,
                    run_manager=self.prev_run_manager,
                )
                count += 1
            v2 = val.replace("```json", "```").split("```")[1]
            try:
                v4 = json.loads(v2)
            except Exception:
                # Not strict JSON (e.g. single quotes); round-trip through a
                # Python literal to normalise it.
                v4 = json.loads(json.dumps(ast.literal_eval(v2)))
        except Exception:
            # No fenced block found; try to parse the raw output as a
            # Python literal instead.
            v4 = json.loads(json.dumps(ast.literal_eval(val)))
        return v4
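
    # Illustrative examples for extractJson (hypothetical inputs):
    #   extractJson('```json\n{"a": 1}\n```')  -> {"a": 1}
    #   extractJson("{'a': 1}")                -> {"a": 1}  (via ast.literal_eval)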
    def extractPython(self, val: str) -> Any:
        """Helper to extract Python code from this LLM's output."""
        # Assumes the Python code is the first item within ``` fences.
        v2 = val.replace("```python", "```").split("```")[1]
        return v2
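

# Minimal usage sketch (illustrative, not part of the original module). It
# assumes a GOOGLE_API_KEY environment variable is set and the
# google-generativeai SDK is installed; the prompt is a hypothetical example.
if __name__ == "__main__":
    import os

    genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
    llm = GeminiLLM(temperature=0.0, max_tokens=1024)
    raw = llm("Return a JSON object with keys 'city' and 'country' for Paris.")
    print(llm.extractJson(raw))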