Spaces:
Runtime error
Runtime error
File size: 2,787 Bytes
5cee86f 643d279 5cee86f 4d381d5 5cee86f 7a59307 5cee86f ea73078 5cee86f b18e03b 1656b5a 0f5b6fb 5cee86f dc96994 5cee86f 63acdcd 5cee86f dc96994 5cee86f b0a3ad2 5cee86f e4d28ce dc96994 e44be4d 5cee86f a15843c 5cee86f 74e53ba 5cee86f 74e53ba 5cee86f 7a59307 1656b5a 5cee86f 1656b5a 5cee86f 1656b5a 5cee86f b18e03b 5cee86f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
from asyncio import constants
import gradio as gr
import requests
import os
import re
import random
from words import *
# Hosted GPT-J-6B model on the Hugging Face Inference API.
API_URL = "https://api-inference.huggingface.co/models/EleutherAI/gpt-j-6B"

# Fixed preamble placed in front of every request. It frames the transcript
# as a recorded text-adventure session of "computer:" / "player:" lines.
basePrompt = (
    "\n"
    "The following session was recorded from a text adventure game.\n"
    "computer: you are an adventurer exploring the darkest dungeon\n"
    "player: enter dungeon\n"
)

# Initial contents of the story box when the app starts.
default_story = "computer: you are standing in front of a dark dungeon.\n"
def fallbackResponse():
    """Return a canned encounter sentence featuring a randomly chosen monster.

    ``monsters`` is not defined in this file. Presumably it comes from the
    star-imported ``words`` module — confirm there. It must be a non-empty
    sequence, since ``random.choice`` raises ``IndexError`` on an empty one.
    """
    creature = random.choice(monsters)
    return f"You are attacked by a {creature}!"
def _query_model(prompt_text, timeout=30):
    """POST *prompt_text* to the GPT-J endpoint and return the decoded JSON.

    Returns None instead of raising when the request fails or times out, or
    when the body is not JSON. The caller can then fall back gracefully
    rather than crashing the UI handler.
    """
    payload = {
        "inputs": prompt_text,
        "parameters": {
            "top_p": 0.9,
            "temperature": 1.1,
            "max_new_tokens": 50,
            "return_full_text": False,
        },
    }
    try:
        # Without a timeout a stalled API call would block the handler forever.
        response = requests.post(API_URL, json=payload, timeout=timeout)
        return response.json()
    except (requests.RequestException, ValueError) as err:
        # ValueError covers JSON decode failures across requests versions.
        print(f"model request failed: {err!r}")
        return None


def _extract_reply(output):
    """Reduce raw API output to a single ``computer:`` line.

    Uses the first line of the generated text. Falls back to a random
    monster encounter in these cases:
    - the request failed (``None``);
    - the API returned an error object or an unexpected shape;
    - the first generated line is not a ``computer:`` line.
    """
    try:
        generated = output[0]["generated_text"]
    except (TypeError, LookupError):
        # None, {"error": ...} dicts, empty lists, items without the key, ...
        print(f"using fallback description method! API output was: {output!r}")
        return "computer:" + fallbackResponse()
    if not isinstance(generated, str):
        print(f"using fallback description method! API output was: {output!r}")
        return "computer:" + fallbackResponse()
    # The model usually keeps writing further turns; keep only the first line.
    first_line = generated.strip().split("\n", 1)[0].strip()
    if not first_line.startswith("computer:"):
        return "computer:" + fallbackResponse()
    return first_line


def _trim_story(story, max_lines=6):
    """Return the last *max_lines* lines of *story*, each newline-terminated."""
    if max_lines <= 0:
        return ""
    # splitlines() ignores the trailing newline, so exactly max_lines real
    # lines are kept (split("\n") would count the empty tail as one).
    kept = story.splitlines()[-max_lines:]
    return "".join(line + "\n" for line in kept)


def continue_story(prompt, story):
    """Advance the adventure by one turn.

    Args:
        prompt: the player's command from the input box (``None`` is
            treated as an empty command).
        story: the current transcript from the story box.

    Returns:
        The new transcript: the last few lines of *story*, then
        ``player:<prompt>``, then the model's (or fallback) ``computer:`` line.
    """
    command = "" if prompt is None else str(prompt)
    story = story or ""
    if story and not story.endswith("\n"):
        # The user may have edited the box; start the new turn on its own line.
        story += "\n"
    # End the prompt with a newline so the model's continuation begins a new
    # (computer:) line instead of extending the player's command.
    model_input = basePrompt + story + "player:" + command + "\n"
    reply = _extract_reply(_query_model(model_input))
    return _trim_story(story) + "player:" + command + "\n" + reply + "\n"
# Gradio UI: a story transcript box, a command box, and a Submit button.
demo = gr.Blocks()
with demo:
    gr.Markdown("<h1><center>LiteDungeon</center></h1>")
    gr.Markdown(
        "<div>Create a text adventure, using GPT-J</div>"
    )
    with gr.Row():
        output_story = gr.Textbox(value=default_story, label="story", lines=7)
    with gr.Row():
        input_command = gr.Textbox(label="input", placeholder="look around")
    with gr.Row():
        b0 = gr.Button("Submit")
    # Both the button and pressing Enter in the command box run one turn.
    b0.click(
        continue_story,
        inputs=[input_command, output_story],
        outputs=[output_story],
    )
    input_command.submit(
        continue_story,
        inputs=[input_command, output_story],
        outputs=[output_story],
    )

# launch(enable_queue=...) is deprecated in Gradio 3.x and rejected in 4.x;
# queueing is enabled on the Blocks object instead.
demo.queue()
demo.launch(debug=True)