Spaces:
				
			
			
	
			
			
		Paused
		
	
	
	
			
			
	
	
	
	
		
		
		Paused
		
	
		Antonio Ramos
		
		antoniora
		
	commited on
		
		
					Add langserve endpoint (#1009)
Browse files* Add support for langserve endpoints
* Add support for langserve endpoints
* Fix linting
* Fix linting issues
* Fix issue import
---------
Co-authored-by: antoniora <[email protected]>
    	
        README.md
    CHANGED
    
    | 
         @@ -618,6 +618,24 @@ MODELS=`[ 
     | 
|
| 618 | 
         | 
| 619 | 
         
             
            ```
         
     | 
| 620 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 621 | 
         
             
            ### Custom endpoint authorization
         
     | 
| 622 | 
         | 
| 623 | 
         
             
            #### Basic and Bearer
         
     | 
| 
         | 
|
| 618 | 
         | 
| 619 | 
         
             
            ```
         
     | 
| 620 | 
         | 
| 621 | 
         
            +
            ##### LangServe
         
     | 
| 622 | 
         
            +
             
     | 
| 623 | 
         
            +
            LangChain applications that are deployed using LangServe can be called with the following config:
         
     | 
| 624 | 
         
            +
             
     | 
| 625 | 
         
            +
            ```
         
     | 
| 626 | 
         
            +
            MODELS=`[
         
     | 
| 627 | 
         
            +
            //...
         
     | 
| 628 | 
         
            +
                {
         
     | 
| 629 | 
         
            +
                   "name": "summarization-chain", //model-name
         
     | 
| 630 | 
         
            +
                   "endpoints" : [{
         
     | 
| 631 | 
         
            +
                     "type": "langserve",
         
     | 
| 632 | 
         
            +
                     "url" : "http://127.0.0.1:8100",
         
     | 
| 633 | 
         
            +
                   }]
         
     | 
| 634 | 
         
            +
                 },
         
     | 
| 635 | 
         
            +
            ]`
         
     | 
| 636 | 
         
            +
             
     | 
| 637 | 
         
            +
            ```
         
     | 
| 638 | 
         
            +
             
     | 
| 639 | 
         
             
            ### Custom endpoint authorization
         
     | 
| 640 | 
         | 
| 641 | 
         
             
            #### Basic and Bearer
         
     | 
    	
        src/lib/server/endpoints/endpoints.ts
    CHANGED
    
    | 
         @@ -17,6 +17,9 @@ import endpointCloudflare, { 
     | 
|
| 17 | 
         
             
            	endpointCloudflareParametersSchema,
         
     | 
| 18 | 
         
             
            } from "./cloudflare/endpointCloudflare";
         
     | 
| 19 | 
         
             
            import { endpointCohere, endpointCohereParametersSchema } from "./cohere/endpointCohere";
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 20 | 
         | 
| 21 | 
         
             
            // parameters passed when generating text
         
     | 
| 22 | 
         
             
            export interface EndpointParameters {
         
     | 
| 
         @@ -48,6 +51,7 @@ export const endpoints = { 
     | 
|
| 48 | 
         
             
            	vertex: endpointVertex,
         
     | 
| 49 | 
         
             
            	cloudflare: endpointCloudflare,
         
     | 
| 50 | 
         
             
            	cohere: endpointCohere,
         
     | 
| 
         | 
|
| 51 | 
         
             
            };
         
     | 
| 52 | 
         | 
| 53 | 
         
             
            export const endpointSchema = z.discriminatedUnion("type", [
         
     | 
| 
         @@ -60,5 +64,6 @@ export const endpointSchema = z.discriminatedUnion("type", [ 
     | 
|
| 60 | 
         
             
            	endpointVertexParametersSchema,
         
     | 
| 61 | 
         
             
            	endpointCloudflareParametersSchema,
         
     | 
| 62 | 
         
             
            	endpointCohereParametersSchema,
         
     | 
| 
         | 
|
| 63 | 
         
             
            ]);
         
     | 
| 64 | 
         
             
            export default endpoints;
         
     | 
| 
         | 
|
| 17 | 
         
             
            	endpointCloudflareParametersSchema,
         
     | 
| 18 | 
         
             
            } from "./cloudflare/endpointCloudflare";
         
     | 
| 19 | 
         
             
            import { endpointCohere, endpointCohereParametersSchema } from "./cohere/endpointCohere";
         
     | 
| 20 | 
         
            +
            import endpointLangserve, {
         
     | 
| 21 | 
         
            +
            	endpointLangserveParametersSchema,
         
     | 
| 22 | 
         
            +
            } from "./langserve/endpointLangserve";
         
     | 
| 23 | 
         | 
| 24 | 
         
             
            // parameters passed when generating text
         
     | 
| 25 | 
         
             
            export interface EndpointParameters {
         
     | 
| 
         | 
|
| 51 | 
         
             
            	vertex: endpointVertex,
         
     | 
| 52 | 
         
             
            	cloudflare: endpointCloudflare,
         
     | 
| 53 | 
         
             
            	cohere: endpointCohere,
         
     | 
| 54 | 
         
            +
            	langserve: endpointLangserve,
         
     | 
| 55 | 
         
             
            };
         
     | 
| 56 | 
         | 
| 57 | 
         
             
            export const endpointSchema = z.discriminatedUnion("type", [
         
     | 
| 
         | 
|
| 64 | 
         
             
            	endpointVertexParametersSchema,
         
     | 
| 65 | 
         
             
            	endpointCloudflareParametersSchema,
         
     | 
| 66 | 
         
             
            	endpointCohereParametersSchema,
         
     | 
| 67 | 
         
            +
            	endpointLangserveParametersSchema,
         
     | 
| 68 | 
         
             
            ]);
         
     | 
| 69 | 
         
             
            export default endpoints;
         
     | 
    	
        src/lib/server/endpoints/langserve/endpointLangserve.ts
    ADDED
    
    | 
         @@ -0,0 +1,128 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import { buildPrompt } from "$lib/buildPrompt";
         
     | 
| 2 | 
         
            +
            import { z } from "zod";
         
     | 
| 3 | 
         
            +
            import type { Endpoint } from "../endpoints";
         
     | 
| 4 | 
         
            +
            import type { TextGenerationStreamOutput } from "@huggingface/inference";
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            export const endpointLangserveParametersSchema = z.object({
         
     | 
| 7 | 
         
            +
            	weight: z.number().int().positive().default(1),
         
     | 
| 8 | 
         
            +
            	model: z.any(),
         
     | 
| 9 | 
         
            +
            	type: z.literal("langserve"),
         
     | 
| 10 | 
         
            +
            	url: z.string().url(),
         
     | 
| 11 | 
         
            +
            });
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
            export function endpointLangserve(
         
     | 
| 14 | 
         
            +
            	input: z.input<typeof endpointLangserveParametersSchema>
         
     | 
| 15 | 
         
            +
            ): Endpoint {
         
     | 
| 16 | 
         
            +
            	const { url, model } = endpointLangserveParametersSchema.parse(input);
         
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
            	return async ({ messages, preprompt, continueMessage }) => {
         
     | 
| 19 | 
         
            +
            		const prompt = await buildPrompt({
         
     | 
| 20 | 
         
            +
            			messages,
         
     | 
| 21 | 
         
            +
            			continueMessage,
         
     | 
| 22 | 
         
            +
            			preprompt,
         
     | 
| 23 | 
         
            +
            			model,
         
     | 
| 24 | 
         
            +
            		});
         
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
            +
            		const r = await fetch(`${url}/stream`, {
         
     | 
| 27 | 
         
            +
            			method: "POST",
         
     | 
| 28 | 
         
            +
            			headers: {
         
     | 
| 29 | 
         
            +
            				"Content-Type": "application/json",
         
     | 
| 30 | 
         
            +
            			},
         
     | 
| 31 | 
         
            +
            			body: JSON.stringify({
         
     | 
| 32 | 
         
            +
            				input: { text: prompt },
         
     | 
| 33 | 
         
            +
            			}),
         
     | 
| 34 | 
         
            +
            		});
         
     | 
| 35 | 
         
            +
             
     | 
| 36 | 
         
            +
            		if (!r.ok) {
         
     | 
| 37 | 
         
            +
            			throw new Error(`Failed to generate text: ${await r.text()}`);
         
     | 
| 38 | 
         
            +
            		}
         
     | 
| 39 | 
         
            +
             
     | 
| 40 | 
         
            +
            		const encoder = new TextDecoderStream();
         
     | 
| 41 | 
         
            +
            		const reader = r.body?.pipeThrough(encoder).getReader();
         
     | 
| 42 | 
         
            +
             
     | 
| 43 | 
         
            +
            		return (async function* () {
         
     | 
| 44 | 
         
            +
            			let stop = false;
         
     | 
| 45 | 
         
            +
            			let generatedText = "";
         
     | 
| 46 | 
         
            +
            			let tokenId = 0;
         
     | 
| 47 | 
         
            +
            			let accumulatedData = ""; // Buffer to accumulate data chunks
         
     | 
| 48 | 
         
            +
             
     | 
| 49 | 
         
            +
            			while (!stop) {
         
     | 
| 50 | 
         
            +
            				// Read the stream and log the outputs to console
         
     | 
| 51 | 
         
            +
            				const out = (await reader?.read()) ?? { done: false, value: undefined };
         
     | 
| 52 | 
         
            +
             
     | 
| 53 | 
         
            +
            				// If it's done, we cancel
         
     | 
| 54 | 
         
            +
            				if (out.done) {
         
     | 
| 55 | 
         
            +
            					reader?.cancel();
         
     | 
| 56 | 
         
            +
            					return;
         
     | 
| 57 | 
         
            +
            				}
         
     | 
| 58 | 
         
            +
             
     | 
| 59 | 
         
            +
            				if (!out.value) {
         
     | 
| 60 | 
         
            +
            					return;
         
     | 
| 61 | 
         
            +
            				}
         
     | 
| 62 | 
         
            +
             
     | 
| 63 | 
         
            +
            				// Accumulate the data chunk
         
     | 
| 64 | 
         
            +
            				accumulatedData += out.value;
         
     | 
| 65 | 
         
            +
            				// Keep read data to check event type
         
     | 
| 66 | 
         
            +
            				const eventData = out.value;
         
     | 
| 67 | 
         
            +
             
     | 
| 68 | 
         
            +
            				// Process each complete JSON object in the accumulated data
         
     | 
| 69 | 
         
            +
            				while (accumulatedData.includes("\n")) {
         
     | 
| 70 | 
         
            +
            					// Assuming each JSON object ends with a newline
         
     | 
| 71 | 
         
            +
            					const endIndex = accumulatedData.indexOf("\n");
         
     | 
| 72 | 
         
            +
            					let jsonString = accumulatedData.substring(0, endIndex).trim();
         
     | 
| 73 | 
         
            +
            					// Remove the processed part from the buffer
         
     | 
| 74 | 
         
            +
             
     | 
| 75 | 
         
            +
            					accumulatedData = accumulatedData.substring(endIndex + 1);
         
     | 
| 76 | 
         
            +
             
     | 
| 77 | 
         
            +
            					// Stopping with end event
         
     | 
| 78 | 
         
            +
            					if (eventData.startsWith("event: end")) {
         
     | 
| 79 | 
         
            +
            						stop = true;
         
     | 
| 80 | 
         
            +
            						yield {
         
     | 
| 81 | 
         
            +
            							token: {
         
     | 
| 82 | 
         
            +
            								id: tokenId++,
         
     | 
| 83 | 
         
            +
            								text: "",
         
     | 
| 84 | 
         
            +
            								logprob: 0,
         
     | 
| 85 | 
         
            +
            								special: true,
         
     | 
| 86 | 
         
            +
            							},
         
     | 
| 87 | 
         
            +
            							generated_text: generatedText,
         
     | 
| 88 | 
         
            +
            							details: null,
         
     | 
| 89 | 
         
            +
            						} satisfies TextGenerationStreamOutput;
         
     | 
| 90 | 
         
            +
            						reader?.cancel();
         
     | 
| 91 | 
         
            +
            						continue;
         
     | 
| 92 | 
         
            +
            					}
         
     | 
| 93 | 
         
            +
             
     | 
| 94 | 
         
            +
            					if (eventData.startsWith("event: data") && jsonString.startsWith("data: ")) {
         
     | 
| 95 | 
         
            +
            						jsonString = jsonString.slice(6);
         
     | 
| 96 | 
         
            +
            						let data = null;
         
     | 
| 97 | 
         
            +
             
     | 
| 98 | 
         
            +
            						// Handle the parsed data
         
     | 
| 99 | 
         
            +
            						try {
         
     | 
| 100 | 
         
            +
            							data = JSON.parse(jsonString);
         
     | 
| 101 | 
         
            +
            						} catch (e) {
         
     | 
| 102 | 
         
            +
            							console.error("Failed to parse JSON", e);
         
     | 
| 103 | 
         
            +
            							console.error("Problematic JSON string:", jsonString);
         
     | 
| 104 | 
         
            +
            							continue; // Skip this iteration and try the next chunk
         
     | 
| 105 | 
         
            +
            						}
         
     | 
| 106 | 
         
            +
            						// Assuming content within data is a plain string
         
     | 
| 107 | 
         
            +
            						if (data) {
         
     | 
| 108 | 
         
            +
            							generatedText += data;
         
     | 
| 109 | 
         
            +
            							const output: TextGenerationStreamOutput = {
         
     | 
| 110 | 
         
            +
            								token: {
         
     | 
| 111 | 
         
            +
            									id: tokenId++,
         
     | 
| 112 | 
         
            +
            									text: data,
         
     | 
| 113 | 
         
            +
            									logprob: 0,
         
     | 
| 114 | 
         
            +
            									special: false,
         
     | 
| 115 | 
         
            +
            								},
         
     | 
| 116 | 
         
            +
            								generated_text: null,
         
     | 
| 117 | 
         
            +
            								details: null,
         
     | 
| 118 | 
         
            +
            							};
         
     | 
| 119 | 
         
            +
            							yield output;
         
     | 
| 120 | 
         
            +
            						}
         
     | 
| 121 | 
         
            +
            					}
         
     | 
| 122 | 
         
            +
            				}
         
     | 
| 123 | 
         
            +
            			}
         
     | 
| 124 | 
         
            +
            		})();
         
     | 
| 125 | 
         
            +
            	};
         
     | 
| 126 | 
         
            +
            }
         
     | 
| 127 | 
         
            +
             
     | 
| 128 | 
         
            +
            export default endpointLangserve;
         
     | 
    	
        src/lib/server/models.ts
    CHANGED
    
    | 
         @@ -177,6 +177,8 @@ const addEndpoint = (m: Awaited<ReturnType<typeof processModel>>) => ({ 
     | 
|
| 177 | 
         
             
            						return await endpoints.cloudflare(args);
         
     | 
| 178 | 
         
             
            					case "cohere":
         
     | 
| 179 | 
         
             
            						return await endpoints.cohere(args);
         
     | 
| 
         | 
|
| 
         | 
|
| 180 | 
         
             
            					default:
         
     | 
| 181 | 
         
             
            						// for legacy reason
         
     | 
| 182 | 
         
             
            						return endpoints.tgi(args);
         
     | 
| 
         | 
|
| 177 | 
         
             
            						return await endpoints.cloudflare(args);
         
     | 
| 178 | 
         
             
            					case "cohere":
         
     | 
| 179 | 
         
             
            						return await endpoints.cohere(args);
         
     | 
| 180 | 
         
            +
            					case "langserve":
         
     | 
| 181 | 
         
            +
            						return await endpoints.langserve(args);
         
     | 
| 182 | 
         
             
            					default:
         
     | 
| 183 | 
         
             
            						// for legacy reason
         
     | 
| 184 | 
         
             
            						return endpoints.tgi(args);
         
     |