# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# This file is modified from https://github.com/haotian-liu/LLaVA/
import dataclasses
from enum import Enum, auto
from typing import List

from llava.utils.logging import logger


class SeparatorStyle(Enum):
    """Different separator styles."""

    AUTO = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_3 = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""

    system: str
    roles: List[str]
    messages: List[List[str]]
    sep_style: SeparatorStyle = SeparatorStyle.AUTO
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    def get_prompt(self):
        """Assemble the conversation into a single prompt string according to sep_style."""
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            # The first message carries an image; move the <image> token to the front of it.
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace("<image>", "").strip()
            messages[0] = (init_role, "<image>\n" + init_msg)

        if self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.LLAMA_3:
            ret = self.system + self.sep
            for rid, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message = message[0]
                    sep = self.sep if rid < len(messages) - 1 else self.sep2
                    ret += role + message + sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        self.messages.append([role, message])

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version,
        )


conv_auto = Conversation(
    system="",
    roles=("", ""),
    messages=(),
    sep_style=SeparatorStyle.AUTO,
    sep="\n",
)

conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)
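
# Illustrative usage sketch (not part of the upstream module): with SeparatorStyle.TWO,
# get_prompt() alternates `sep` and `sep2` between turns, and a trailing None message
# leaves a bare "ASSISTANT:" to cue the model to generate.
#
#   conv = conv_vicuna_v1.copy()
#   conv.append_message(conv.roles[0], "Describe the image.")
#   conv.append_message(conv.roles[1], None)
#   conv.get_prompt()
#   # -> "A chat between ... questions. USER: Describe the image. ASSISTANT:"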

conv_llava_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(),
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)

hermes_2 = Conversation(
    system="<|im_start|>system\nAnswer the questions.",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
    messages=(),
    version="hermes-2",
)

# Template added by Yukang. Note (kentang-mit@): sep is <|eot_id|> for official template.
llama_3_chat = Conversation(
    system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. "
    "You are able to understand the visual content that the user provides, "
    "and assist the user with a variety of tasks using natural language.",
    roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
    version="llama_v3",
    messages=(),
    sep_style=SeparatorStyle.LLAMA_3,
    sep="<|eot_id|>",
    sep2="<|end_of_text|>",
)
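
# Illustrative sketch (assumption, not upstream code): with SeparatorStyle.LLAMA_3,
# turns before the final one are closed with `sep` (<|eot_id|>); the final message,
# when present, is closed with `sep2` (<|end_of_text|>), and a trailing None message
# leaves the bare assistant header to cue generation.
#
#   conv = llama_3_chat.copy()
#   conv.append_message(conv.roles[0], "Hi")
#   conv.append_message(conv.roles[1], None)
#   conv.get_prompt()
#   # -> "<|begin_of_text|>...natural language.<|eot_id|>"
#   #    "<|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|>"
#   #    "<|start_header_id|>assistant<|end_header_id|>\n\n"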

default_conversation = conv_auto
conv_templates = {
    "auto": conv_auto,
    "hermes-2": hermes_2,
    "llama_3": llama_3_chat,
    "v1": conv_vicuna_v1,
    "vicuna_v1": conv_vicuna_v1,
    "plain": conv_llava_plain,
}

CONVERSATION_MODE_MAPPING = {
    "vila1.5-3b": "vicuna_v1",
    "vila1.5-8b": "llama_3",
    "vila1.5-13b": "vicuna_v1",
    "vila1.5-40b": "hermes-2",
    "llama-3": "llama_3",
    "llama3": "llama_3",
}


def auto_set_conversation_mode(model_name_or_path: str) -> None:
    # Pick the conversation template by case-insensitive substring match on the model name/path.
    global default_conversation
    for k, v in CONVERSATION_MODE_MAPPING.items():
        if k in model_name_or_path.lower():
            logger.info(f"Setting conversation mode to `{v}` based on model name/path `{model_name_or_path}`.")
            default_conversation = conv_templates[v]
            return
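
# Illustrative sketch (assumption, not upstream code; the model path below is hypothetical):
# matching is a lowercased substring test against CONVERSATION_MODE_MAPPING, e.g.
#
#   auto_set_conversation_mode("SomeOrg/VILA1.5-3b-finetune")
#   # logs "Setting conversation mode to `vicuna_v1` ..." and points
#   # default_conversation at conv_templates["vicuna_v1"]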