ahmed-eisa committed on
Commit
ec10d0e
·
1 Parent(s): be3cdfa

added image endpoint

Browse files
Files changed (4) hide show
  1. main.py +13 -5
  2. models.py +16 -0
  3. requirements.txt +2 -1
  4. utils.py +10 -1
main.py CHANGED
@@ -1,10 +1,9 @@
1
  # main.py
2
- from fastapi import FastAPI,status
3
  from fastapi.responses import StreamingResponse
4
-
5
- from models import load_text_model,generate_text,load_audio_model,generate_audio
6
  from schemas import VoicePresets
7
- from utils import audio_array_to_buffer
8
  app = FastAPI()
9
 
10
  @app.get("/")
@@ -31,4 +30,13 @@ def serve_text_to_audio_model_controller(
31
  output, sample_rate = generate_audio(processor, model, prompt, preset)
32
  return StreamingResponse(
33
  audio_array_to_buffer(output, sample_rate), media_type="audio/wav"
34
- )
 
 
 
 
 
 
 
 
 
 
1
  # main.py
2
+ from fastapi import FastAPI,status,Response
3
  from fastapi.responses import StreamingResponse
4
+ from models import load_text_model,generate_text,load_audio_model,generate_audio,load_image_model, generate_image
 
5
  from schemas import VoicePresets
6
+ from utils import audio_array_to_buffer,img_to_bytes
7
  app = FastAPI()
8
 
9
  @app.get("/")
 
30
  output, sample_rate = generate_audio(processor, model, prompt, preset)
31
  return StreamingResponse(
32
  audio_array_to_buffer(output, sample_rate), media_type="audio/wav"
33
+ )
34
+
35
+
36
@app.get(
    "/generate/image",
    responses={status.HTTP_200_OK: {"content": {"image/png": {}}}},
    response_class=Response,
)
def serve_text_to_image_model_controller(prompt: str):
    """Generate an image from a text prompt and return it as a PNG response.

    The OpenAPI schema advertises an ``image/png`` body for HTTP 200.
    """
    # NOTE(review): the pipeline is loaded on every request, which is
    # expensive — consider caching it at application startup.
    pipe = load_image_model()
    image = generate_image(pipe, prompt)
    return Response(content=img_to_bytes(image), media_type="image/png")
models.py CHANGED
@@ -3,6 +3,8 @@
3
  import torch
4
  from transformers import Pipeline, pipeline,AutoProcessor, AutoModel, BarkProcessor, BarkModel
5
  from schemas import VoicePresets
 
 
6
  import numpy as np
7
 
8
  prompt = "How to set up a FastAPI project?"
@@ -60,4 +62,18 @@ def generate_text(pipe: Pipeline, prompt: str, temperature: float = 0.7) -> str:
60
  top_p=0.95,
61
  )
62
  output = predictions[0]["generated_text"].split("</s>\n<|assistant|>\n")[-1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  return output
 
3
  import torch
4
  from transformers import Pipeline, pipeline,AutoProcessor, AutoModel, BarkProcessor, BarkModel
5
  from schemas import VoicePresets
6
+ from diffusers import DiffusionPipeline, StableDiffusionInpaintPipelineLegacy
7
+ from PIL import Image
8
  import numpy as np
9
 
10
  prompt = "How to set up a FastAPI project?"
 
62
  top_p=0.95,
63
  )
64
  output = predictions[0]["generated_text"].split("</s>\n<|assistant|>\n")[-1]
65
+ return output
66
+
67
+
68
def load_image_model() -> StableDiffusionInpaintPipelineLegacy:
    """Load the ``segmind/tiny-sd`` diffusion pipeline and place it on ``device``.

    Returns:
        The loaded pipeline, moved to the module-level ``device``.

    NOTE(review): the return annotation follows the original code, but
    ``DiffusionPipeline.from_pretrained`` picks the concrete pipeline class
    from the checkpoint config — confirm it really is the legacy inpaint class.
    """
    pipe = DiffusionPipeline.from_pretrained(
        "segmind/tiny-sd",
        torch_dtype=torch.float32,
    )
    # Bug fix: `from_pretrained` accepts no `device` keyword — passing one is
    # forwarded to the model constructors and fails. Device placement in
    # diffusers is done with `.to(device)` after loading.
    return pipe.to(device)
74
+
75
def generate_image(
    pipe: StableDiffusionInpaintPipelineLegacy, prompt: str
) -> Image.Image:
    """Run the diffusion pipeline on *prompt* and return the first image.

    Uses 10 inference steps, trading quality for speed.
    """
    result = pipe(prompt, num_inference_steps=10)
    return result.images[0]
requirements.txt CHANGED
@@ -4,4 +4,5 @@ transformers
4
  torch
5
  pydantic
6
  bitsandbytes
7
- soundfile
 
 
4
  torch
5
  pydantic
6
  bitsandbytes
7
+ soundfile
8
+ diffusers
utils.py CHANGED
@@ -1,9 +1,18 @@
1
  from io import BytesIO
2
  import soundfile
3
  import numpy as np
 
4
 
5
  def audio_array_to_buffer(audio_array: np.array, sample_rate: int) -> BytesIO:
6
  buffer = BytesIO()
7
  soundfile.write(buffer, audio_array, sample_rate, format="wav")
8
  buffer.seek(0)
9
- return buffer
 
 
 
 
 
 
 
 
 
from io import BytesIO
from typing import Literal

import numpy as np
import soundfile
from PIL import Image
5
 
6
def audio_array_to_buffer(audio_array: np.ndarray, sample_rate: int) -> BytesIO:
    """Encode a raw audio sample array as WAV into an in-memory buffer.

    Args:
        audio_array: sample data; assumes a shape soundfile accepts (1-D mono
            or 2-D frames x channels) — TODO confirm against callers.
        sample_rate: sampling rate in Hz.

    Returns:
        A ``BytesIO`` rewound to position 0, ready for streaming.
    """
    # Annotation fixed from `np.array` (a factory function, not a type) to
    # `np.ndarray`, the actual array type.
    buffer = BytesIO()
    soundfile.write(buffer, audio_array, sample_rate, format="wav")
    buffer.seek(0)
    return buffer
11
+
12
+
13
def img_to_bytes(
    image: Image.Image, img_format: Literal["PNG", "JPEG"] = "PNG"
) -> bytes:
    """Serialize a PIL image into raw encoded bytes (PNG by default)."""
    with BytesIO() as stream:
        image.save(stream, format=img_format)
        return stream.getvalue()