雷娃 committed on
Commit
26ca9d4
·
1 Parent(s): 7685489

replace app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -52
app.py CHANGED
@@ -1,73 +1,49 @@
1
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
2
- from threading import Thread
3
  import gradio as gr
4
  import torch
5
 
6
- # 加载模型和分词器
7
  model_name = "inclusionAI/Ling-lite-1.5"
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
  model = AutoModelForCausalLM.from_pretrained(
10
  model_name,
11
  torch_dtype="auto",
12
- device_map="auto",
13
  trust_remote_code=True
14
  ).eval()
15
 
16
-
17
- # 支持流式输出的聊天函数
18
- def chat_stream(message, history):
19
- system_prompt = {"role": "system", "content": "You are Ling, an assistant created by inclusionAI"}
20
- user_message = {"role": "user", "content": message}
21
-
22
- # 构建消息历史
23
- messages = [system_prompt] + history + [user_message]
24
-
25
- # 应用 chat template
26
- text = tokenizer.apply_chat_template(
27
- messages,
28
- tokenize=False,
29
- add_generation_prompt=True
30
- )
31
-
32
- # 编码输入
33
- inputs = tokenizer([text], return_tensors="pt").to(model.device)
34
-
35
- # 设置 streamer
36
- streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
37
-
38
- # 生成参数
39
- generate_kwargs = dict(
40
- input_ids=inputs["input_ids"],
41
- attention_mask=inputs["attention_mask"],
42
- streamer=streamer,
43
- max_new_tokens=512,
44
- do_sample=True,
45
- temperature=0.7,
46
- pad_token_id=tokenizer.eos_token_id
47
- )
48
-
49
- # 在后台线程中启动生成
50
- def generate():
51
- model.generate(**generate_kwargs)
52
-
53
- thread = Thread(target=generate)
54
- thread.start()
55
-
56
- # 逐步读取生成的内容
57
- response = ""
58
- for new_text in streamer:
59
- response += new_text
60
- yield response.strip()
61
 
62
  # Construct Gradio Interface
63
  interface = gr.Interface(
64
- fn=chat_stream,
65
  inputs=[
66
  gr.Textbox(lines=5, label="输入你的问题"),
67
  gr.Slider(minimum=100, maximum=1024, step=50, label="生成长度")
68
  ],
69
  outputs=gr.Textbox(label="模型回复"),
70
- title="Ling-lite-1.5 MoE AI助手",
71
  description="基于 [inclusionAI/Ling-lite-1.5](https://huggingface.co/inclusionAI/Ling-lite-1.5) 的对话式文本生成演示。",
72
  examples=[
73
  ["介绍大型语言模型的基本概念", 512],
@@ -76,4 +52,4 @@ interface = gr.Interface(
76
  )
77
 
78
  # launch Gradio Service
79
- interface.launch()
 
1
+ # app.py
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import gradio as gr
4
  import torch
5
 
6
+ # load model and tokenizer
7
  model_name = "inclusionAI/Ling-lite-1.5"
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
  model = AutoModelForCausalLM.from_pretrained(
10
  model_name,
11
  torch_dtype="auto",
12
+ device_map="auto",
13
  trust_remote_code=True
14
  ).eval()
15
 
16
+ # define chat function
17
+ def chat(user_input, max_new_tokens=512):
18
+ # chat history
19
+ messages = [
20
+ {"role": "system", "content": "You are Ling, an assistant created by inclusionAI"},
21
+ {"role": "user", "content": user_input}
22
+ ]
23
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
24
+
25
+ # encode the input prompt
26
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
27
+
28
+ # generate response
29
+ with torch.no_grad():
30
+ outputs = model.generate(
31
+ **inputs,
32
+ max_new_tokens=max_new_tokens,
33
+ pad_token_id=tokenizer.eos_token_id
34
+ )
35
+ response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True)
36
+ return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  # Construct Gradio Interface
39
  interface = gr.Interface(
40
+ fn=chat,
41
  inputs=[
42
  gr.Textbox(lines=5, label="输入你的问题"),
43
  gr.Slider(minimum=100, maximum=1024, step=50, label="生成长度")
44
  ],
45
  outputs=gr.Textbox(label="模型回复"),
46
+ title="Ling-lite-1.5 MoE 模型 Demo",
47
  description="基于 [inclusionAI/Ling-lite-1.5](https://huggingface.co/inclusionAI/Ling-lite-1.5) 的对话式文本生成演示。",
48
  examples=[
49
  ["介绍大型语言模型的基本概念", 512],
 
52
  )
53
 
54
  # launch Gradio Service
55
+ interface.launch()