multimodalart (HF Staff) committed
Commit 5cda1cc · 1 Parent(s): a8a9fb2

Update app.py

Files changed (1):
  1. app.py +30 -42
app.py CHANGED
@@ -1,6 +1,6 @@
import gradio as gr
import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, StoppingCriteria, StoppingCriteriaList
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
import time
import numpy as np
from torch.nn import functional as F
@@ -28,43 +28,11 @@ class StopOnTokens(StoppingCriteria):
                return True
        return False

-
- def contrastive_generate(text, bad_text):
-     with torch.no_grad():
-         tokens = tok(text, return_tensors="pt")[
-             'input_ids'].cuda()[:, :4096-1024]
-         bad_tokens = tok(bad_text, return_tensors="pt")[
-             'input_ids'].cuda()[:, :4096-1024]
-         history = None
-         bad_history = None
-         curr_output = list()
-         for i in range(1024):
-             out = m(tokens, past_key_values=history, use_cache=True)
-             logits = out.logits
-             history = out.past_key_values
-             bad_out = m(bad_tokens, past_key_values=bad_history,
-                         use_cache=True)
-             bad_logits = bad_out.logits
-             bad_history = bad_out.past_key_values
-             probs = F.softmax(logits.float(), dim=-1)[0][-1].cpu()
-             bad_probs = F.softmax(bad_logits.float(), dim=-1)[0][-1].cpu()
-             logits = torch.log(probs)
-             bad_logits = torch.log(bad_probs)
-             logits[probs > 0.1] = logits[probs > 0.1] - bad_logits[probs > 0.1]
-             probs = F.softmax(logits)
-             out = int(torch.multinomial(probs, 1))
-             if out in [50278, 50279, 50277, 1, 0]:
-                 break
-             else:
-                 curr_output.append(out)
-             out = np.array([out])
-             tokens = torch.from_numpy(np.array([out])).to(
-                 tokens.device)
-             bad_tokens = torch.from_numpy(np.array([out])).to(
-                 tokens.device)
-         return tok.decode(curr_output)
-
-
+ def streaming_generate(text, bad_text=None):
+
+     return model_output
+
+
def generate(text, bad_text=None):
    stop = StopOnTokens()
    result = generator(text, max_new_tokens=1024, num_return_sequences=1, num_beams=1, do_sample=True,
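Aside on the removed function: `contrastive_generate` implemented a small contrastive-decoding loop, sampling each token from the prompt-conditioned distribution after penalizing tokens that a second, "bad" prompt also rates as likely. Below is a minimal sketch of just the per-step adjustment it performed; the tiny dummy vocabulary and random logits are placeholders for the two forward passes, and the 0.1 probability cutoff mirrors the removed code:

```python
import torch
import torch.nn.functional as F

# Tiny dummy vocabulary and logits standing in for the last-position outputs of the
# two forward passes (m(tokens, ...) and m(bad_tokens, ...) in the removed code).
logits = torch.randn(8)
bad_logits = torch.randn(8)

probs = F.softmax(logits.float(), dim=-1)
bad_probs = F.softmax(bad_logits.float(), dim=-1)

# Penalize only tokens the main prompt already rates as likely (p > 0.1) by
# subtracting the "bad" prompt's log-probability from theirs.
log_p = torch.log(probs)
bad_log_p = torch.log(bad_probs)
mask = probs > 0.1
log_p[mask] = log_p[mask] - bad_log_p[mask]

# Renormalize and sample the next token id, as the removed loop did at every step.
next_id = int(torch.multinomial(F.softmax(log_p, dim=-1), 1))
```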
@@ -81,9 +49,29 @@ def bot(history, curr_system_message):
    messages = curr_system_message + \
        "".join(["".join(["<|USER|>"+item[0], "<|ASSISTANT|>"+item[1]])
                 for item in history])
-     output = generate(messages)
-     history[-1][1] = output
-     time.sleep(1)
+
+     model_inputs = tok(messages, return_tensors="pt")['input_ids'].cuda()[:, :4096-1024]
+
+     streamer = TextIteratorStreamer(tok, timeout=10., skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         model_inputs,
+         streamer=streamer,
+         max_new_tokens=1024,
+         do_sample=True,
+         top_p=0.95,
+         top_k=1000,
+         temperature=1.0,
+         num_beams=1,
+         stopping_criteria=StoppingCriteriaList([stop])
+     )
+     t = Thread(target=m.generate, kwargs=generate_kwargs)
+     t.start()
+
+     model_output = ""
+     for new_text in streamer:
+         history[-1][1] += new_text
+         yield history
+
    return history, history


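The replacement path in `bot()` builds on transformers' `TextIteratorStreamer`: `generate()` runs in a background thread and pushes decoded text into the streamer, which the handler iterates over and yields piece by piece. A minimal sketch of that pattern under the same assumptions as the app (a loaded model `m`, tokenizer `tok`, and the `StopOnTokens` criterion defined earlier in app.py); `stream_reply` is a hypothetical helper, and note the pattern also requires `from threading import Thread`, an import the hunks above do not add:

```python
from threading import Thread
from transformers import StoppingCriteriaList, TextIteratorStreamer

def stream_reply(prompt: str):
    # Tokenize the prompt (the app additionally truncates input_ids to leave
    # room for the 1024 new tokens).
    model_inputs = tok(prompt, return_tensors="pt").to(m.device)

    # skip_prompt drops the echoed input; the timeout keeps the iterator from
    # blocking forever if generation dies.
    streamer = TextIteratorStreamer(tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        model_inputs,                 # BatchEncoding: supplies input_ids / attention_mask
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=1.0,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
    )

    # generate() blocks until it finishes, so run it in a worker thread and
    # consume the streamer on this side as tokens are decoded.
    Thread(target=m.generate, kwargs=generate_kwargs).start()

    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial                 # a Gradio handler yields these to update the UI live
```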
@@ -107,5 +95,5 @@ with gr.Blocks() as demo:
    submit.click(fn=user, inputs=[msg, history], outputs=[msg, chatbot, history], queue=False).then(
        fn=bot, inputs=[chatbot, system_msg], outputs=[chatbot, history], queue=True)
    clear.click(lambda: [None, []], None, [chatbot, history], queue=False)
-     demo.queue(concurrency_count=5)
+     demo.queue(concurrency_count=1)
    demo.launch()
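On the Gradio side, streaming needs nothing beyond what this last hunk keeps in place: when an event handler is a generator, each yielded value replaces the bound output, provided the event runs through the queue (hence `queue=True` on the `.then()` call and the `demo.queue(...)` line, now with `concurrency_count=1` so requests to the shared model are served one at a time). A stripped-down, hypothetical echo bot showing just that mechanism, written against the Gradio 3.x API this Space appears to use; `echo_bot` and `user` here are illustration-only stand-ins for the app's functions:

```python
import time
import gradio as gr

def echo_bot(history):
    # Stream a canned reply word by word; every yield rewrites the Chatbot value.
    reply = "this reply arrives one word at a time"
    history[-1][1] = ""
    for word in reply.split():
        history[-1][1] += word + " "
        time.sleep(0.1)
        yield history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()

    def user(message, history):
        # Append the user turn with an empty assistant slot and clear the textbox.
        return "", history + [[message, ""]]

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        echo_bot, chatbot, chatbot, queue=True)

demo.queue(concurrency_count=1)
demo.launch()
```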
 