jtordable committed · Commit 85d0b5f · verified · 1 Parent(s): b893fc1

Update app.py

Files changed (1):
  1. app.py +34 -18
app.py CHANGED
@@ -3,6 +3,8 @@ import gradio as gr
 import spaces
 import torch
 import logging
+import time
+
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from transformers.utils import logging as hf_logging
 
@@ -31,25 +33,39 @@ model = AutoModelForCausalLM.from_pretrained(
 
 @spaces.GPU
 def chat_fn(prompt, max_tokens=512):
+    t0 = time.time()
     max_tokens = min(int(max_tokens), 32_000)
-    messages = [{"role": "user", "content": prompt}]
-    chat_prompt = tokenizer.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
-    inputs = tokenizer(chat_prompt, return_tensors="pt").to(model.device)
-
-    # Generate with proper parameters
-    outputs = model.generate(
-        **inputs,
-        max_new_tokens=max_tokens,
-        do_sample=True,
-        temperature=0.1,
-        pad_token_id=tokenizer.eos_token_id
-    )
-
-    # Decode only the new tokens (not the input)
-    generated_text = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
-    return generated_text
+
+    try:
+        messages = [{"role": "user", "content": prompt}]
+        chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+        inputs = tokenizer(chat_prompt, return_tensors="pt").to(model.device)
+        t1 = time.time()
+        logging.info(f"🧠 Tokenization complete in {t1 - t0:.2f}s")
+
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=max_tokens,
+            do_sample=True,
+            temperature=0.1,
+            pad_token_id=tokenizer.eos_token_id
+        )
+        t2 = time.time()
+        logging.info(f"⚡️ Generation complete in {t2 - t1:.2f}s (max_tokens={max_tokens})")
+
+        generated_text = tokenizer.decode(
+            outputs[0][inputs['input_ids'].shape[1]:],
+            skip_special_tokens=True
+        )
+        t3 = time.time()
+        logging.info(f"🔓 Decoding complete in {t3 - t2:.2f}s (output length: {len(generated_text)})")
+
+        return generated_text
+
+    except Exception:
+        logging.exception("❌ Exception during generation")
+        return "⚠️ Generation failed"
 
 gr.Interface(
     fn=chat_fn,
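
One side note on the change (not part of the commit itself): Python's root logger only emits WARNING and above by default, so the new logging.info() timing lines will be silently dropped unless logging is configured elsewhere in app.py. A minimal sketch of that configuration, assuming the app relies on the standard-library root logger (the hf_logging import suggests transformers verbosity may also be tuned here):

import logging
from transformers.utils import logging as hf_logging

# Let the INFO-level timing lines (🧠/⚡️/🔓) reach the Space logs.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)

# Optional: surface transformers' own INFO-level logs as well.
hf_logging.set_verbosity_info()

With something like this in place, each chat_fn call logs its tokenization, generation, and decoding times; without it, only the except branch's logging.exception() output (which logs at ERROR level) would appear.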