david-thrower committed · verified
Commit a16489c · 1 Parent(s): 2013b5a

Update app.py


Nested the generation call inside torch.inference_mode().

Files changed (1): app.py (+12, -12)
app.py CHANGED
@@ -69,18 +69,18 @@ def chat_fn(history, enable_thinking, temperature, top_p, top_k, repetition_pena
         # xml_tools=TOOLS
     )
     inputs = tokenizer(text, return_tensors="pt").to(DEVICE)
-
-    streamer = model.generate(
-        **inputs,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        temperature=temperature,
-        top_p=top_p,
-        top_k=top_k,
-        repetition_penalty=repetition_penalty,
-        pad_token_id=tokenizer.eos_token_id,
-        streamer=None  # we'll yield manually
-    )
+    with torch.inference_mode():
+        streamer = model.generate(
+            **inputs,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            pad_token_id=tokenizer.eos_token_id,
+            streamer=None  # we'll yield manually
+        )
     output_ids = streamer[0][len(inputs.input_ids[0]):]
     response = tokenizer.decode(output_ids, skip_special_tokens=True)
     if isinstance(response, str):
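
For context on the change: torch.inference_mode() disables autograd tracking for everything executed inside the block, so the decode loop in model.generate allocates no gradient state, reducing memory use and per-step overhead. Below is a minimal, self-contained sketch of the same pattern; the model ID, prompt, and sampling values are placeholders for illustration, not the app's actual configuration.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "gpt2"  # placeholder; app.py loads its own checkpoint
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(DEVICE)
inputs = tokenizer("Hello, world", return_tensors="pt").to(DEVICE)

# Generation runs with autograd disabled, mirroring the commit's change.
with torch.inference_mode():
    output = model.generate(
        **inputs,
        max_new_tokens=32,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        top_k=50,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.eos_token_id,
    )

# Strip the prompt tokens and decode only the new ones, as app.py does.
response = tokenizer.decode(output[0][len(inputs.input_ids[0]):],
                            skip_special_tokens=True)
print(response)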