import argparse
import time
from PIL import Image
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList
import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any
import string
import cv2
import gradio as gr
from huggingface_hub import hf_hub_download, login
from open_flamingo.src.factory import create_model_and_transforms
class SeparatorStyle(Enum):
    """How ``Conversation.get_prompt`` terminates each turn.

    SINGLE: every turn is followed by ``sep``.
    TWO: turns alternate between ``sep`` (even index) and ``sep2`` (odd index).
    """

    # Explicit values equal to what ``auto()`` assigns (1, 2).
    SINGLE = 1
    TWO = 2
@dataclasses.dataclass
class Conversation:
    """Conversation history plus the template used to render it as a prompt.

    Fields:
        system: Text placed before the first turn.
        roles: Names of the two speakers, e.g. ``("Human", "Assistant")``.
        messages: ``[role, message]`` pairs in order. A falsy message renders
            as an open ``"role:"`` slot for the model to complete.
        offset: Number of leading messages hidden by ``to_gradio_chatbot``.
        sep_style: How turns are separated (see ``SeparatorStyle``).
        sep: Separator after the system text and after each turn (SINGLE), or
            after even-indexed turns (TWO).
        sep2: Separator after odd-indexed turns; must be set for TWO.
        skip_next: Flag carried with the conversation (not read in this class).
        conv_id: Optional identifier.
    """

    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None  # may be None; only SeparatorStyle.TWO reads it
    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        """Render the system text and all messages into one prompt string.

        Raises:
            ValueError: If ``sep_style`` is not a known ``SeparatorStyle``.
        """
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        """Append a ``[role, message]`` pair to the history."""
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        """Return ``[[user_msg, reply], ...]`` pairs for messages after ``offset``."""
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a copy with its own ``messages`` list.

        Every field, including ``skip_next``, is carried over. The previous
        hand-written constructor call dropped ``skip_next``. The inner
        ``[role, message]`` pairs are copied so appends and edits on the copy
        do not leak back into the original.
        """
        return dataclasses.replace(self, messages=[[x, y] for x, y in self.messages])

    def dict(self):
        """Return a plain-dict snapshot (``sep_style`` and ``skip_next`` are omitted)."""
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the first sequence in the batch ends with a stop sequence.

    Args:
        stops: Iterable of 1-D token-id tensors. Generation stops as soon as
            ``input_ids[0]`` ends with any of them. Defaults to no stops.
        encounters: Unused; kept for signature compatibility.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # ``None`` instead of a shared mutable ``[]`` default.
        self.stops = [] if stops is None else stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        """Return True if ``input_ids[0]`` ends with one of ``self.stops``.

        ``**kwargs`` absorbs extra arguments newer ``transformers`` versions
        forward to stopping criteria.
        """
        sequence = input_ids[0]
        for stop in self.stops:
            n = len(stop)
            # An empty stop, or one longer than the sequence, cannot match.
            # Without this guard the short tail slice would broadcast against
            # ``stop`` and could report a false match.
            if n == 0 or n > sequence.shape[0]:
                continue
            if torch.all(stop == sequence[-n:]).item():
                return True
        return False
# Default vision-chat template. ``offset=2`` hides the first two messages from
# ``to_gradio_chatbot``; every turn is terminated by "###" (SeparatorStyle.SINGLE).
CONV_VISION = Conversation(
    system=(
        "Give the following image: ![]() ImageContent. "
        "You will be able to see the image once I provide it to you. Please answer my questions."
    ),
    roles=("Human", "Assistant"),
    messages=[],
    offset=2,
    sep="###",
    sep_style=SeparatorStyle.SINGLE,
)
def get_outputs(
    model,
    batch_images,
    attention_mask,
    max_generation_length,
    min_generation_length,
    num_beams,
    length_penalty,
    input_ids,
    image_start_index_list=None,
    image_nums=None,
    bad_words_ids=None,
):
    """Run one forward pass of ``model`` under ``torch.inference_mode``.

    No autoregressive decoding happens here. Only the images, prompt ids,
    attention mask and image-position arguments reach the model.
    ``max_generation_length``, ``min_generation_length``, ``num_beams``,
    ``length_penalty`` and ``bad_words_ids`` are accepted for call-site
    compatibility but are not used.

    Returns:
        Whatever ``model(...)`` returns.
    """
    forward_kwargs = dict(
        vision_x=batch_images,
        lang_x=input_ids,
        attention_mask=attention_mask,
        labels=None,
        image_nums=image_nums,
        image_start_index_list=image_start_index_list,
        added_bbox_list=None,
        add_box=False,
    )
    with torch.inference_mode():
        return model(**forward_kwargs)
def generate(
    idx,
    image,
    text,
    image_processor,
    tokenizer,
    flamingo,
    vis_embed_size=256,
    rank=0,
    world_size=1,
):
    """Run one demo query against ``flamingo`` and draw the best box, if any.

    Args:
        idx: Task selector. ``1`` builds a referring-expression prompt
            (``<|#object#|> text <|#endofobject#|><|#visual#|>``); any other
            value builds a plain image + text prompt.
        image: Uploaded PIL image. ``None`` raises ``gr.Error``.
        text: User query; trailing periods are stripped.
        image_processor: Callable mapping a 224x224 PIL image to a tensor.
        tokenizer: Tokenizer that knows the ``<|#...#|>`` special tokens.
        flamingo: Model passed to ``get_outputs``. Its result is indexed with
            ``"boxes"`` and ``"scores"``.
        vis_embed_size: Number of pad tokens reserved for the image embedding.
        rank, world_size: Unused; kept for interface compatibility.

    Returns:
        ``(f"Output:{box}", annotated_image)`` when a box was predicted (box in
        original-image pixels). Otherwise the string form of
        ``tokenizer.batch_decode(outputs)``.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    flamingo.eval()
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]

    # Convert once; used both for the model input and for drawing, so
    # non-RGB uploads (L, RGBA, P) no longer break the BGR/RGB handling.
    rgb_image = image.convert("RGB")
    width, height = rgb_image.width, rgb_image.height
    # The model sees a fixed 224x224 input; predicted boxes are in that frame.
    batch_images = image_processor(rgb_image.resize((224, 224))).unsqueeze(0).unsqueeze(1).unsqueeze(0)

    if idx == 1:
        prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|><|#object#|> {text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"]
        bad_words_ids = None
        max_generation_length = 5
    else:
        prompt = [f"<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{text.rstrip('.')}"]
        # Location tokens should not appear in free-form answers.
        # NOTE(review): the token literal here was empty (f""), so tokenizer("")
        # returned [] and [-1] raised IndexError on every call. "<loc_{i}>" is
        # inferred from the loop index and the variable name — confirm against
        # the tokenizer's added tokens.
        loc_token_ids = [
            int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1])
            for i in range(1000)
        ]
        # Previously referenced the undefined name ``loc_word_ids``. Built in the
        # list-of-lists form Hugging Face ``bad_words_ids`` expects. Note that
        # ``get_outputs`` currently does not forward it.
        bad_words_ids = [[token_id] for token_id in loc_token_ids]
        max_generation_length = 300

    encodings = tokenizer(
        prompt,
        padding="longest",
        truncation=True,
        return_tensors="pt",
        max_length=2000,
    )
    input_ids = encodings["input_ids"]
    attention_mask = encodings["attention_mask"]
    # Visual embeddings start right after each <|#image#|> token.
    image_start_index_list = [
        [x] for x in ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
    ]
    image_nums = [1] * len(input_ids)
    outputs = get_outputs(
        model=flamingo,
        batch_images=batch_images,
        attention_mask=attention_mask,
        max_generation_length=max_generation_length,
        min_generation_length=4,
        num_beams=1,
        length_penalty=1.0,
        input_ids=input_ids,
        bad_words_ids=bad_words_ids,
        image_start_index_list=image_start_index_list,
        image_nums=image_nums,
    )
    boxes = outputs["boxes"]
    scores = outputs["scores"]

    # Both must be non-empty. Previously an empty ``scores`` left ``box``
    # unbound, and an empty ``boxes`` raised IndexError on ``boxes[...]``.
    if len(scores) > 0 and len(boxes) > 0:
        box = boxes[scores.argmax()] / 224  # normalise to [0, 1]
        print(f"{box}")
        box = box * [width, height, width, height]  # back to original pixels
        x0, y0, x1, y1 = (int(v) for v in box)
        canvas = np.array(rgb_image)
        # Drawn directly on RGB: (0, 0, 255) is the same blue the old BGR
        # round-trip produced. cv2 wants plain int tuples for the corners.
        cv2.rectangle(canvas, (x0, y0), (x1, y1), (0, 0, 255), 2)
        return f"Output:{box}", Image.fromarray(canvas)

    # NOTE(review): ``outputs`` is the forward-pass result indexed as a mapping
    # above; batch-decoding it likely does not yield generated text, and this
    # branch returns one value instead of the 2-tuple above — confirm with the UI.
    gen_text = tokenizer.batch_decode(outputs)
    return f"{gen_text}"
def preprocess_conv(data):
    """Render ``{"from", "value"}`` turns as a ``### Role: text`` transcript.

    Speaker names are matched case-insensitively: ``"human"`` becomes
    ``"Human"`` and ``"gpt"`` becomes ``"Assistant"``. Any other speaker is
    rendered as ``"unknown"``. Each turn ends with a newline.
    """
    role_names = {"human": "Human", "gpt": "Assistant"}
    lines = []
    for turn in data:
        speaker = role_names.get(turn["from"].lower(), "unknown")
        lines.append("### " + speaker + ": " + turn["value"] + "\n")
    return "".join(lines)
class Chat:
    """Answer questions about an uploaded image with a vision-language model.

    The conversation passed to ``ask``/``answer`` is a plain list of
    ``{"from": "human" | "gpt", "value": str}`` dicts, rendered into a prompt
    by ``preprocess_conv``. Uploaded images are collected in ``img_list``.
    """

    def __init__(self, model, vis_processor, tokenizer, vis_embed_size, device="cuda"):
        """Store the model components.

        Args:
            model: Model exposing ``eval()`` and ``generate(...)``.
            vis_processor: Image processor; ``vis_processor.size["shortest_edge"]``
                is used as the square input size.
            tokenizer: Tokenizer that knows the ``<|#...#|>`` special tokens.
            vis_embed_size: Number of pad tokens reserved for the image embedding.
            device: Target device for image and prompt tensors. Previously this
                read an undefined global ``device``. Default matches the
                hard-coded ``"cuda"`` used before.
        """
        self.device = device
        self.model = model
        self.vis_processor = vis_processor
        self.tokenizer = tokenizer
        self.vis_embed_size = vis_embed_size
        self.conv = []

    def ask(self, text, conv):
        """Append a human turn containing ``text`` to ``conv``."""
        conv.append({
            "from": "human",
            "value": text,
        })

    def answer(self, conv, img_list, max_new_tokens=200, num_beams=5, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1, temperature=1, max_length=2000):
        """Generate the assistant reply for ``conv`` about ``img_list[0]``.

        A ``{"from": "gpt"}`` turn holding the reply is appended to ``conv``,
        and the reply is printed.

        Args:
            conv: List of turn dicts, normally ending with the human question.
            img_list: Uploaded images; only the first one is used.
            max_length: Tokenizer truncation length for the prompt.
            max_new_tokens, num_beams, min_length, top_p, repetition_penalty,
            length_penalty, temperature: Currently ignored. Decoding uses 100 new
                tokens with ``num_beams=1``. They are kept for API compatibility.

        Returns:
            ``(output_text, output_token_ids)``; the ids are a NumPy array.

        Raises:
            ValueError: If ``img_list`` is empty.
        """
        if not img_list:
            raise ValueError("No image uploaded; call upload_img first.")
        tokenizer = self.tokenizer
        media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
        size = self.vis_processor.size["shortest_edge"]
        self.model.eval()

        # Previously this prompted for a path on stdin and loaded an image it never
        # used. The uploaded image is prepared the way ``evaluate_exp`` does it.
        image = img_list[0]
        if isinstance(image, Image.Image):
            image = image.convert("RGB").resize((size, size))
        batch_images = (
            preprocess_image(image, self.vis_processor)
            .unsqueeze(0).unsqueeze(1).unsqueeze(0)
            .to(self.device)
        )

        conv.append({
            "from": "gpt",
            "value": "",
        })
        text = preprocess_conv(conv).strip()
        caption = f"<|#image#|>{tokenizer.pad_token * self.vis_embed_size}<|#endofimage#|>{text}"
        encodings = tokenizer(
            caption,
            padding="longest",
            truncation=True,
            return_tensors="pt",
            max_length=max_length,
        )
        input_ids = encodings["input_ids"].to(self.device)
        attention_mask = encodings["attention_mask"].to(self.device)
        # Visual embeddings start right after each <|#image#|> token.
        image_start_index_list = [
            [x] for x in ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
        ]
        image_nums = [1] * len(input_ids)
        # A comma, not ``and``: ``a and b`` evaluates to ``b`` alone, so
        # no_grad was never actually entered before.
        with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
            outputs = self.model.generate(
                batch_images,
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=100,
                num_beams=1,
                image_start_index_list=image_start_index_list,
                image_nums=image_nums,
            )
        # Keep only newly generated tokens (drop the echoed prompt).
        output_token = outputs[0, input_ids.shape[1]:]
        output_text = tokenizer.decode(output_token, skip_special_tokens=True).strip()
        conv[-1]["value"] = output_text
        print(f"### Assistant: {output_text}")
        return output_text, output_token.cpu().numpy()

    def upload_img(self, image, conv, img_list):
        """Record an uploaded image and return the acknowledgement ``"Received."``.

        ``answer`` inserts the image tokens itself, so the list-of-dicts ``conv``
        used by ``ask``/``answer`` needs no placeholder turn (and has no
        ``append_message``). Only a ``Conversation``-style object still receives
        the placeholder message.
        """
        img_list.append(image)
        if hasattr(conv, "append_message"):
            conv.append_message(conv.roles[0], "![]() ")
        msg = "Received."
        return msg

    def get_context_emb(self, conv, img_list):
        """Interleave prompt-segment embeddings with image embeddings.

        NOTE(review): ``prompt.split('')`` always raises
        ``ValueError: empty separator``. The image-placeholder literal appears
        to have been lost; it should match the marker inserted for each image.
        This method also expects a ``Conversation`` (``get_prompt``) and a model
        exposing ``llama_tokenizer``/``llama_model``, unlike the rest of this
        class. Assumes ``img_list`` holds embedding tensors — TODO confirm.
        """
        prompt = conv.get_prompt()
        prompt_segs = prompt.split('')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.model.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs
    
def evaluate_exp(
    model,
    tokenizer,
    image_processor,
    vis_embed_size=None,
    rank=0,
    world_size=1,
    id=0,
    add_visual=True,
):
    """Interactive terminal chat about one image.

    Prompts on stdin for an image path, then loops reading ``### Human:``
    lines until ``#end#``. Each reply is printed and stored, so later turns
    see the full dialogue.

    Args:
        model: Model exposing ``eval()`` and ``generate(...)``.
        tokenizer: Tokenizer that knows the ``<|#...#|>`` special tokens.
        image_processor: Processor; ``size["shortest_edge"]`` sets the input size.
        vis_embed_size: Number of pad tokens reserved for the image embedding
            (required).
        rank, world_size, id, add_visual: Unused; kept for interface compatibility.

    Raises:
        ValueError: If ``vis_embed_size`` is None (previously a TypeError only
            after the first question).
    """
    if vis_embed_size is None:
        raise ValueError("vis_embed_size must be the number of visual embedding tokens")
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    size = image_processor.size["shortest_edge"]
    model.eval()
    image_path = input("Please enter the image path: ")
    image = Image.open(image_path).convert("RGB")
    image = image.resize((size, size))
    print(f"image size: {image.size}")
    batch_images = preprocess_image(image, image_processor).unsqueeze(0).unsqueeze(1).unsqueeze(0).to("cuda")
    # The image prefix is identical for every turn, so build it once.
    image_prefix = f"<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>"
    conversation = []
    while True:
        human_sentence = input("### Human: ")
        if human_sentence == "#end#":
            break
        conversation.append({
            "from": "human",
            "value": human_sentence,
        })
        conversation.append({
            "from": "gpt",
            "value": "",
        })
        text = preprocess_conv(conversation).strip()
        encodings = tokenizer(
            image_prefix + text,
            padding="longest",
            truncation=True,
            return_tensors="pt",
            max_length=2000,
        )
        input_ids = encodings["input_ids"].to("cuda")
        attention_mask = encodings["attention_mask"].to("cuda")
        # Visual embeddings start right after each <|#image#|> token.
        image_start_index_list = [
            [x] for x in ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
        ]
        image_nums = [1] * len(input_ids)
        # A comma, not ``and``: ``a and b`` evaluates to ``b`` alone, so
        # no_grad was never actually entered before.
        with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
            outputs = model.generate(
                batch_images,
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=100,
                num_beams=1,
                image_start_index_list=image_start_index_list,
                image_nums=image_nums,
            )
        reply = tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True).strip()
        # Store the reply; previously the history kept an empty assistant turn,
        # so every later prompt lost the model's earlier answers.
        conversation[-1]["value"] = reply
        print(f"### Assistant: {reply}")