Spaces:
Build error
Build error
File size: 10,415 Bytes
3382f47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 |
import 'package:auto_gpt_flutter_client/models/benchmark/benchmark_run.dart';
import 'package:auto_gpt_flutter_client/models/benchmark/benchmark_step_request_body.dart';
import 'package:auto_gpt_flutter_client/models/benchmark/benchmark_task_request_body.dart';
import 'package:auto_gpt_flutter_client/models/benchmark/benchmark_task_status.dart';
import 'package:auto_gpt_flutter_client/models/skill_tree/skill_tree_edge.dart';
import 'package:auto_gpt_flutter_client/models/skill_tree/skill_tree_node.dart';
import 'package:auto_gpt_flutter_client/models/step.dart';
import 'package:auto_gpt_flutter_client/models/task.dart';
import 'package:auto_gpt_flutter_client/models/test_option.dart';
import 'package:auto_gpt_flutter_client/models/test_suite.dart';
import 'package:auto_gpt_flutter_client/services/benchmark_service.dart';
import 'package:auto_gpt_flutter_client/services/leaderboard_service.dart';
import 'package:auto_gpt_flutter_client/services/shared_preferences_service.dart';
import 'package:auto_gpt_flutter_client/viewmodels/chat_viewmodel.dart';
import 'package:auto_gpt_flutter_client/viewmodels/task_viewmodel.dart';
import 'package:collection/collection.dart';
import 'package:flutter/foundation.dart';
import 'package:uuid/uuid.dart';
import 'package:auto_gpt_flutter_client/utils/stack.dart';
/// View model that decides which skill-tree nodes are queued for benchmarking,
/// runs the benchmark for that queue, and submits the collected runs to the
/// leaderboard.
///
/// Listeners are notified whenever the selected hierarchy, the per-node status
/// map, or [isBenchmarkRunning] changes.
class TaskQueueViewModel extends ChangeNotifier {
  /// Label of the node that [TestOption.runAllTestsInCategory] starts its
  /// traversal from.
  static const _rootNodeLabel = "WriteFile";

  /// Creates benchmark tasks, executes their steps and triggers evaluation.
  final BenchmarkService benchmarkService;

  /// Receives the benchmark reports in [submitToLeaderboard].
  final LeaderboardService leaderboardService;

  /// Injected preferences service; no method of this class reads it.
  final SharedPreferencesService prefsService;

  /// Whether [runBenchmark] is currently in flight.
  bool isBenchmarkRunning = false;

  /// Benchmark status of every node queued by the latest [runBenchmark] call.
  Map<SkillTreeNode, BenchmarkTaskStatus> benchmarkStatusMap = {};

  /// Evaluation results gathered by the latest [runBenchmark] call; emptied
  /// again once [submitToLeaderboard] has submitted all of them.
  List<BenchmarkRun> currentBenchmarkRuns = [];

  List<SkillTreeNode>? _selectedNodeHierarchy;

  TestOption _selectedOption = TestOption.runSingleTest;

  /// The option most recently passed to
  /// [updateSelectedNodeHierarchyBasedOnOption].
  TestOption get selectedOption => _selectedOption;

  /// The nodes queued for benchmarking, in execution order, or `null` if no
  /// selection has been made yet.
  List<SkillTreeNode>? get selectedNodeHierarchy => _selectedNodeHierarchy;

  TaskQueueViewModel(
      this.benchmarkService, this.leaderboardService, this.prefsService);

  /// Rebuilds the queued hierarchy for [selectedOption] and notifies listeners
  /// exactly once.
  ///
  /// * [TestOption.runSingleTest] queues only [selectedNode] (or nothing when
  ///   it is `null`).
  /// * [TestOption.runTestSuiteIncludingSelectedNodeAndAncestors] queues
  ///   [selectedNode] preceded by all of its ancestors.
  /// * [TestOption.runAllTestsInCategory] queues every node reachable from the
  ///   node labelled "WriteFile"; [selectedNode] only gates whether the
  ///   rebuild happens.
  ///
  /// For the last two options a `null` [selectedNode] leaves the previous
  /// hierarchy untouched.
  void updateSelectedNodeHierarchyBasedOnOption(
      TestOption selectedOption,
      SkillTreeNode? selectedNode,
      List<SkillTreeNode> nodes,
      List<SkillTreeEdge> edges) {
    _selectedOption = selectedOption;
    switch (selectedOption) {
      case TestOption.runSingleTest:
        _selectedNodeHierarchy = selectedNode != null ? [selectedNode] : [];
        break;
      case TestOption.runTestSuiteIncludingSelectedNodeAndAncestors:
        if (selectedNode != null) {
          // Use the non-notifying helper so listeners fire once (below)
          // instead of twice.
          _rebuildAncestorHierarchy(selectedNode.id, nodes, edges);
        }
        break;
      case TestOption.runAllTestsInCategory:
        if (selectedNode != null) {
          _getAllNodesInDepthFirstOrderEnsuringParents(nodes, edges);
        }
        break;
    }
    notifyListeners();
  }

  /// Replaces the queued hierarchy with every node reachable from the root
  /// node, in depth-first order, such that a node is only listed after all of
  /// its parents have been listed.
  ///
  /// The hierarchy becomes empty when no node carries the root label. Nodes
  /// that have a parent which is never listed (for example a parent that is
  /// not reachable from the root) are left out.
  void _getAllNodesInDepthFirstOrderEnsuringParents(
      List<SkillTreeNode> skillTreeNodes, List<SkillTreeEdge> skillTreeEdges) {
    final ordered = <SkillTreeNode>[];

    final root = skillTreeNodes
        .firstWhereOrNull((node) => node.label == _rootNodeLabel);
    if (root == null) {
      _selectedNodeHierarchy = ordered;
      return;
    }

    // Ids of nodes already appended to `ordered`. A node counts as done only
    // once it has been emitted, not merely pushed, so a child can never be
    // emitted ahead of a parent that is still waiting on the stack.
    final emittedIds = <String>{};
    final pending = Stack<SkillTreeNode>();
    pending.push(root);

    while (pending.isNotEmpty) {
      final node = pending.peek();
      pending.pop();

      // The same node may have been pushed by several parents.
      if (emittedIds.contains(node.id)) continue;

      final parents =
          _getParentsOfNodeUsingEdges(node.id, skillTreeNodes, skillTreeEdges);
      // Not ready yet: drop it for now. The parent that is emitted last will
      // push it again.
      if (!parents.every((parent) => emittedIds.contains(parent.id))) continue;

      emittedIds.add(node.id);
      ordered.add(node);

      final children =
          _getChildrenOfNodeUsingEdges(node.id, skillTreeNodes, skillTreeEdges);
      for (final child in children) {
        if (!emittedIds.contains(child.id)) {
          pending.push(child);
        }
      }
    }

    _selectedNodeHierarchy = ordered;
  }

  /// Returns the nodes that have an edge pointing to [nodeId], in edge order.
  ///
  /// Edges whose source id is not present in [nodes] are skipped.
  List<SkillTreeNode> _getParentsOfNodeUsingEdges(
      String nodeId, List<SkillTreeNode> nodes, List<SkillTreeEdge> edges) {
    final parents = <SkillTreeNode>[];
    for (final edge in edges) {
      if (edge.to != nodeId) continue;
      final parent = nodes.firstWhereOrNull((node) => node.id == edge.from);
      if (parent != null) {
        parents.add(parent);
      }
    }
    return parents;
  }

  /// Returns the nodes that [nodeId] has an edge pointing to, in edge order.
  ///
  /// Edges whose target id is not present in [nodes] are skipped.
  List<SkillTreeNode> _getChildrenOfNodeUsingEdges(
      String nodeId, List<SkillTreeNode> nodes, List<SkillTreeEdge> edges) {
    final children = <SkillTreeNode>[];
    for (final edge in edges) {
      if (edge.from != nodeId) continue;
      final child = nodes.firstWhereOrNull((node) => node.id == edge.to);
      if (child != null) {
        children.add(child);
      }
    }
    return children;
  }

  // TODO: Do we want to continue testing other branches of tree if one branch side fails benchmarking?
  /// Replaces the queued hierarchy with the node identified by [startNodeId]
  /// preceded by all of its ancestors, then notifies listeners.
  void populateSelectedNodeHierarchy(String startNodeId,
      List<SkillTreeNode> nodes, List<SkillTreeEdge> edges) {
    _rebuildAncestorHierarchy(startNodeId, nodes, edges);
    notifyListeners();
  }

  /// Resets the queued hierarchy and fills it with [startNodeId] and its
  /// ancestors without notifying listeners.
  void _rebuildAncestorHierarchy(String startNodeId, List<SkillTreeNode> nodes,
      List<SkillTreeEdge> edges) {
    _selectedNodeHierarchy = <SkillTreeNode>[];
    recursivePopulateHierarchy(startNodeId, <String>{}, nodes, edges);
  }

  /// Appends the node identified by [nodeId] to the queued hierarchy after
  /// first appending all of its not-yet-added ancestors.
  ///
  /// [addedNodes] holds the ids handled so far; it prevents duplicates and
  /// stops the recursion on cyclic edges. Unknown ids are ignored. The
  /// hierarchy list is created on demand if it does not exist yet.
  void recursivePopulateHierarchy(String nodeId, Set<String> addedNodes,
      List<SkillTreeNode> nodes, List<SkillTreeEdge> edges) {
    final currentNode = nodes.firstWhereOrNull((node) => node.id == nodeId);

    // `Set.add` returns false when the id was already present.
    if (currentNode == null || !addedNodes.add(currentNode.id)) return;

    // Parents first, so that they end up ahead of the current node.
    final parentEdges = edges.where((edge) => edge.to == currentNode.id);
    for (final parentEdge in parentEdges) {
      recursivePopulateHierarchy(parentEdge.from, addedNodes, nodes, edges);
    }

    (_selectedNodeHierarchy ??= <SkillTreeNode>[]).add(currentNode);
  }

  /// Benchmarks every node of the queued hierarchy in order.
  ///
  /// For each node a benchmark task is created, its steps are executed until
  /// one reports `isLast`, and the evaluation is triggered. The run stops at
  /// the first node whose evaluation is unsuccessful. When the loop finishes
  /// without an exception the resulting [TestSuite] is handed to
  /// [taskViewModel].
  ///
  /// Exceptions are logged rather than rethrown; the node that was in
  /// progress is then marked as failed. [isBenchmarkRunning] is always reset
  /// and listeners are always notified when the method completes. A missing
  /// hierarchy is treated like an empty one.
  Future<void> runBenchmark(
      ChatViewModel chatViewModel, TaskViewModel taskViewModel) async {
    final hierarchy = _selectedNodeHierarchy ?? const <SkillTreeNode>[];

    // Start from a clean slate for this run.
    benchmarkStatusMap.clear();
    currentBenchmarkRuns = [];
    final testSuite =
        TestSuite(timestamp: DateTime.now().toIso8601String(), tests: []);

    isBenchmarkRunning = true;
    notifyListeners();

    for (final node in hierarchy) {
      benchmarkStatusMap[node] = BenchmarkTaskStatus.notStarted;
    }

    // Node currently being benchmarked, so a failure can be attributed to it.
    SkillTreeNode? activeNode;
    try {
      for (final node in hierarchy) {
        activeNode = node;
        benchmarkStatusMap[node] = BenchmarkTaskStatus.inProgress;
        notifyListeners();

        final benchmarkTaskRequestBody = BenchmarkTaskRequestBody(
            input: node.data.task, evalId: node.data.evalId);
        final createdTask = await benchmarkService
            .createBenchmarkTask(benchmarkTaskRequestBody);
        final task =
            Task(id: createdTask['task_id'], title: createdTask['input']);

        chatViewModel.setCurrentTaskId(task.id);

        // The first step carries the task input; follow-up steps send none.
        Map<String, dynamic> stepResponse =
            await benchmarkService.executeBenchmarkStep(
                task.id, BenchmarkStepRequestBody(input: node.data.task));
        Step step = Step.fromMap(stepResponse);
        // NOTE(review): not awaited; if this returns a Future, its errors
        // would not reach the catch below - confirm this is intended.
        chatViewModel.fetchChatsForTask();

        while (!step.isLast) {
          stepResponse = await benchmarkService.executeBenchmarkStep(
              task.id, BenchmarkStepRequestBody(input: null));
          step = Step.fromMap(stepResponse);
          chatViewModel.fetchChatsForTask();
        }

        final evaluationResponse =
            await benchmarkService.triggerEvaluation(task.id);
        final benchmarkRun = BenchmarkRun.fromJson(evaluationResponse);
        currentBenchmarkRuns.add(benchmarkRun);

        final successStatus = benchmarkRun.metrics.success;
        benchmarkStatusMap[node] = successStatus
            ? BenchmarkTaskStatus.success
            : BenchmarkTaskStatus.failure;
        await Future.delayed(const Duration(seconds: 1));
        notifyListeners();

        testSuite.tests.add(task);

        if (!successStatus) {
          debugPrint(
              "Benchmark for node ${node.id} failed. Stopping all benchmarks.");
          break;
        }
      }

      taskViewModel.addTestSuite(testSuite);
    } catch (e) {
      // Do not leave the interrupted node stuck in "in progress".
      if (activeNode != null &&
          benchmarkStatusMap[activeNode] == BenchmarkTaskStatus.inProgress) {
        benchmarkStatusMap[activeNode] = BenchmarkTaskStatus.failure;
      }
      debugPrint("Error while running benchmark: $e");
    } finally {
      isBenchmarkRunning = false;
      notifyListeners();
    }
  }

  /// Submits every run in [currentBenchmarkRuns] to the leaderboard, one
  /// after another, and clears the list once all submissions have completed.
  ///
  /// All runs are stamped with the given repository details and share a
  /// single freshly generated run id. If a submission throws, the exception
  /// propagates and the list is left uncleared.
  Future<void> submitToLeaderboard(
      String teamName, String repoUrl, String agentGitCommitSha) async {
    // One UUID v4 identifies the whole batch of runs.
    final uuid = const Uuid().v4();

    for (final run in currentBenchmarkRuns) {
      run.repositoryInfo.teamName = teamName;
      run.repositoryInfo.repoUrl = repoUrl;
      run.repositoryInfo.agentGitCommitSha = agentGitCommitSha;
      run.runDetails.runId = uuid;

      await leaderboardService.submitReport(run);
      debugPrint('Completed submission to leaderboard!');
    }

    currentBenchmarkRuns.clear();
  }
}
|