lukecq committed
Commit 73933cb · verified · 1 Parent(s): ce92757

Upload 2 files

Files changed (2)
  1. app.py +176 -128
  2. requirements.txt +178 -31
app.py CHANGED
@@ -1,133 +1,181 @@
- # Copyright: DAMO Academy, Alibaba Group
- # By Xuan Phi Nguyen at DAMO Academy, Alibaba Group
-
- # Description:
- """
- Demo script to launch Language chat model
- """
-
- import spaces
- import os
- from gradio.themes import ThemeClass as Theme
- import numpy as np
- import argparse
- # import torch
  import gradio as gr
- from typing import Any, Iterator
- from typing import Iterator, List, Optional, Tuple
- import filelock
- import glob
- import json
  import time
- from gradio.routes import Request
- from gradio.utils import SyncToAsyncIterator, async_iteration
- from gradio.helpers import special_args
- import anyio
- from typing import AsyncGenerator, Callable, Literal, Union, cast
-
- from gradio_client.documentation import document, set_documentation_group
-
- from typing import List, Optional, Union, Dict, Tuple
- from tqdm.auto import tqdm
- from huggingface_hub import snapshot_download
- from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
- from gradio.components import Button, Component
- from gradio.events import Dependency, EventListenerMethod
-
- from multipurpose_chatbot.demos.base_demo import CustomTabbedInterface
-
- from multipurpose_chatbot.configs import (
-     MODEL_TITLE,
-     MODEL_DESC,
-     MODEL_INFO,
-     CITE_MARKDOWN,
-     ALLOWED_PATHS,
-     PROXY,
-     PORT,
-     MODEL_PATH,
-     MODEL_NAME,
-     BACKEND,
-     DEMOS,
-     DELETE_FOLDER,
- )
-
-
- demo = None
-
-
-
- if DELETE_FOLDER is not None and os.path.exists(DELETE_FOLDER):
-     print(F'WARNING deleting folder: {DELETE_FOLDER}')
-     import shutil
-     print(f'DELETE ALL FILES IN {DELETE_FOLDER}')
-     for filename in os.listdir(DELETE_FOLDER):
-         file_path = os.path.join(DELETE_FOLDER, filename)
-         try:
-             if os.path.isfile(file_path) or os.path.islink(file_path):
-                 os.unlink(file_path)
-             elif os.path.isdir(file_path):
-                 shutil.rmtree(file_path)
-             print(f'deleted: {file_path}')
-         except Exception as e:
-             print('Failed to delete %s. Reason: %s' % (file_path, e))
-
-
- def launch_demo():
-     global demo, MODEL_ENGINE
-     model_desc = MODEL_DESC
-     model_path = MODEL_PATH
-
-     print(f'Begin importing models')
-     from multipurpose_chatbot.demos import get_demo_class
-
-     # demos = {
-     #     k: get_demo_class(k)().create_demo()
-     #     for k in demo_and_tab_names.keys()
-     # }
-     print(f'{DEMOS=}')
-     demo_class_objects = {
-         k: get_demo_class(k)()
-         for k in DEMOS
-     }
-     demos = {
-         k: get_demo_class(k)().create_demo()
-         for k in DEMOS
-     }
-     demos_names = [x.tab_name for x in demo_class_objects.values()]
-
-     descriptions = model_desc
-     if MODEL_INFO is not None and MODEL_INFO != "":
-         descriptions += (
-             f"<br>" +
-             MODEL_INFO.format(model_path=model_path)
-         )
-     if len(demos) == 1:
-         demo = demos[DEMOS[0]]
      else:
-         demo = CustomTabbedInterface(
-             interface_list=list(demos.values()),
-             tab_names=demos_names,
-             title=f"{MODEL_TITLE}",
-             description=descriptions,
-         )
-
-     demo.title = MODEL_NAME
-
-     # with demo:
-     #     gr.Markdown(CITE_MARKDOWN)
-
-     demo.queue(api_open=False)
-     return demo
-
-
-
- if __name__ == "__main__":
-     demo = launch_demo()
-     if PROXY is not None and PROXY != "":
-         print(f'{PROXY=} {PORT=}')
-         print(f"{ALLOWED_PATHS=}")
-         demo.launch(server_port=PORT, root_path=PROXY, show_api=False, allowed_paths=ALLOWED_PATHS)
-     else:
-         demo.launch(server_port=PORT, show_api=False, allowed_paths=ALLOWED_PATHS)
-
-
  import gradio as gr
  import time
+ from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor
+ from io import BytesIO
+ from urllib.request import urlopen
+ import librosa
+ import os, json
+ from sys import argv
+ from vllm import LLM, SamplingParams
+
+ def load_model_processor(model_path):
+     processor = AutoProcessor.from_pretrained(model_path)
+     llm = LLM(
+         model=model_path, trust_remote_code=True, gpu_memory_utilization=0.8,
+         enforce_eager=True, device = "cuda",
+         limit_mm_per_prompt={"audio": 5},
+     )
+     return llm, processor
+
+ model_path1 = "Qwen/Qwen2-Audio-7B-Instruct" #argv[1]
+ model1, processor1 = load_model_processor(model_path1)
+
+ def response_to_audio_conv(conversation, model=None, processor=None, temperature = 0.1,repetition_penalty=1.1, top_p = 0.9,
+                            max_new_tokens = 2048):
+     text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
+     audios = []
+     for message in conversation:
+         if isinstance(message["content"], list):
+             for ele in message["content"]:
+                 if ele["type"] == "audio":
+                     if ele['audio_url'] != None:
+                         audios.append(librosa.load(
+                             ele['audio_url'],
+                             sr=processor.feature_extractor.sampling_rate)[0]
+                         )
+
+     sampling_params = SamplingParams(
+         temperature=temperature, max_tokens=max_new_tokens, repetition_penalty=repetition_penalty, top_p=top_p, top_k=20,
+         stop_token_ids=[],
+     )
+
+     input = {
+         'prompt': text,
+         'multi_modal_data': {
+             'audio': [(audio, 16000) for audio in audios]
+         }
+     }
+
+     output = model.generate([input], sampling_params=sampling_params)[0]
+     response = output.outputs[0].text
+     return response
+
+ def print_like_dislike(x: gr.LikeData):
+     print(x.index, x.value, x.liked)
+
+ def add_message(history, message):
+     paths = []
+     for turn in history:
+         if turn['role'] == "user" and type(turn['content']) != str:
+             paths.append(turn['content'][0])
+     for x in message["files"]:
+         if x not in paths:
+             history.append({"role": "user", "content": {"path": x}})
+     if message["text"] is not None:
+         history.append({"role": "user", "content": message["text"]})
+     return history, gr.MultimodalTextbox(value=None, interactive=False)
+
+ def format_user_messgae(message):
+     if type(message['content']) == str:
+         return {"role": "user", "content": [{"type": "text", "text": message['content']}]}
      else:
+         return {"role": "user", "content": [{"type": "audio", "audio_url": message['content'][0]}]}
+
+ def history_to_conversation(history):
+     conversation = []
+     audio_paths = []
+     for turn in history:
+         if turn['role'] == "user":
+             if not turn['content']:
+                 continue
+             turn = format_user_messgae(turn)
+             if turn['content'][0]['type'] == 'audio':
+                 if turn['content'][0]['audio_url'] in audio_paths:
+                     continue
+                 else:
+                     audio_paths.append(turn['content'][0]['audio_url'])
+
+             if len(conversation) > 0 and conversation[-1]["role"] == "user":
+                 conversation[-1]['content'].append(turn['content'][0])
+             else:
+                 conversation.append(turn)
+         else:
+             conversation.append(turn)

+     print(json.dumps(conversation, indent=4, ensure_ascii=False))
+     return conversation
+
+ def bot(history: list, temperature = 0.1,repetition_penalty=1.1, top_p = 0.9,
+         max_new_tokens = 2048):
+     conversation = history_to_conversation(history)
+     response = response_to_audio_conv(conversation, model=model1, processor=processor1, temperature = temperature,repetition_penalty=repetition_penalty, top_p = top_p, max_new_tokens = max_new_tokens)
+     # response = "Nice to meet you!"
+     print("Bot:",response)
+
+     history.append({"role": "assistant", "content": ""})
+     for character in response:
+         history[-1]["content"] += character
+         time.sleep(0.01)
+         yield history
+
+ insturctions = """**Instruction**: there are three input format:
+ 1. text: input text message only
+ 2. audio: upload audio file or record a voice message
+ 3. audio + text: record a voice message and input text message"""
+
+ with gr.Blocks() as demo:
+     # gr.Markdown("""<p align="center"><img src="images/seal_logo.png" style="height: 80px"/><p>""")
+     # gr.Image("images/seal_logo.png", elem_id="seal_logo", show_label=False,height=80,show_fullscreen_button=False)
+     gr.Markdown(
+         """<div style="text-align: center; font-size: 32px; font-weight: bold;">SeaLLMs-Audio ChatBot</div>""",
+     )
+
+     # Description text
+     gr.Markdown(
+         """<div style="text-align: center; font-size: 16px;">
+         This WebUI is based on SeaLLMs-Audio-7B-Chat, developed by Alibaba DAMO Academy.<br>
+         You can interact with the chatbot in <b>English, Chinese, Indonesian, Thai, or Vietnamese</b>.<br>
+         For each round, you can input <b>audio and/or text</b>.
+         </div>""",
+     )
+
+     # Links with proper formatting
+     gr.Markdown(
+         """<div style="text-align: center; font-size: 16px;">
+         <a href="https://huggingface.co/SeaLLMs/SeaLLMs-v3-7B-Chat">[Website]</a> &nbsp;
+         <a href="https://huggingface.co/SeaLLMs/SeaLLMs-v3-7B-Chat">[Model🤗]</a> &nbsp;
+         <a href="https://github.com/liuchaoqun/SeaLLMs-Audio">[Github]</a>
+         </div>""",
+     )
+
+     # gr.Markdown(insturctions)
+     # with gr.Row():
+     #     with gr.Column():
+     #         temperature = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.1, label="Temperature")
+     #     with gr.Column():
+     #         top_p = gr.Slider(minimum=0.1, maximum=1, value=0.5, step=0.1, label="Top P")
+     #     with gr.Column():
+     #         repetition_penalty = gr.Slider(minimum=0, maximum=2, value=1.1, step=0.1, label="Repetition Penalty")
+     chatbot = gr.Chatbot(elem_id="chatbot", bubble_full_width=False, type="messages")
+
+     chat_input = gr.MultimodalTextbox(
+         interactive=True,
+         file_count="single",
+         file_types=['.wav'],
+         placeholder="Enter message (optional) ...",
+         show_label=False,
+         sources=["microphone", "upload"],
+     )
+
+     chat_msg = chat_input.submit(
+         add_message, [chatbot, chat_input], [chatbot, chat_input]
+     )
+     bot_msg = chat_msg.then(bot, chatbot, chatbot, api_name="bot_response")
+     # bot_msg = chat_msg.then(bot, [chatbot, temperature, repetition_penalty, top_p], chatbot, api_name="bot_response")
+     bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
+
+     # chatbot.like(print_like_dislike, None, None, like_user_message=True)
+
+     clear_button = gr.ClearButton([chatbot, chat_input])

+ # PORT = 7950
+ # demo.launch(server_port=PORT, show_api = True, allowed_paths = [],
+ #             root_path = f"https://dsw-gateway.alibaba-inc.com/dsw81322/proxy/{PORT}/")
+
+ demo.launch(
+     share=False,
+     inbrowser=True,
+     server_port=7950,
+     server_name="0.0.0.0",
+     max_threads=40
+ )
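
For reference, here is a minimal sketch (not part of the commit) of the single-turn flow the new app.py implements: build a Qwen2-Audio-style conversation, apply the chat template, load the audio, and generate with vLLM. It mirrors load_model_processor() and response_to_audio_conv() above; the file name sample.wav and the question text are hypothetical placeholders.

# Sketch only (assumptions: same GPU environment as app.py; "sample.wav" is a
# hypothetical local recording resampled to the processor's sampling rate).
import librosa
from transformers import AutoProcessor
from vllm import LLM, SamplingParams

MODEL_PATH = "Qwen/Qwen2-Audio-7B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_PATH)
llm = LLM(model=MODEL_PATH, trust_remote_code=True, gpu_memory_utilization=0.8,
          enforce_eager=True, limit_mm_per_prompt={"audio": 5})

# One user turn carrying an audio clip and a text question.
conversation = [{
    "role": "user",
    "content": [
        {"type": "audio", "audio_url": "sample.wav"},
        {"type": "text", "text": "What language is spoken in this clip?"},
    ],
}]

prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
audio, _ = librosa.load("sample.wav", sr=processor.feature_extractor.sampling_rate)

outputs = llm.generate(
    [{"prompt": prompt, "multi_modal_data": {"audio": [(audio, 16000)]}}],
    sampling_params=SamplingParams(temperature=0.1, top_p=0.9, top_k=20, max_tokens=2048),
)
print(outputs[0].outputs[0].text)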
requirements.txt CHANGED
@@ -1,31 +1,178 @@
- spaces
- torch
- gradio
- tiktoken
- openai
- transformers==4.38
- langchain
- langchain-community
- langchain-core
- chromadb
- pypdf
- docx2txt
- sentencepiece
- accelerate
- evaluate
- datasets
- sacrebleu
- websockets
- omegaconf
- scikit-learn
- jiwer
- tenacity
- pynvml
- ninja
- fastapi
- geomloss
- einops
- langdetect
- plotly
- faiss-cpu
- sentence-transformers
+ accelerate==1.2.1
+ aiofiles==23.2.1
+ aiohappyeyeballs==2.4.4
+ aiohttp==3.11.11
+ aiohttp-cors==0.7.0
+ aiosignal==1.3.2
+ airportsdata==20241001
+ annotated-types==0.7.0
+ anyio==4.7.0
+ astor==0.8.1
+ async-timeout==5.0.1
+ attrs==24.3.0
+ audioread==3.0.1
+ blake3==1.0.4
+ cachetools==5.5.1
+ certifi==2024.12.14
+ cffi==1.17.1
+ charset-normalizer==3.4.0
+ click==8.1.7
+ cloudpickle==3.1.1
+ colorful==0.5.6
+ compressed-tensors==0.9.1
+ cupy-cuda12x==13.3.0
+ dashscope==1.20.14
+ datasets==3.3.2
+ depyf==0.18.0
+ dill==0.3.8
+ diskcache==5.6.3
+ distlib==0.3.9
+ distro==1.9.0
+ dnspython==2.7.0
+ einops==0.8.1
+ email_validator==2.2.0
+ fastapi==0.115.6
+ fastapi-cli==0.0.7
+ fastrlock==0.8.3
+ ffmpy==0.4.0
+ filelock==3.16.1
+ frozenlist==1.5.0
+ fsspec==2024.10.0
+ gguf==0.10.0
+ google-api-core==2.24.1
+ google-auth==2.38.0
+ googleapis-common-protos==1.67.0
+ gradio==5.10.0
+ gradio_client==1.5.3
+ gradio_leaderboard==0.0.13
+ grpcio==1.70.0
+ h11==0.14.0
+ httpcore==1.0.7
+ httptools==0.6.4
+ httpx==0.28.1
+ huggingface-hub==0.27.0
+ idna==3.10
+ importlib_metadata==8.6.1
+ iniconfig==2.0.0
+ interegular==0.3.3
+ Jinja2==3.1.4
+ jiter==0.8.2
+ joblib==1.4.2
+ jsonschema==4.23.0
+ jsonschema-specifications==2024.10.1
+ lark==1.2.2
+ lazy_loader==0.4
+ librosa==0.10.2.post1
+ llvmlite==0.43.0
+ lm-format-enforcer==0.10.10
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.5
+ mdurl==0.1.2
+ mistral_common==1.5.3
+ modelscope_studio==1.0.2
+ mpmath==1.3.0
+ msgpack==1.1.0
+ msgspec==0.19.0
+ multidict==6.1.0
+ multiprocess==0.70.16
+ networkx==3.4.2
+ numba==0.60.0
+ numpy==1.26.4
+ nvidia-cublas-cu12==12.4.5.8
+ nvidia-cuda-cupti-cu12==12.4.127
+ nvidia-cuda-nvrtc-cu12==12.4.127
+ nvidia-cuda-runtime-cu12==12.4.127
+ nvidia-cudnn-cu12==9.1.0.70
+ nvidia-cufft-cu12==11.2.1.3
+ nvidia-curand-cu12==10.3.5.147
+ nvidia-cusolver-cu12==11.6.1.9
+ nvidia-cusparse-cu12==12.3.1.170
+ nvidia-ml-py==12.570.86
+ nvidia-nccl-cu12==2.21.5
+ nvidia-nvjitlink-cu12==12.4.127
+ nvidia-nvtx-cu12==12.4.127
+ openai==1.63.0
+ opencensus==0.11.4
+ opencensus-context==0.1.3
+ opencv-python-headless==4.11.0.86
+ orjson==3.10.12
+ outlines==0.1.11
+ outlines_core==0.1.26
+ pandas==2.2.3
+ partial-json-parser==0.2.1.1.post5
+ peft==0.14.0
+ pillow==11.0.0
+ pluggy==1.5.0
+ pooch==1.8.2
+ prometheus-fastapi-instrumentator==7.0.2
+ prometheus_client==0.21.1
+ propcache==0.2.1
+ proto-plus==1.26.0
+ protobuf==5.29.3
+ py-cpuinfo==9.0.0
+ py-spy==0.4.0
+ pyarrow==19.0.1
+ pyasn1==0.6.1
+ pyasn1_modules==0.4.1
+ pybind11==2.13.6
+ pycountry==24.6.1
+ pycparser==2.22
+ pydantic==2.10.3
+ pydantic_core==2.27.1
+ pydub==0.25.1
+ Pygments==2.18.0
+ pytest==8.3.4
+ python-dotenv==1.0.1
+ python-multipart==0.0.20
+ pytz==2024.2
+ PyYAML==6.0.2
+ ray==2.40.0
+ referencing==0.36.2
+ regex==2024.11.6
+ requests==2.32.3
+ rich==13.9.4
+ rich-toolkit==0.13.2
+ rpds-py==0.22.3
+ rsa==4.9
+ ruff==0.8.4
+ safehttpx==0.1.6
+ safetensors==0.4.5
+ scikit-learn==1.6.0
+ scipy==1.14.1
+ semantic-version==2.10.0
+ sentencepiece==0.2.0
+ shellingham==1.5.4
+ smart-open==7.1.0
+ sniffio==1.3.1
+ soundfile==0.12.1
+ soxr==0.5.0.post1
+ starlette==0.41.3
+ sympy==1.13.1
+ tenacity==9.0.0
+ threadpoolctl==3.5.0
+ tiktoken==0.9.0
+ tokenizers==0.21.0
+ tomli==2.2.1
+ tomlkit==0.13.2
+ torch==2.5.1
+ torchaudio==2.5.1
+ torchvision==0.20.1
+ tqdm==4.67.1
+ transformers==4.48.3
+ triton==3.1.0
+ typer==0.15.1
+ tzdata==2024.2
+ urllib3==2.2.3
+ uvicorn==0.34.0
+ uvloop==0.21.0
+ virtualenv==20.29.2
+ vllm==0.7.3
+ watchfiles==1.0.4
+ websocket-client==1.8.0
+ websockets==14.1
+ wrapt==1.17.2
+ xformers==0.0.28.post3
+ xgrammar==0.1.11
+ xxhash==3.5.0
+ yarl==1.18.3
+ zipp==3.21.0
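
As a quick sanity check (not part of the commit), a fresh environment can be compared against the pins above before launching the Space. A minimal sketch follows; the choice of which packages count as "key" is an assumption.

# Sketch: compare installed versions against a few of the pins listed above.
from importlib.metadata import version

expected = {
    "vllm": "0.7.3",
    "transformers": "4.48.3",
    "gradio": "5.10.0",
    "torch": "2.5.1",
    "librosa": "0.10.2.post1",
}
for pkg, want in expected.items():
    got = version(pkg)
    print(f"{pkg}: installed {got}, pinned {want}, {'OK' if got == want else 'MISMATCH'}")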