File size: 2,744 Bytes
5cee86f
 
 
 
 
 
 
 
643d279
 
 
 
5cee86f
 
4d381d5
 
5cee86f
 
74e53ba
5cee86f
 
 
 
ea73078
5cee86f
 
a15843c
1656b5a
0f5b6fb
 
 
 
5cee86f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc96994
5cee86f
 
 
 
 
63acdcd
5cee86f
dc96994
 
 
5cee86f
b0a3ad2
 
5cee86f
 
 
 
 
 
 
e44be4d
dc96994
 
e44be4d
 
 
 
 
 
 
 
5cee86f
a15843c
5cee86f
 
 
 
 
74e53ba
5cee86f
74e53ba
5cee86f
 
 
99260a9
1656b5a
5cee86f
 
1656b5a
5cee86f
1656b5a
 
 
 
 
5cee86f
99260a9
5cee86f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from asyncio import constants
import gradio as gr
import requests
import os 
import re
import random
from words import *


# Hugging Face hosted inference endpoint for the GPT-J-6B model; continue_story
# POSTs each prompt here.
API_URL = "https://api-inference.huggingface.co/models/EleutherAI/gpt-j-6B"

# Fixed prefix placed in front of every prompt: it frames the text as a
# recorded text-adventure session and shows one example player turn.
basePrompt="""

The following session was recorded from a text adventure game.

computer: you are an adventurer exploring the darkest dungeon

player: enter dungeon

"""

# Running transcript of the adventure. It is shown in the story textbox at
# start-up and is reassigned by continue_story() after every turn.
story="""Computer: You are standing in the entrance of the dungeon.

"""


def fallbackResponse():
    """Build a canned encounter line for when no model text is available.

    One entry is picked at random from ``monsters``. That name is not defined
    in this file; presumably it comes from ``from words import *``.
    """
    chosen = random.choice(monsters)
    return f"You are attacked by a {chosen}!"

# Seconds to wait for the inference API before giving up and using the
# fallback reply; without a timeout requests.post can block indefinitely.
_API_TIMEOUT_SECONDS = 60

# Number of transcript lines carried over between turns, keeping the prompt
# sent to the model short.
_STORY_HISTORY_LINES = 6


def _query_model(p):
    """Send prompt *p* to the GPT-J inference API and return the generated text.

    Returns None when the request fails or times out, or when the body is not
    JSON. Also returns None when the API answers with an ``"error"`` object or
    the payload lacks the expected ``[{"generated_text": <str>}]`` shape.
    """
    json_ = {
        "inputs": p,
        "parameters": {
            "top_p": 0.9,
            "temperature": 1.1,
            "max_new_tokens": 50,
            # Only the continuation is wanted, not the echoed prompt.
            "return_full_text": False,
        },
    }
    try:
        # No auth header is sent; add headers= here if a Hugging Face token is needed.
        response = requests.post(API_URL, json=json_, timeout=_API_TIMEOUT_SECONDS)
        output = response.json()
    except (requests.RequestException, ValueError) as err:
        # Network failure, timeout, or a non-JSON body (e.g. an HTML error page).
        print(f"Inference request failed: {err!r}")
        return None

    if isinstance(output, dict) and "error" in output:
        print(f"Inference API returned an error: {output}")
        return None
    try:
        text = output[0]["generated_text"]
    except (IndexError, KeyError, TypeError):
        print(f"Unexpected inference API response: {output}")
        return None
    return text if isinstance(text, str) else None


def _clean_response(text):
    """Reduce raw model output to a single ``computer:``-labelled transcript line.

    Returns None when nothing usable remains. That covers blank output and
    output where the model wrote the player's turn instead of the computer's.
    """
    # Keep only the first non-blank line; the model tends to keep writing
    # further turns after its reply.
    line = text.strip().split("\n", 1)[0].strip()
    if not line:
        return None
    lowered = line.lower()
    if lowered.startswith("player:"):
        # The model skipped the computer's turn; this is not a usable reply.
        return None
    if not lowered.startswith("computer:"):
        line = "computer:" + line
    return line


def _truncate_history(text, max_lines=_STORY_HISTORY_LINES):
    """Return the last *max_lines* lines of *text*, each newline-terminated."""
    if max_lines <= 0:
        return ""
    return "".join(line + "\n" for line in text.splitlines()[-max_lines:])


def continue_story(prompt):
    """Advance the adventure by one turn and return the updated transcript.

    Args:
        prompt: The player's command as typed into the input textbox.

    Returns:
        The new module-level ``story``. It holds the most recent transcript
        lines, then ``player:<prompt>``, then the computer's reply (model text
        or the fallback encounter).

    NOTE(review): ``story`` is a module-level global, so every visitor to the
    app shares and overwrites the same adventure; per-session state would need
    ``gr.State``.
    """
    global story
    command = str(prompt)
    # Terminate the player's line the same way the stored transcript does, so
    # the model replies as the computer rather than extending the command.
    p = basePrompt + story + "player:" + command + "\n"
    print("got prompt:\n\n", p)

    generated = _query_model(p)
    reply = _clean_response(generated) if generated is not None else None
    if reply is None:
        print("using fallback description method!")
        # Label the canned line like every other computer turn so the next
        # prompt keeps a consistent transcript format.
        reply = "computer:" + fallbackResponse()
    else:
        print("generated text was", generated)
        print("which was trimmed to", reply)

    story = _truncate_history(story) + "player:" + command + "\n" + reply + "\n"
    return story


# Single-page UI: a transcript box, a command box and a Submit button that
# runs continue_story and writes its result back into the transcript box.
demo = gr.Blocks()

with demo:
    gr.Markdown("<h1><center>LiteDungeon</center></h1>")
    gr.Markdown("<div>Create a text adventure, using GPT-J</div>")

    with gr.Row():
        output_story = gr.Textbox(value=story, label="story", lines=7)

    with gr.Row():
        input_command = gr.Textbox(label="input", placeholder="look around")

    with gr.Row():
        b0 = gr.Button("Submit")

    b0.click(continue_story, inputs=[input_command], outputs=[output_story])


if __name__ == "__main__":
    # Newer Gradio releases configure queuing via Blocks.queue(); the old
    # launch(enable_queue=...) argument was deprecated and later removed, so
    # use queue() when it exists and fall back for older versions.
    if hasattr(demo, "queue"):
        demo.queue()
        demo.launch(debug=True)
    else:
        demo.launch(enable_queue=True, debug=True)