broadfield-dev committed on
Commit fc3cc26 · verified · 1 Parent(s): 29e5977

Update app.py

Files changed (1)
  1. app.py +25 -5
app.py CHANGED
@@ -1,5 +1,6 @@
 import numpy as np
 import gradio as gr
+import matplotlib.pyplot as plt

 # Define the grid world environment
 class GridWorld:
@@ -59,29 +60,48 @@ class QLearningAgent:
 env = GridWorld()
 agent = QLearningAgent(env)

-
-def visualize_grid(agent_pos, goal_pos, obstacles):
+def visualize_grid(agent_pos, goal_pos, obstacles, path=None):
     grid = np.zeros((env.size, env.size), dtype=str)
     grid[agent_pos[0], agent_pos[1]] = 'A'
     grid[goal_pos[0], goal_pos[1]] = 'G'
     for obstacle in obstacles:
         grid[obstacle[0], obstacle[1]] = 'X'
-    return '\n'.join([' '.join(row) for row in grid])
+
+    if path:
+        for step in path:
+            grid[step[0], step[1]] = 'P'
+
+    fig, ax = plt.subplots()
+    ax.imshow(grid, cmap='viridis', aspect='equal')
+    ax.set_xticks(np.arange(-0.5, env.size, 1))
+    ax.set_yticks(np.arange(-0.5, env.size, 1))
+    ax.grid(color='w', linestyle='-', linewidth=2)
+    ax.set_xticklabels([])
+    ax.set_yticklabels([])
+
+    for i in range(env.size):
+        for j in range(env.size):
+            ax.text(j, i, grid[i, j], ha='center', va='center', color='w')
+
+    return fig

 def train_agent(steps=100):
     state = env.reset()
+    path = [env.agent_pos]
     for _ in range(steps):
         action = agent.choose_action(state)
         next_state, reward, done, _ = env.step(action)
         agent.learn(state, action, reward, next_state)
         state = next_state
+        path.append(env.agent_pos)
         if done:
             break
-    return visualize_grid(env.agent_pos, env.goal_pos, env.obstacles)
+    fig = visualize_grid(env.agent_pos, env.goal_pos, env.obstacles, path)
+    return fig

 # Create the Gradio interface
 input_steps = gr.Slider(1, 1000, value=100, label="Number of Training Steps")
-output_grid = gr.Textbox(label="Grid World")
+output_grid = gr.Plot(label="Grid World")

 # Define the Gradio interface function
 def update_grid(steps):
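
The hunk's context window ends at the `def update_grid(steps):` line, so the body of `update_grid` and the code that assembles the interface are not part of this diff. As a sketch of how the pieces above presumably fit together, assuming `update_grid` simply forwards the slider value to `train_agent` (the delegation, the `demo` name, and the launch guard are assumptions, not shown in the commit):

# Sketch only: the diff cuts off before this point, so app.py's actual
# wiring may differ. Assumes update_grid delegates to train_agent.
def update_grid(steps):
    return train_agent(int(steps))  # train_agent now returns a matplotlib figure

demo = gr.Interface(
    fn=update_grid,
    inputs=input_steps,   # the gr.Slider defined above
    outputs=output_grid,  # gr.Plot after this commit, so it accepts a figure
)

if __name__ == "__main__":
    demo.launch()

Switching `output_grid` from `gr.Textbox` to `gr.Plot` is what lets the interface function return the figure object directly instead of a preformatted string.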
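
One caveat when reading the new `visualize_grid`: `np.zeros((env.size, env.size), dtype=str)` produces an array of empty strings, and Matplotlib's `imshow` accepts only numeric or RGB(A) data, so passing the string array `grid` to `ax.imshow` raises a `TypeError` at render time. A minimal sketch of a numeric variant that keeps the same look (the `CELL_CODES` mapping and the standalone `render_grid` helper are illustrative names, not part of the commit):

import numpy as np
import matplotlib.pyplot as plt

# Illustrative mapping from cell labels to integers; imshow needs numeric
# data, so the string grid is converted before plotting.
CELL_CODES = {'': 0, 'P': 1, 'X': 2, 'A': 3, 'G': 4}

def render_grid(grid):
    """Render a 2D array of single-character cell labels as a figure."""
    codes = np.vectorize(CELL_CODES.get, otypes=[int])(grid)
    fig, ax = plt.subplots()
    ax.imshow(codes, cmap='viridis', aspect='equal')
    ax.set_xticks(np.arange(-0.5, grid.shape[1], 1))
    ax.set_yticks(np.arange(-0.5, grid.shape[0], 1))
    ax.grid(color='w', linestyle='-', linewidth=2)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    # Overlay the original labels so each cell stays readable.
    for i in range(grid.shape[0]):
        for j in range(grid.shape[1]):
            ax.text(j, i, grid[i, j], ha='center', va='center', color='w')
    return fig

# Example: a 3x3 grid with the agent at (0, 0), an obstacle, and the goal.
demo_grid = np.zeros((3, 3), dtype=str)
demo_grid[0, 0], demo_grid[1, 1], demo_grid[2, 2] = 'A', 'X', 'G'
fig = render_grid(demo_grid)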