feat(endpoints): Add conv ID to headers passed to TGI (#1511)
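This change threads the conversation's MongoDB `_id` through `EndpointParameters` and attaches it to outgoing requests from the TGI and OpenAI-compatible endpoints as a `ChatUI-Conversation-ID` header, so a backend can correlate generation requests with chat-ui conversations.

A minimal sketch (not chat-ui code) of what an outgoing request carries after this change; the URL and payload are illustrative assumptions, while the header name and the empty-string fallback mirror the diff below:

import { ObjectId } from "mongodb";

// Undefined when the request is not tied to a stored conversation.
const conversationId: ObjectId | undefined = new ObjectId();

await fetch("https://tgi.example/generate_stream", {
	method: "POST",
	headers: {
		"Content-Type": "application/json",
		// Falls back to an empty string when there is no conversation.
		"ChatUI-Conversation-ID": conversationId?.toString() ?? "",
	},
	body: JSON.stringify({ inputs: "Hello", parameters: { max_new_tokens: 16 } }),
});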
src/lib/server/endpoints/endpoints.ts
CHANGED
@@ -29,6 +29,7 @@ import endpointLangserve, {
 } from "./langserve/endpointLangserve";
 
 import type { Tool, ToolCall, ToolResult } from "$lib/types/Tool";
+import type { ObjectId } from "mongodb";
 
 export type EndpointMessage = Omit<Message, "id">;
 
@@ -41,6 +42,7 @@ export interface EndpointParameters {
 	tools?: Tool[];
 	toolResults?: ToolResult[];
 	isMultimodal?: boolean;
+	conversationId?: ObjectId;
 }
 
 interface CommonEndpoint {
src/lib/server/endpoints/openai/endpointOai.ts
CHANGED
@@ -149,7 +149,7 @@ export async function endpointOai(
 				"Tools are not supported for 'completions' mode, switch to 'chat_completions' instead"
 			);
 		}
-		return async ({ messages, preprompt, continueMessage, generateSettings }) => {
+		return async ({ messages, preprompt, continueMessage, generateSettings, conversationId }) => {
 			const prompt = await buildPrompt({
 				messages,
 				continueMessage,
@@ -171,12 +171,22 @@ export async function endpointOai(
 
 			const openAICompletion = await openai.completions.create(body, {
 				body: { ...body, ...extraBody },
+				headers: {
+					"ChatUI-Conversation-ID": conversationId?.toString() ?? "",
+				},
 			});
 
 			return openAICompletionToTextGenerationStream(openAICompletion);
 		};
 	} else if (completion === "chat_completions") {
-		return async ({ messages, preprompt, generateSettings, tools, toolResults }) => {
+		return async ({
+			messages,
+			preprompt,
+			generateSettings,
+			tools,
+			toolResults,
+			conversationId,
+		}) => {
 			let messagesOpenAI: OpenAI.Chat.Completions.ChatCompletionMessageParam[] =
 				await prepareMessages(messages, imageProcessor, !model.tools && model.multimodal);
 
@@ -240,6 +250,9 @@ export async function endpointOai(
 
 			const openChatAICompletion = await openai.chat.completions.create(body, {
 				body: { ...body, ...extraBody },
+				headers: {
+					"ChatUI-Conversation-ID": conversationId?.toString() ?? "",
+				},
 			});
 
 			return openAIChatToTextGenerationStream(openChatAICompletion);
src/lib/server/endpoints/tgi/endpointTgi.ts
CHANGED
@@ -43,6 +43,7 @@ export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>):
 		tools,
 		toolResults,
 		isMultimodal,
+		conversationId,
 	}) => {
 		const messagesWithResizedFiles = await Promise.all(
 			messages.map((message) => prepareMessage(Boolean(isMultimodal), message, imageProcessor))
@@ -72,6 +73,7 @@ export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>):
 			info.headers = {
 				...info.headers,
 				Authorization: authorization,
+				"ChatUI-Conversation-ID": conversationId?.toString() ?? "",
 			};
 		}
 		return fetch(endpointUrl, info);
src/lib/server/textGeneration/generate.ts
CHANGED
@@ -18,6 +18,7 @@ export async function* generate(
 		generateSettings: assistant?.generateSettings,
 		toolResults,
 		isMultimodal: model.multimodal,
+		conversationId: conv._id,
 	})) {
 		// text generation completed
 		if (output.generated_text) {
src/lib/server/textGeneration/tools.ts
CHANGED
@@ -196,6 +196,7 @@ export async function* runTools(
 				type: input.type === "file" ? "str" : input.type,
 			})),
 		})),
+		conversationId: conv._id,
 	})) {
 		// model natively supports tool calls
 		if (output.token.toolCalls) {