Update app.py
app.py CHANGED
@@ -14,6 +14,7 @@ from duckduckgo_search import DDGS
 MODEL_NAME = "lilmeaty/my_xdd"
 global_model = None
 global_tokenizer = None
+global_tokens = {}
 
 async def cleanup_memory(device: str):
     gc.collect()
@@ -41,7 +42,7 @@ app = FastAPI()
 
 @app.on_event("startup")
 async def load_global_model():
-    global global_model, global_tokenizer
+    global global_model, global_tokenizer, global_tokens
     config = AutoConfig.from_pretrained(MODEL_NAME)
     global_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, config=config)
     global_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, config=config, torch_dtype=torch.float16)
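A side note for readers of this hunk: @app.on_event("startup") is deprecated in recent FastAPI releases in favor of a lifespan context manager. A minimal sketch of the same load-once pattern written that way follows; the direct call to load_global_model is illustrative and assumes the function is usable without the decorator, none of this is part of the commit:

from contextlib import asynccontextmanager
from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Runs once before the app starts serving: load the model,
    # as load_global_model does in this commit.
    await load_global_model()
    yield  # requests are served here; teardown would go after the yield

app = FastAPI(lifespan=lifespan)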
@@ -49,6 +50,8 @@ async def load_global_model():
     global_tokenizer.pad_token_id = config.pad_token_id or global_tokenizer.eos_token_id
     device = "cuda" if torch.cuda.is_available() else "cpu"
     global_model.to(device)
+    global_tokens["eos_token_id"] = global_tokenizer.eos_token_id
+    global_tokens["pad_token_id"] = global_tokenizer.pad_token_id
     print(f"Modelo {MODEL_NAME} cargado correctamente en {device}.")
 
 async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
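The hunk above resolves the special-token ids once at startup and stashes them in the module-level global_tokens dict, so the per-token streaming loop further down can do a plain dict lookup instead of reaching through the tokenizer on every step. A minimal sketch of that caching pattern; the cache_special_tokens helper is hypothetical, not part of the commit:

global_tokens = {}

def cache_special_tokens(tokenizer):
    # Resolve the ids once at startup; fall back to EOS when PAD is unset,
    # mirroring the pad_token_id assignment in the hunk above.
    global_tokens["eos_token_id"] = tokenizer.eos_token_id
    global_tokens["pad_token_id"] = tokenizer.pad_token_id or tokenizer.eos_token_id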
@@ -68,7 +71,7 @@ async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
     return summary
 
 async def stream_text(request: GenerateRequest, device: str):
-    global global_model, global_tokenizer
+    global global_model, global_tokenizer, global_tokens
     encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
     input_ids = encoded_input.input_ids
     accumulated_text = ""
@@ -119,7 +122,7 @@ async def stream_text(request: GenerateRequest, device: str):
         chunk_token_count = 0
         await asyncio.sleep(0)
         input_ids = next_token
-        if token_id ==
+        if token_id == global_tokens["eos_token_id"]:
             break
     if current_chunk:
         yield current_chunk
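For context, the EOS comparison this hunk rewires (the old side is truncated in this view) typically sits at the bottom of a token-by-token generation loop. A self-contained sketch of such a loop, assuming greedy decoding; the actual sampling logic of stream_text is not shown in this diff:

import torch

def greedy_stream(model, input_ids, eos_token_id, max_new_tokens=64):
    # Emit one token id per step and stop once the model produces EOS,
    # the same check the hunk performs via global_tokens["eos_token_id"].
    for _ in range(max_new_tokens):
        with torch.no_grad():
            logits = model(input_ids).logits
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        token_id = next_token.item()
        yield token_id
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        if token_id == eos_token_id:
            break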
@@ -130,7 +133,7 @@ async def stream_text(request: GenerateRequest, device: str):
 
 @app.post("/generate")
 async def generate_text(request: GenerateRequest):
-    global global_model, global_tokenizer
+    global global_model, global_tokenizer, global_tokens
     if global_model is None or global_tokenizer is None:
         raise HTTPException(status_code=500, detail="El modelo no se ha cargado correctamente.")
     device = "cuda" if torch.cuda.is_available() else "cpu"
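A quick way to exercise the endpoint after this change. Only the input_text field of GenerateRequest is visible in this diff, so any other fields are omitted here, and the host and port are assumptions (uvicorn defaults):

import requests

resp = requests.post(
    "http://localhost:8000/generate",  # assumed local dev server
    json={"input_text": "Write a haiku about the sea."},
)
print(resp.status_code)
print(resp.text)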