mjpsm committed
Commit 9dd2063 · verified · 1 Parent(s): ba959d7

Update app.py

Files changed (1):
  1. app.py +17 -3
app.py CHANGED

@@ -10,16 +10,30 @@ model = AutoModelForCausalLM.from_pretrained(model_name)
 # Generation function
 def generate_affirmation(prompt):
     inputs = tokenizer(prompt, return_tensors="pt")
+    input_ids = inputs["input_ids"]
+
     with torch.no_grad():
         outputs = model.generate(
-            inputs["input_ids"],
+            input_ids,
             max_new_tokens=100,
             temperature=0.7,
             top_k=50,
             top_p=0.95,
-            do_sample=True
+            do_sample=True,
+            pad_token_id=tokenizer.eos_token_id
         )
-    return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    # Decode full output
+    full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    # Remove the input prompt from the output to isolate the generated part
+    if full_output.startswith(prompt):
+        affirmation = full_output[len(prompt):].strip()
+    else:
+        affirmation = full_output.strip()
+
+    return affirmation
+
 
 # Gradio interface
 demo = gr.Interface(
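
For reference, a minimal self-contained sketch of the updated app.py as it would run outside the diff. The checkpoint name "gpt2" is a stand-in, since the actual model_name is assigned above this hunk and not visible here; the gr.Interface arguments are likewise assumed, as the hunk ends before them.

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "gpt2"  # stand-in; the Space's real checkpoint is not shown in this hunk
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

def generate_affirmation(prompt):
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"]

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_new_tokens=100,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            do_sample=True,
            # GPT-2-style models define no pad token; reusing EOS avoids the
            # warning transformers otherwise prints during open-ended generation
            pad_token_id=tokenizer.eos_token_id,
        )

    # generate() echoes the prompt, so decode everything and strip the
    # prompt to isolate the newly generated affirmation
    full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    if full_output.startswith(prompt):
        return full_output[len(prompt):].strip()
    return full_output.strip()

# Gradio interface (argument names assumed; the hunk cuts off before them)
demo = gr.Interface(fn=generate_affirmation, inputs="text", outputs="text")

if __name__ == "__main__":
    demo.launch()

One design note: stripping the prompt by string prefix can silently fail when decoding does not round-trip the prompt exactly (leading whitespace, special characters), in which case the else branch returns the prompt together with the affirmation. A sketch of an alternative that slices by token position instead (not what this commit does):

# Inside generate_affirmation, after model.generate(): drop the prompt by
# token count rather than by string prefix
generated_ids = outputs[0][input_ids.shape[-1]:]
affirmation = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()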