goupilew nsarrazin HF Staff commited on
Commit
61d56f7
·
unverified ·
1 Parent(s): ba87c5b

[VertexAI] Add support for tools parameter (#1065)

Browse files

* [VertexAI] Add support for tools parameter

* Simplify tools parameter parsing and add support for passing parameters in model

---------

Co-authored-by: Nathan Sarrazin <[email protected]>

README.md CHANGED
@@ -619,7 +619,12 @@ MODELS=`[
619
 
620
  // Optional
621
  "safetyThreshold": "BLOCK_MEDIUM_AND_ABOVE",
622
- "apiEndpoint": "", // alternative api endpoint url
 
 
 
 
 
623
  }]
624
  },
625
  ]`
 
619
 
620
  // Optional
621
  "safetyThreshold": "BLOCK_MEDIUM_AND_ABOVE",
622
+ "apiEndpoint": "", // alternative api endpoint url,
623
+ "tools": [{
624
+ "googleSearchRetrieval": {
625
+ "disableAttribution": true
626
+ }
627
+ }]
628
  }]
629
  },
630
  ]`
src/lib/server/endpoints/google/endpointVertex.ts CHANGED
@@ -26,10 +26,11 @@ export const endpointVertexParametersSchema = z.object({
26
  HarmBlockThreshold.BLOCK_ONLY_HIGH,
27
  ])
28
  .optional(),
 
29
  });
30
 
31
  export function endpointVertex(input: z.input<typeof endpointVertexParametersSchema>): Endpoint {
32
- const { project, location, model, apiEndpoint, safetyThreshold } =
33
  endpointVertexParametersSchema.parse(input);
34
 
35
  const vertex_ai = new VertexAI({
@@ -39,6 +40,8 @@ export function endpointVertex(input: z.input<typeof endpointVertexParametersSch
39
  });
40
 
41
  return async ({ messages, preprompt, generateSettings }) => {
 
 
42
  const generativeModel = vertex_ai.getGenerativeModel({
43
  model: model.id ?? model.name,
44
  safetySettings: safetyThreshold
@@ -66,10 +69,11 @@ export function endpointVertex(input: z.input<typeof endpointVertexParametersSch
66
  ]
67
  : undefined,
68
  generationConfig: {
69
- maxOutputTokens: generateSettings?.max_new_tokens ?? 4096,
70
- stopSequences: generateSettings?.stop,
71
- temperature: generateSettings?.temperature ?? 1,
72
  },
 
73
  });
74
 
75
  // Preprompt is the same as the first system message.
 
26
  HarmBlockThreshold.BLOCK_ONLY_HIGH,
27
  ])
28
  .optional(),
29
+ tools: z.array(z.any()),
30
  });
31
 
32
  export function endpointVertex(input: z.input<typeof endpointVertexParametersSchema>): Endpoint {
33
+ const { project, location, model, apiEndpoint, safetyThreshold, tools } =
34
  endpointVertexParametersSchema.parse(input);
35
 
36
  const vertex_ai = new VertexAI({
 
40
  });
41
 
42
  return async ({ messages, preprompt, generateSettings }) => {
43
+ const parameters = { ...model.parameters, ...generateSettings };
44
+
45
  const generativeModel = vertex_ai.getGenerativeModel({
46
  model: model.id ?? model.name,
47
  safetySettings: safetyThreshold
 
69
  ]
70
  : undefined,
71
  generationConfig: {
72
+ maxOutputTokens: parameters?.max_new_tokens ?? 4096,
73
+ stopSequences: parameters?.stop,
74
+ temperature: parameters?.temperature ?? 1,
75
  },
76
+ tools,
77
  });
78
 
79
  // Preprompt is the same as the first system message.