File size: 4,541 Bytes
e730afb
 
 
 
 
 
 
4fe7de8
 
39eaec8
 
 
 
6485df8
87fa838
39eaec8
 
0afdc15
 
 
e730afb
 
 
 
 
 
 
4fe7de8
e730afb
1136305
49e0d0c
4fe7de8
e730afb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f141d7a
21e5df5
 
 
 
 
 
 
 
 
 
 
 
f141d7a
 
21e5df5
 
 
 
f141d7a
21e5df5
 
 
 
f141d7a
6f6099a
e730afb
 
 
6f6099a
e730afb
 
 
 
 
 
6f6099a
e730afb
 
 
 
 
 
 
 
 
 
 
c6820b5
 
e730afb
 
6f6099a
f141d7a
 
 
 
e730afb
 
 
 
 
f141d7a
e730afb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49e0d0c
e730afb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""
File: vlm.py
Description: Vision language model utility functions.
Author: Didier Guillevic
Date: 2025-05-08
"""

from transformers import AutoProcessor
from transformers import Mistral3ForConditionalGeneration
from transformers import TextIteratorStreamer
from threading import Thread
import re
import time
import torch
import base64
import spaces

import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Set up root logging at INFO level (basicConfig does nothing if the root
# logger already has handlers configured).
logging.basicConfig(level=logging.INFO)

#
# Load the model: OPEA/Mistral-Small-3.1-24B-Instruct-2503-int4-AutoRound-awq-sym
#
# Checkpoint identifier handed to `from_pretrained` for both processor and model.
model_id = "OPEA/Mistral-Small-3.1-24B-Instruct-2503-int4-AutoRound-awq-sym"

# Use the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

processor = AutoProcessor.from_pretrained(model_id)

# Load the weights with torch_dtype=float16.
# (`_attn_implementation="flash_attention_2"` is intentionally left disabled.)
model = Mistral3ForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
)
# Switch to eval mode, then move the model onto the selected device.
model = model.eval().to(device)

#
# Encode images as base64
#
def encode_image(image_path):
    """Read an image file and return its content as a base64 string.

    Args:
        image_path: path of the file to read (opened in binary mode).

    Returns:
        The file's bytes encoded as base64 and decoded to ``str``, or
        ``None`` when the file is missing or cannot be read. Failures are
        reported through the module's logger rather than raised.
    """
    # Same logger object as the module-level `logger` (same name).
    log = logging.getLogger(__name__)
    try:
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')
    except FileNotFoundError:
        log.error("Error: The file %s was not found.", image_path)
        return None
    except Exception as e:
        # Best effort by design: callers get None instead of an exception.
        log.error("Error while encoding %s: %s", image_path, e)
        return None


#
# Build messages
#
def normalize_message_content(msg: dict) -> dict:
    """Return a copy of a history message whose content is a list of parts.

    Args:
        msg: dictionary with a 'role' key and an optional 'content' key.

    Returns:
        ``{"role": ..., "content": [...]}`` where content is built as follows:
          * list of dicts: passed through unchanged;
          * str: wrapped as a single ``{"type": "text"}`` part;
          * tuple: each ``str`` item is treated as an image path and encoded
            with `encode_image`; items that fail to encode are dropped;
          * anything else: logged as unexpected and stringified as text.

    Raises:
        KeyError: if `msg` has no 'role' key.
    """
    role = msg["role"]
    content = msg.get("content")

    # Case 1: Already in expected format
    if isinstance(content, list) and all(isinstance(item, dict) for item in content):
        return {"role": role, "content": content}

    # Case 2: String (assume text)
    if isinstance(content, str):
        return {"role": role, "content": [{"type": "text", "text": content}]}

    # Case 3: Tuple with image path(s)
    if isinstance(content, tuple):
        encoded_images = (
            encode_image(path) for path in content if isinstance(path, str)
        )
        # encode_image returns None on failure; skip those so no part is
        # emitted with `"image": None`.
        # NOTE(review): images here are raw base64 while build_messages uses
        # a `data:` URL prefix - confirm the processor accepts both forms.
        return {
            "role": role,
            "content": [
                {"type": "image", "image": encoded}
                for encoded in encoded_images
                if encoded is not None
            ],
        }

    logger.warning(f"Unexpected content format in message: {msg}")
    return {"role": role, "content": [{"type": "text", "text": str(content)}]}


def _image_mime_type(image_path) -> str:
    """Return the image MIME type guessed from the file name, else JPEG."""
    import mimetypes  # local import: not part of the module's import block

    guessed, _ = mimetypes.guess_type(str(image_path))
    if guessed and guessed.startswith("image/"):
        return guessed
    return "image/jpeg"


def build_messages(message: dict, history: list[dict]):
    """Build messages given message & history from a **multimodal** chat interface.

    Args:
        message: dictionary with keys: 'text', 'files'
        history: list of dictionaries

    Returns:
        list of messages (to be sent to the model): every history entry
        passed through `normalize_message_content`, followed by one new
        'user' message holding the text (if any) and one image part per
        file that could be encoded.
    """
    logger.info(f"{message=}")
    logger.info(f"{history=}")

    # Get the user's text and list of images (tolerate an explicit None)
    user_text = message.get("text", "")
    user_images = message.get("files") or []

    # Build the user message's content from the provided message
    user_content = []
    if user_text:
        user_content.append({"type": "text", "text": user_text})
    for image in user_images:
        encoded = encode_image(image)
        if encoded is None:
            # encode_image already logged the cause; do not emit a bogus
            # "data:...;base64,None" payload.
            logger.warning("Skipping image that could not be encoded: %s", image)
            continue
        user_content.append(
            {
                "type": "image",
                "image": f"data:{_image_mime_type(image)};base64,{encoded}",
            }
        )

    # Normalize existing history content
    messages = [normalize_message_content(msg) for msg in history]

    # Append new user message
    messages.append({'role': 'user', 'content': user_content})
    logger.info(f"{messages=}")

    return messages


#
# stream response
#
@spaces.GPU
@torch.inference_mode()
def stream_response(
        messages: list[dict],
        max_new_tokens: int=1_024,
        temperature: float=0.15
    ):
    """Stream the model's response to the chat interface.

    Args:
        messages: list of messages to send to the model
        max_new_tokens: upper bound on the number of generated tokens
        temperature: sampling temperature; a value <= 0 selects greedy
            decoding, because sampling needs a strictly positive temperature

    Yields:
        The response accumulated so far (each item extends the previous one).

    Raises:
        Exception: whatever `model.generate` raised in the worker thread,
            re-raised once the stream has been drained.
    """
    # Tokenize the conversation and move the tensors next to the model
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.float16)

    # Generate
    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True)
    generation_args = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
    )
    if temperature is not None and temperature <= 0:
        # Sampling with a non-positive temperature is invalid; decode greedily.
        generation_args["do_sample"] = False
    else:
        generation_args.update(
            temperature=temperature,
            top_p=0.9,
            do_sample=True,
        )

    errors = []

    def _generate():
        """Run generation; on failure, end the stream so the reader stops."""
        try:
            model.generate(**generation_args)
        except Exception as exc:
            errors.append(exc)
            # Without this the loop below would wait on the streamer forever.
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    partial_message = ""
    for new_text in streamer:
        partial_message += new_text
        yield partial_message

    # The stream is exhausted, so generation is finished (or has failed).
    thread.join()
    if errors:
        raise errors[0]