fthor commited on
Commit
31d8777
·
1 Parent(s): 032b71e

Fixed missing output for last prediction

Browse files
Files changed (1) hide show
  1. app.py +14 -3
app.py CHANGED
@@ -7,6 +7,8 @@ from transformers import BitsAndBytesConfig
7
 
8
  from sentence_transformers import SentenceTransformer, util
9
 
 
 
10
  quantization_config = BitsAndBytesConfig(
11
  load_in_4bit=True,
12
  bnb_4bit_compute_dtype=torch.float16
@@ -20,7 +22,8 @@ model = LlavaForConditionalGeneration.from_pretrained(
20
  quantization_config=quantization_config,
21
  device_map="auto",
22
  # use_flash_attention_2=True,
23
- low_cpu_mem_usage=True
 
24
  )
25
 
26
  MAXIMUM_PIXEL_VALUES = 3725568
@@ -63,6 +66,14 @@ def text_to_image(image, prompt, duplications: float):
63
  batch = dict(input_ids=list(), attention_mask=list(), pixel_values=list())
64
  else:
65
  i += 1
 
 
 
 
 
 
 
 
66
  else:
67
  batched_inputs.append(inputs)
68
 
@@ -73,8 +84,8 @@ def text_to_image(image, prompt, duplications: float):
73
  batch['input_ids'] = batch['input_ids'].to(model.device)
74
  batch['attention_mask'] = batch['attention_mask'].to(model.device)
75
  batch['pixel_values'] = batch['pixel_values'].to(model.device)
76
- output = model.generate(**batch, max_new_tokens=500, temperature=0.3)
77
-
78
  # Unload GPU
79
  batch['input_ids'].to('cpu')
80
  batch['attention_mask'].to('cpu')
 
7
 
8
  from sentence_transformers import SentenceTransformer, util
9
 
10
+ from transformers import PretrainedConfig
11
+
12
  quantization_config = BitsAndBytesConfig(
13
  load_in_4bit=True,
14
  bnb_4bit_compute_dtype=torch.float16
 
22
  quantization_config=quantization_config,
23
  device_map="auto",
24
  # use_flash_attention_2=True,
25
+ low_cpu_mem_usage=True,
26
+ # config=PretrainedConfig(do_sample=True)
27
  )
28
 
29
  MAXIMUM_PIXEL_VALUES = 3725568
 
66
  batch = dict(input_ids=list(), attention_mask=list(), pixel_values=list())
67
  else:
68
  i += 1
69
+ if i >= len(inputs['pixel_values']) and len(batch['input_ids']) > 0:
70
+ batch['input_ids'] = torch.stack(batch['input_ids'], dim=0)
71
+ batch['attention_mask'] = torch.stack(batch['attention_mask'], dim=0)
72
+ batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
73
+
74
+ # Add to the batched_inputs
75
+ batched_inputs.append(batch)
76
+ batch = dict(input_ids=list(), attention_mask=list(), pixel_values=list())
77
  else:
78
  batched_inputs.append(inputs)
79
 
 
84
  batch['input_ids'] = batch['input_ids'].to(model.device)
85
  batch['attention_mask'] = batch['attention_mask'].to(model.device)
86
  batch['pixel_values'] = batch['pixel_values'].to(model.device)
87
+ # output = model.generate(**batch, max_new_tokens=500, temperature=0.3)
88
+ output = model.generate(**batch, max_new_tokens=500)
89
  # Unload GPU
90
  batch['input_ids'].to('cpu')
91
  batch['attention_mask'].to('cpu')