Extend endpointOai.ts to allow usage of extra sampling parameters (#1032)
* Extend endpointOai.ts to allow usage of extra sampling parameters when calling vLLM as an OpenAI-compatible server
* refactor: prettier endpointOai.ts
* Fix: Corrected type imports in endpointOai.ts
* Simplifies code a bit and adds `extraBody` to the OpenAI endpoint
* Update zod schema to allow any type in extraBody
---------
Co-authored-by: Nathan Sarrazin <[email protected]>
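For context, the new `extraBody` field is declared as `z.record(z.any()).optional()`, so any JSON-serializable key/value pairs pass validation and get forwarded to the server. Below is a minimal sketch of an endpoint entry that exercises it, assuming a local vLLM server; the `top_k` and `repetition_penalty` keys are vLLM-specific extensions, shown purely for illustration:

```ts
// Hypothetical "openai" endpoint entry, e.g. inside a chat-ui MODELS config.
// The baseURL and the extraBody keys are assumptions for illustration:
// vLLM's OpenAI-compatible server accepts extra sampling parameters such as
// top_k in the request body, which the official OpenAI API does not define.
const endpointInput = {
	type: "openai",
	baseURL: "http://localhost:8000/v1",
	completion: "chat_completions",
	extraBody: {
		top_k: 50,
		repetition_penalty: 1.2,
	},
};
```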
src/lib/server/endpoints/openai/endpointOai.ts
CHANGED
```diff
@@ -1,6 +1,8 @@
 import { z } from "zod";
 import { openAICompletionToTextGenerationStream } from "./openAICompletionToTextGenerationStream";
 import { openAIChatToTextGenerationStream } from "./openAIChatToTextGenerationStream";
+import type { CompletionCreateParamsStreaming } from "openai/resources/completions";
+import type { ChatCompletionCreateParamsStreaming } from "openai/resources/chat/completions";
 import { buildPrompt } from "$lib/buildPrompt";
 import { env } from "$env/dynamic/private";
 import type { Endpoint } from "../endpoints";
@@ -16,12 +18,13 @@ export const endpointOAIParametersSchema = z.object({
 		.default("chat_completions"),
 	defaultHeaders: z.record(z.string()).optional(),
 	defaultQuery: z.record(z.string()).optional(),
+	extraBody: z.record(z.any()).optional(),
 });
 
 export async function endpointOai(
 	input: z.input<typeof endpointOAIParametersSchema>
 ): Promise<Endpoint> {
-	const { baseURL, apiKey, completion, model, defaultHeaders, defaultQuery } =
+	const { baseURL, apiKey, completion, model, defaultHeaders, defaultQuery, extraBody } =
 		endpointOAIParametersSchema.parse(input);
 	let OpenAI;
 	try {
@@ -47,19 +50,22 @@ export async function endpointOai(
 			});
 
 			const parameters = { ...model.parameters, ...generateSettings };
+			const body: CompletionCreateParamsStreaming = {
+				model: model.id ?? model.name,
+				prompt,
+				stream: true,
+				max_tokens: parameters?.max_new_tokens,
+				stop: parameters?.stop,
+				temperature: parameters?.temperature,
+				top_p: parameters?.top_p,
+				frequency_penalty: parameters?.repetition_penalty,
+			};
 
-			return openAICompletionToTextGenerationStream(
-				await openai.completions.create({
-					model: model.id ?? model.name,
-					prompt,
-					stream: true,
-					max_tokens: parameters?.max_new_tokens,
-					stop: parameters?.stop,
-					temperature: parameters?.temperature,
-					top_p: parameters?.top_p,
-					frequency_penalty: parameters?.repetition_penalty,
-				})
-			);
+			const openAICompletion = await openai.completions.create(body, {
+				body: { ...body, ...extraBody },
+			});
+
+			return openAICompletionToTextGenerationStream(openAICompletion);
 		};
 	} else if (completion === "chat_completions") {
 		return async ({ messages, preprompt, generateSettings }) => {
@@ -77,19 +83,22 @@ export async function endpointOai(
 			}
 
 			const parameters = { ...model.parameters, ...generateSettings };
+			const body: ChatCompletionCreateParamsStreaming = {
+				model: model.id ?? model.name,
+				messages: messagesOpenAI,
+				stream: true,
+				max_tokens: parameters?.max_new_tokens,
+				stop: parameters?.stop,
+				temperature: parameters?.temperature,
+				top_p: parameters?.top_p,
+				frequency_penalty: parameters?.repetition_penalty,
+			};
+
+			const openChatAICompletion = await openai.chat.completions.create(body, {
+				body: { ...body, ...extraBody },
+			});
 
-			return openAIChatToTextGenerationStream(
-				await openai.chat.completions.create({
-					model: model.id ?? model.name,
-					messages: messagesOpenAI,
-					stream: true,
-					max_tokens: parameters?.max_new_tokens,
-					stop: parameters?.stop,
-					temperature: parameters?.temperature,
-					top_p: parameters?.top_p,
-					frequency_penalty: parameters?.repetition_penalty,
-				})
-			);
+			return openAIChatToTextGenerationStream(openChatAICompletion);
 		};
 	} else {
 		throw new Error("Invalid completion type");
```