Didier committed
Commit e730afb · verified · 1 Parent(s): 6dc7ece

Create vlm.py

Files changed (1)
  1. vlm.py +135 -0
vlm.py ADDED
"""
File: vlm.py
Description: Vision language model utility functions.
Author: Didier Guillevic
Date: 2025-05-08
"""

import base64
import logging
from threading import Thread

import spaces
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from transformers import TextIteratorStreamer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

#
# Load the model: OPEA/Mistral-Small-3.1-24B-Instruct-2503-int4-AutoRound-awq-sym
#
model_id = "OPEA/Mistral-Small-3.1-24B-Instruct-2503-int4-AutoRound-awq-sym"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    _attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16
).to(device)

#
# Encode images as base64
#
def encode_image(image_path):
    """Encode the image to base64."""
    try:
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')
    except FileNotFoundError:
        print(f"Error: The file {image_path} was not found.")
        return None
    except Exception as e:  # Catch any other read/encoding error
        print(f"Error: {e}")
        return None

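# Usage (illustrative): encode_image("photo.jpg") -> "/9j/4AAQSkZJRg..." (a
# base64 string), or None if the file could not be read.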

#
# Build messages
#
def build_messages(message: dict, history: list[tuple]):
    """Build messages given message & history from a **multimodal** chat interface.

    Args:
        message: dictionary with keys: 'text', 'files'
        history: list of tuples with (message, response)

    Returns:
        list of messages (to be sent to the model)
    """
    logger.info(f"{message=}")
    logger.info(f"{history=}")

    # Get the user's text and list of images
    user_text = message.get("text", "")
    user_images = message.get("files", [])  # List of images

    # Build the message list including history
    messages = []
    combined_user_input = []  # Combine images and text found in the same turn
    for user_turn, bot_turn in history:
        if isinstance(user_turn, tuple):  # Image input
            image_content = [
                {
                    "type": "image_url",
                    "image_url": f"data:image/jpeg;base64,{encode_image(image)}"
                } for image in user_turn
            ]
            combined_user_input.extend(image_content)
        elif isinstance(user_turn, str):  # Text input
            combined_user_input.append({"type": "text", "text": user_turn})
        if combined_user_input and bot_turn:
            messages.append({'role': 'user', 'content': combined_user_input})
            messages.append(
                {'role': 'assistant', 'content': [{"type": "text", "text": bot_turn}]})
            combined_user_input = []  # Reset the combined user input

    # Build the user message's content from the provided message
    user_content = []
    if user_text:
        user_content.append({"type": "text", "text": user_text})
    for image in user_images:
        user_content.append(
            {
                "type": "image_url",
                "image_url": f"data:image/jpeg;base64,{encode_image(image)}"
            }
        )

    messages.append({'role': 'user', 'content': user_content})
    logger.info(f"{messages=}")

    return messages

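# Example (illustrative): for message = {"text": "Describe this image",
# "files": ["photo.jpg"]} and an empty history, build_messages returns:
#   [{'role': 'user',
#     'content': [{'type': 'text', 'text': 'Describe this image'},
#                 {'type': 'image_url',
#                  'image_url': 'data:image/jpeg;base64,/9j/4AAQ...'}]}]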

#
# Stream response
#
@spaces.GPU
@torch.inference_mode()
def stream_response(
    messages: list[dict],
    max_new_tokens: int = 1_024,
    temperature: float = 0.15
):
    """Stream the model's response to the chat interface.

    Args:
        messages: list of messages to send to the model
        max_new_tokens: maximum number of new tokens to generate
        temperature: sampling temperature
    """
    # Tokenize the messages (text + images) with the model's chat template
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)

    # Run generation in a background thread, streaming decoded tokens as they arrive
    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True)
    generation_args = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=0.9,
        do_sample=True
    )

    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()

    # Yield the accumulated text so the chat UI updates incrementally
    partial_message = ""
    for new_text in streamer:
        partial_message += new_text
        yield partial_message
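
For context, here is a minimal sketch of how these utilities might be wired into a Gradio multimodal chat Space. This wiring is not part of the commit: the app.py module name, the respond helper, and the tuple-style history are assumptions chosen to match what build_messages and stream_response expect.

    # app.py (hypothetical): one possible way to use vlm.py in a Gradio Space.
    import gradio as gr

    import vlm


    def respond(message: dict, history: list[tuple]):
        # With multimodal=True, Gradio passes `message` as {"text": ..., "files": [...]},
        # which is the dict shape vlm.build_messages expects.
        messages = vlm.build_messages(message, history)
        yield from vlm.stream_response(messages)


    # Tuple-style (user, assistant) history matches what build_messages iterates over.
    demo = gr.ChatInterface(fn=respond, multimodal=True, type="tuples")

    if __name__ == "__main__":
        demo.launch()

Because stream_response yields the accumulated text as TextIteratorStreamer produces tokens, the chat UI updates incrementally with this setup.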