Update app.py
Browse files
app.py
CHANGED
@@ -9,7 +9,7 @@ DEBUG = False
|
|
9 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
10 |
pipe = pipeline(
|
11 |
"text-generation",
|
12 |
-
model="jrahn/
|
13 |
torch_dtype=torch.bfloat16,
|
14 |
device=device
|
15 |
)
|
@@ -34,6 +34,7 @@ def generate_action(state):
|
|
34 |
try:
|
35 |
action = generation[0]['generated_text'].split("B: ")[-1].strip()
|
36 |
gr.Info(f"Policy generated move: {action}", duration=3)
|
|
|
37 |
except:
|
38 |
gr.Info(f"Policy generation invalid: {generation}", duration=None)
|
39 |
action = "0000"
|
@@ -48,7 +49,8 @@ def generate_state(state, action, history):
|
|
48 |
generation = pipe(prompt, **sampling_args)
|
49 |
if DEBUG: print(generation)
|
50 |
try:
|
51 |
-
|
|
|
52 |
#gr.Info(f"Environment generated state: {new_state}", duration=3)
|
53 |
except:
|
54 |
new_state, reward, terminated, truncated = START_POSITION, "0", "0", "1"
|
|
|
9 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
10 |
pipe = pipeline(
|
11 |
"text-generation",
|
12 |
+
model="jrahn/RookWorld-LM-124M",
|
13 |
torch_dtype=torch.bfloat16,
|
14 |
device=device
|
15 |
)
|
|
|
34 |
try:
|
35 |
action = generation[0]['generated_text'].split("B: ")[-1].strip()
|
36 |
gr.Info(f"Policy generated move: {action}", duration=3)
|
37 |
+
# TODO: display generated CoT
|
38 |
except:
|
39 |
gr.Info(f"Policy generation invalid: {generation}", duration=None)
|
40 |
action = "0000"
|
|
|
49 |
generation = pipe(prompt, **sampling_args)
|
50 |
if DEBUG: print(generation)
|
51 |
try:
|
52 |
+
parts = generation[0]['generated_text'].split("+")
|
53 |
+
new_state, reward, terminated, truncated = parts[-4], parts[-3], parts[-2], parts[-1]
|
54 |
#gr.Info(f"Environment generated state: {new_state}", duration=3)
|
55 |
except:
|
56 |
new_state, reward, terminated, truncated = START_POSITION, "0", "0", "1"
|