Spaces:
				
			
			
	
			
			
		Paused
		
	
	
	
			
			
	
	
	
	
		
		
		Paused
		
	Implement Cloudflare Workers AI endpoint (#907) (#972)
Browse files* Implement Cloudflare Workers AI endpoint (#907)
* Renamed to Cloudflare Workers AI in docs
* Add note about sampling parameters
* clean up env example
- .env +3 -0
- README.md +32 -0
- src/lib/server/endpoints/cloudflare/endpointCloudflare.ts +134 -0
- src/lib/server/endpoints/endpoints.ts +5 -0
- src/lib/server/models.ts +2 -0
    	
        .env
    CHANGED
    
    | @@ -8,8 +8,11 @@ MONGODB_DIRECT_CONNECTION=false | |
| 8 | 
             
            COOKIE_NAME=hf-chat
         | 
| 9 | 
             
            HF_TOKEN=#hf_<token> from https://huggingface.co/settings/token
         | 
| 10 | 
             
            HF_API_ROOT=https://api-inference.huggingface.co/models
         | 
|  | |
| 11 | 
             
            OPENAI_API_KEY=#your openai api key here
         | 
| 12 | 
             
            ANTHROPIC_API_KEY=#your anthropic api key here
         | 
|  | |
|  | |
| 13 |  | 
| 14 | 
             
            HF_ACCESS_TOKEN=#LEGACY! Use HF_TOKEN instead
         | 
| 15 |  | 
|  | |
| 8 | 
             
            COOKIE_NAME=hf-chat
         | 
| 9 | 
             
            HF_TOKEN=#hf_<token> from https://huggingface.co/settings/token
         | 
| 10 | 
             
            HF_API_ROOT=https://api-inference.huggingface.co/models
         | 
| 11 | 
            +
             | 
| 12 | 
             
            OPENAI_API_KEY=#your openai api key here
         | 
| 13 | 
             
            ANTHROPIC_API_KEY=#your anthropic api key here
         | 
| 14 | 
            +
            CLOUDFLARE_ACCOUNT_ID=#your cloudflare account id here
         | 
| 15 | 
            +
            CLOUDFLARE_API_TOKEN=#your cloudflare api token here
         | 
| 16 |  | 
| 17 | 
             
            HF_ACCESS_TOKEN=#LEGACY! Use HF_TOKEN instead
         | 
| 18 |  | 
    	
        README.md
    CHANGED
    
    | @@ -528,6 +528,38 @@ You can also set `"service" : "lambda"` to use a lambda instance. | |
| 528 |  | 
| 529 | 
             
            You can get the `accessKey` and `secretKey` from your AWS user, under programmatic access.
         | 
| 530 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 531 | 
             
            ##### Google Vertex models
         | 
| 532 |  | 
| 533 | 
             
            Chat UI can connect to the google Vertex API endpoints ([List of supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models)).
         | 
|  | |
| 528 |  | 
| 529 | 
             
            You can get the `accessKey` and `secretKey` from your AWS user, under programmatic access.
         | 
| 530 |  | 
| 531 | 
            +
            #### Cloudflare Workers AI
         | 
| 532 | 
            +
             | 
| 533 | 
            +
            You can also use Cloudflare Workers AI to run your own models with serverless inference.
         | 
| 534 | 
            +
             | 
| 535 | 
            +
            You will need to have a Cloudflare account, then get your [account ID](https://developers.cloudflare.com/fundamentals/setup/find-account-and-zone-ids/) as well as your [API token](https://developers.cloudflare.com/workers-ai/get-started/rest-api/#1-get-an-api-token) for Workers AI.
         | 
| 536 | 
            +
             | 
| 537 | 
            +
            You can either specify them directly in your `.env.local` using the `CLOUDFLARE_ACCOUNT_ID` and `CLOUDFLARE_API_TOKEN` variables, or you can set them directly in the endpoint config.
         | 
| 538 | 
            +
             | 
| 539 | 
            +
            You can find the list of models available on Cloudflare [here](https://developers.cloudflare.com/workers-ai/models/#text-generation).
         | 
| 540 | 
            +
             | 
| 541 | 
            +
            ```env
         | 
| 542 | 
            +
              {
         | 
| 543 | 
            +
              "name" : "nousresearch/hermes-2-pro-mistral-7b",
         | 
| 544 | 
            +
              "tokenizer": "nousresearch/hermes-2-pro-mistral-7b",
         | 
| 545 | 
            +
              "parameters": {
         | 
| 546 | 
            +
                "stop": ["<|im_end|>"]
         | 
| 547 | 
            +
              },
         | 
| 548 | 
            +
              "endpoints" : [
         | 
| 549 | 
            +
                {
         | 
| 550 | 
            +
                  "type" : "cloudflare"
         | 
| 551 | 
            +
                  <!-- optionally specify these
         | 
| 552 | 
            +
                  "accountId": "your-account-id",
         | 
| 553 | 
            +
                  "authToken": "your-api-token"
         | 
| 554 | 
            +
                  -->
         | 
| 555 | 
            +
                }
         | 
| 556 | 
            +
              ]
         | 
| 557 | 
            +
            }
         | 
| 558 | 
            +
            ```
         | 
| 559 | 
            +
             | 
| 560 | 
            +
            > [!NOTE]  
         | 
| 561 | 
            +
            > Cloudlare Workers AI currently do not support custom sampling parameters like temperature, top_p, etc.
         | 
| 562 | 
            +
             | 
| 563 | 
             
            ##### Google Vertex models
         | 
| 564 |  | 
| 565 | 
             
            Chat UI can connect to the google Vertex API endpoints ([List of supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models)).
         | 
    	
        src/lib/server/endpoints/cloudflare/endpointCloudflare.ts
    ADDED
    
    | @@ -0,0 +1,134 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import { z } from "zod";
         | 
| 2 | 
            +
            import type { Endpoint } from "../endpoints";
         | 
| 3 | 
            +
            import type { TextGenerationStreamOutput } from "@huggingface/inference";
         | 
| 4 | 
            +
            import { CLOUDFLARE_ACCOUNT_ID, CLOUDFLARE_API_TOKEN } from "$env/static/private";
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            export const endpointCloudflareParametersSchema = z.object({
         | 
| 7 | 
            +
            	weight: z.number().int().positive().default(1),
         | 
| 8 | 
            +
            	model: z.any(),
         | 
| 9 | 
            +
            	type: z.literal("cloudflare"),
         | 
| 10 | 
            +
            	accountId: z.string().default(CLOUDFLARE_ACCOUNT_ID),
         | 
| 11 | 
            +
            	apiToken: z.string().default(CLOUDFLARE_API_TOKEN),
         | 
| 12 | 
            +
            });
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            export async function endpointCloudflare(
         | 
| 15 | 
            +
            	input: z.input<typeof endpointCloudflareParametersSchema>
         | 
| 16 | 
            +
            ): Promise<Endpoint> {
         | 
| 17 | 
            +
            	const { accountId, apiToken, model } = endpointCloudflareParametersSchema.parse(input);
         | 
| 18 | 
            +
            	const apiURL = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/@hf/${model.id}`;
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            	return async ({ messages, preprompt }) => {
         | 
| 21 | 
            +
            		let messagesFormatted = messages.map((message) => ({
         | 
| 22 | 
            +
            			role: message.from,
         | 
| 23 | 
            +
            			content: message.content,
         | 
| 24 | 
            +
            		}));
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            		if (messagesFormatted?.[0]?.role !== "system") {
         | 
| 27 | 
            +
            			messagesFormatted = [{ role: "system", content: preprompt ?? "" }, ...messagesFormatted];
         | 
| 28 | 
            +
            		}
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            		const payload = JSON.stringify({
         | 
| 31 | 
            +
            			messages: messagesFormatted,
         | 
| 32 | 
            +
            			stream: true,
         | 
| 33 | 
            +
            		});
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            		const res = await fetch(apiURL, {
         | 
| 36 | 
            +
            			method: "POST",
         | 
| 37 | 
            +
            			headers: {
         | 
| 38 | 
            +
            				Authorization: `Bearer ${apiToken}`,
         | 
| 39 | 
            +
            				"Content-Type": "application/json",
         | 
| 40 | 
            +
            			},
         | 
| 41 | 
            +
            			body: payload,
         | 
| 42 | 
            +
            		});
         | 
| 43 | 
            +
             | 
| 44 | 
            +
            		if (!res.ok) {
         | 
| 45 | 
            +
            			throw new Error(`Failed to generate text: ${await res.text()}`);
         | 
| 46 | 
            +
            		}
         | 
| 47 | 
            +
             | 
| 48 | 
            +
            		const encoder = new TextDecoderStream();
         | 
| 49 | 
            +
            		const reader = res.body?.pipeThrough(encoder).getReader();
         | 
| 50 | 
            +
             | 
| 51 | 
            +
            		return (async function* () {
         | 
| 52 | 
            +
            			let stop = false;
         | 
| 53 | 
            +
            			let generatedText = "";
         | 
| 54 | 
            +
            			let tokenId = 0;
         | 
| 55 | 
            +
            			let accumulatedData = ""; // Buffer to accumulate data chunks
         | 
| 56 | 
            +
             | 
| 57 | 
            +
            			while (!stop) {
         | 
| 58 | 
            +
            				const out = await reader?.read();
         | 
| 59 | 
            +
             | 
| 60 | 
            +
            				// If it's done, we cancel
         | 
| 61 | 
            +
            				if (out?.done) {
         | 
| 62 | 
            +
            					reader?.cancel();
         | 
| 63 | 
            +
            					return;
         | 
| 64 | 
            +
            				}
         | 
| 65 | 
            +
             | 
| 66 | 
            +
            				if (!out?.value) {
         | 
| 67 | 
            +
            					return;
         | 
| 68 | 
            +
            				}
         | 
| 69 | 
            +
             | 
| 70 | 
            +
            				// Accumulate the data chunk
         | 
| 71 | 
            +
            				accumulatedData += out.value;
         | 
| 72 | 
            +
             | 
| 73 | 
            +
            				// Process each complete JSON object in the accumulated data
         | 
| 74 | 
            +
            				while (accumulatedData.includes("\n")) {
         | 
| 75 | 
            +
            					// Assuming each JSON object ends with a newline
         | 
| 76 | 
            +
            					const endIndex = accumulatedData.indexOf("\n");
         | 
| 77 | 
            +
            					let jsonString = accumulatedData.substring(0, endIndex).trim();
         | 
| 78 | 
            +
             | 
| 79 | 
            +
            					// Remove the processed part from the buffer
         | 
| 80 | 
            +
            					accumulatedData = accumulatedData.substring(endIndex + 1);
         | 
| 81 | 
            +
             | 
| 82 | 
            +
            					if (jsonString.startsWith("data: ")) {
         | 
| 83 | 
            +
            						jsonString = jsonString.slice(6);
         | 
| 84 | 
            +
            						let data = null;
         | 
| 85 | 
            +
             | 
| 86 | 
            +
            						if (jsonString === "[DONE]") {
         | 
| 87 | 
            +
            							stop = true;
         | 
| 88 | 
            +
             | 
| 89 | 
            +
            							yield {
         | 
| 90 | 
            +
            								token: {
         | 
| 91 | 
            +
            									id: tokenId++,
         | 
| 92 | 
            +
            									text: "",
         | 
| 93 | 
            +
            									logprob: 0,
         | 
| 94 | 
            +
            									special: true,
         | 
| 95 | 
            +
            								},
         | 
| 96 | 
            +
            								generated_text: generatedText,
         | 
| 97 | 
            +
            								details: null,
         | 
| 98 | 
            +
            							} satisfies TextGenerationStreamOutput;
         | 
| 99 | 
            +
            							reader?.cancel();
         | 
| 100 | 
            +
             | 
| 101 | 
            +
            							continue;
         | 
| 102 | 
            +
            						}
         | 
| 103 | 
            +
             | 
| 104 | 
            +
            						try {
         | 
| 105 | 
            +
            							data = JSON.parse(jsonString);
         | 
| 106 | 
            +
            						} catch (e) {
         | 
| 107 | 
            +
            							console.error("Failed to parse JSON", e);
         | 
| 108 | 
            +
            							console.error("Problematic JSON string:", jsonString);
         | 
| 109 | 
            +
            							continue; // Skip this iteration and try the next chunk
         | 
| 110 | 
            +
            						}
         | 
| 111 | 
            +
             | 
| 112 | 
            +
            						// Handle the parsed data
         | 
| 113 | 
            +
            						if (data.response) {
         | 
| 114 | 
            +
            							generatedText += data.response ?? "";
         | 
| 115 | 
            +
            							const output: TextGenerationStreamOutput = {
         | 
| 116 | 
            +
            								token: {
         | 
| 117 | 
            +
            									id: tokenId++,
         | 
| 118 | 
            +
            									text: data.response ?? "",
         | 
| 119 | 
            +
            									logprob: 0,
         | 
| 120 | 
            +
            									special: false,
         | 
| 121 | 
            +
            								},
         | 
| 122 | 
            +
            								generated_text: null,
         | 
| 123 | 
            +
            								details: null,
         | 
| 124 | 
            +
            							};
         | 
| 125 | 
            +
            							yield output;
         | 
| 126 | 
            +
            						}
         | 
| 127 | 
            +
            					}
         | 
| 128 | 
            +
            				}
         | 
| 129 | 
            +
            			}
         | 
| 130 | 
            +
            		})();
         | 
| 131 | 
            +
            	};
         | 
| 132 | 
            +
            }
         | 
| 133 | 
            +
             | 
| 134 | 
            +
            export default endpointCloudflare;
         | 
    	
        src/lib/server/endpoints/endpoints.ts
    CHANGED
    
    | @@ -13,6 +13,9 @@ import { | |
| 13 | 
             
            	endpointAnthropicParametersSchema,
         | 
| 14 | 
             
            } from "./anthropic/endpointAnthropic";
         | 
| 15 | 
             
            import type { Model } from "$lib/types/Model";
         | 
|  | |
|  | |
|  | |
| 16 |  | 
| 17 | 
             
            // parameters passed when generating text
         | 
| 18 | 
             
            export interface EndpointParameters {
         | 
| @@ -42,6 +45,7 @@ export const endpoints = { | |
| 42 | 
             
            	llamacpp: endpointLlamacpp,
         | 
| 43 | 
             
            	ollama: endpointOllama,
         | 
| 44 | 
             
            	vertex: endpointVertex,
         | 
|  | |
| 45 | 
             
            };
         | 
| 46 |  | 
| 47 | 
             
            export const endpointSchema = z.discriminatedUnion("type", [
         | 
| @@ -52,5 +56,6 @@ export const endpointSchema = z.discriminatedUnion("type", [ | |
| 52 | 
             
            	endpointLlamacppParametersSchema,
         | 
| 53 | 
             
            	endpointOllamaParametersSchema,
         | 
| 54 | 
             
            	endpointVertexParametersSchema,
         | 
|  | |
| 55 | 
             
            ]);
         | 
| 56 | 
             
            export default endpoints;
         | 
|  | |
| 13 | 
             
            	endpointAnthropicParametersSchema,
         | 
| 14 | 
             
            } from "./anthropic/endpointAnthropic";
         | 
| 15 | 
             
            import type { Model } from "$lib/types/Model";
         | 
| 16 | 
            +
            import endpointCloudflare, {
         | 
| 17 | 
            +
            	endpointCloudflareParametersSchema,
         | 
| 18 | 
            +
            } from "./cloudflare/endpointCloudflare";
         | 
| 19 |  | 
| 20 | 
             
            // parameters passed when generating text
         | 
| 21 | 
             
            export interface EndpointParameters {
         | 
|  | |
| 45 | 
             
            	llamacpp: endpointLlamacpp,
         | 
| 46 | 
             
            	ollama: endpointOllama,
         | 
| 47 | 
             
            	vertex: endpointVertex,
         | 
| 48 | 
            +
            	cloudflare: endpointCloudflare,
         | 
| 49 | 
             
            };
         | 
| 50 |  | 
| 51 | 
             
            export const endpointSchema = z.discriminatedUnion("type", [
         | 
|  | |
| 56 | 
             
            	endpointLlamacppParametersSchema,
         | 
| 57 | 
             
            	endpointOllamaParametersSchema,
         | 
| 58 | 
             
            	endpointVertexParametersSchema,
         | 
| 59 | 
            +
            	endpointCloudflareParametersSchema,
         | 
| 60 | 
             
            ]);
         | 
| 61 | 
             
            export default endpoints;
         | 
    	
        src/lib/server/models.ts
    CHANGED
    
    | @@ -130,6 +130,8 @@ const addEndpoint = (m: Awaited<ReturnType<typeof processModel>>) => ({ | |
| 130 | 
             
            						return endpoints.ollama(args);
         | 
| 131 | 
             
            					case "vertex":
         | 
| 132 | 
             
            						return await endpoints.vertex(args);
         | 
|  | |
|  | |
| 133 | 
             
            					default:
         | 
| 134 | 
             
            						// for legacy reason
         | 
| 135 | 
             
            						return endpoints.tgi(args);
         | 
|  | |
| 130 | 
             
            						return endpoints.ollama(args);
         | 
| 131 | 
             
            					case "vertex":
         | 
| 132 | 
             
            						return await endpoints.vertex(args);
         | 
| 133 | 
            +
            					case "cloudflare":
         | 
| 134 | 
            +
            						return await endpoints.cloudflare(args);
         | 
| 135 | 
             
            					default:
         | 
| 136 | 
             
            						// for legacy reason
         | 
| 137 | 
             
            						return endpoints.tgi(args);
         | 
 
			

