jrahn committed on
Commit
d28b192
·
verified ·
1 Parent(s): 0a2b7af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -9,7 +9,7 @@ DEBUG = False
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
  pipe = pipeline(
11
  "text-generation",
12
- model="jrahn/rookworld_7m_3e_gpt2_124M_hf",
13
  torch_dtype=torch.bfloat16,
14
  device=device
15
  )
@@ -34,6 +34,7 @@ def generate_action(state):
34
  try:
35
  action = generation[0]['generated_text'].split("B: ")[-1].strip()
36
  gr.Info(f"Policy generated move: {action}", duration=3)
 
37
  except:
38
  gr.Info(f"Policy generation invalid: {generation}", duration=None)
39
  action = "0000"
@@ -48,7 +49,8 @@ def generate_state(state, action, history):
48
  generation = pipe(prompt, **sampling_args)
49
  if DEBUG: print(generation)
50
  try:
51
- new_state, reward, terminated, truncated = generation[0]['generated_text'].split("+")
 
52
  #gr.Info(f"Environment generated state: {new_state}", duration=3)
53
  except:
54
  new_state, reward, terminated, truncated = START_POSITION, "0", "0", "1"
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
  pipe = pipeline(
11
  "text-generation",
12
+ model="jrahn/RookWorld-LM-124M",
13
  torch_dtype=torch.bfloat16,
14
  device=device
15
  )
 
34
  try:
35
  action = generation[0]['generated_text'].split("B: ")[-1].strip()
36
  gr.Info(f"Policy generated move: {action}", duration=3)
37
+ # TODO: display generated CoT
38
  except:
39
  gr.Info(f"Policy generation invalid: {generation}", duration=None)
40
  action = "0000"
 
49
  generation = pipe(prompt, **sampling_args)
50
  if DEBUG: print(generation)
51
  try:
52
+ parts = generation[0]['generated_text'].split("+")
53
+ new_state, reward, terminated, truncated = parts[-4], parts[-3], parts[-2], parts[-1]
54
  #gr.Info(f"Environment generated state: {new_state}", duration=3)
55
  except:
56
  new_state, reward, terminated, truncated = START_POSITION, "0", "0", "1"