lionelgarnier commited on
Commit
07c838c
·
1 Parent(s): 48cc831

fix image generation

Browse files
Files changed (1) hide show
  1. app.py +17 -12
app.py CHANGED
@@ -5,7 +5,7 @@ import os
5
  import spaces
6
  import torch
7
  from diffusers import DiffusionPipeline
8
- from transformers import pipeline
9
  from huggingface_hub import login
10
 
11
  hf_token = os.getenv("hf_token")
@@ -37,17 +37,22 @@ def get_image_gen_pipeline():
37
  def get_text_gen_pipeline():
38
  global _text_gen_pipeline
39
  if _text_gen_pipeline is None:
40
- try:
41
- device = "cuda" if torch.cuda.is_available() else "cpu"
42
- _text_gen_pipeline = pipeline(
43
- "text-generation",
44
- model="mistralai/Mistral-7B-Instruct-v0.3",
45
- max_new_tokens=2048,
46
- device=device,
47
- )
48
- except Exception as e:
49
- print(f"Error loading text generation model: {e}")
50
- return None
 
 
 
 
 
51
  return _text_gen_pipeline
52
 
53
  @spaces.GPU()
 
5
  import spaces
6
  import torch
7
  from diffusers import DiffusionPipeline
8
+ from transformers import pipeline, AutoTokenizer
9
  from huggingface_hub import login
10
 
11
  hf_token = os.getenv("hf_token")
 
37
  def get_text_gen_pipeline():
38
  global _text_gen_pipeline
39
  if _text_gen_pipeline is None:
40
+ try:
41
+ device = "cuda" if torch.cuda.is_available() else "cpu"
42
+ tokenizer = AutoTokenizer.from_pretrained(
43
+ "mistralai/Mistral-7B-Instruct-v0.3",
44
+ use_fast=True # Ensure fast tokenizer
45
+ )
46
+ _text_gen_pipeline = pipeline(
47
+ "text-generation",
48
+ model="mistralai/Mistral-7B-Instruct-v0.3",
49
+ tokenizer=tokenizer,
50
+ max_new_tokens=2048,
51
+ device=device,
52
+ )
53
+ except Exception as e:
54
+ print(f"Error loading text generation model: {e}")
55
+ return None
56
  return _text_gen_pipeline
57
 
58
  @spaces.GPU()