File size: 4,756 Bytes
4bebcaf
 
 
 
 
 
 
41f2871
 
4bebcaf
 
 
f136d79
a6bc163
 
 
 
 
41f2871
 
2c2417f
e487146
a6bc163
 
4bebcaf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import html
import logging
import os
from datetime import datetime

import gradio as gr
from PIL import Image

# from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig, ChatTemplateConfig
from lmdeploy import pipeline, GenerationConfig, ChatTemplateConfig
from lmdeploy.vl import load_image

class ConversationalAgent:
    """Chat backend that connects Gradio event handlers to an lmdeploy pipeline.

    The instance owns one lmdeploy pipeline, the session returned by the most
    recent ``pipe.chat`` call, and a directory where every received image is
    written as a JPEG so the chat history can reference it by path.
    """

    # Pillow image modes that the JPEG writer accepts as-is; any other mode
    # (e.g. RGBA or P, typical for PNG uploads) makes ``Image.save`` raise
    # OSError when the target is a ``.jpg`` file.
    _JPEG_MODES = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    def __init__(self, model_path, outputs_dir, device='cpu'):
        """Build the pipeline and prepare the upload directory.

        Args:
            model_path: Model location handed to ``lmdeploy.pipeline``.
            outputs_dir: Base directory; images go to ``<outputs_dir>/uploaded``.
            device: Forwarded to ``pipeline`` as the ``device`` keyword.
        """
        # Accept a device argument and pass it on when constructing the pipeline.
        # NOTE(review): confirm that ``lmdeploy.pipeline`` actually honours a
        # ``device`` keyword rather than ignoring it.
        self.device = device
        self.pipe = pipeline(
            model_path,
            chat_template_config=ChatTemplateConfig(model_name='internvl2-internlm2'),
            # backend_config=TurbomindEngineConfig(session_len=8192),
            # backend_config=PytorchEngineConfig(max_length=8192),
            device=self.device
        )

        self.uploaded_images_storage = os.path.abspath(
            os.path.join(outputs_dir, "uploaded")
        )
        os.makedirs(self.uploaded_images_storage, exist_ok=True)

        # Initialise the conversation fields here as well, so handlers such as
        # upload_image do not fail with AttributeError when they run before
        # start_chat has been called.
        self._reset_state()

    def _reset_state(self):
        """Drop the current session and clear all per-conversation fields."""
        self.sess = None
        self.context = ""
        self.current_image_id = -1
        self.image_list = []
        self.pixel_values_list = []
        self.seen_image_idx = []

    def _save_image(self, image):
        """Write ``image`` into the upload directory as a JPEG; return its path."""
        index = len(os.listdir(self.uploaded_images_storage))
        save_image_path = os.path.join(self.uploaded_images_storage, f"{index}.jpg")
        # The directory-size index can coincide with an existing file once
        # anything has been removed from the directory; never overwrite.
        while os.path.exists(save_image_path):
            index += 1
            save_image_path = os.path.join(self.uploaded_images_storage, f"{index}.jpg")

        # Only the on-disk copy is converted; callers keep the original image.
        if image.mode in self._JPEG_MODES:
            image.save(save_image_path)
        else:
            image.convert("RGB").save(save_image_path)
        logging.info(f"image save path: {save_image_path}")
        return save_image_path

    def start_chat(self, chat_state):
        """Reset the conversation and return updates that enable the chat UI.

        Returns:
            A 6-tuple of updates, in order, for: input textbox, start button,
            clear button, image input, upload button, and ``chat_state``
            (returned unchanged). The binding to components happens outside
            this class.
        """
        self._reset_state()
        logging.info("=" * 30 + "Start Chat" + "=" * 30)

        return (
            gr.update(interactive=True, placeholder='input the text.'),  # [input_text] Textbox
            gr.update(interactive=False),  # [start_btn] Button
            gr.update(interactive=True),  # [clear_btn] Button
            gr.update(interactive=True),  # [image] Image
            gr.update(interactive=True),  # [upload_btn] Button
            chat_state  # [chat_state] State
        )

    def restart_chat(self, chat_state):
        """Reset the conversation and return updates that disable the chat UI.

        Returns:
            A 7-tuple of updates, in order, for: chatbot (cleared with
            ``None``), input textbox, start button, clear button, image input
            (value cleared), upload button, and ``chat_state`` (unchanged).
        """
        self._reset_state()
        logging.info("=" * 30 + "End Chat" + "=" * 30)

        return (
            None,  # [chatbot] Chatbot
            gr.update(interactive=False, placeholder="Please click the <Start Chat> button to start chat!"),  # [input_text] Textbox
            gr.update(interactive=True),  # [start] Button
            gr.update(interactive=False),  # [clear] Button
            gr.update(value=None, interactive=False),  # [image] Image
            gr.update(interactive=False),  # [upload_btn] Button
            chat_state  # [chat_state] State
        )

    def upload_image(self, image: Image.Image, chat_history: gr.Chatbot, chat_state: gr.State):
        """Store an uploaded image and acknowledge it in the chat history.

        Args:
            image: The uploaded image, or ``None`` when nothing was selected.
            chat_history: List of ``(user, bot)`` pairs; a new pair is appended.
            chat_state: Passed through unchanged.

        Returns:
            ``(None, chat_history, chat_state)``; the leading ``None`` is the
            new value for the image input.
        """
        logging.info(f"type(image): {type(image)}")
        if chat_history is None:
            chat_history = []
        if image is None:
            # Nothing to store; leave the history as it is instead of crashing.
            return None, chat_history, chat_state

        self.image_list.append(image)
        save_image_path = self._save_image(image)
        chat_history.append((gr.HTML(f'<img src="./file={save_image_path}" style="width: 200px; height: auto; display: inline-block;">'), "Received."))

        return None, chat_history, chat_state

    def respond(
        self,
        message,
        image,
        chat_history: gr.Chatbot,
        top_p,
        temperature,
        chat_state,
    ):
        """Send the user's turn to the pipeline and append the reply to the history.

        Args:
            message: Text typed by the user.
            image: Optional image sent with the message; when present it is
                saved to disk and passed to the pipeline as ``(message, image)``.
            chat_history: List of ``(user, bot)`` pairs; a new pair is appended.
            top_p: Forwarded to ``GenerationConfig``.
            temperature: Forwarded to ``GenerationConfig``.
            chat_state: Passed through unchanged.

        Returns:
            ``("", None, chat_history, chat_state)``; the first two items clear
            the text and image inputs.
        """
        current_time = datetime.now().strftime("%b%d-%H:%M:%S")
        logging.info(f"Time: {current_time}")
        logging.info(f"User: {message}")
        gen_config = GenerationConfig(top_p=top_p, temperature=temperature)
        if chat_history is None:
            chat_history = []

        save_image_path = None
        chat_input = message
        if image is not None:
            save_image_path = self._save_image(image)
            chat_input = (message, image)

        # The first turn opens a new session; later turns continue the stored one.
        if self.sess is None:
            self.sess = self.pipe.chat(chat_input, gen_config=gen_config)
        else:
            self.sess = self.pipe.chat(chat_input, session=self.sess, gen_config=gen_config)
        response = self.sess.response.text

        if save_image_path is not None:
            # The message is embedded in an HTML fragment, so escape it to keep
            # user-typed markup from being interpreted as HTML.
            safe_message = html.escape(str(message))
            user_turn = gr.HTML(f'{safe_message}\n\n<img src="./file={save_image_path}" style="width: 200px; height: auto; display: inline-block;">')
        else:
            user_turn = message
        chat_history.append((user_turn, response))

        logging.info(f"generated text = \n{response}")

        return "", None, chat_history, chat_state