lionelgarnier commited on
Commit
f22117a
·
1 Parent(s): 74f1ee7

switch to deepseek 8B

Browse files
Files changed (1) hide show
  1. app.py +8 -9
app.py CHANGED
@@ -68,12 +68,7 @@ def preprocess_image(image: Image.Image) -> Image.Image:
68
  if trellis is None:
69
  # If the pipeline is not loaded, just return the original image
70
  return image
71
-
72
- # Check if image is a numpy array and convert to PIL Image if needed
73
- if isinstance(image, np.ndarray):
74
- image = Image.fromarray(image.astype('uint8'))
75
 
76
- # trellis.cuda()
77
  processed_image = trellis.preprocess_image(image)
78
  return processed_image
79
 
@@ -104,8 +99,8 @@ def get_text_gen_pipeline():
104
  try:
105
  device = "cuda" if torch.cuda.is_available() else "cpu"
106
  tokenizer = AutoTokenizer.from_pretrained(
107
- # "mistralai/Mistral-7B-Instruct-v0.3",
108
- "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
109
  use_fast=True
110
  )
111
  tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
@@ -113,7 +108,7 @@ def get_text_gen_pipeline():
113
  _text_gen_pipeline = pipeline(
114
  # "text-generation",
115
  # model="mistralai/Mistral-7B-Instruct-v0.3",
116
- model="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
117
  tokenizer=tokenizer,
118
  max_new_tokens=2048,
119
  device=device,
@@ -138,7 +133,11 @@ def get_trellis_pipeline():
138
 
139
 
140
  @spaces.GPU()
141
- def refine_prompt(prompt, system_prompt=DEFAULT_SYSTEM_PROMPT, progress=gr.Progress()):
 
 
 
 
142
  text_gen = get_text_gen_pipeline()
143
  if text_gen is None:
144
  return "", "Text generation model is unavailable."
 
68
  if trellis is None:
69
  # If the pipeline is not loaded, just return the original image
70
  return image
 
 
 
 
71
 
 
72
  processed_image = trellis.preprocess_image(image)
73
  return processed_image
74
 
 
99
  try:
100
  device = "cuda" if torch.cuda.is_available() else "cpu"
101
  tokenizer = AutoTokenizer.from_pretrained(
102
+ "mistralai/Mistral-7B-Instruct-v0.3",
103
+ "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
104
  use_fast=True
105
  )
106
  tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
 
108
  _text_gen_pipeline = pipeline(
109
  # "text-generation",
110
  # model="mistralai/Mistral-7B-Instruct-v0.3",
111
+ model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
112
  tokenizer=tokenizer,
113
  max_new_tokens=2048,
114
  device=device,
 
133
 
134
 
135
  @spaces.GPU()
136
+ def refine_prompt(
137
+ prompt,
138
+ system_prompt=DEFAULT_SYSTEM_PROMPT,
139
+ progress=gr.Progress(track_tqdm=True)
140
+ ):
141
  text_gen = get_text_gen_pipeline()
142
  if text_gen is None:
143
  return "", "Text generation model is unavailable."