雷娃 committed on
Commit
7685489
·
1 Parent(s): 628c773

modify app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -28
app.py CHANGED
@@ -1,49 +1,73 @@
1
- # app.py
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import gradio as gr
4
  import torch
5
 
6
- # load model and tokenizer
7
  model_name = "inclusionAI/Ling-lite-1.5"
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
  model = AutoModelForCausalLM.from_pretrained(
10
  model_name,
11
  torch_dtype="auto",
12
- device_map="auto",
13
  trust_remote_code=True
14
  ).eval()
15
 
16
- # define chat function
17
- def chat(user_input, max_new_tokens=512):
18
- # chat history
19
- messages = [
20
- {"role": "system", "content": "You are Ling, an assistant created by inclusionAI"},
21
- {"role": "user", "content": user_input}
22
- ]
23
- prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
24
-
25
- # encode the input prompt
26
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
27
-
28
- # generate response
29
- with torch.no_grad():
30
- outputs = model.generate(
31
- **inputs,
32
- max_new_tokens=max_new_tokens,
33
- pad_token_id=tokenizer.eos_token_id
34
- )
35
- response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True)
36
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  # Construct Gradio Interface
39
  interface = gr.Interface(
40
- fn=chat,
41
  inputs=[
42
  gr.Textbox(lines=5, label="输入你的问题"),
43
  gr.Slider(minimum=100, maximum=1024, step=50, label="生成长度")
44
  ],
45
  outputs=gr.Textbox(label="模型回复"),
46
- title="Ling-lite-1.5 MoE 模型 Demo",
47
  description="基于 [inclusionAI/Ling-lite-1.5](https://huggingface.co/inclusionAI/Ling-lite-1.5) 的对话式文本生成演示。",
48
  examples=[
49
  ["介绍大型语言模型的基本概念", 512],
@@ -52,4 +76,4 @@ interface = gr.Interface(
52
  )
53
 
54
  # launch Gradio Service
55
- interface.launch()
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
2
+ from threading import Thread
3
  import gradio as gr
4
  import torch
5
 
6
+ # 加载模型和分词器
7
  model_name = "inclusionAI/Ling-lite-1.5"
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
  model = AutoModelForCausalLM.from_pretrained(
10
  model_name,
11
  torch_dtype="auto",
12
+ device_map="auto",
13
  trust_remote_code=True
14
  ).eval()
15
 
16
+
17
+ # 支持流式输出的聊天函数
18
def chat_stream(message, history):
    """Stream the model's reply to ``message`` as progressively longer text.

    Args:
        message: The user's latest input text.
        history: Either the earlier conversation as a list of
            ``{"role": ..., "content": ...}`` dicts, ``None``/empty for a
            fresh conversation, or a number. The ``gr.Interface`` in this
            file wires a Textbox and a "generation length" Slider to this
            function, so the second argument arrives as the slider's
            number; in that case it is used as ``max_new_tokens`` and the
            conversation starts without prior turns.

    Yields:
        The reply accumulated so far, with surrounding whitespace stripped.
    """
    max_new_tokens = 512
    # A numeric `history` is the slider value, not a conversation. Without
    # this guard, `list + int` below raises TypeError on every request.
    # bool is excluded explicitly because it is a subclass of int.
    if isinstance(history, (int, float)) and not isinstance(history, bool):
        max_new_tokens = max(1, int(history))
        prior_turns = []
    else:
        prior_turns = list(history) if history else []

    system_prompt = {"role": "system", "content": "You are Ling, an assistant created by inclusionAI"}
    user_message = {"role": "user", "content": message}
    messages = [system_prompt, *prior_turns, user_message]

    # Render the conversation with the model's chat template and leave the
    # assistant turn open so generation continues from it.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # skip_prompt=True keeps the echoed prompt out of the streamed text.
    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=10.0,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    generate_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,
    )

    # generate() blocks until it finishes, so it runs in a background thread
    # while this generator drains the streamer.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        yield response.strip()
61
 
62
  # Construct Gradio Interface
63
  interface = gr.Interface(
64
+ fn=chat_stream,
65
  inputs=[
66
  gr.Textbox(lines=5, label="输入你的问题"),
67
  gr.Slider(minimum=100, maximum=1024, step=50, label="生成长度")
68
  ],
69
  outputs=gr.Textbox(label="模型回复"),
70
+ title="Ling-lite-1.5 MoE AI助手",
71
  description="基于 [inclusionAI/Ling-lite-1.5](https://huggingface.co/inclusionAI/Ling-lite-1.5) 的对话式文本生成演示。",
72
  examples=[
73
  ["介绍大型语言模型的基本概念", 512],
 
76
  )
77
 
78
  # launch Gradio Service
79
+ interface.launch()