File size: 2,880 Bytes
5cee86f
 
 
 
 
 
 
 
643d279
 
 
 
5cee86f
 
 
 
 
 
 
 
 
 
 
 
 
 
a15843c
1656b5a
0f5b6fb
 
 
 
5cee86f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63acdcd
5cee86f
 
 
 
 
 
 
 
 
e44be4d
 
 
 
 
 
 
 
 
5cee86f
a15843c
5cee86f
 
 
 
 
 
 
 
 
 
 
 
 
 
99260a9
1656b5a
5cee86f
 
1656b5a
5cee86f
1656b5a
 
 
 
 
5cee86f
99260a9
5cee86f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from asyncio import constants
import gradio as gr
import requests
import os 
import re
import random
from words import *


# Hosted GPT-J-6B model on the Hugging Face Inference API; continue_story
# posts its prompts to this URL.
API_URL = "https://api-inference.huggingface.co/models/EleutherAI/gpt-j-6B"

# Fixed preamble prepended to every prompt so the model treats the text that
# follows as a recorded text-adventure transcript.
basePrompt="""

The following session was recorded from a text adventure game.

----

"""

# Running transcript: shown in the UI textbox and fed back to the model.
# It starts with the opening scene and is rewritten by continue_story
# (via `global story`) on every turn.
story="""

Computer: You approach the entrance of the dungeon.

"""


def fallbackResponse(choices=None):
    """Return a canned story line for when text generation is unavailable.

    Args:
        choices: optional sequence of monster names to pick from. When None,
            the module-level ``monsters`` (expected to come from
            ``from words import *``) is used.

    Returns:
        A sentence of the form ``"You are attacked by a <monster>!"``.

    Raises:
        IndexError: if the pool of monster names is empty.
    """
    if choices is None:
        choices = monsters
    return "You are attacked by a {monster}!".format(monster=random.choice(choices))

def _query_model(prompt_text, timeout=60):
  """Ask the GPT-J inference API to continue *prompt_text*.

  Returns the generated text, or None if the request failed, the API
  reported an error, or the response did not have the expected shape.
  """
  payload = {
      "inputs": prompt_text,
      "parameters": {
          "top_p": 0.9,
          "temperature": 1.1,
          "max_new_tokens": 50,
          # Ask for only the continuation, not the prompt echoed back.
          "return_full_text": False,
      },
  }
  try:
    # Without a timeout a stalled API call would block this request forever.
    response = requests.post(API_URL, json=payload, timeout=timeout)
    output = response.json()
  except (requests.RequestException, ValueError) as err:  # ValueError: body was not JSON
    print(f"API request failed: {err}")
    return None
  if isinstance(output, dict) and "error" in output:
    print(f"API returned an error: {output['error']}")
    return None
  try:
    return output[0]["generated_text"]
  except (IndexError, KeyError, TypeError):
    print(f"Unexpected API response: {output}")
    return None


def _format_reply(text):
  """Reduce raw model output to one transcript line tagged as the computer's."""
  # Drop leading whitespace first: the model often opens with a newline, and
  # cutting at that newline would throw the whole reply away.
  first_line = (text or "").lstrip().split("\n", 1)[0].rstrip()
  if not first_line.casefold().startswith("computer:"):
    first_line = "computer:" + first_line
  return first_line


def continue_story(prompt):
  """Advance the adventure transcript by one player turn.

  Args:
    prompt: the player's command, as typed into the input textbox.

  Returns:
    The updated transcript: the trimmed recent history followed by the new
    ``player:`` line and the computer's reply line.
  """
  # NOTE(review): `story` is a module-level global, so every session of the
  # app shares and overwrites a single transcript; gr.State would make it
  # per-user.
  global story
  command = str(prompt)
  # End the prompt with a newline so the model begins a fresh line (the
  # computer's reply) instead of continuing the player's command.
  p = basePrompt + story + "player:" + command + "\n"
  print(f"*****Inside desc_generate - Prompt is :{p}")

  generated = _query_model(p)
  if generated is None:
    print("using fallback description method!")
    # `or ""` guards against a fallback that yields nothing.
    generated = fallbackResponse() or ""
  reply = _format_reply(generated)

  # Keep only the last 6 newline-separated segments of history so the
  # prompt stays short.
  story = "\n".join(story.split("\n")[-6:])
  story = story + "player:" + command + "\n" + reply + "\n"
  return story


# NOTE(review): the heading and description below appear to be copied from an
# NPC-generator Space and don't describe this text-adventure app — confirm
# the intended wording.
with gr.Blocks() as demo:
  gr.Markdown("<h1><center>NPC Generator</center></h1>")
  gr.Markdown(
      "based on <a href=https://huggingface.co/spaces/Gradio-Blocks/GPTJ6B_Poetry_LatentDiff_Illustration> Gradio poetry generator</a>."
      "<div>first input name, race and class (or generate them randomly)</div>"
      "<div>Next, use GPT-J to generate a short description</div>"
      "<div>Finally, Generate an illustration 🎨 provided by <a href=https://huggingface.co/spaces/multimodalart/latentdiffusion>Latent Diffusion model</a>.</div>"
  )

  # Transcript display, seeded with the module-level opening `story`.
  with gr.Row():
    output_story = gr.Textbox(label="story", lines=7, value=story)

  # Free-text command entry for the player.
  with gr.Row():
    input_command = gr.Textbox(placeholder="look around", label="input")

  with gr.Row():
    b0 = gr.Button("Submit")

  # Clicking Submit passes the command to continue_story and shows the
  # transcript it returns.
  b0.click(continue_story, inputs=[input_command], outputs=[output_story])

demo.launch(enable_queue=True, debug=True)