Manasa1 commited on
Commit
0e19546
·
verified ·
1 Parent(s): d29b4df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -2
app.py CHANGED
@@ -10,18 +10,27 @@ tokenizer = AutoTokenizer.from_pretrained("Manasa1/gpt-finetuned-tweets")
10
 
11
  def generate_tweet():
12
  prompt = "Write a concise, creative tweet reflecting the style and personality in the fine-tuned dataset."
13
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=100)
 
 
 
 
 
 
14
  outputs = model.generate(
15
  inputs["input_ids"],
 
16
  max_length=140,
17
  num_return_sequences=1,
18
  top_p=0.8,
19
  temperature=0.6,
20
- repetition_penalty=1.2, # Penalizes repetitive tokens
21
  )
 
22
  generated_tweet = tokenizer.decode(outputs[0], skip_special_tokens=True)
23
  return generated_tweet.strip()
24
 
 
25
  # Gradio Interface
26
  with gr.Blocks() as app:
27
  gr.Markdown("# AI Tweet Generator")
 
10
 
11
  def generate_tweet():
12
  prompt = "Write a concise, creative tweet reflecting the style and personality in the fine-tuned dataset."
13
+ # Tokenize the input prompt
14
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=100, padding=True)
15
+
16
+ # Explicitly set the pad_token_id
17
+ model.config.pad_token_id = model.config.eos_token_id
18
+
19
+ # Generate the tweet with the attention mask
20
  outputs = model.generate(
21
  inputs["input_ids"],
22
+ attention_mask=inputs["attention_mask"], # Pass attention_mask explicitly
23
  max_length=140,
24
  num_return_sequences=1,
25
  top_p=0.8,
26
  temperature=0.6,
27
+ repetition_penalty=1.2, # Penalize repetition
28
  )
29
+ # Decode and return the generated tweet
30
  generated_tweet = tokenizer.decode(outputs[0], skip_special_tokens=True)
31
  return generated_tweet.strip()
32
 
33
+
34
  # Gradio Interface
35
  with gr.Blocks() as app:
36
  gr.Markdown("# AI Tweet Generator")