Thomas G. Lopes commited on
Commit
dab40ed
·
unverified ·
1 Parent(s): 7419732

branching wip (#98)

Browse files
src/lib/components/inference-playground/message.svelte CHANGED
@@ -22,6 +22,9 @@
22
  import { parseThinkingTokens } from "$lib/utils/thinking.js";
23
  import IconChevronDown from "~icons/carbon/chevron-down";
24
  import IconChevronRight from "~icons/carbon/chevron-right";
 
 
 
25
 
26
  type Props = {
27
  conversation: ConversationClass;
@@ -341,6 +344,34 @@
341
  </div>
342
  </Tooltip>
343
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
  <Tooltip>
345
  {#snippet trigger(tooltip)}
346
  <button
@@ -399,3 +430,23 @@
399
  {/each}
400
  </div>
401
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  import { parseThinkingTokens } from "$lib/utils/thinking.js";
23
  import IconChevronDown from "~icons/carbon/chevron-down";
24
  import IconChevronRight from "~icons/carbon/chevron-right";
25
+ import ArrowSplitRounded from "~icons/material-symbols/arrow-split-rounded";
26
+ import { addToast } from "$lib/components/toaster.svelte.js";
27
+ import { projects } from "$lib/state/projects.svelte";
28
 
29
  type Props = {
30
  conversation: ConversationClass;
 
344
  </div>
345
  </Tooltip>
346
 
347
+ <Tooltip>
348
+ {#snippet trigger(tooltip)}
349
+ <button
350
+ tabindex="0"
351
+ onclick={async () => {
352
+ try {
353
+ await projects.branch(projects.activeId, index);
354
+ } catch (error) {
355
+ addToast({
356
+ title: "Error",
357
+ description: error instanceof Error ? error.message : "Failed to create branch",
358
+ variant: "error",
359
+ });
360
+ }
361
+ }}
362
+ type="button"
363
+ class="grid size-7 place-items-center border border-gray-200 bg-white text-xs font-medium text-gray-900 hover:bg-gray-100
364
+ hover:text-blue-700 focus:z-10 focus:ring-4 focus:ring-gray-100
365
+ focus:outline-hidden dark:border-gray-600 dark:bg-gray-800
366
+ dark:text-gray-400 dark:hover:bg-gray-700 dark:hover:text-white dark:focus:ring-gray-700"
367
+ {...tooltip.trigger}
368
+ >
369
+ <ArrowSplitRounded />
370
+ </button>
371
+ {/snippet}
372
+ Branch from here
373
+ </Tooltip>
374
+
375
  <Tooltip>
376
  {#snippet trigger(tooltip)}
377
  <button
 
430
  {/each}
431
  </div>
432
  </div>
433
+
434
+ {#if projects.current?.branchedFromId && projects.current?.branchedFromMessageIndex === index}
435
+ <div class="mt-4 flex items-center justify-center">
436
+ <div
437
+ class="flex items-center gap-1 rounded-full bg-gray-100 px-3 py-1.5 text-sm text-gray-600 dark:bg-gray-800 dark:text-gray-400"
438
+ >
439
+ <ArrowSplitRounded class="mr-1 size-4" />
440
+ <span>Branched from</span>
441
+ <button
442
+ onclick={() => {
443
+ if (!projects.current?.branchedFromId) return;
444
+ projects.activeId = projects.current.branchedFromId;
445
+ }}
446
+ class="font-medium text-blue-600 underline hover:text-blue-800 dark:text-blue-400 dark:hover:text-blue-300"
447
+ >
448
+ {projects.getBranchedFromProject(projects.current.id)?.name || "original project"}
449
+ </button>
450
+ </div>
451
+ </div>
452
+ {/if}
src/lib/components/inference-playground/project-select.svelte CHANGED
@@ -10,6 +10,7 @@
10
  import IconHistory from "~icons/carbon/recently-viewed";
11
  import IconSave from "~icons/carbon/save";
12
  import IconDelete from "~icons/carbon/trash-can";
 
13
  import Dialog from "../dialog.svelte";
14
  import { prompt } from "../prompts.svelte";
15
  import Tooltip from "../tooltip.svelte";
@@ -114,7 +115,21 @@
114
  >
115
  <div class="flex items-center gap-2">
116
  {name}
117
- {#if hasCheckpoints}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  <div
119
  class="text-3xs grid aspect-square place-items-center rounded bg-yellow-300 p-0.5 text-yellow-700 dark:bg-yellow-400/25 dark:text-yellow-400"
120
  aria-label="Project has checkpoints"
 
10
  import IconHistory from "~icons/carbon/recently-viewed";
11
  import IconSave from "~icons/carbon/save";
12
  import IconDelete from "~icons/carbon/trash-can";
13
+ import ArrowSplitRounded from "~icons/material-symbols/arrow-split-rounded";
14
  import Dialog from "../dialog.svelte";
15
  import { prompt } from "../prompts.svelte";
16
  import Tooltip from "../tooltip.svelte";
 
115
  >
116
  <div class="flex items-center gap-2">
117
  {name}
118
+ {#if projects.all.find(p => p.id === id)?.branchedFromId}
119
+ {@const originalProject = projects.getBranchedFromProject(id)}
120
+ <Tooltip>
121
+ {#snippet trigger(tooltip)}
122
+ <div
123
+ class="text-3xs grid aspect-square place-items-center rounded bg-blue-300 p-0.5 text-blue-700 dark:bg-blue-400/25 dark:text-blue-400"
124
+ aria-label="Branched project"
125
+ {...tooltip.trigger}
126
+ >
127
+ <ArrowSplitRounded />
128
+ </div>
129
+ {/snippet}
130
+ Branched from {originalProject?.name || "unknown project"}
131
+ </Tooltip>
132
+ {:else if hasCheckpoints}
133
  <div
134
  class="text-3xs grid aspect-square place-items-center rounded bg-yellow-300 p-0.5 text-yellow-700 dark:bg-yellow-400/25 dark:text-yellow-400"
135
  aria-label="Project has checkpoints"
src/lib/state/conversations.svelte.ts CHANGED
@@ -157,9 +157,24 @@ export class ConversationClass {
157
  });
158
  };
159
 
 
 
 
 
 
 
 
 
 
 
 
160
  deleteMessage = async (idx: number) => {
161
  if (!this.data.messages) return;
162
  const imgKeys = this.data.messages.flatMap(m => m.images).filter(isString);
 
 
 
 
163
  await Promise.all([
164
  ...imgKeys.map(k => images.delete(k)),
165
  this.update({
@@ -174,6 +189,9 @@ export class ConversationClass {
174
  const sliced = this.data.messages.slice(0, from);
175
  const notSliced = this.data.messages.slice(from);
176
 
 
 
 
177
  const imgKeys = notSliced.flatMap(m => m.images).filter(isString);
178
  await Promise.all([
179
  ...imgKeys.map(k => images.delete(k)),
@@ -380,6 +398,42 @@ class Conversations {
380
  );
381
  };
382
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
  genNextMessages = async (conv: "left" | "right" | "both" | ConversationClass = "both") => {
384
  if (!token.value) {
385
  token.showModal = true;
 
157
  });
158
  };
159
 
160
+ checkAndClearBranchStatus = async (deletionIndex: number) => {
161
+ const currentProject = projects.current;
162
+
163
+ if (!currentProject?.branchedFromId || typeof currentProject?.branchedFromMessageIndex !== "number") return;
164
+
165
+ // If we're deleting messages at or before the branch point, clear branch status
166
+ if (deletionIndex <= currentProject.branchedFromMessageIndex) {
167
+ await projects.clearBranchStatus(currentProject.id);
168
+ }
169
+ };
170
+
171
  deleteMessage = async (idx: number) => {
172
  if (!this.data.messages) return;
173
  const imgKeys = this.data.messages.flatMap(m => m.images).filter(isString);
174
+
175
+ // Check if we need to clear branch status
176
+ await this.checkAndClearBranchStatus(idx);
177
+
178
  await Promise.all([
179
  ...imgKeys.map(k => images.delete(k)),
180
  this.update({
 
189
  const sliced = this.data.messages.slice(0, from);
190
  const notSliced = this.data.messages.slice(from);
191
 
192
+ // Check if we need to clear branch status
193
+ await this.checkAndClearBranchStatus(from);
194
+
195
  const imgKeys = notSliced.flatMap(m => m.images).filter(isString);
196
  await Promise.all([
197
  ...imgKeys.map(k => images.delete(k)),
 
398
  );
399
  };
400
 
401
+ duplicateUpToMessage = async (from: ProjectEntity["id"], to: ProjectEntity["id"], messageIndex: number) => {
402
+ const fromArr = this.#conversations[from] ?? [];
403
+
404
+ // Clear any existing conversations for the target project first
405
+ this.#conversations[to] = [];
406
+
407
+ // Delete any existing conversations in the database for this project
408
+ const existingConversations = await conversationsRepo.find({ where: { projectId: to } });
409
+ await Promise.all(existingConversations.map(c => conversationsRepo.delete(c.id)));
410
+
411
+ const newConversations: ConversationClass[] = [];
412
+
413
+ for (const c of fromArr) {
414
+ // Copy only messages up to the specified index with deep clone
415
+ const truncatedMessages =
416
+ c.data.messages?.slice(0, messageIndex + 1).map(msg => ({
417
+ ...msg,
418
+ images: msg.images ? [...msg.images] : undefined,
419
+ })) || [];
420
+
421
+ const conversationData = {
422
+ ...snapshot(c.data),
423
+ projectId: to,
424
+ messages: truncatedMessages,
425
+ id: undefined, // Let the database generate a new ID
426
+ };
427
+
428
+ // Use conversationsRepo directly to avoid default conversation merging
429
+ const saved = await conversationsRepo.save(conversationData);
430
+ newConversations.push(new ConversationClass(saved));
431
+ }
432
+
433
+ // Update the in-memory cache
434
+ this.#conversations[to] = newConversations;
435
+ };
436
+
437
  genNextMessages = async (conv: "left" | "right" | "both" | ConversationClass = "both") => {
438
  if (!token.value) {
439
  token.showModal = true;
src/lib/state/projects.svelte.ts CHANGED
@@ -15,6 +15,12 @@ export class ProjectEntity {
15
 
16
  @Fields.string()
17
  systemMessage?: string;
 
 
 
 
 
 
18
  }
19
 
20
  export type ProjectEntityMembers = MembersOnly<ProjectEntity>;
@@ -72,14 +78,6 @@ class Projects {
72
  return id;
73
  };
74
 
75
- setCurrent = async (id: string) => {
76
- await checkpoints.migrate(id, this.activeId);
77
- conversations.migrate(this.activeId, id).then(() => {
78
- this.#activeId.current = id;
79
- });
80
- this.activeId = id;
81
- };
82
-
83
  get current() {
84
  return this.#projects[this.activeId];
85
  }
@@ -105,6 +103,46 @@ class Projects {
105
  this.activeId = DEFAULT_PROJECT_ID;
106
  }
107
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  }
109
 
110
  export const projects = new Projects();
 
15
 
16
  @Fields.string()
17
  systemMessage?: string;
18
+
19
+ @Fields.string()
20
+ branchedFromId?: string | null;
21
+
22
+ @Fields.number()
23
+ branchedFromMessageIndex?: number | null;
24
  }
25
 
26
  export type ProjectEntityMembers = MembersOnly<ProjectEntity>;
 
78
  return id;
79
  };
80
 
 
 
 
 
 
 
 
 
81
  get current() {
82
  return this.#projects[this.activeId];
83
  }
 
103
  this.activeId = DEFAULT_PROJECT_ID;
104
  }
105
  }
106
+
107
+ branch = async (fromProjectId: string, messageIndex: number): Promise<string> => {
108
+ const fromProject = this.#projects[fromProjectId];
109
+ if (!fromProject) throw new Error("Source project not found");
110
+
111
+ // Create new project with branching info
112
+ const newProjectId = await this.create({
113
+ name: `${fromProject.name} (branch)`,
114
+ systemMessage: fromProject.systemMessage,
115
+ branchedFromId: fromProjectId,
116
+ branchedFromMessageIndex: messageIndex,
117
+ });
118
+
119
+ // Copy conversations up to the specified message index
120
+ await conversations.duplicateUpToMessage(fromProjectId, newProjectId, messageIndex);
121
+
122
+ // Switch to the new project
123
+ this.activeId = newProjectId;
124
+
125
+ return newProjectId;
126
+ };
127
+
128
+ getBranchedFromProject = (projectId: string) => {
129
+ const project = this.#projects[projectId];
130
+ if (!project?.branchedFromId) return null;
131
+
132
+ const originalProject = this.#projects[project.branchedFromId];
133
+ return originalProject;
134
+ };
135
+
136
+ clearBranchStatus = async (projectId: string) => {
137
+ const project = this.#projects[projectId];
138
+ if (!project?.branchedFromId) return;
139
+
140
+ await this.update({
141
+ ...project,
142
+ branchedFromId: null,
143
+ branchedFromMessageIndex: null,
144
+ });
145
+ };
146
  }
147
 
148
  export const projects = new Projects();