import logging
from typing import List, Optional, Union

import torch
from PIL import Image

logger = logging.getLogger(__name__)
|

T2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award-winning movies. When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes.
Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph.
Start directly with the action, and keep descriptions literal and precise.
Think like a cinematographer describing a shot list.
Do not change the user input intent, just enhance it.
Keep within 150 words.
For best results, build your prompts using this structure:
Start with the main action in a single sentence
Add specific details about movements and gestures
Describe character/object appearances precisely
Include background and environment details
Specify camera angles and movements
Describe lighting and colors
Note any changes or sudden events
Do not exceed the 150 word limit!
Output the enhanced prompt only.
"""
|

I2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award-winning movies. When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes.
Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph.
Start directly with the action, and keep descriptions literal and precise.
Think like a cinematographer describing a shot list.
Keep within 150 words.
For best results, build your prompts using this structure:
Describe the image first and then add the user input. The image description should take first priority! Align to the image caption if it contradicts the user text input.
Start with the main action in a single sentence
Add specific details about movements and gestures
Describe character/object appearances precisely
Include background and environment details
Specify camera angles and movements
Describe lighting and colors
Note any changes or sudden events
Align to the image caption if it contradicts the user text input.
Do not exceed the 150 word limit!
Output the enhanced prompt only.
"""
|

def tensor_to_pil(tensor):
    # Expects a single frame as a (channels, height, width) tensor with
    # values scaled to [-1, 1].
    assert tensor.min() >= -1 and tensor.max() <= 1

    # Rescale from [-1, 1] to [0, 1].
    tensor = (tensor + 1) / 2

    # Rearrange from (C, H, W) to (H, W, C) as PIL expects.
    tensor = tensor.permute(1, 2, 0)

    # Convert to a uint8 numpy array in [0, 255].
    numpy_image = (tensor.cpu().numpy() * 255).astype("uint8")

    return Image.fromarray(numpy_image)
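
# A minimal usage sketch for tensor_to_pil (the values below are illustrative
# assumptions, not taken from this module):
#
#     frame = torch.rand(3, 64, 64) * 2 - 1   # (C, H, W) in [-1, 1]
#     image = tensor_to_pil(frame)            # -> 64x64 PIL.Image.Image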
|

def generate_cinematic_prompt(
    image_caption_model,
    image_caption_processor,
    prompt_enhancer_model,
    prompt_enhancer_tokenizer,
    prompt: Union[str, List[str]],
    conditioning_items: Optional[List] = None,
    max_new_tokens: int = 256,
) -> List[str]:
    prompts = [prompt] if isinstance(prompt, str) else prompt

    if conditioning_items is None:
        # Text-to-video: enhance the prompts directly.
        prompts = _generate_t2v_prompt(
            prompt_enhancer_model,
            prompt_enhancer_tokenizer,
            prompts,
            max_new_tokens,
            T2V_CINEMATIC_PROMPT,
        )
    else:
        if (
            len(conditioning_items) > 1
            or conditioning_items[0].media_frame_number != 0
        ):
            logger.warning(
                "Prompt enhancement only supports unconditional generation or a "
                "single conditioning item at the first frame; returning the "
                "original prompts."
            )
            return prompts

        first_frame_conditioning_item = conditioning_items[0]
        first_frames = _get_first_frames_from_conditioning_item(
            first_frame_conditioning_item
        )

        assert len(first_frames) == len(
            prompts
        ), "Number of conditioning frames must match number of prompts"

        # Image-to-video: caption each first frame, then enhance the prompts
        # together with the captions.
        prompts = _generate_i2v_prompt(
            image_caption_model,
            image_caption_processor,
            prompt_enhancer_model,
            prompt_enhancer_tokenizer,
            prompts,
            first_frames,
            max_new_tokens,
            I2V_CINEMATIC_PROMPT,
        )

    return prompts
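
# Example usage (a sketch; the checkpoints named below are assumptions chosen
# for illustration, not something this module prescribes):
#
#     from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
#
#     caption_model = AutoModelForCausalLM.from_pretrained(
#         "microsoft/Florence-2-large", trust_remote_code=True
#     )
#     caption_processor = AutoProcessor.from_pretrained(
#         "microsoft/Florence-2-large", trust_remote_code=True
#     )
#     enhancer_model = AutoModelForCausalLM.from_pretrained("<instruct-llm>")
#     enhancer_tokenizer = AutoTokenizer.from_pretrained("<instruct-llm>")
#
#     enhanced = generate_cinematic_prompt(
#         caption_model,
#         caption_processor,
#         enhancer_model,
#         enhancer_tokenizer,
#         prompt="a red fox trotting through fresh snow",
#     )  # -> list with one enhanced prompt string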
|

def _get_first_frames_from_conditioning_item(conditioning_item) -> List[Image.Image]:
    # media_item is a (batch, channels, frames, height, width) tensor;
    # take frame 0 of each batch element.
    frames_tensor = conditioning_item.media_item
    return [
        tensor_to_pil(frames_tensor[i, :, 0, :, :])
        for i in range(frames_tensor.shape[0])
    ]
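
# Illustrative example: a media_item of shape (2, 3, 9, 480, 704), i.e.
# (batch, channels, frames, height, width), yields two 704x480 PIL images,
# each taken from frame index 0 of its batch element.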
|

def _generate_t2v_prompt(
    prompt_enhancer_model,
    prompt_enhancer_tokenizer,
    prompts: List[str],
    max_new_tokens: int,
    system_prompt: str,
) -> List[str]:
    # Build one chat per prompt: the cinematic system prompt plus the user prompt.
    messages = [
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"user_prompt: {p}"},
        ]
        for p in prompts
    ]

    texts = [
        prompt_enhancer_tokenizer.apply_chat_template(
            m, tokenize=False, add_generation_prompt=True
        )
        for m in messages
    ]
    # Note: prompts in a batch are assumed to tokenize to equal lengths;
    # batching differently sized prompts would require enabling padding here.
    model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to(
        prompt_enhancer_model.device
    )

    return _generate_and_decode_prompts(
        prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens
    )
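
# For reference (illustrative; the exact markup depends on the tokenizer's chat
# template): apply_chat_template with add_generation_prompt=True renders each
# message list as a single string roughly like
#
#     "<|system|>You are an expert cinematic director ...<|user|>user_prompt: a cat ...<|assistant|>"
#
# so that the enhancer model continues the text with the enhanced prompt.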
|

def _generate_i2v_prompt(
    image_caption_model,
    image_caption_processor,
    prompt_enhancer_model,
    prompt_enhancer_tokenizer,
    prompts: List[str],
    first_frames: List[Image.Image],
    max_new_tokens: int,
    system_prompt: str,
) -> List[str]:
    image_captions = _generate_image_captions(
        image_caption_model, image_caption_processor, first_frames
    )

    messages = [
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"user_prompt: {p}\nimage_caption: {c}"},
        ]
        for p, c in zip(prompts, image_captions)
    ]

    texts = [
        prompt_enhancer_tokenizer.apply_chat_template(
            m, tokenize=False, add_generation_prompt=True
        )
        for m in messages
    ]
    model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to(
        prompt_enhancer_model.device
    )

    return _generate_and_decode_prompts(
        prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens
    )
|

def _generate_image_captions(
    image_caption_model,
    image_caption_processor,
    images: List[Image.Image],
    system_prompt: str = "<DETAILED_CAPTION>",
) -> List[str]:
    # Caption every image with the same task prompt ("<DETAILED_CAPTION>"
    # follows the Florence-2 task-prompt convention).
    image_caption_prompts = [system_prompt] * len(images)
    inputs = image_caption_processor(
        image_caption_prompts, images, return_tensors="pt"
    ).to(image_caption_model.device)

    with torch.inference_mode():
        generated_ids = image_caption_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )

    return image_caption_processor.batch_decode(generated_ids, skip_special_tokens=True)
|

def _generate_and_decode_prompts(
    prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens: int
) -> List[str]:
    with torch.inference_mode():
        outputs = prompt_enhancer_model.generate(
            **model_inputs, max_new_tokens=max_new_tokens
        )
        # generate() returns the prompt tokens followed by the completion;
        # slice off the prompt so only the newly generated text is decoded.
        generated_ids = [
            output_ids[len(input_ids) :]
            for input_ids, output_ids in zip(model_inputs.input_ids, outputs)
        ]
        decoded_prompts = prompt_enhancer_tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True
        )

    return decoded_prompts
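
# Minimal illustration of the prompt-stripping slice above (plain tensors, no
# model involved): with 4 prompt tokens and 2 newly generated tokens,
# output_ids[len(input_ids) :] keeps only the new ones.
#
#     input_ids = torch.tensor([1, 2, 3, 4])
#     output_ids = torch.tensor([1, 2, 3, 4, 9, 9])
#     new_ids = output_ids[len(input_ids) :]  # tensor([9, 9])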