TalHach61 commited on
Commit
1b8d605
·
verified ·
1 Parent(s): 6f20c60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -25,6 +25,17 @@ login(token=hf_token)
25
 
26
  MAX_SEED = np.iinfo(np.int32).max
27
 
 
 
 
 
 
 
 
 
 
 
 
28
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
29
  if randomize_seed:
30
  seed = random.randint(0, MAX_SEED)
@@ -67,10 +78,9 @@ import PIL.Image as Image
67
 
68
  base_model = 'briaai/BRIA-4B-Adapt'
69
  controlnet_model = 'briaai/BRIA-4B-Adapt-ControlNet-Union'
70
-
71
  controlnet = BriaControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
72
  pipe = BriaControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16, trust_remote_code=True)
73
- pipe.to("cuda")
74
 
75
  mode_mapping = {
76
  "depth": 0,
@@ -172,6 +182,7 @@ def infer(cond_in, image_in, prompt, inference_steps, guidance_scale, control_mo
172
  guidance_scale=guidance_scale,
173
  generator=torch.manual_seed(seed),
174
  max_sequence_length=128,
 
175
  ).images[0]
176
 
177
  torch.cuda.empty_cache()
 
25
 
26
  MAX_SEED = np.iinfo(np.int32).max
27
 
28
+ try:
29
+ local_dir = os.path.dirname(__file__)
30
+ except:
31
+ local_dir = '.'
32
+
33
+ hf_hub_download(repo_id="briaai/BRIA-4B-Adapt", filename='pipeline_bria.py', local_dir=local_dir)
34
+ hf_hub_download(repo_id="briaai/BRIA-4B-Adapt", filename='transformer_bria.py', local_dir=local_dir)
35
+ hf_hub_download(repo_id="briaai/BRIA-4B-Adapt", filename='bria_utils.py', local_dir=local_dir)
36
+ hf_hub_download(repo_id="briaai/BRIA-4B-Adapt-ControlNet-Union", filename='pipeline_bria_controlnet.py', local_dir=local_dir)
37
+ hf_hub_download(repo_id="briaai/BRIA-4B-Adapt-ControlNet-Union", filename='controlnet_bria.py', local_dir=local_dir)
38
+
39
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
40
  if randomize_seed:
41
  seed = random.randint(0, MAX_SEED)
 
78
 
79
  base_model = 'briaai/BRIA-4B-Adapt'
80
  controlnet_model = 'briaai/BRIA-4B-Adapt-ControlNet-Union'
 
81
  controlnet = BriaControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
82
  pipe = BriaControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16, trust_remote_code=True)
83
+ pipe = pipeline.to(device="cuda", dtype=torch.bfloat16)
84
 
85
  mode_mapping = {
86
  "depth": 0,
 
182
  guidance_scale=guidance_scale,
183
  generator=torch.manual_seed(seed),
184
  max_sequence_length=128,
185
+ negative_prompt="Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate"
186
  ).images[0]
187
 
188
  torch.cuda.empty_cache()