prithivMLmods committed on
Commit 69471fa · verified · 1 Parent(s): bef0395

Delete app.py

Files changed (1)
app.py +0 -236
app.py DELETED
@@ -1,236 +0,0 @@
- import os
- from collections.abc import Iterator
- from threading import Thread
- import gradio as gr
- import spaces
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
- from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
- from PIL import Image
- import uuid
- import io
- import re
- import time
-
- # Text-only model setup
- DESCRIPTION = """
- # GWQ PREV
- """
-
- MAX_MAX_NEW_TOKENS = 2048
- DEFAULT_MAX_NEW_TOKENS = 1024
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
- model_id = "prithivMLmods/GWQ2b"
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- model = AutoModelForCausalLM.from_pretrained(
-     model_id,
-     device_map="auto",
-     torch_dtype=torch.bfloat16,
- )
- model.config.sliding_window = 4096
- model.eval()
-
- # Multimodal model setup
- MULTIMODAL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
- multimodal_model = Qwen2VLForConditionalGeneration.from_pretrained(
-     MULTIMODAL_MODEL_ID,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to("cuda").eval()
- multimodal_processor = AutoProcessor.from_pretrained(MULTIMODAL_MODEL_ID, trust_remote_code=True)
-
- image_extensions = Image.registered_extensions()
-
- def identify_and_save_image(blob_path):
-     """Identifies if the blob is an image and saves it accordingly."""
-     try:
-         with open(blob_path, 'rb') as file:
-             blob_content = file.read()
-
-         # Try to identify if it's an image
-         try:
-             Image.open(io.BytesIO(blob_content)).verify()  # Check if it's a valid image
-             extension = ".png"  # Default to PNG for saving
-             media_type = "image"
-         except (IOError, SyntaxError):
-             raise ValueError("Unsupported media type. Please upload an image.")
-
-         # Create a unique filename
-         filename = f"temp_{uuid.uuid4()}_media{extension}"
-         with open(filename, "wb") as f:
-             f.write(blob_content)
-
-         return filename, media_type
-
-     except FileNotFoundError:
-         raise ValueError(f"The file {blob_path} was not found.")
-     except Exception as e:
-         raise ValueError(f"An error occurred while processing the file: {e}")
-
- @spaces.GPU()
- def generate(
-     message: str,
-     chat_history: list[dict],
-     max_new_tokens: int = 1024,
-     temperature: float = 0.6,
-     top_p: float = 0.9,
-     top_k: int = 50,
-     repetition_penalty: float = 1.2,
-     files: list = None,
- ) -> Iterator[str]:
-     if files and len(files) > 0:
-         # Multimodal input (image only)
-         media_path = files[0]
-         if media_path.endswith(tuple([i for i, f in image_extensions.items()])):
-             media_type = "image"
-         else:
-             try:
-                 media_path, media_type = identify_and_save_image(media_path)
-             except Exception as e:
-                 raise ValueError("Unsupported media type. Please upload an image.")
-
-         # Load the image
-         image = Image.open(media_path).convert("RGB")
-
-         # Prepare the input for the multimodal model
-         messages = [
-             {
-                 "role": "user",
-                 "content": [
-                     {"image": media_path},  # Pass the image path
-                     {"text": message},  # Pass the text prompt
-                 ],
-             }
-         ]
-
-         # Process the input
-         inputs = multimodal_processor(
-             messages,
-             return_tensors="pt",
-             padding=True,
-         ).to("cuda")
-
-         # Stream the output
-         streamer = TextIteratorStreamer(
-             multimodal_processor, skip_prompt=True, skip_special_tokens=True
-         )
-         generation_kwargs = dict(
-             inputs,
-             streamer=streamer,
-             max_new_tokens=max_new_tokens,
-             do_sample=True,
-             temperature=temperature,
-             top_p=top_p,
-             top_k=top_k,
-             repetition_penalty=repetition_penalty,
-         )
-
-         # Start the generation in a separate thread
-         thread = Thread(target=multimodal_model.generate, kwargs=generation_kwargs)
-         thread.start()
-
-         # Stream the output token by token
-         buffer = ""
-         for new_text in streamer:
-             buffer += new_text
-             yield buffer
-     else:
-         # Text-only input
-         # Ensure the chat history alternates between user and assistant roles
-         conversation = []
-         for i, entry in enumerate(chat_history):
-             if i % 2 == 0:
-                 conversation.append({"role": "user", "content": entry["content"]})
-             else:
-                 conversation.append({"role": "assistant", "content": entry["content"]})
-         conversation.append({"role": "user", "content": message})
-
-         # Apply the chat template
-         input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
-         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-             gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-         input_ids = input_ids.to(model.device)
-
-         # Stream the output
-         streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-         generate_kwargs = dict(
-             {"input_ids": input_ids},
-             streamer=streamer,
-             max_new_tokens=max_new_tokens,
-             do_sample=True,
-             top_p=top_p,
-             top_k=top_k,
-             temperature=temperature,
-             num_beams=1,
-             repetition_penalty=repetition_penalty,
-         )
-         t = Thread(target=model.generate, kwargs=generate_kwargs)
-         t.start()
-
-         outputs = []
-         for text in streamer:
-             outputs.append(text)
-             yield "".join(outputs)
-
- demo = gr.ChatInterface(
-     fn=generate,
-     additional_inputs=[
-         gr.Slider(
-             label="Max new tokens",
-             minimum=1,
-             maximum=MAX_MAX_NEW_TOKENS,
-             step=1,
-             value=DEFAULT_MAX_NEW_TOKENS,
-         ),
-         gr.Slider(
-             label="Temperature",
-             minimum=0.1,
-             maximum=4.0,
-             step=0.1,
-             value=0.6,
-         ),
-         gr.Slider(
-             label="Top-p (nucleus sampling)",
-             minimum=0.05,
-             maximum=1.0,
-             step=0.05,
-             value=0.9,
-         ),
-         gr.Slider(
-             label="Top-k",
-             minimum=1,
-             maximum=1000,
-             step=1,
-             value=50,
-         ),
-         gr.Slider(
-             label="Repetition penalty",
-             minimum=1.0,
-             maximum=2.0,
-             step=0.05,
-             value=1.2,
-         ),
-     ],
-     stop_btn=None,
-     examples=[
-         ["Hello there! How are you doing?"],
-         ["Can you explain briefly to me what is the Python programming language?"],
-         ["Explain the plot of Cinderella in a sentence."],
-         ["How many hours does it take a man to eat a Helicopter?"],
-         ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
-     ],
-     cache_examples=False,
-     type="messages",
-     description=DESCRIPTION,
-     css_paths="style.css",
-     fill_height=True,
-     multimodal=True,
-     textbox=gr.MultimodalTextbox(),
- )
-
- if __name__ == "__main__":
-     demo.queue(max_size=20).launch(share=True)  # Set share=True for a public link