zizo66 committed
Commit 6905010 · verified · 1 Parent(s): 75da7c7

Update app.py

Files changed (1)
  1. app.py +115 -58
app.py CHANGED
@@ -1,71 +1,128 @@
 import gradio as gr
-from huggingface_hub import InferenceClient
-
-# Load the LLaMA model from Hugging Face
-client = InferenceClient("meta-llama/Llama-2-7b-chat-hf")
-
-# List of available scenarios
-scenarios = {
-    "restaurant": "You are in a restaurant. Help the user order food in English.",
-    "airport": "You are at an airport. Help the user check in and find their gate.",
-    "hotel": "You are in a hotel. Help the user book a room.",
-    "shopping": "You are in a store. Help the user ask for prices and sizes.",
-}
-
-# Pick the prompt that matches the selected scenario
-def scenario_prompt(choice):
-    return scenarios.get(choice, "You are a language tutor AI. Help users practice real-life conversations.")
-
-# Chat function for talking to the model
-def respond(message, history, scenario, system_message, max_tokens, temperature, top_p):
-    # Make sure history is a list; if there is no prior data, use an empty list.
-    if history is None:
-        history = []
-
-    # Replace system_message with the scenario-specific prompt
-    sys_msg = scenario_prompt(scenario)
-
-    # Build the message list in the Chat Completion format
-    messages = [{"role": "system", "content": sys_msg}]
-
-    # history is expected to be a list of (user, assistant) tuples
-    for pair in history:
-        if isinstance(pair, (list, tuple)) and len(pair) == 2:
-            user_msg, assistant_msg = pair
-            if user_msg:
-                messages.append({"role": "user", "content": user_msg})
-            if assistant_msg:
-                messages.append({"role": "assistant", "content": assistant_msg})
-
-    # Append the user's current message
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-    # Get the reply with streaming
-    for m in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
         top_p=top_p,
-    ):
-        # Extract the text from each streamed token
-        token = m.choices[0].delta.content if m.choices[0].delta.content is not None else ""
-        response += token
-        yield response
-
-# Create the Gradio interface for user interaction
 demo = gr.ChatInterface(
-    fn=respond,
-    chatbot=gr.Chatbot(type="messages"),
     additional_inputs=[
-        gr.Dropdown(choices=list(scenarios.keys()), label="Choose a scenario", value="restaurant"),
-        gr.Textbox(value=scenario_prompt("restaurant"), label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
     ],
 )
 
 if __name__ == "__main__":
-    demo.launch()
+import os
+from collections.abc import Iterator
+from threading import Thread
+
 import gradio as gr
+import spaces
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+DESCRIPTION = """\
+# Llama 3.2 3B Instruct
+Llama 3.2 3B is Meta's latest iteration of open LLMs.
+This is a demo of [`meta-llama/Llama-3.2-3B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct), fine-tuned for instruction following.
+For more details, please check [our post](https://huggingface.co/blog/llama32).
+"""
+
+MAX_MAX_NEW_TOKENS = 2048
+DEFAULT_MAX_NEW_TOKENS = 1024
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+model_id = "meta-llama/Llama-3.2-3B-Instruct"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
+)
+model.eval()
+
+
+@spaces.GPU(duration=90)
+def generate(
+    message: str,
+    chat_history: list[dict],
+    max_new_tokens: int = 1024,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
+) -> Iterator[str]:
+    conversation = [*chat_history, {"role": "user", "content": message}]
+
+    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        {"input_ids": input_ids},
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
         top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)
+
+
 demo = gr.ChatInterface(
+    fn=generate,
     additional_inputs=[
+        gr.Slider(
+            label="Max new tokens",
+            minimum=1,
+            maximum=MAX_MAX_NEW_TOKENS,
+            step=1,
+            value=DEFAULT_MAX_NEW_TOKENS,
+        ),
+        gr.Slider(
+            label="Temperature",
+            minimum=0.1,
+            maximum=4.0,
+            step=0.1,
+            value=0.6,
+        ),
+        gr.Slider(
+            label="Top-p (nucleus sampling)",
+            minimum=0.05,
+            maximum=1.0,
+            step=0.05,
+            value=0.9,
+        ),
+        gr.Slider(
+            label="Top-k",
+            minimum=1,
+            maximum=1000,
+            step=1,
+            value=50,
+        ),
+        gr.Slider(
+            label="Repetition penalty",
+            minimum=1.0,
+            maximum=2.0,
+            step=0.05,
+            value=1.2,
+        ),
+    ],
+    stop_btn=None,
+    examples=[
+        ["Hello there! How are you doing?"],
+        ["Can you explain briefly to me what is the Python programming language?"],
+        ["Explain the plot of Cinderella in a sentence."],
+        ["How many hours does it take a man to eat a Helicopter?"],
+        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
     ],
+    cache_examples=False,
+    type="messages",
+    description=DESCRIPTION,
+    css_paths="style.css",
+    fill_height=True,
 )
 
+
 if __name__ == "__main__":
+    demo.queue(max_size=20).launch()
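
For reference, here is a minimal sketch of the threaded streaming pattern the new generate() function relies on: model.generate runs on a worker thread while a TextIteratorStreamer yields decoded text chunks as they are produced. The prompt and max_new_tokens value below are illustrative, and the gated meta-llama/Llama-3.2-3B-Instruct checkpoint can be swapped for any chat-tuned model you have access to.

from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "meta-llama/Llama-3.2-3B-Instruct"  # gated; any chat model works for the pattern
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)

# Build the prompt with the model's chat template (the user message is illustrative).
conversation = [{"role": "user", "content": "Explain the plot of Cinderella in a sentence."}]
input_ids = tokenizer.apply_chat_template(
    conversation, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

# The streamer yields decoded text chunks while generation runs on a background thread.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
Thread(
    target=model.generate,
    kwargs={"input_ids": input_ids, "streamer": streamer, "max_new_tokens": 128},
).start()

for chunk in streamer:
    print(chunk, end="", flush=True)

In app.py the same loop accumulates the chunks and yields the growing string so gr.ChatInterface can render the partial reply as it streams.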