nsarrazin HF Staff commited on
Commit
b6d5d03
·
unverified ·
1 Parent(s): 2353ab2

Add SD3 as a tool (#1279)

Browse files

* Add SD3 as a tool

* add comments

src/lib/server/tools/images/generation.ts CHANGED
@@ -3,14 +3,8 @@ import { uploadFile } from "../../files/uploadFile";
3
  import { MessageUpdateType } from "$lib/types/MessageUpdate";
4
  import { callSpace, getIpToken, type GradioImage } from "../utils";
5
 
6
- type ImageGenerationInput = [
7
- number /* number (numeric value between 1 and 8) in 'Number of Images' Slider component */,
8
- number /* number in 'Image Height' Number component */,
9
- number /* number in 'Image Width' Number component */,
10
- string /* prompt */,
11
- number /* seed random */
12
- ];
13
- type ImageGenerationOutput = [{ image: GradioImage }[]];
14
 
15
  const imageGeneration: BackendTool = {
16
  name: "image_generation",
@@ -24,11 +18,12 @@ const imageGeneration: BackendTool = {
24
  type: "string",
25
  required: true,
26
  },
27
- numberOfImages: {
28
- description: "Number of images to generate, between 1 and 8.",
29
- type: "number",
 
30
  required: false,
31
- default: 1,
32
  },
33
  width: {
34
  description: "Width of the generated image.",
@@ -43,41 +38,38 @@ const imageGeneration: BackendTool = {
43
  default: 1024,
44
  },
45
  },
46
- async *call({ prompt, numberOfImages, width, height }, { conv, ip, username }) {
47
  const ipToken = await getIpToken(ip, username);
48
 
49
  const outputs = await callSpace<ImageGenerationInput, ImageGenerationOutput>(
50
- "ByteDance/Hyper-SDXL-1Step-T2I",
51
- "/process_image",
52
  [
53
- Number(numberOfImages), // number (numeric value between 1 and 8) in 'Number of Images' Slider component
54
- Number(height), // number in 'Image Height' Number component
55
- Number(width), // number in 'Image Width' Number component
56
  String(prompt), // prompt
 
57
  Math.floor(Math.random() * 1000), // seed random
 
 
 
 
 
58
  ],
59
  ipToken
60
  );
61
- const imageBlobs = await Promise.all(
62
- outputs[0].map((output) =>
63
- fetch(output.image.url)
64
- .then((res) => res.blob())
65
- .then(
66
- (blob) =>
67
- new File([blob], `${prompt}.${blob.type.split("/")[1] ?? "png"}`, { type: blob.type })
68
- )
69
- .then((file) => uploadFile(file, conv))
70
  )
71
- );
72
 
73
- for (const image of imageBlobs) {
74
- yield {
75
- type: MessageUpdateType.File,
76
- name: image.name,
77
- sha: image.value,
78
- mime: image.mime,
79
- };
80
- }
81
 
82
  return {
83
  outputs: [
 
3
  import { MessageUpdateType } from "$lib/types/MessageUpdate";
4
  import { callSpace, getIpToken, type GradioImage } from "../utils";
5
 
6
+ type ImageGenerationInput = [string, string, number, boolean, number, number, number, number];
7
+ type ImageGenerationOutput = [GradioImage, unknown];
 
 
 
 
 
 
8
 
9
  const imageGeneration: BackendTool = {
10
  name: "image_generation",
 
18
  type: "string",
19
  required: true,
20
  },
21
+ negativePrompt: {
22
+ description:
23
+ "A prompt for things that should not be in the image. Simple terms, separate terms with a comma.",
24
+ type: "string",
25
  required: false,
26
+ default: "",
27
  },
28
  width: {
29
  description: "Width of the generated image.",
 
38
  default: 1024,
39
  },
40
  },
41
+ async *call({ prompt, negativePrompt, width, height }, { conv, ip, username }) {
42
  const ipToken = await getIpToken(ip, username);
43
 
44
  const outputs = await callSpace<ImageGenerationInput, ImageGenerationOutput>(
45
+ "stabilityai/stable-diffusion-3-medium",
46
+ "/infer",
47
  [
 
 
 
48
  String(prompt), // prompt
49
+ String(negativePrompt), // negative prompt
50
  Math.floor(Math.random() * 1000), // seed random
51
+ true, // randomize seed
52
+ Number(width), // number in 'Image Width' Number component
53
+ Number(height), // number in 'Image Height' Number component
54
+ 5, // guidance scale
55
+ 28, // steps
56
  ],
57
  ipToken
58
  );
59
+ const image = await fetch(outputs[0].url)
60
+ .then((res) => res.blob())
61
+ .then(
62
+ (blob) =>
63
+ new File([blob], `${prompt}.${blob.type.split("/")[1] ?? "png"}`, { type: blob.type })
 
 
 
 
64
  )
65
+ .then((file) => uploadFile(file, conv));
66
 
67
+ yield {
68
+ type: MessageUpdateType.File,
69
+ name: image.name,
70
+ sha: image.value,
71
+ mime: image.mime,
72
+ };
 
 
73
 
74
  return {
75
  outputs: [