broadfield-dev committed
Commit 29e5977 · verified · 1 Parent(s): 816eda2

Create app.py

Files changed (1)
  app.py  +100 −0
app.py ADDED
@@ -0,0 +1,100 @@
+ import numpy as np
+ import gradio as gr
+ 
+ # Define the grid world environment
+ class GridWorld:
+     def __init__(self, size=4):
+         self.size = size
+         self.agent_pos = [0, 0]
+         self.goal_pos = [size - 1, size - 1]
+         self.obstacles = [(1, 1), (2, 2)]
+ 
+     def reset(self):
+         self.agent_pos = [0, 0]
+         return self.agent_pos
+ 
+     def step(self, action):
+         x, y = self.agent_pos
+ 
+         if action == 0:  # Up
+             x = max(0, x - 1)
+         elif action == 1:  # Down
+             x = min(self.size - 1, x + 1)
+         elif action == 2:  # Left
+             y = max(0, y - 1)
+         elif action == 3:  # Right
+             y = min(self.size - 1, y + 1)
+ 
+         self.agent_pos = [x, y]
+ 
+         # Return (next_state, reward, done, info)
+         if tuple(self.agent_pos) in self.obstacles:
+             return self.agent_pos, -10, False, {}
+         elif self.agent_pos == self.goal_pos:
+             return self.agent_pos, 10, True, {}
+         else:
+             return self.agent_pos, -1, False, {}
+ 
+ # Define the RL agent
+ class QLearningAgent:
+     def __init__(self, env, alpha=0.1, gamma=0.9, epsilon=0.1):
+         self.env = env
+         self.alpha = alpha      # learning rate
+         self.gamma = gamma      # discount factor
+         self.epsilon = epsilon  # exploration rate
+         self.q_table = np.zeros((env.size, env.size, 4))
+ 
+     def choose_action(self, state):
+         # Epsilon-greedy action selection
+         if np.random.uniform(0, 1) < self.epsilon:
+             return np.random.choice(4)
+         else:
+             return np.argmax(self.q_table[state[0], state[1]])
+ 
+     def learn(self, state, action, reward, next_state):
+         # Q-learning update: Q(s, a) += alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
+         best_next_action = np.argmax(self.q_table[next_state[0], next_state[1]])
+         td_target = reward + self.gamma * self.q_table[next_state[0], next_state[1], best_next_action]
+         td_error = td_target - self.q_table[state[0], state[1], action]
+         self.q_table[state[0], state[1], action] += self.alpha * td_error
+ 
+ # Initialize the environment and agent
+ env = GridWorld()
+ agent = QLearningAgent(env)
+ 
+ 
+ def visualize_grid(agent_pos, goal_pos, obstacles):
+     # Render the grid as text: '.' empty, 'G' goal, 'X' obstacle, 'A' agent (drawn last so it stays visible)
+     grid = np.full((env.size, env.size), '.', dtype=str)
+     grid[goal_pos[0], goal_pos[1]] = 'G'
+     for obstacle in obstacles:
+         grid[obstacle[0], obstacle[1]] = 'X'
+     grid[agent_pos[0], agent_pos[1]] = 'A'
+     return '\n'.join(' '.join(row) for row in grid)
+ 
+ def train_agent(steps=100):
+     state = env.reset()
+     for _ in range(int(steps)):  # slider values may arrive as floats
+         action = agent.choose_action(state)
+         next_state, reward, done, _ = env.step(action)
+         agent.learn(state, action, reward, next_state)
+         state = next_state
+         if done:
+             break
+     return visualize_grid(env.agent_pos, env.goal_pos, env.obstacles)
+ 
+ # Create the Gradio interface components
+ input_steps = gr.Slider(1, 1000, value=100, step=1, label="Number of Training Steps")
+ output_grid = gr.Textbox(label="Grid World")
+ 
+ # Define the Gradio interface function
+ def update_grid(steps):
+     return train_agent(steps)
+ 
+ # Create the Gradio interface
+ iface = gr.Interface(
+     fn=update_grid,
+     inputs=[input_steps],
+     outputs=[output_grid],
+     title="Reinforcement Learning with Grid World",
+     description="Train a Q-learning agent to navigate a grid world and visualize the results."
+ )
+ 
+ # Launch the interface
+ iface.launch()
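
For reference, a rough local sanity check of the same classes, run in a Python session with GridWorld and QLearningAgent from the file above pasted in (the episode loop and greedy rollout below are illustrative, not part of the committed file), might look like:

import numpy as np

env = GridWorld()
agent = QLearningAgent(env)

# Train over several short episodes instead of one fixed-length run
for episode in range(200):
    state = env.reset()
    for _ in range(100):
        action = agent.choose_action(state)
        next_state, reward, done, _ = env.step(action)
        agent.learn(state, action, reward, next_state)
        state = next_state
        if done:
            break

# Greedy rollout: follow argmax of the learned Q-table from the start state
state = env.reset()
path = [tuple(state)]
for _ in range(20):
    action = int(np.argmax(agent.q_table[state[0], state[1]]))
    state, reward, done, _ = env.step(action)
    path.append(tuple(state))
    if done:
        break
print(path)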