Spaces:
Paused
Paused
Add support for Anthropic models via AWS Bedrock (#1413)
Browse files
* Add support for Anthropic models via AWS Bedrock
* deps
* Fixed type errors
* Temporary fix for continue button showing up on Claude
* Fix continue button issue by setting the last message token's special to true
---------
Co-authored-by: Nathan Sarrazin <[email protected]>
- package-lock.json +0 -0
- package.json +1 -0
- src/lib/server/endpoints/aws/endpointBedrock.ts +150 -0
- src/lib/server/endpoints/endpoints.ts +3 -0
- src/lib/server/models.ts +2 -0
package-lock.json
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
package.json
CHANGED
|
@@ -108,6 +108,7 @@
|
|
| 108 |
"zod": "^3.22.3"
|
| 109 |
},
|
| 110 |
"optionalDependencies": {
|
|
|
|
| 111 |
"@anthropic-ai/sdk": "^0.25.0",
|
| 112 |
"@anthropic-ai/vertex-sdk": "^0.4.1",
|
| 113 |
"@google-cloud/vertexai": "^1.1.0",
|
|
|
|
| 108 |
"zod": "^3.22.3"
|
| 109 |
},
|
| 110 |
"optionalDependencies": {
|
| 111 |
+
"@aws-sdk/client-bedrock-runtime": "^3.631.0",
|
| 112 |
"@anthropic-ai/sdk": "^0.25.0",
|
| 113 |
"@anthropic-ai/vertex-sdk": "^0.4.1",
|
| 114 |
"@google-cloud/vertexai": "^1.1.0",
|
src/lib/server/endpoints/aws/endpointBedrock.ts
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { z } from "zod";
|
| 2 |
+
import type { Endpoint } from "../endpoints";
|
| 3 |
+
import type { TextGenerationStreamOutput } from "@huggingface/inference";
|
| 4 |
+
import {
|
| 5 |
+
BedrockRuntimeClient,
|
| 6 |
+
InvokeModelWithResponseStreamCommand,
|
| 7 |
+
} from "@aws-sdk/client-bedrock-runtime";
|
| 8 |
+
import { createImageProcessorOptionsValidator, makeImageProcessor } from "../images";
|
| 9 |
+
import type { EndpointMessage } from "../endpoints";
|
| 10 |
+
import type { MessageFile } from "$lib/types/Message";
|
| 11 |
+
|
| 12 |
+
export const endpointBedrockParametersSchema = z.object({
|
| 13 |
+
weight: z.number().int().positive().default(1),
|
| 14 |
+
type: z.literal("bedrock"),
|
| 15 |
+
region: z.string().default("us-east-1"),
|
| 16 |
+
model: z.any(),
|
| 17 |
+
anthropicVersion: z.string().default("bedrock-2023-05-31"),
|
| 18 |
+
multimodal: z
|
| 19 |
+
.object({
|
| 20 |
+
image: createImageProcessorOptionsValidator({
|
| 21 |
+
supportedMimeTypes: [
|
| 22 |
+
"image/png",
|
| 23 |
+
"image/jpeg",
|
| 24 |
+
"image/webp",
|
| 25 |
+
"image/avif",
|
| 26 |
+
"image/tiff",
|
| 27 |
+
"image/gif",
|
| 28 |
+
],
|
| 29 |
+
preferredMimeType: "image/webp",
|
| 30 |
+
maxSizeInMB: Infinity,
|
| 31 |
+
maxWidth: 4096,
|
| 32 |
+
maxHeight: 4096,
|
| 33 |
+
}),
|
| 34 |
+
})
|
| 35 |
+
.default({}),
|
| 36 |
+
});
|
| 37 |
+
|
| 38 |
+
export async function endpointBedrock(
|
| 39 |
+
input: z.input<typeof endpointBedrockParametersSchema>
|
| 40 |
+
): Promise<Endpoint> {
|
| 41 |
+
const { region, model, anthropicVersion, multimodal } =
|
| 42 |
+
endpointBedrockParametersSchema.parse(input);
|
| 43 |
+
const client = new BedrockRuntimeClient({
|
| 44 |
+
region,
|
| 45 |
+
});
|
| 46 |
+
const imageProcessor = makeImageProcessor(multimodal.image);
|
| 47 |
+
|
| 48 |
+
return async ({ messages, preprompt, generateSettings }) => {
|
| 49 |
+
let system = preprompt;
|
| 50 |
+
// Use the first message as the system prompt if it's of type "system"
|
| 51 |
+
if (messages?.[0]?.from === "system") {
|
| 52 |
+
system = messages[0].content;
|
| 53 |
+
messages = messages.slice(1); // Remove the first system message from the array
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
const formattedMessages = await prepareMessages(messages, imageProcessor);
|
| 57 |
+
|
| 58 |
+
let tokenId = 0;
|
| 59 |
+
const parameters = { ...model.parameters, ...generateSettings };
|
| 60 |
+
return (async function* () {
|
| 61 |
+
const command = new InvokeModelWithResponseStreamCommand({
|
| 62 |
+
body: Buffer.from(
|
| 63 |
+
JSON.stringify({
|
| 64 |
+
anthropic_version: anthropicVersion,
|
| 65 |
+
max_tokens: parameters.max_new_tokens ? parameters.max_new_tokens : 4096,
|
| 66 |
+
messages: formattedMessages,
|
| 67 |
+
system,
|
| 68 |
+
}),
|
| 69 |
+
"utf-8"
|
| 70 |
+
),
|
| 71 |
+
contentType: "application/json",
|
| 72 |
+
accept: "application/json",
|
| 73 |
+
modelId: model.id,
|
| 74 |
+
trace: "DISABLED",
|
| 75 |
+
});
|
| 76 |
+
|
| 77 |
+
const response = await client.send(command);
|
| 78 |
+
|
| 79 |
+
let text = "";
|
| 80 |
+
|
| 81 |
+
for await (const item of response.body ?? []) {
|
| 82 |
+
const chunk = JSON.parse(new TextDecoder().decode(item.chunk?.bytes));
|
| 83 |
+
const chunk_type = chunk.type;
|
| 84 |
+
|
| 85 |
+
if (chunk_type === "content_block_delta") {
|
| 86 |
+
text += chunk.delta.text;
|
| 87 |
+
yield {
|
| 88 |
+
token: {
|
| 89 |
+
id: tokenId++,
|
| 90 |
+
text: chunk.delta.text,
|
| 91 |
+
logprob: 0,
|
| 92 |
+
special: false,
|
| 93 |
+
},
|
| 94 |
+
generated_text: null,
|
| 95 |
+
details: null,
|
| 96 |
+
} satisfies TextGenerationStreamOutput;
|
| 97 |
+
} else if (chunk_type === "message_stop") {
|
| 98 |
+
yield {
|
| 99 |
+
token: {
|
| 100 |
+
id: tokenId++,
|
| 101 |
+
text: "",
|
| 102 |
+
logprob: 0,
|
| 103 |
+
special: true,
|
| 104 |
+
},
|
| 105 |
+
generated_text: text,
|
| 106 |
+
details: null,
|
| 107 |
+
} satisfies TextGenerationStreamOutput;
|
| 108 |
+
}
|
| 109 |
+
}
|
| 110 |
+
})();
|
| 111 |
+
};
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
// Prepare the messages excluding system prompts
|
| 115 |
+
async function prepareMessages(
|
| 116 |
+
messages: EndpointMessage[],
|
| 117 |
+
imageProcessor: ReturnType<typeof makeImageProcessor>
|
| 118 |
+
) {
|
| 119 |
+
const formattedMessages = [];
|
| 120 |
+
|
| 121 |
+
for (const message of messages) {
|
| 122 |
+
const content = [];
|
| 123 |
+
|
| 124 |
+
if (message.files?.length) {
|
| 125 |
+
content.push(...(await prepareFiles(imageProcessor, message.files)));
|
| 126 |
+
}
|
| 127 |
+
content.push({ type: "text", text: message.content });
|
| 128 |
+
|
| 129 |
+
const lastMessage = formattedMessages[formattedMessages.length - 1];
|
| 130 |
+
if (lastMessage && lastMessage.role === message.from) {
|
| 131 |
+
// If the last message has the same role, merge the content
|
| 132 |
+
lastMessage.content.push(...content);
|
| 133 |
+
} else {
|
| 134 |
+
formattedMessages.push({ role: message.from, content });
|
| 135 |
+
}
|
| 136 |
+
}
|
| 137 |
+
return formattedMessages;
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
// Process files and convert them to base64 encoded strings
|
| 141 |
+
async function prepareFiles(
|
| 142 |
+
imageProcessor: ReturnType<typeof makeImageProcessor>,
|
| 143 |
+
files: MessageFile[]
|
| 144 |
+
) {
|
| 145 |
+
const processedFiles = await Promise.all(files.map(imageProcessor));
|
| 146 |
+
return processedFiles.map((file) => ({
|
| 147 |
+
type: "image",
|
| 148 |
+
source: { type: "base64", media_type: "image/jpeg", data: file.image.toString("base64") },
|
| 149 |
+
}));
|
| 150 |
+
}
|
src/lib/server/endpoints/endpoints.ts
CHANGED
|
@@ -9,6 +9,7 @@ import endpointLlamacpp, { endpointLlamacppParametersSchema } from "./llamacpp/e
|
|
| 9 |
import endpointOllama, { endpointOllamaParametersSchema } from "./ollama/endpointOllama";
|
| 10 |
import endpointVertex, { endpointVertexParametersSchema } from "./google/endpointVertex";
|
| 11 |
import endpointGenAI, { endpointGenAIParametersSchema } from "./google/endpointGenAI";
|
|
|
|
| 12 |
|
| 13 |
import {
|
| 14 |
endpointAnthropic,
|
|
@@ -61,6 +62,7 @@ export const endpoints = {
|
|
| 61 |
tgi: endpointTgi,
|
| 62 |
anthropic: endpointAnthropic,
|
| 63 |
anthropicvertex: endpointAnthropicVertex,
|
|
|
|
| 64 |
aws: endpointAws,
|
| 65 |
openai: endpointOai,
|
| 66 |
llamacpp: endpointLlamacpp,
|
|
@@ -76,6 +78,7 @@ export const endpointSchema = z.discriminatedUnion("type", [
|
|
| 76 |
endpointAnthropicParametersSchema,
|
| 77 |
endpointAnthropicVertexParametersSchema,
|
| 78 |
endpointAwsParametersSchema,
|
|
|
|
| 79 |
endpointOAIParametersSchema,
|
| 80 |
endpointTgiParametersSchema,
|
| 81 |
endpointLlamacppParametersSchema,
|
|
|
|
| 9 |
import endpointOllama, { endpointOllamaParametersSchema } from "./ollama/endpointOllama";
|
| 10 |
import endpointVertex, { endpointVertexParametersSchema } from "./google/endpointVertex";
|
| 11 |
import endpointGenAI, { endpointGenAIParametersSchema } from "./google/endpointGenAI";
|
| 12 |
+
import { endpointBedrock, endpointBedrockParametersSchema } from "./aws/endpointBedrock";
|
| 13 |
|
| 14 |
import {
|
| 15 |
endpointAnthropic,
|
|
|
|
| 62 |
tgi: endpointTgi,
|
| 63 |
anthropic: endpointAnthropic,
|
| 64 |
anthropicvertex: endpointAnthropicVertex,
|
| 65 |
+
bedrock: endpointBedrock,
|
| 66 |
aws: endpointAws,
|
| 67 |
openai: endpointOai,
|
| 68 |
llamacpp: endpointLlamacpp,
|
|
|
|
| 78 |
endpointAnthropicParametersSchema,
|
| 79 |
endpointAnthropicVertexParametersSchema,
|
| 80 |
endpointAwsParametersSchema,
|
| 81 |
+
endpointBedrockParametersSchema,
|
| 82 |
endpointOAIParametersSchema,
|
| 83 |
endpointTgiParametersSchema,
|
| 84 |
endpointLlamacppParametersSchema,
|
src/lib/server/models.ts
CHANGED
|
@@ -280,6 +280,8 @@ const addEndpoint = (m: Awaited<ReturnType<typeof processModel>>) => ({
|
|
| 280 |
return endpoints.anthropic(args);
|
| 281 |
case "anthropic-vertex":
|
| 282 |
return endpoints.anthropicvertex(args);
|
|
|
|
|
|
|
| 283 |
case "aws":
|
| 284 |
return await endpoints.aws(args);
|
| 285 |
case "openai":
|
|
|
|
| 280 |
return endpoints.anthropic(args);
|
| 281 |
case "anthropic-vertex":
|
| 282 |
return endpoints.anthropicvertex(args);
|
| 283 |
+
case "bedrock":
|
| 284 |
+
return endpoints.bedrock(args);
|
| 285 |
case "aws":
|
| 286 |
return await endpoints.aws(args);
|
| 287 |
case "openai":
|