Hjgugugjhuhjggg committed
Commit 64be9ea · verified · 1 parent: 458211a

Update app.py

Files changed (1): app.py (+7 -4)
app.py CHANGED
@@ -14,6 +14,7 @@ from duckduckgo_search import DDGS
 MODEL_NAME = "lilmeaty/my_xdd"
 global_model = None
 global_tokenizer = None
+global_tokens = {}
 
 async def cleanup_memory(device: str):
     gc.collect()
@@ -41,7 +42,7 @@ app = FastAPI()
 
 @app.on_event("startup")
 async def load_global_model():
-    global global_model, global_tokenizer
+    global global_model, global_tokenizer, global_tokens
     config = AutoConfig.from_pretrained(MODEL_NAME)
     global_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, config=config)
     global_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, config=config, torch_dtype=torch.float16)
@@ -49,6 +50,8 @@ async def load_global_model():
     global_tokenizer.pad_token_id = config.pad_token_id or global_tokenizer.eos_token_id
     device = "cuda" if torch.cuda.is_available() else "cpu"
     global_model.to(device)
+    global_tokens["eos_token_id"] = global_tokenizer.eos_token_id
+    global_tokens["pad_token_id"] = global_tokenizer.pad_token_id
     print(f"Modelo {MODEL_NAME} cargado correctamente en {device}.")
 
 async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
@@ -68,7 +71,7 @@ async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
     return summary
 
 async def stream_text(request: GenerateRequest, device: str):
-    global global_model, global_tokenizer
+    global global_model, global_tokenizer, global_tokens
     encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
     input_ids = encoded_input.input_ids
     accumulated_text = ""
@@ -119,7 +122,7 @@ async def stream_text(request: GenerateRequest, device: str):
             chunk_token_count = 0
             await asyncio.sleep(0)
         input_ids = next_token
-        if token_id == global_tokenizer.eos_token_id:
+        if token_id == global_tokens["eos_token_id"]:
             break
     if current_chunk:
         yield current_chunk
@@ -130,7 +133,7 @@ async def stream_text(request: GenerateRequest, device: str):
 
 @app.post("/generate")
 async def generate_text(request: GenerateRequest):
-    global global_model, global_tokenizer
+    global global_model, global_tokenizer, global_tokens
     if global_model is None or global_tokenizer is None:
         raise HTTPException(status_code=500, detail="El modelo no se ha cargado correctamente.")
     device = "cuda" if torch.cuda.is_available() else "cpu"
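
The hunks above show only the stop check inside stream_text, so here is a minimal sketch of the kind of greedy, KV-cached decoding loop the cached ID slots into. It assumes global_model, global_tokenizer, and global_tokens are populated as in the startup hook; stream_greedy and max_new_tokens are illustrative names, not identifiers from app.py:

import asyncio
import torch

async def stream_greedy(prompt: str, device: str, max_new_tokens: int = 128):
    encoded = global_tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = encoded.input_ids
    past_key_values = None
    current_chunk = ""
    for _ in range(max_new_tokens):
        with torch.no_grad():
            out = global_model(input_ids=input_ids,
                               past_key_values=past_key_values,
                               use_cache=True)
        past_key_values = out.past_key_values
        # Greedy pick of the next token; shape stays [1, 1] for the next step.
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        token_id = next_token.item()
        current_chunk += global_tokenizer.decode(next_token[0],
                                                 skip_special_tokens=True)
        await asyncio.sleep(0)   # let the event loop breathe between tokens
        input_ids = next_token   # with a KV cache, only the new token is fed back
        if token_id == global_tokens["eos_token_id"]:  # plain dict lookup, no tokenizer access
            break
    if current_chunk:
        yield current_chunk

Besides saving an attribute lookup per generated token, copying the IDs into global_tokens at startup pins the stop condition: even if the tokenizer object were reloaded later, an in-flight stream keeps comparing against the value captured when the model was loaded.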
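On the client side, the chunks yielded by stream_text can be consumed as they arrive. A hedged example using httpx: the input_text field name comes from the diff (request.input_text), but the diff does not show the full GenerateRequest schema, so this assumes input_text is the only required field; host, port, and prompt are placeholders.

import asyncio
import httpx

async def consume_stream():
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", "http://localhost:8000/generate",
                                 json={"input_text": "Hola"}) as resp:
            # Print each streamed chunk as soon as the server flushes it.
            async for chunk in resp.aiter_text():
                print(chunk, end="", flush=True)

asyncio.run(consume_stream())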