# server.py
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ModelInput(BaseModel):
    prompt: str = Field(..., description="The input prompt for text generation")
    max_new_tokens: int = Field(default=2048, gt=0, le=4096, description="Maximum number of tokens to generate")

app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Define model paths
BASE_MODEL_PATH = "HuggingFaceTB/SmolLM2-135M-Instruct"
ADAPTER_PATH = "khurrameycon/SmolLM-135M-Instruct-qa_pairs_converted.json-25epochs"

def format_prompt(instruction):
    """Format the prompt according to the model's expected format."""
    return f"""### Instruction:
{instruction}
### Response:
"""

def load_model_and_tokenizer():
    """Load the model, tokenizer, and adapter weights."""
    try:
        logger.info("Loading base model...")
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_PATH,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map="auto",
            use_cache=True
        )

        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL_PATH,
            padding_side="left",
            truncation_side="left"
        )

        # Ensure the tokenizer has the necessary special tokens
        special_tokens = {
            "pad_token": "<|padding|>",
            "eos_token": "</s>",
            "bos_token": "<s>",
            "unk_token": "<|unknown|>"
        }
        tokenizer.add_special_tokens(special_tokens)

        # Resize the model embeddings to match the new tokenizer size
        model.resize_token_embeddings(len(tokenizer))

        logger.info("Downloading adapter weights...")
        adapter_path_local = snapshot_download(repo_id=ADAPTER_PATH)

        logger.info("Loading adapter weights...")
        adapter_file = f"{adapter_path_local}/adapter_model.safetensors"
        state_dict = load_file(adapter_file)

        logger.info("Applying adapter weights...")
        model.load_state_dict(state_dict, strict=False)

        logger.info("Model and adapter loaded successfully!")
        return model, tokenizer
    except Exception as e:
        logger.error(f"Error during model loading: {e}", exc_info=True)
        raise

# Load model and tokenizer at startup
try:
    model, tokenizer = load_model_and_tokenizer()
except Exception as e:
    logger.error(f"Failed to load model at startup: {e}", exc_info=True)
    model = None
    tokenizer = None

def generate_response(model, tokenizer, instruction, max_new_tokens=2048):
    """Generate a response from the model based on an instruction."""
    try:
        # Format the prompt
        formatted_prompt = format_prompt(instruction)
        logger.info(f"Formatted prompt: {formatted_prompt}")

        # Encode input with truncation
        inputs = tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=tokenizer.model_max_length,
            padding=True,
            add_special_tokens=True
        ).to(model.device)
        logger.info(f"Input shape: {inputs.input_ids.shape}")

        # Generate response
        with torch.inference_mode():
            outputs = model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_new_tokens,
                temperature=0.7,
                top_p=0.9,
                top_k=50,
                do_sample=True,
                num_return_sequences=1,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,
                length_penalty=1.0,
                no_repeat_ngram_size=3
            )
        logger.info(f"Output shape: {outputs.shape}")

        # Decode the response
        response = tokenizer.decode(
            outputs[0, inputs.input_ids.shape[1]:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        response = response.strip()
        logger.info(f"Generated text length: {len(response)}")
        logger.info(f"Generated text preview: {response[:100]}...")

        if not response:
            logger.warning("Empty response generated")
            raise ValueError("Model generated an empty response")

        return response
    except Exception as e:
        logger.error(f"Error generating response: {e}", exc_info=True)
        raise ValueError(f"Error generating response: {e}")
@app.post("/generate")
async def generate_text(input: ModelInput, request: Request):
"""Generate text based on the input prompt."""
try:
if model is None or tokenizer is None:
raise HTTPException(status_code=503, detail="Model not loaded")
logger.info(f"Received request from {request.client.host}")
logger.info(f"Prompt: {input.prompt[:100]}...")
response = generate_response(
model=model,
tokenizer=tokenizer,
instruction=input.prompt,
max_new_tokens=input.max_new_tokens
)
return {"generated_text": response}
except Exception as e:
logger.error(f"Error in generate_text endpoint: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@app.get("/")
async def root():
"""Root endpoint that returns a welcome message."""
return {"message": "Welcome to the Model API!", "status": "running"}
@app.get("/health")
async def health_check():
"""Health check endpoint."""
return {
"status": "healthy",
"model_loaded": model is not None and tokenizer is not None,
"model_device": str(next(model.parameters()).device) if model else None,
"tokenizer_vocab_size": len(tokenizer) if tokenizer else None
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info") |
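
# Example client call (illustrative sketch; assumes the server is running locally on
# port 8000 and that the requests package is installed):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/generate",
#       json={"prompt": "Explain what an instruction-tuned model is.", "max_new_tokens": 256},
#   )
#   print(resp.json()["generated_text"])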