Update TextGen/router.py
Browse files- TextGen/router.py +54 -13
TextGen/router.py
CHANGED
|
@@ -9,7 +9,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|
| 9 |
from langchain.chains import LLMChain
|
| 10 |
from langchain.prompts import PromptTemplate
|
| 11 |
from TextGen.suno import custom_generate_audio, get_audio_information,generate_lyrics
|
| 12 |
-
from TextGen.diffusion import generate_image
|
| 13 |
#from coqui import predict
|
| 14 |
from langchain_google_genai import (
|
| 15 |
ChatGoogleGenerativeAI,
|
|
@@ -73,6 +73,11 @@ main_npcs={
|
|
| 73 |
"Herbalist":"./voices/female.mp3",
|
| 74 |
"Bard":"./voices/Bard_voice.mp3"
|
| 75 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
main_npc_system_prompts={
|
| 77 |
"Blacksmith":"You are a blacksmith in a video game",
|
| 78 |
"Herbalist":"You are an herbalist in a video game",
|
|
@@ -82,6 +87,10 @@ main_npc_system_prompts={
|
|
| 82 |
class Generate(BaseModel):
|
| 83 |
text:str
|
| 84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
def generate_text(messages: List[str], npc:str):
|
| 86 |
print(npc)
|
| 87 |
if npc in main_npcs:
|
|
@@ -123,6 +132,24 @@ app.add_middleware(
|
|
| 123 |
allow_headers=["*"],
|
| 124 |
)
|
| 125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
@app.get("/", tags=["Home"])
|
| 127 |
def api_home():
|
| 128 |
return {'detail': 'Everchanging Quest backend, nothing to see here'}
|
|
@@ -131,6 +158,10 @@ def api_home():
|
|
| 131 |
def inference(message: Message):
|
| 132 |
return generate_text(messages=message.messages, npc=message.npc)
|
| 133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
#Dummy function for now
|
| 135 |
def determine_vocie_from_npc(npc,genre):
|
| 136 |
if npc in main_npcs:
|
|
@@ -142,7 +173,17 @@ def determine_vocie_from_npc(npc,genre):
|
|
| 142 |
return"./voices/default_female.mp3"
|
| 143 |
else:
|
| 144 |
return "./voices/narator_out.wav"
|
| 145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
@app.post("/generate_wav")
|
| 148 |
async def generate_wav(message: VoiceMessage):
|
|
@@ -234,15 +275,15 @@ async def generate_song():
|
|
| 234 |
infos=get_audio_information(f"{data[0]['id']},{data[1]['id']}")
|
| 235 |
return infos
|
| 236 |
|
| 237 |
-
|
| 238 |
-
def Imagen(image:ImageGen=None):
|
| 239 |
-
pil_image =generate_image(image.prompt)
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
# Convert the PIL Image to bytes
|
| 243 |
-
img_byte_arr = BytesIO()
|
| 244 |
-
pil_image.save(img_byte_arr, format='PNG')
|
| 245 |
-
img_byte_arr = img_byte_arr.getvalue()
|
| 246 |
-
|
| 247 |
# Return the image as a PNG response
|
| 248 |
-
return Response(content=img_byte_arr, media_type="image/png")
|
|
|
|
| 9 |
from langchain.chains import LLMChain
|
| 10 |
from langchain.prompts import PromptTemplate
|
| 11 |
from TextGen.suno import custom_generate_audio, get_audio_information,generate_lyrics
|
| 12 |
+
#from TextGen.diffusion import generate_image
|
| 13 |
#from coqui import predict
|
| 14 |
from langchain_google_genai import (
|
| 15 |
ChatGoogleGenerativeAI,
|
|
|
|
| 73 |
"Herbalist":"./voices/female.mp3",
|
| 74 |
"Bard":"./voices/Bard_voice.mp3"
|
| 75 |
}
|
| 76 |
+
main_npcs_elevenlabs={
|
| 77 |
+
"Blacksmith":"",
|
| 78 |
+
"Herbalist":"",
|
| 79 |
+
"Bard":""
|
| 80 |
+
}
|
| 81 |
main_npc_system_prompts={
|
| 82 |
"Blacksmith":"You are a blacksmith in a video game",
|
| 83 |
"Herbalist":"You are an herbalist in a video game",
|
|
|
|
| 87 |
class Generate(BaseModel):
|
| 88 |
text:str
|
| 89 |
|
| 90 |
+
class Invoke(BaseModel):
|
| 91 |
+
system_prompt:str
|
| 92 |
+
message:str
|
| 93 |
+
|
| 94 |
def generate_text(messages: List[str], npc:str):
|
| 95 |
print(npc)
|
| 96 |
if npc in main_npcs:
|
|
|
|
| 132 |
allow_headers=["*"],
|
| 133 |
)
|
| 134 |
|
| 135 |
+
def inference_model(system_messsage, prompt):
|
| 136 |
+
|
| 137 |
+
new_messages=[{"role": "user", "content": system_messsage},{"role": "user", "content": prompt}]
|
| 138 |
+
llm = ChatGoogleGenerativeAI(
|
| 139 |
+
model="gemini-1.5-pro-latest",
|
| 140 |
+
max_output_tokens=100,
|
| 141 |
+
temperature=1,
|
| 142 |
+
safety_settings={
|
| 143 |
+
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
|
| 144 |
+
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
|
| 145 |
+
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
|
| 146 |
+
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE
|
| 147 |
+
},
|
| 148 |
+
)
|
| 149 |
+
llm_response = llm.invoke(new_messages)
|
| 150 |
+
print(llm_response)
|
| 151 |
+
return Generate(text=llm_response.content)
|
| 152 |
+
|
| 153 |
@app.get("/", tags=["Home"])
|
| 154 |
def api_home():
|
| 155 |
return {'detail': 'Everchanging Quest backend, nothing to see here'}
|
|
|
|
| 158 |
def inference(message: Message):
|
| 159 |
return generate_text(messages=message.messages, npc=message.npc)
|
| 160 |
|
| 161 |
+
@app.post("/invoke_model", response_model=Generate)
|
| 162 |
+
def story(prompt: Invoke):
|
| 163 |
+
return inference_model(system_messsage=prompt.system_prompt,prompt=prompt.message)
|
| 164 |
+
|
| 165 |
#Dummy function for now
|
| 166 |
def determine_vocie_from_npc(npc,genre):
|
| 167 |
if npc in main_npcs:
|
|
|
|
| 173 |
return"./voices/default_female.mp3"
|
| 174 |
else:
|
| 175 |
return "./voices/narator_out.wav"
|
| 176 |
+
#Dummy function for now
|
| 177 |
+
def determine_elevenLav_voice_from_npc(npc,genre):
|
| 178 |
+
if npc in main_npcs:
|
| 179 |
+
return main_npcs[npc]
|
| 180 |
+
else:
|
| 181 |
+
if genre =="Male":
|
| 182 |
+
"./voices/default_male.mp3"
|
| 183 |
+
if genre=="Female":
|
| 184 |
+
return"./voices/default_female.mp3"
|
| 185 |
+
else:
|
| 186 |
+
return "./voices/narator_out.wav"
|
| 187 |
|
| 188 |
@app.post("/generate_wav")
|
| 189 |
async def generate_wav(message: VoiceMessage):
|
|
|
|
| 275 |
infos=get_audio_information(f"{data[0]['id']},{data[1]['id']}")
|
| 276 |
return infos
|
| 277 |
|
| 278 |
+
#@app.post('/generate_image')
|
| 279 |
+
#def Imagen(image:ImageGen=None):
|
| 280 |
+
# pil_image =generate_image(image.prompt)
|
| 281 |
+
#
|
| 282 |
+
#
|
| 283 |
+
# # Convert the PIL Image to bytes
|
| 284 |
+
# img_byte_arr = BytesIO()
|
| 285 |
+
# pil_image.save(img_byte_arr, format='PNG')
|
| 286 |
+
# img_byte_arr = img_byte_arr.getvalue()
|
| 287 |
+
#
|
| 288 |
# Return the image as a PNG response
|
| 289 |
+
# return Response(content=img_byte_arr, media_type="image/png")
|