Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import numpy as np
|
2 |
import gradio as gr
|
|
|
3 |
|
4 |
# Define the grid world environment
|
5 |
class GridWorld:
|
@@ -59,29 +60,48 @@ class QLearningAgent:
|
|
59 |
env = GridWorld()
|
60 |
agent = QLearningAgent(env)
|
61 |
|
62 |
-
|
63 |
-
def visualize_grid(agent_pos, goal_pos, obstacles):
|
64 |
grid = np.zeros((env.size, env.size), dtype=str)
|
65 |
grid[agent_pos[0], agent_pos[1]] = 'A'
|
66 |
grid[goal_pos[0], goal_pos[1]] = 'G'
|
67 |
for obstacle in obstacles:
|
68 |
grid[obstacle[0], obstacle[1]] = 'X'
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
def train_agent(steps=100):
|
72 |
state = env.reset()
|
|
|
73 |
for _ in range(steps):
|
74 |
action = agent.choose_action(state)
|
75 |
next_state, reward, done, _ = env.step(action)
|
76 |
agent.learn(state, action, reward, next_state)
|
77 |
state = next_state
|
|
|
78 |
if done:
|
79 |
break
|
80 |
-
|
|
|
81 |
|
82 |
# Create the Gradio interface
|
83 |
input_steps = gr.Slider(1, 1000, value=100, label="Number of Training Steps")
|
84 |
-
output_grid = gr.
|
85 |
|
86 |
# Define the Gradio interface function
|
87 |
def update_grid(steps):
|
|
|
1 |
import numpy as np
|
2 |
import gradio as gr
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
|
5 |
# Define the grid world environment
|
6 |
class GridWorld:
|
|
|
60 |
env = GridWorld()
|
61 |
agent = QLearningAgent(env)
|
62 |
|
63 |
+
def visualize_grid(agent_pos, goal_pos, obstacles, path=None):
|
|
|
64 |
grid = np.zeros((env.size, env.size), dtype=str)
|
65 |
grid[agent_pos[0], agent_pos[1]] = 'A'
|
66 |
grid[goal_pos[0], goal_pos[1]] = 'G'
|
67 |
for obstacle in obstacles:
|
68 |
grid[obstacle[0], obstacle[1]] = 'X'
|
69 |
+
|
70 |
+
if path:
|
71 |
+
for step in path:
|
72 |
+
grid[step[0], step[1]] = 'P'
|
73 |
+
|
74 |
+
fig, ax = plt.subplots()
|
75 |
+
ax.imshow(grid, cmap='viridis', aspect='equal')
|
76 |
+
ax.set_xticks(np.arange(-0.5, env.size, 1))
|
77 |
+
ax.set_yticks(np.arange(-0.5, env.size, 1))
|
78 |
+
ax.grid(color='w', linestyle='-', linewidth=2)
|
79 |
+
ax.set_xticklabels([])
|
80 |
+
ax.set_yticklabels([])
|
81 |
+
|
82 |
+
for i in range(env.size):
|
83 |
+
for j in range(env.size):
|
84 |
+
ax.text(j, i, grid[i, j], ha='center', va='center', color='w')
|
85 |
+
|
86 |
+
return fig
|
87 |
|
88 |
def train_agent(steps=100):
|
89 |
state = env.reset()
|
90 |
+
path = [env.agent_pos]
|
91 |
for _ in range(steps):
|
92 |
action = agent.choose_action(state)
|
93 |
next_state, reward, done, _ = env.step(action)
|
94 |
agent.learn(state, action, reward, next_state)
|
95 |
state = next_state
|
96 |
+
path.append(env.agent_pos)
|
97 |
if done:
|
98 |
break
|
99 |
+
fig = visualize_grid(env.agent_pos, env.goal_pos, env.obstacles, path)
|
100 |
+
return fig
|
101 |
|
102 |
# Create the Gradio interface
|
103 |
input_steps = gr.Slider(1, 1000, value=100, label="Number of Training Steps")
|
104 |
+
output_grid = gr.Plot(label="Grid World")
|
105 |
|
106 |
# Define the Gradio interface function
|
107 |
def update_grid(steps):
|