TuringsSolutions committed on
Commit
7f11fb0
·
verified ·
1 Parent(s): 654c6e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -29
app.py CHANGED
@@ -1,35 +1,50 @@
1
- def load_model():
2
- return pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.bfloat16, device_map="auto")
 
3
 
4
- models = [load_model() for _ in range(3)]
 
 
 
 
 
 
 
 
 
 
 
5
  tokenizer = models[0].tokenizer
6
 
7
- # Enhanced prompt engineering (unchanged)
8
- messages = [
9
- {
10
- "role": "system",
11
- "content": "You are a friendly chatbot who always responds in the style of a pirate. Use pirate vocabulary and mannerisms in your replies.",
12
- },
13
- {"role": "user", "content": "How many helicopters can a human eat in one sitting, matey?"},
14
- ]
15
- prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
 
 
16
 
17
- # Ensemble generation with averaging (corrected)
18
- responses = []
19
- for model in models:
20
- outputs = model(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
21
- response = outputs[0]['generated_text']
22
- responses.append(response)
 
 
 
23
 
24
- # Average the generated text directly
25
- averaged_text = ""
26
- for i in range(min(len(response) for response in responses)):
27
- token_counts = {}
28
- for response in responses:
29
- token = response[i]
30
- token_counts[token] = token_counts.get(token, 0) + 1
31
- most_frequent_tokens = sorted(token_counts.items(), key=lambda x: x[1], reverse=True)
32
- averaged_token = most_frequent_tokens[0][0] # Choose the most frequent token
33
- averaged_text += averaged_token
34
 
35
- print(averaged_text)
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import pipeline, AutoTokenizer
3
+ import gradio as gr
4
 
5
+ def load_models():
6
+ return [
7
+ pipeline(
8
+ "text-generation",
9
+ model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
10
+ torch_dtype=torch.bfloat16,
11
+ device_map="auto",
12
+ )
13
+ for _ in range(3)
14
+ ]
15
+
16
+ models = load_models()
17
  tokenizer = models[0].tokenizer
18
 
19
+ def generate_text(prompt):
20
+ messages = [
21
+ {"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate. Use pirate vocabulary and mannerisms in your replies."},
22
+ {"role": "user", "content": prompt},
23
+ ]
24
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
25
+
26
+ responses = []
27
+ for model in models:
28
+ outputs = model(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
29
+ response = outputs[0]["generated_text"]
30
+ responses.append(response)
31
 
32
+ averaged_text = ""
33
+ for i in range(min(len(response) for response in responses)):
34
+ token_counts = {}
35
+ for response in responses:
36
+ token = response[i]
37
+ token_counts[token] = token_counts.get(token, 0) + 1
38
+ most_frequent_tokens = sorted(token_counts.items(), key=lambda x: x[1], reverse=True)
39
+ averaged_token = most_frequent_tokens[0][0] # Choose the most frequent token
40
+ averaged_text += averaged_token
41
 
42
+ return averaged_text
 
 
 
 
 
 
 
 
 
43
 
44
+ iface = gr.Interface(
45
+ generate_text,
46
+ [gr.Textbox(lines=2, label="Enter your prompt")],
47
+ "textbox",
48
+ title="Pirate Chatbot",
49
+ )
50
+ iface.launch()