Spaces:
Running
Running
File size: 3,464 Bytes
b5ae065 ec2a4ed b5ae065 a8a9533 b5ae065 a8a9533 b5ae065 4ef3bdf d8a16f3 ec2a4ed b5ae065 51b0991 ec2a4ed 4ef3bdf b5ae065 347b211 4ef3bdf d8a16f3 b5ae065 4e43408 2a808d7 4e43408 ec2a4ed 4e43408 ec2a4ed b5ae065 4e43408 2a808d7 b5ae065 2a808d7 c78cb53 2a808d7 4e43408 ec2a4ed 4e43408 ec2a4ed b5ae065 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import { z } from "zod";
import { openAICompletionToTextGenerationStream } from "./openAICompletionToTextGenerationStream";
import { openAIChatToTextGenerationStream } from "./openAIChatToTextGenerationStream";
import type { CompletionCreateParamsStreaming } from "openai/resources/completions";
import type { ChatCompletionCreateParamsStreaming } from "openai/resources/chat/completions";
import { buildPrompt } from "$lib/buildPrompt";
import { env } from "$env/dynamic/private";
import type { Endpoint } from "../endpoints";
/**
 * Configuration schema for an OpenAI-compatible endpoint.
 * Parsed (with defaults applied) by `endpointOai` below.
 */
export const endpointOAIParametersSchema = z.object({
	// Relative weight used when load-balancing across several endpoints.
	weight: z.number().int().positive().default(1),
	// Model descriptor injected by the caller; shape is project-defined, so left open.
	model: z.any(),
	// Discriminator selecting this endpoint implementation.
	type: z.literal("openai"),
	// Any OpenAI-compatible server works; defaults to the official API.
	baseURL: z.string().url().default("https://api.openai.com/v1"),
	// Falls back to a placeholder key so local OpenAI-compatible servers work unauthenticated.
	apiKey: z.string().default(env.OPENAI_API_KEY ?? "sk-"),
	// Which API flavor to use: legacy text completions or chat completions.
	completion: z
		.union([z.literal("completions"), z.literal("chat_completions")])
		.default("chat_completions"),
	// Extra HTTP headers / query params forwarded verbatim to the OpenAI client.
	defaultHeaders: z.record(z.string()).optional(),
	defaultQuery: z.record(z.string()).optional(),
	// Extra JSON fields merged into every request body (e.g. vendor-specific options).
	extraBody: z.record(z.any()).optional(),
});
/**
 * Creates an `Endpoint` backed by an OpenAI-compatible API.
 *
 * Depending on the `completion` setting, the returned function streams either
 * legacy text completions (prompt built via `buildPrompt`) or chat completions
 * (messages forwarded with a leading system message carrying the preprompt).
 *
 * @param input raw endpoint configuration; validated and defaulted against
 *              `endpointOAIParametersSchema`
 * @returns a streaming text-generation endpoint
 * @throws if the optional `openai` dependency cannot be imported, or if the
 *         configured completion type is unknown
 */
export async function endpointOai(
	input: z.input<typeof endpointOAIParametersSchema>
): Promise<Endpoint> {
	const { baseURL, apiKey, completion, model, defaultHeaders, defaultQuery, extraBody } =
		endpointOAIParametersSchema.parse(input);

	// Import lazily so the `openai` package is only required when this endpoint
	// type is actually configured.
	let OpenAI;
	try {
		OpenAI = (await import("openai")).OpenAI;
	} catch (e) {
		throw new Error("Failed to import OpenAI", { cause: e });
	}

	const openai = new OpenAI({
		// `apiKey` is always defined here: the schema defaults it to
		// env.OPENAI_API_KEY ?? "sk-", so no extra fallback is needed.
		apiKey,
		baseURL,
		defaultHeaders,
		defaultQuery,
	});

	if (completion === "completions") {
		// Legacy text-completions API: flatten the conversation into one prompt.
		return async ({ messages, preprompt, continueMessage, generateSettings }) => {
			const prompt = await buildPrompt({
				messages,
				continueMessage,
				preprompt,
				model,
			});

			// Per-request settings override the model's defaults.
			const parameters = { ...model.parameters, ...generateSettings };
			const body: CompletionCreateParamsStreaming = {
				model: model.id ?? model.name,
				prompt,
				stream: true,
				max_tokens: parameters?.max_new_tokens,
				stop: parameters?.stop,
				temperature: parameters?.temperature,
				top_p: parameters?.top_p,
				frequency_penalty: parameters?.repetition_penalty,
			};

			// `extraBody` is merged via the request options so vendor-specific
			// fields survive the SDK's own body typing.
			const openAICompletion = await openai.completions.create(body, {
				body: { ...body, ...extraBody },
			});

			return openAICompletionToTextGenerationStream(openAICompletion);
		};
	} else if (completion === "chat_completions") {
		return async ({ messages, preprompt, generateSettings }) => {
			let messagesOpenAI = messages.map((message) => ({
				role: message.from,
				content: message.content,
			}));

			// Guarantee a leading system message to carry the preprompt.
			if (messagesOpenAI?.[0]?.role !== "system") {
				messagesOpenAI = [{ role: "system", content: "" }, ...messagesOpenAI];
			}

			if (messagesOpenAI?.[0]) {
				// FIX: previously `preprompt ?? ""` unconditionally clobbered an
				// existing system message's content with the empty string whenever
				// no preprompt was configured. Keep the original content instead.
				messagesOpenAI[0].content = preprompt ?? messagesOpenAI[0].content ?? "";
			}

			// Per-request settings override the model's defaults.
			const parameters = { ...model.parameters, ...generateSettings };
			const body: ChatCompletionCreateParamsStreaming = {
				model: model.id ?? model.name,
				messages: messagesOpenAI,
				stream: true,
				max_tokens: parameters?.max_new_tokens,
				stop: parameters?.stop,
				temperature: parameters?.temperature,
				top_p: parameters?.top_p,
				frequency_penalty: parameters?.repetition_penalty,
			};

			// As above: merge `extraBody` through the request options.
			const openChatAICompletion = await openai.chat.completions.create(body, {
				body: { ...body, ...extraBody },
			});

			return openAIChatToTextGenerationStream(openChatAICompletion);
		};
	} else {
		throw new Error("Invalid completion type");
	}
}
|