Spaces:
Paused
Paused
feat(tools): use gradio ETA for tools (#1411)
Browse files
src/lib/components/chat/ToolUpdate.svelte
CHANGED
|
@@ -9,7 +9,7 @@
|
|
| 9 |
import CarbonTools from "~icons/carbon/tools";
|
| 10 |
import { ToolResultStatus, type ToolFront } from "$lib/types/Tool";
|
| 11 |
import { page } from "$app/stores";
|
| 12 |
-
import {
|
| 13 |
import { browser } from "$app/environment";
|
| 14 |
|
| 15 |
export let tool: MessageToolUpdate[];
|
|
@@ -19,6 +19,8 @@
|
|
| 19 |
$: toolError = tool.some(isMessageToolErrorUpdate);
|
| 20 |
$: toolDone = tool.some(isMessageToolResultUpdate);
|
| 21 |
|
|
|
|
|
|
|
| 22 |
const availableTools: ToolFront[] = $page.data.tools;
|
| 23 |
|
| 24 |
let loadingBarEl: HTMLDivElement;
|
|
@@ -26,16 +28,24 @@
|
|
| 26 |
|
| 27 |
let isShowingLoadingBar = false;
|
| 28 |
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
loadingBarEl.classList.remove("hidden");
|
| 32 |
isShowingLoadingBar = true;
|
| 33 |
animation = loadingBarEl.animate([{ width: "0%" }, { width: "calc(100%+1rem)" }], {
|
| 34 |
-
duration:
|
| 35 |
fill: "forwards",
|
| 36 |
});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
}
|
| 38 |
-
return () => animation?.cancel();
|
| 39 |
});
|
| 40 |
|
| 41 |
// go to 100% quickly if loading is done
|
|
|
|
| 9 |
import CarbonTools from "~icons/carbon/tools";
|
| 10 |
import { ToolResultStatus, type ToolFront } from "$lib/types/Tool";
|
| 11 |
import { page } from "$app/stores";
|
| 12 |
+
import { onDestroy } from "svelte";
|
| 13 |
import { browser } from "$app/environment";
|
| 14 |
|
| 15 |
export let tool: MessageToolUpdate[];
|
|
|
|
| 19 |
$: toolError = tool.some(isMessageToolErrorUpdate);
|
| 20 |
$: toolDone = tool.some(isMessageToolResultUpdate);
|
| 21 |
|
| 22 |
+
$: eta = tool.find((el) => el.subtype === MessageToolUpdateType.ETA)?.eta;
|
| 23 |
+
|
| 24 |
const availableTools: ToolFront[] = $page.data.tools;
|
| 25 |
|
| 26 |
let loadingBarEl: HTMLDivElement;
|
|
|
|
| 28 |
|
| 29 |
let isShowingLoadingBar = false;
|
| 30 |
|
| 31 |
+
$: !toolError &&
|
| 32 |
+
!toolDone &&
|
| 33 |
+
loading &&
|
| 34 |
+
loadingBarEl &&
|
| 35 |
+
eta &&
|
| 36 |
+
(() => {
|
| 37 |
loadingBarEl.classList.remove("hidden");
|
| 38 |
isShowingLoadingBar = true;
|
| 39 |
animation = loadingBarEl.animate([{ width: "0%" }, { width: "calc(100%+1rem)" }], {
|
| 40 |
+
duration: eta * 1000,
|
| 41 |
fill: "forwards",
|
| 42 |
});
|
| 43 |
+
})();
|
| 44 |
+
|
| 45 |
+
onDestroy(() => {
|
| 46 |
+
if (animation) {
|
| 47 |
+
animation.cancel();
|
| 48 |
}
|
|
|
|
| 49 |
});
|
| 50 |
|
| 51 |
// go to 100% quickly if loading is done
|
src/lib/server/textGeneration/tools.ts
CHANGED
|
@@ -70,7 +70,7 @@ async function* callTool(
|
|
| 70 |
};
|
| 71 |
|
| 72 |
try {
|
| 73 |
-
const toolResult = yield* tool.call(call.parameters, ctx);
|
| 74 |
|
| 75 |
yield {
|
| 76 |
type: MessageUpdateType.Tool,
|
|
|
|
| 70 |
};
|
| 71 |
|
| 72 |
try {
|
| 73 |
+
const toolResult = yield* tool.call(call.parameters, ctx, uuid);
|
| 74 |
|
| 75 |
yield {
|
| 76 |
type: MessageUpdateType.Tool,
|
src/lib/server/tools/index.ts
CHANGED
|
@@ -119,7 +119,7 @@ export const configTools = z
|
|
| 119 |
.transform((val) => [...val, calculator, directlyAnswer, fetchUrl, websearch]);
|
| 120 |
|
| 121 |
export function getCallMethod(tool: Omit<BaseTool, "call">): BackendCall {
|
| 122 |
-
return async function* (params, ctx) {
|
| 123 |
if (
|
| 124 |
tool.endpoint === null ||
|
| 125 |
!tool.baseUrl ||
|
|
@@ -203,11 +203,12 @@ export function getCallMethod(tool: Omit<BaseTool, "call">): BackendCall {
|
|
| 203 |
}
|
| 204 |
});
|
| 205 |
|
| 206 |
-
const outputs =
|
| 207 |
tool.baseUrl,
|
| 208 |
tool.endpoint,
|
| 209 |
await Promise.all(inputs),
|
| 210 |
-
ipToken
|
|
|
|
| 211 |
);
|
| 212 |
|
| 213 |
if (!isValidOutputComponent(tool.outputComponent)) {
|
|
|
|
| 119 |
.transform((val) => [...val, calculator, directlyAnswer, fetchUrl, websearch]);
|
| 120 |
|
| 121 |
export function getCallMethod(tool: Omit<BaseTool, "call">): BackendCall {
|
| 122 |
+
return async function* (params, ctx, uuid) {
|
| 123 |
if (
|
| 124 |
tool.endpoint === null ||
|
| 125 |
!tool.baseUrl ||
|
|
|
|
| 203 |
}
|
| 204 |
});
|
| 205 |
|
| 206 |
+
const outputs = yield* callSpace(
|
| 207 |
tool.baseUrl,
|
| 208 |
tool.endpoint,
|
| 209 |
await Promise.all(inputs),
|
| 210 |
+
ipToken,
|
| 211 |
+
uuid
|
| 212 |
);
|
| 213 |
|
| 214 |
if (!isValidOutputComponent(tool.outputComponent)) {
|
src/lib/server/tools/utils.ts
CHANGED
|
@@ -1,27 +1,20 @@
|
|
| 1 |
import { env } from "$env/dynamic/private";
|
| 2 |
import { Client } from "@gradio/client";
|
| 3 |
import { SignJWT } from "jose";
|
| 4 |
-
import { logger } from "../logger";
|
| 5 |
import JSON5 from "json5";
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
export
|
| 8 |
-
path: string;
|
| 9 |
-
url: string;
|
| 10 |
-
orig_name: string;
|
| 11 |
-
is_stream: boolean;
|
| 12 |
-
meta: Record<string, unknown>;
|
| 13 |
-
};
|
| 14 |
-
|
| 15 |
-
type GradioResponse = {
|
| 16 |
-
data: unknown[];
|
| 17 |
-
};
|
| 18 |
-
|
| 19 |
-
export async function callSpace<TInput extends unknown[], TOutput extends unknown[]>(
|
| 20 |
name: string,
|
| 21 |
func: string,
|
| 22 |
parameters: TInput,
|
| 23 |
-
ipToken: string | undefined
|
| 24 |
-
|
|
|
|
| 25 |
class CustomClient extends Client {
|
| 26 |
fetch(input: RequestInfo | URL, init?: RequestInit): Promise<Response> {
|
| 27 |
init = init || {};
|
|
@@ -34,15 +27,32 @@ export async function callSpace<TInput extends unknown[], TOutput extends unknow
|
|
| 34 |
}
|
| 35 |
const client = await CustomClient.connect(name, {
|
| 36 |
hf_token: (env.HF_TOKEN ?? env.HF_ACCESS_TOKEN) as unknown as `hf_${string}`,
|
|
|
|
| 37 |
});
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
}
|
| 47 |
|
| 48 |
export async function getIpToken(ip: string, username?: string) {
|
|
|
|
| 1 |
import { env } from "$env/dynamic/private";
|
| 2 |
import { Client } from "@gradio/client";
|
| 3 |
import { SignJWT } from "jose";
|
|
|
|
| 4 |
import JSON5 from "json5";
|
| 5 |
+
import {
|
| 6 |
+
MessageToolUpdateType,
|
| 7 |
+
MessageUpdateType,
|
| 8 |
+
type MessageToolUpdate,
|
| 9 |
+
} from "$lib/types/MessageUpdate";
|
| 10 |
|
| 11 |
+
export async function* callSpace<TInput extends unknown[], TOutput extends unknown[]>(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
name: string,
|
| 13 |
func: string,
|
| 14 |
parameters: TInput,
|
| 15 |
+
ipToken: string | undefined,
|
| 16 |
+
uuid: string
|
| 17 |
+
): AsyncGenerator<MessageToolUpdate, TOutput, undefined> {
|
| 18 |
class CustomClient extends Client {
|
| 19 |
fetch(input: RequestInfo | URL, init?: RequestInit): Promise<Response> {
|
| 20 |
init = init || {};
|
|
|
|
| 27 |
}
|
| 28 |
const client = await CustomClient.connect(name, {
|
| 29 |
hf_token: (env.HF_TOKEN ?? env.HF_ACCESS_TOKEN) as unknown as `hf_${string}`,
|
| 30 |
+
events: ["status", "data"],
|
| 31 |
});
|
| 32 |
|
| 33 |
+
const job = client.submit(func, parameters);
|
| 34 |
+
|
| 35 |
+
let data;
|
| 36 |
+
for await (const output of job) {
|
| 37 |
+
console.log({ output });
|
| 38 |
+
if (output.type === "data") {
|
| 39 |
+
data = output.data as TOutput;
|
| 40 |
+
}
|
| 41 |
+
if (output.type === "status" && output.eta) {
|
| 42 |
+
yield {
|
| 43 |
+
type: MessageUpdateType.Tool,
|
| 44 |
+
subtype: MessageToolUpdateType.ETA,
|
| 45 |
+
eta: output.eta,
|
| 46 |
+
uuid,
|
| 47 |
+
};
|
| 48 |
+
}
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
if (!data) {
|
| 52 |
+
throw new Error("No data found in tool call");
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
return data;
|
| 56 |
}
|
| 57 |
|
| 58 |
export async function getIpToken(ip: string, username?: string) {
|
src/lib/types/MessageUpdate.ts
CHANGED
|
@@ -75,7 +75,10 @@ export enum MessageToolUpdateType {
|
|
| 75 |
Result = "result",
|
| 76 |
/** Error while running tool */
|
| 77 |
Error = "error",
|
|
|
|
|
|
|
| 78 |
}
|
|
|
|
| 79 |
interface MessageToolBaseUpdate<TSubType extends MessageToolUpdateType> {
|
| 80 |
type: MessageUpdateType.Tool;
|
| 81 |
subtype: TSubType;
|
|
@@ -91,10 +94,16 @@ export interface MessageToolResultUpdate
|
|
| 91 |
export interface MessageToolErrorUpdate extends MessageToolBaseUpdate<MessageToolUpdateType.Error> {
|
| 92 |
message: string;
|
| 93 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
export type MessageToolUpdate =
|
| 95 |
| MessageToolCallUpdate
|
| 96 |
| MessageToolResultUpdate
|
| 97 |
-
| MessageToolErrorUpdate
|
|
|
|
| 98 |
|
| 99 |
// Everything else
|
| 100 |
export interface MessageTitleUpdate {
|
|
|
|
| 75 |
Result = "result",
|
| 76 |
/** Error while running tool */
|
| 77 |
Error = "error",
|
| 78 |
+
/** ETA update */
|
| 79 |
+
ETA = "eta",
|
| 80 |
}
|
| 81 |
+
|
| 82 |
interface MessageToolBaseUpdate<TSubType extends MessageToolUpdateType> {
|
| 83 |
type: MessageUpdateType.Tool;
|
| 84 |
subtype: TSubType;
|
|
|
|
| 94 |
export interface MessageToolErrorUpdate extends MessageToolBaseUpdate<MessageToolUpdateType.Error> {
|
| 95 |
message: string;
|
| 96 |
}
|
| 97 |
+
|
| 98 |
+
export interface MessageToolETAUpdate extends MessageToolBaseUpdate<MessageToolUpdateType.ETA> {
|
| 99 |
+
eta: number;
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
export type MessageToolUpdate =
|
| 103 |
| MessageToolCallUpdate
|
| 104 |
| MessageToolResultUpdate
|
| 105 |
+
| MessageToolErrorUpdate
|
| 106 |
+
| MessageToolETAUpdate;
|
| 107 |
|
| 108 |
// Everything else
|
| 109 |
export interface MessageTitleUpdate {
|
src/lib/types/Tool.ts
CHANGED
|
@@ -177,5 +177,6 @@ export interface ToolCall {
|
|
| 177 |
|
| 178 |
export type BackendCall = (
|
| 179 |
params: Record<string, string | number | boolean>,
|
| 180 |
-
context: BackendToolContext
|
|
|
|
| 181 |
) => AsyncGenerator<MessageUpdate, Omit<ToolResultSuccess, "status" | "call" | "type">, undefined>;
|
|
|
|
| 177 |
|
| 178 |
export type BackendCall = (
|
| 179 |
params: Record<string, string | number | boolean>,
|
| 180 |
+
context: BackendToolContext,
|
| 181 |
+
uuid: string
|
| 182 |
) => AsyncGenerator<MessageUpdate, Omit<ToolResultSuccess, "status" | "call" | "type">, undefined>;
|