# models.py
import torch
from transformers import Pipeline, pipeline, AutoProcessor, AutoModel, BarkProcessor, BarkModel
from schemas import VoicePresets
from diffusers import DiffusionPipeline, StableDiffusionPipeline
from PIL import Image
import numpy as np
prompt = "How to set up a FastAPI project?"
system_prompt = """
Your name is FastAPI bot and you are a helpful
chatbot responsible for teaching FastAPI to your users.
Always respond in markdown.
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_audio_model() -> tuple[BarkProcessor, BarkModel]:
    # The processor stays on the CPU; only the model is moved to the target device.
    processor = AutoProcessor.from_pretrained("suno/bark-small")
    model = AutoModel.from_pretrained("suno/bark-small")
    model.to(device)
    return processor, model

def generate_audio(
    processor: BarkProcessor,
    model: BarkModel,
    prompt: str,
    preset: VoicePresets,
) -> tuple[np.ndarray, int]:
    # Encode the prompt and move the inputs (voice-preset history included)
    # onto the model's device before generation.
    inputs = processor(text=[prompt], return_tensors="pt", voice_preset=preset).to(device)
    output = model.generate(**inputs, do_sample=True).cpu().numpy().squeeze()
    sample_rate = model.generation_config.sample_rate
    return output, sample_rate
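
# Example usage (a sketch, not part of the service; assumes "v2/en_speaker_1" is one of
# the VoicePresets values defined in schemas, and that scipy is installed for WAV output):
#
#   processor, model = load_audio_model()
#   audio, rate = generate_audio(processor, model, "Welcome to FastAPI bot!", "v2/en_speaker_1")
#   from scipy.io import wavfile
#   wavfile.write("speech.wav", rate, audio)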

def load_text_model() -> Pipeline:
    pipe = pipeline(
        "text-generation",
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        torch_dtype=torch.bfloat16,
        device=device,
    )
    return pipe

def generate_text(pipe: Pipeline, prompt: str, temperature: float = 0.7) -> str:
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    # Render the conversation with the model's chat template before generation.
    chat_prompt = pipe.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    predictions = pipe(
        chat_prompt,
        temperature=temperature,
        max_new_tokens=256,
        do_sample=True,
        top_k=50,
        top_p=0.95,
    )
    # The pipeline returns the full prompt plus completion; keep only the text
    # after TinyLlama's assistant marker.
    output = predictions[0]["generated_text"].split("</s>\n<|assistant|>\n")[-1]
    return output
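
# Example usage (a sketch, not part of the service; reuses the module-level example prompt):
#
#   pipe = load_text_model()
#   print(generate_text(pipe, prompt))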

def load_image_model() -> StableDiffusionPipeline:
    # from_pretrained() takes no device argument; move the pipeline explicitly.
    pipe = DiffusionPipeline.from_pretrained(
        "segmind/tiny-sd", torch_dtype=torch.float32
    )
    pipe.to(device)
    return pipe

def generate_image(
    pipe: StableDiffusionPipeline, prompt: str
) -> Image.Image:
    # 10 inference steps trades some image quality for speed on small hardware.
    output = pipe(prompt, num_inference_steps=10).images[0]
    return output
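
# Minimal smoke test (a sketch, not part of the service; the first run downloads the
# segmind/tiny-sd weights from the Hugging Face Hub):
if __name__ == "__main__":
    image_pipe = load_image_model()
    image = generate_image(image_pipe, "a cat sitting on a laptop keyboard")
    image.save("generated.png")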