# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import json
import logging
import math
import os
import random
import sys
import tempfile
from dataclasses import dataclass
from http import HTTPStatus
from typing import Optional, Union

import dashscope
import torch
from PIL import Image

try:
    from flash_attn import flash_attn_varlen_func
    FLASH_VER = 2
except ModuleNotFoundError:
    # flash-attn is optional; fall back gracefully on machines without it (e.g. CPU-only).
    flash_attn_varlen_func = None
    FLASH_VER = None

from .system_prompt import *
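
# Default system prompts, keyed by task name. "i2v-A14B" additionally carries a
# dedicated entry for the empty-prompt case, and "ti2v-5B" distinguishes its
# t2v / i2v sub-modes; see PromptExpander.decide_system_prompt for how an entry
# is selected.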
DEFAULT_SYS_PROMPTS = {
    "t2v-A14B": {
        "zh": T2V_A14B_ZH_SYS_PROMPT,
        "en": T2V_A14B_EN_SYS_PROMPT,
    },
    "i2v-A14B": {
        "zh": I2V_A14B_ZH_SYS_PROMPT,
        "en": I2V_A14B_EN_SYS_PROMPT,
        "empty": {
            "zh": I2V_A14B_EMPTY_ZH_SYS_PROMPT,
            "en": I2V_A14B_EMPTY_EN_SYS_PROMPT,
        }
    },
    "ti2v-5B": {
        "t2v": {
            "zh": T2V_A14B_ZH_SYS_PROMPT,
            "en": T2V_A14B_EN_SYS_PROMPT,
        },
        "i2v": {
            "zh": I2V_A14B_ZH_SYS_PROMPT,
            "en": I2V_A14B_EN_SYS_PROMPT,
        }
    },
}


@dataclass
class PromptOutput(object):
    status: bool
    prompt: str
    seed: int
    system_prompt: str
    message: str

    def add_custom_field(self, key: str, value) -> None:
        self.__setattr__(key, value)


class PromptExpander:

    def __init__(self, model_name, task, is_vl=False, device=0, **kwargs):
        self.model_name = model_name
        self.task = task
        self.is_vl = is_vl
        self.device = device

    def extend_with_img(self,
                        prompt,
                        system_prompt,
                        image=None,
                        seed=-1,
                        *args,
                        **kwargs):
        pass

    def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
        pass

    def decide_system_prompt(self, tar_lang="zh", prompt=None):
        assert self.task is not None
        # ti2v tasks pick the i2v or t2v prompt depending on whether the
        # expander is a visual-language model.
        if "ti2v" in self.task:
            if self.is_vl:
                return DEFAULT_SYS_PROMPTS[self.task]["i2v"][tar_lang]
            else:
                return DEFAULT_SYS_PROMPTS[self.task]["t2v"][tar_lang]
        # i2v tasks use a dedicated system prompt when the user prompt is empty.
        if "i2v" in self.task and len(prompt) == 0:
            return DEFAULT_SYS_PROMPTS[self.task]["empty"][tar_lang]
        return DEFAULT_SYS_PROMPTS[self.task][tar_lang]

    def __call__(self,
                 prompt,
                 system_prompt=None,
                 tar_lang="zh",
                 image=None,
                 seed=-1,
                 *args,
                 **kwargs):
        if system_prompt is None:
            system_prompt = self.decide_system_prompt(
                tar_lang=tar_lang, prompt=prompt)
        if seed < 0:
            seed = random.randint(0, sys.maxsize)
        # Dispatch to the image-aware path only for visual-language expanders.
        if image is not None and self.is_vl:
            return self.extend_with_img(
                prompt, system_prompt, image=image, seed=seed, *args, **kwargs)
        elif not self.is_vl:
            return self.extend(prompt, system_prompt, seed, *args, **kwargs)
        else:
            raise NotImplementedError


class DashScopePromptExpander(PromptExpander):

    def __init__(self,
                 api_key=None,
                 model_name=None,
                 task=None,
                 max_image_size=512 * 512,
                 retry_times=4,
                 is_vl=False,
                 **kwargs):
        '''
        Args:
            api_key: The API key used to authenticate with DashScope. If omitted,
                the DASH_API_KEY environment variable is used instead.
            model_name: Model name, 'qwen-plus' for extending text prompts,
                'qwen-vl-max' for extending prompts together with images.
            task: Task name. Required to determine the default system prompt.
            max_image_size: Maximum image area in pixels (width * height); larger
                images are downscaled to this area before being sent to the API.
            retry_times: Number of retry attempts in case of request failure.
            is_vl: Whether the task involves visual-language processing.
            **kwargs: Additional keyword arguments forwarded to PromptExpander.
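
        Example (illustrative sketch; assumes the DASH_API_KEY environment
        variable holds a valid key and that the 'qwen-plus' model is enabled
        for your account):

            expander = DashScopePromptExpander(task="t2v-A14B")
            out = expander("a cat surfing at the beach", tar_lang="en")
            if out.status:
                print(out.prompt)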
        '''
        if model_name is None:
            model_name = 'qwen-plus' if not is_vl else 'qwen-vl-max'
        super().__init__(model_name, task, is_vl, **kwargs)
        if api_key is not None:
            dashscope.api_key = api_key
        elif 'DASH_API_KEY' in os.environ and os.environ[
                'DASH_API_KEY'] is not None:
            dashscope.api_key = os.environ['DASH_API_KEY']
        else:
            raise ValueError("DASH_API_KEY is not set")
        if 'DASH_API_URL' in os.environ and os.environ[
                'DASH_API_URL'] is not None:
            dashscope.base_http_api_url = os.environ['DASH_API_URL']
        else:
            dashscope.base_http_api_url = 'https://dashscope.aliyuncs.com/api/v1'
        self.api_key = api_key
        self.max_image_size = max_image_size
        self.model = model_name
        self.retry_times = retry_times

    def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
        messages = [{
            'role': 'system',
            'content': system_prompt
        }, {
            'role': 'user',
            'content': prompt
        }]
        exception = None
        # Retry the request up to self.retry_times times before giving up.
        for _ in range(self.retry_times):
            try:
                response = dashscope.Generation.call(
                    self.model,
                    messages=messages,
                    seed=seed,
                    result_format='message',  # return results in "message" format
                )
                assert response.status_code == HTTPStatus.OK, response
                expanded_prompt = response['output']['choices'][0]['message'][
                    'content']
                return PromptOutput(
                    status=True,
                    prompt=expanded_prompt,
                    seed=seed,
                    system_prompt=system_prompt,
                    message=json.dumps(response, ensure_ascii=False))
            except Exception as e:
                exception = e
        # All retries failed: return the original prompt and the last error.
        return PromptOutput(
            status=False,
            prompt=prompt,
            seed=seed,
            system_prompt=system_prompt,
            message=str(exception))

    def extend_with_img(self,
                        prompt,
                        system_prompt,
                        image: Union[Image.Image, str] = None,
                        seed=-1,
                        *args,
                        **kwargs):
        if isinstance(image, str):
            image = Image.open(image).convert('RGB')
        # Downscale the image so its area does not exceed max_image_size while
        # preserving the aspect ratio.
        w = image.width
        h = image.height
        area = min(w * h, self.max_image_size)
        aspect_ratio = h / w
        resized_h = round(math.sqrt(area * aspect_ratio))
        resized_w = round(math.sqrt(area / aspect_ratio))
        image = image.resize((resized_w, resized_h))
        # The DashScope multimodal API takes a local file URI, so write the
        # resized image to a temporary file.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f:
            image.save(f.name)
            fname = f.name
            image_path = f"file://{f.name}"
        messages = [
            {
                'role': 'system',
                'content': [{
                    "text": system_prompt
                }]
            },
            {
                'role': 'user',
                'content': [{
                    "text": prompt
                }, {
                    "image": image_path
                }]
            },
        ]
        response = None
        result_prompt = prompt
        exception = None
        status = False
        for _ in range(self.retry_times):
            try:
                response = dashscope.MultiModalConversation.call(
                    self.model,
                    messages=messages,
                    seed=seed,
                    result_format='message',  # return results in "message" format
                )
                assert response.status_code == HTTPStatus.OK, response
                result_prompt = response['output']['choices'][0]['message'][
                    'content'][0]['text'].replace('\n', '\\n')
                status = True
                break
            except Exception as e:
                exception = e
        result_prompt = result_prompt.replace('\n', '\\n')
        os.remove(fname)
        return PromptOutput(
            status=status,
            prompt=result_prompt,
            seed=seed,
            system_prompt=system_prompt,
            message=str(exception) if not status else json.dumps(
                response, ensure_ascii=False))


class QwenPromptExpander(PromptExpander):
    model_dict = {
        "QwenVL2.5_3B": "Qwen/Qwen2.5-VL-3B-Instruct",
        "QwenVL2.5_7B": "Qwen/Qwen2.5-VL-7B-Instruct",
        "Qwen2.5_3B": "Qwen/Qwen2.5-3B-Instruct",
        "Qwen2.5_7B": "Qwen/Qwen2.5-7B-Instruct",
        "Qwen2.5_14B": "Qwen/Qwen2.5-14B-Instruct",
    }

    def __init__(self,
                 model_name=None,
                 task=None,
                 device=0,
                 is_vl=False,
                 **kwargs):
        '''
        Args:
            model_name: One of the predefined names in `model_dict` (e.g.
                'QwenVL2.5_7B' or 'Qwen2.5_14B'), the local path to a downloaded
                checkpoint, or a model name from the Hugging Face hub.
            task: Task name. Required to determine the default system prompt.
            device: Device index the model is moved to during generation.
            is_vl: Whether the task involves visual-language processing.
            **kwargs: Additional keyword arguments forwarded to PromptExpander.
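
        Example (illustrative sketch; assumes a CUDA device 0 is available and
        that the default 'Qwen2.5_14B' checkpoint can be fetched from the
        Hugging Face hub):

            expander = QwenPromptExpander(task="t2v-A14B", device=0)
            out = expander("a cat surfing at the beach", tar_lang="en", seed=42)
            print(out.prompt)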
        '''
        if model_name is None:
            model_name = 'Qwen2.5_14B' if not is_vl else 'QwenVL2.5_7B'
        super().__init__(model_name, task, is_vl, device, **kwargs)
        # Resolve predefined short names to their Hugging Face repo ids.
        if (not os.path.exists(self.model_name)) and (self.model_name
                                                      in self.model_dict):
            self.model_name = self.model_dict[self.model_name]

        if self.is_vl:
            # default: Load the model on the available device(s)
            from transformers import (
                AutoProcessor,
                AutoTokenizer,
                Qwen2_5_VLForConditionalGeneration,
            )
            try:
                from .qwen_vl_utils import process_vision_info
            except ImportError:
                from qwen_vl_utils import process_vision_info
            self.process_vision_info = process_vision_info
            min_pixels = 256 * 28 * 28
            max_pixels = 1280 * 28 * 28
            self.processor = AutoProcessor.from_pretrained(
                self.model_name,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
                use_fast=True)
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                self.model_name,
                torch_dtype=torch.bfloat16 if FLASH_VER == 2 else
                torch.float16 if "AWQ" in self.model_name else "auto",
                attn_implementation="flash_attention_2"
                if FLASH_VER == 2 else None,
                device_map="cpu")
        else:
            from transformers import AutoModelForCausalLM, AutoTokenizer
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16
                if "AWQ" in self.model_name else "auto",
                attn_implementation="flash_attention_2"
                if FLASH_VER == 2 else None,
                device_map="cpu")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

    def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
        # Move the model to the target device only for the duration of generation.
        self.model = self.model.to(self.device)
        messages = [{
            "role": "system",
            "content": system_prompt
        }, {
            "role": "user",
            "content": prompt
        }]
        text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        model_inputs = self.tokenizer([text],
                                      return_tensors="pt").to(self.model.device)

        generated_ids = self.model.generate(**model_inputs, max_new_tokens=512)
        # Strip the prompt tokens so only the newly generated text is decoded.
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(
                model_inputs.input_ids, generated_ids)
        ]
        expanded_prompt = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True)[0]
        self.model = self.model.to("cpu")
        return PromptOutput(
            status=True,
            prompt=expanded_prompt,
            seed=seed,
            system_prompt=system_prompt,
            message=json.dumps({"content": expanded_prompt},
                               ensure_ascii=False))

    def extend_with_img(self,
                        prompt,
                        system_prompt,
                        image: Union[Image.Image, str] = None,
                        seed=-1,
                        *args,
                        **kwargs):
        self.model = self.model.to(self.device)
        messages = [{
            'role': 'system',
            'content': [{
                "type": "text",
                "text": system_prompt
            }]
        }, {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image,
                },
                {
                    "type": "text",
                    "text": prompt
                },
            ],
        }]

        # Preparation for inference
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = self.process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.device)

        # Inference: Generation of the output
        generated_ids = self.model.generate(**inputs, max_new_tokens=512)
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        expanded_prompt = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False)[0]
        self.model = self.model.to("cpu")
        return PromptOutput(
            status=True,
            prompt=expanded_prompt,
            seed=seed,
            system_prompt=system_prompt,
            message=json.dumps({"content": expanded_prompt},
                               ensure_ascii=False))


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] %(levelname)s: %(message)s",
        handlers=[logging.StreamHandler(stream=sys.stdout)])

    seed = 100
    prompt = "夏日海滩度假风格,一只戴着墨镜的白色猫咪坐在冲浪板上。猫咪毛发蓬松,表情悠闲,直视镜头。背景是模糊的海滩景色,海水清澈,远处有绿色的山丘和蓝天白云。猫咪的姿态自然放松,仿佛在享受海风和阳光。近景特写,强调猫咪的细节和海滩的清新氛围。"
    en_prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
    image = "./examples/i2v_input.JPG"

    def test(method,
             prompt,
             model_name,
             task,
             image=None,
             en_prompt=None,
             seed=None):
        prompt_expander = method(
            model_name=model_name, task=task, is_vl=image is not None)
        result = prompt_expander(prompt, image=image, tar_lang="zh")
        logging.info(f"zh prompt -> zh: {result.prompt}")
        result = prompt_expander(prompt, image=image, tar_lang="en")
        logging.info(f"zh prompt -> en: {result.prompt}")
        if en_prompt is not None:
            result = prompt_expander(en_prompt, image=image, tar_lang="zh")
            logging.info(f"en prompt -> zh: {result.prompt}")
            result = prompt_expander(en_prompt, image=image, tar_lang="en")
            logging.info(f"en prompt -> en: {result.prompt}")

    ds_model_name = None
    ds_vl_model_name = None
    qwen_model_name = None
    qwen_vl_model_name = None

    for task in ["t2v-A14B", "i2v-A14B", "ti2v-5B"]:
        # test prompt extend
        if "t2v" in task or "ti2v" in task:
            # test dashscope api
            logging.info("-" * 40)
            logging.info(f"Testing {task} dashscope prompt extend")
            test(
                DashScopePromptExpander,
                prompt,
                ds_model_name,
                task,
                image=None,
                en_prompt=en_prompt,
                seed=seed)
            # test qwen api
            logging.info("-" * 40)
            logging.info(f"Testing {task} qwen prompt extend")
            test(
                QwenPromptExpander,
                prompt,
                qwen_model_name,
                task,
                image=None,
                en_prompt=en_prompt,
                seed=seed)
        # test prompt-image extend
        if "i2v" in task:
            # test dashscope api
            logging.info("-" * 40)
            logging.info(f"Testing {task} dashscope vl prompt extend")
            test(
                DashScopePromptExpander,
                prompt,
                ds_vl_model_name,
                task,
                image=image,
                en_prompt=en_prompt,
                seed=seed)
            # test qwen api
            logging.info("-" * 40)
            logging.info(f"Testing {task} qwen vl prompt extend")
            test(
                QwenPromptExpander,
                prompt,
                qwen_vl_model_name,
                task,
                image=image,
                en_prompt=en_prompt,
                seed=seed)
        # test empty prompt extend
        if "i2v-A14B" in task:
            # test dashscope api
            logging.info("-" * 40)
            logging.info(f"Testing {task} dashscope vl empty prompt extend")
            test(
                DashScopePromptExpander,
                "",
                ds_vl_model_name,
                task,
                image=image,
                en_prompt=None,
                seed=seed)
            # test qwen api
            logging.info("-" * 40)
            logging.info(f"Testing {task} qwen vl empty prompt extend")
            test(
                QwenPromptExpander,
                "",
                qwen_vl_model_name,
                task,
                image=image,
                en_prompt=None,
                seed=seed)