nguyenlam0306 commited on
Commit
016478f
·
1 Parent(s): c60b90b

Update to fix early-stopping generation error

Browse files
Files changed (1) hide show
  1. app.py +18 -4
app.py CHANGED
@@ -1,12 +1,26 @@
1
  import gradio as gr
2
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
3
  from diffusers import StableDiffusionPipeline
4
  import torch
5
  import io
6
  from PIL import Image
7
 
8
- # Load the summarization model (lacos03/bart-base-finetuned-xsum)
9
- summarizer = pipeline("summarization", model="lacos03/bart-base-finetuned-xsum")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  # Load the prompt generation model (microsoft/Promptist)
12
  promptist_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
@@ -20,7 +34,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
20
  image_generator = StableDiffusionPipeline.from_pretrained(
21
  model_id,
22
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
23
- use_auth_token=False
24
  ).to(device)
25
 
26
  # Load LoRA weights
 
1
  import gradio as gr
2
+ from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
3
  from diffusers import StableDiffusionPipeline
4
  import torch
5
  import io
6
  from PIL import Image
7
 
8
+ # Load the summarization model (lacos03/bart-base-finetuned-xsum) manually
9
+ model_name = "lacos03/bart-base-finetuned-xsum"
10
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
12
+
13
+ # Create a custom GenerationConfig to fix early_stopping issue
14
+ generation_config = GenerationConfig.from_pretrained(model_name)
15
+ generation_config.early_stopping = True # Set to True to fix the error
16
+
17
+ # Initialize the summarization pipeline
18
+ summarizer = pipeline(
19
+ "summarization",
20
+ model=model,
21
+ tokenizer=tokenizer,
22
+ generation_config=generation_config
23
+ )
24
 
25
  # Load the prompt generation model (microsoft/Promptist)
26
  promptist_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
 
34
  image_generator = StableDiffusionPipeline.from_pretrained(
35
  model_id,
36
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
37
+ use_safetensors=True
38
  ).to(device)
39
 
40
  # Load LoRA weights