multimodalart committed
Commit 22e6618 · 1 Parent(s): adae6e9

Update app.py

Files changed (1): app.py +7 -17
app.py CHANGED

@@ -5,7 +5,8 @@ import time
 import numpy as np
 from torch.nn import functional as F
 import os
-# auth_key = os.environ["HF_ACCESS_TOKEN"]
+from threading import Thread
+
 print(f"Starting to load the model to memory")
 m = AutoModelForCausalLM.from_pretrained(
     "stabilityai/stablelm-tuned-alpha-7b", torch_dtype=torch.float16).cuda()
@@ -28,18 +29,6 @@ class StopOnTokens(StoppingCriteria):
                 return True
         return False
 
-def streaming_generate(text, bad_text=None):
-
-    return model_output
-
-
-def generate(text, bad_text=None):
-    stop = StopOnTokens()
-    result = generator(text, max_new_tokens=1024, num_return_sequences=1, num_beams=1, do_sample=True,
-                       temperature=1.0, top_p=0.95, top_k=1000, stopping_criteria=StoppingCriteriaList([stop]))
-    return result[0]["generated_text"].replace(text, "")
-
-
 def user(user_message, history):
     history = history + [[user_message, ""]]
     return "", history, history
@@ -51,8 +40,8 @@ def bot(history, curr_system_message):
         "".join(["".join(["<|USER|>"+item[0], "<|ASSISTANT|>"+item[1]])
                 for item in history])
 
-    model_inputs = tok(messages, return_tensors="pt")['input_ids'].cuda()[:, :4096-1024]
-
+    #model_inputs = tok([messages], return_tensors="pt")['input_ids'].cuda()[:, :4096-1024]
+    model_inputs = tok([messages], return_tensors="pt").to("cuda")
     streamer = TextIteratorStreamer(tok, timeout=10., skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         model_inputs,
@@ -68,10 +57,11 @@ def bot(history, curr_system_message):
     t = Thread(target=m.generate, kwargs=generate_kwargs)
     t.start()
 
-    model_output = ""
+    print(history)
     for new_text in streamer:
+        print(new_text)
         history[-1][1] += new_text
-        yield history
+        yield history, history
 
     return history, history
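In short: the commit drops the unused blocking generate()/streaming_generate() helpers, imports Thread, and has bot() run m.generate on a background thread while draining a TextIteratorStreamer to grow the chat history chunk by chunk. A minimal, self-contained sketch of that streamer-plus-thread pattern follows; a small placeholder model ("gpt2") and placeholder sampling values stand in for the app's StableLM setup, so treat it as an illustration rather than the app's exact code:

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Placeholder model; the app loads stabilityai/stablelm-tuned-alpha-7b in fp16 on CUDA.
tok = AutoTokenizer.from_pretrained("gpt2")
m = AutoModelForCausalLM.from_pretrained("gpt2")

model_inputs = tok(["<|USER|>Hi!<|ASSISTANT|>"], return_tensors="pt")

# skip_prompt=True keeps the prompt out of the stream; timeout prevents a
# stalled generate() thread from blocking the consumer loop forever.
streamer = TextIteratorStreamer(tok, timeout=10., skip_prompt=True,
                                skip_special_tokens=True)

# BatchEncoding is a mapping, so dict(model_inputs, ...) merges input_ids and
# attention_mask with the generation arguments (placeholder sampling values).
generate_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=64,
                       do_sample=True, top_p=0.95, temperature=1.0)

# generate() blocks until completion, so it runs on a worker thread while the
# main thread consumes decoded text chunks from the streamer's queue.
t = Thread(target=m.generate, kwargs=generate_kwargs)
t.start()

partial = ""
for new_text in streamer:  # yields decoded text as tokens are produced
    partial += new_text
    print(new_text, end="", flush=True)
t.join()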
 
 
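The context lines (return True / return False) come from the StopOnTokens stopping criteria that the surviving generate_kwargs still rely on. A plausible reconstruction of such a criterion is sketched below; the stop-token ids are passed in explicitly because the actual ids are not visible in this diff:

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnTokens(StoppingCriteria):
    # Halts generation as soon as the newest token is one of the given ids.
    def __init__(self, stop_ids):
        self.stop_ids = stop_ids

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in self.stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False

# Hypothetical usage: include it in the kwargs handed to m.generate, e.g.
# stopping_criteria=StoppingCriteriaList([StopOnTokens([tok.eos_token_id])])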