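"""Prompt-enhancement utilities for cinematic video generation.

Expands short user prompts into detailed, single-paragraph cinematic
prompts with an instruction-tuned language model. For image-to-video
generation, the first conditioning frame is captioned first and the
caption is passed to the enhancer alongside the user prompt.
"""
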
import logging
from typing import List, Optional, Union

import torch
from PIL import Image

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

T2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award-winning movies. When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes.
Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph.
Start directly with the action, and keep descriptions literal and precise.
Think like a cinematographer describing a shot list.
Do not change the user input intent, just enhance it.
Keep within 150 words.
For best results, build your prompts using this structure:
Start with the main action in a single sentence
Add specific details about movements and gestures
Describe character/object appearances precisely
Include background and environment details
Specify camera angles and movements
Describe lighting and colors
Note any changes or sudden events
Do not exceed the 150 word limit!
Output the enhanced prompt only.
"""

I2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award-winning movies. When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes.
Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph.
Start directly with the action, and keep descriptions literal and precise.
Think like a cinematographer describing a shot list.
Keep within 150 words.
For best results, build your prompts using this structure:
Describe the image first and then add the user input. The image description takes first priority! Align to the image caption if it contradicts the user text input.
Start with the main action in a single sentence
Add specific details about movements and gestures
Describe character/object appearances precisely
Include background and environment details
Specify camera angles and movements
Describe lighting and colors
Note any changes or sudden events
Do not exceed the 150 word limit!
Output the enhanced prompt only.
"""


def tensor_to_pil(tensor):
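    """Convert a [C, H, W] tensor with values in [-1, 1] to a PIL image."""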
    # Ensure tensor is in range [-1, 1]
    assert tensor.min() >= -1 and tensor.max() <= 1
    # Convert from [-1, 1] to [0, 1]
    tensor = (tensor + 1) / 2
    # Rearrange from [C, H, W] to [H, W, C]
    tensor = tensor.permute(1, 2, 0)
    # Convert to a numpy array in the uint8 range [0, 255]
    numpy_image = (tensor.cpu().numpy() * 255).astype("uint8")
    # Convert to a PIL Image
    return Image.fromarray(numpy_image)


def generate_cinematic_prompt(
    image_caption_model,
    image_caption_processor,
    prompt_enhancer_model,
    prompt_enhancer_tokenizer,
    prompt: Union[str, List[str]],
    conditioning_items: Optional[List] = None,
    max_new_tokens: int = 256,
) -> List[str]:
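    """Enhance user prompts into cinematic prompts.

    Uses the text-to-video system prompt when there are no conditioning
    items; otherwise captions the first conditioning frame and uses the
    image-to-video system prompt. Returns one enhanced prompt per input.
    """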
    prompts = [prompt] if isinstance(prompt, str) else prompt

    if conditioning_items is None:
        prompts = _generate_t2v_prompt(
            prompt_enhancer_model,
            prompt_enhancer_tokenizer,
            prompts,
            max_new_tokens,
            T2V_CINEMATIC_PROMPT,
        )
    else:
        if len(conditioning_items) > 1 or conditioning_items[0].media_frame_number != 0:
            logger.warning(
                "Prompt enhancement only supports a single conditioning item "
                "placed at the first frame (frame 0); returning the original prompts."
            )
            return prompts

        first_frame_conditioning_item = conditioning_items[0]
        first_frames = _get_first_frames_from_conditioning_item(
            first_frame_conditioning_item
        )
        assert len(first_frames) == len(
            prompts
        ), "Number of conditioning frames must match number of prompts"

        prompts = _generate_i2v_prompt(
            image_caption_model,
            image_caption_processor,
            prompt_enhancer_model,
            prompt_enhancer_tokenizer,
            prompts,
            first_frames,
            max_new_tokens,
            I2V_CINEMATIC_PROMPT,
        )

    return prompts


def _get_first_frames_from_conditioning_item(conditioning_item) -> List[Image.Image]:
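    """Extract the first frame of each batch element as a PIL image."""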
    # media_item is expected to have shape [batch, channels, frames, height, width]
    frames_tensor = conditioning_item.media_item
    return [
        tensor_to_pil(frames_tensor[i, :, 0, :, :])
        for i in range(frames_tensor.shape[0])
    ]


def _generate_t2v_prompt(
    prompt_enhancer_model,
    prompt_enhancer_tokenizer,
    prompts: List[str],
    max_new_tokens: int,
    system_prompt: str,
) -> List[str]:
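    """Build text-to-video chat messages and generate enhanced prompts."""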
    messages = [
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"user_prompt: {p}"},
        ]
        for p in prompts
    ]
    texts = [
        prompt_enhancer_tokenizer.apply_chat_template(
            m, tokenize=False, add_generation_prompt=True
        )
        for m in messages
    ]
    # Chat-tuned models often ship without a pad token; fall back to EOS so a
    # batch of unequal-length prompts can be padded into one rectangular tensor.
    if prompt_enhancer_tokenizer.pad_token is None:
        prompt_enhancer_tokenizer.pad_token = prompt_enhancer_tokenizer.eos_token
    model_inputs = prompt_enhancer_tokenizer(
        texts, return_tensors="pt", padding=True
    ).to(prompt_enhancer_model.device)

    return _generate_and_decode_prompts(
        prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens
    )


def _generate_i2v_prompt(
    image_caption_model,
    image_caption_processor,
    prompt_enhancer_model,
    prompt_enhancer_tokenizer,
    prompts: List[str],
    first_frames: List[Image.Image],
    max_new_tokens: int,
    system_prompt: str,
) -> List[str]:
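    """Caption the first frames, then generate caption-grounded enhanced prompts."""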
    image_captions = _generate_image_captions(
        image_caption_model, image_caption_processor, first_frames
    )
    messages = [
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"user_prompt: {p}\nimage_caption: {c}"},
        ]
        for p, c in zip(prompts, image_captions)
    ]
    texts = [
        prompt_enhancer_tokenizer.apply_chat_template(
            m, tokenize=False, add_generation_prompt=True
        )
        for m in messages
    ]
    # Chat-tuned models often ship without a pad token; fall back to EOS so a
    # batch of unequal-length prompts can be padded into one rectangular tensor.
    if prompt_enhancer_tokenizer.pad_token is None:
        prompt_enhancer_tokenizer.pad_token = prompt_enhancer_tokenizer.eos_token
    model_inputs = prompt_enhancer_tokenizer(
        texts, return_tensors="pt", padding=True
    ).to(prompt_enhancer_model.device)

    return _generate_and_decode_prompts(
        prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens
    )


def _generate_image_captions(
    image_caption_model,
    image_caption_processor,
    images: List[Image.Image],
    system_prompt: str = "<DETAILED_CAPTION>",
) -> List[str]:
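    """Generate a detailed caption for each image.

    `system_prompt` is the captioning model's task prompt; the default
    `<DETAILED_CAPTION>` matches Florence-2-style captioning models.
    """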
    image_caption_prompts = [system_prompt] * len(images)
    inputs = image_caption_processor(
        image_caption_prompts, images, return_tensors="pt"
    ).to(image_caption_model.device)

    with torch.inference_mode():
        generated_ids = image_caption_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )

    return image_caption_processor.batch_decode(generated_ids, skip_special_tokens=True)


def _generate_and_decode_prompts(
    prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens: int
) -> List[str]:
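    """Run the enhancer model and decode only the newly generated tokens."""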
    with torch.inference_mode():
        outputs = prompt_enhancer_model.generate(
            **model_inputs, max_new_tokens=max_new_tokens
        )
        # Slice off the prompt tokens so only the newly generated continuation
        # is decoded for each batch element.
        generated_ids = [
            output_ids[len(input_ids) :]
            for input_ids, output_ids in zip(model_inputs.input_ids, outputs)
        ]
        decoded_prompts = prompt_enhancer_tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True
        )

    return decoded_prompts
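

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch, not part of the library. The checkpoint
# names below are placeholders chosen for illustration only; any
# Florence-2-style captioner and chat-tuned LLM with the same Transformers
# interfaces should work.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer

    # Placeholder captioning model (Florence-2 requires trust_remote_code).
    caption_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/Florence-2-base", trust_remote_code=True
    )
    caption_processor = AutoProcessor.from_pretrained(
        "microsoft/Florence-2-base", trust_remote_code=True
    )
    # Placeholder prompt-enhancer LLM.
    enhancer_tokenizer = AutoTokenizer.from_pretrained(
        "meta-llama/Llama-3.2-3B-Instruct"
    )
    enhancer_model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-3B-Instruct"
    )

    # Text-to-video path: no conditioning items, so the T2V system prompt is
    # used and the captioning model is never called.
    enhanced = generate_cinematic_prompt(
        caption_model,
        caption_processor,
        enhancer_model,
        enhancer_tokenizer,
        prompt="a sailboat crossing a stormy sea",
        conditioning_items=None,
    )
    print(enhanced[0])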