pratikshahp commited on
Commit
c00c190
·
verified ·
1 Parent(s): 3e6b5bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -8,8 +8,11 @@ import os
8
  load_dotenv()
9
 
10
  # Load the model and tokenizer
11
- tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
12
- model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
 
 
 
13
 
14
  # Function to generate blog content
15
  def generate_blog(topic, keywords):
@@ -26,7 +29,7 @@ def generate_blog(topic, keywords):
26
  Blog:
27
  """
28
  input_ids = tokenizer(prompt_template, return_tensors="pt", max_length=512, truncation=True)
29
- outputs = model.generate(input_ids["input_ids"], max_length=800, num_return_sequences=1, temperature=0.7)
30
  blog_content = tokenizer.decode(outputs[0], skip_special_tokens=True)
31
 
32
  return blog_content
 
8
  load_dotenv()
9
 
10
  # Load the model and tokenizer
11
+ #tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
12
+ #model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
13
+
14
+ tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-rw-1b", trust_remote_code=True)
15
+ model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-rw-1b", trust_remote_code=True)
16
 
17
  # Function to generate blog content
18
  def generate_blog(topic, keywords):
 
29
  Blog:
30
  """
31
  input_ids = tokenizer(prompt_template, return_tensors="pt", max_length=512, truncation=True)
32
+ outputs = model.generate(input_ids["input_ids"], max_length=800, num_return_sequences=1)
33
  blog_content = tokenizer.decode(outputs[0], skip_special_tokens=True)
34
 
35
  return blog_content