from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file

class ModelInput(BaseModel):
    prompt: str
    max_new_tokens: int = 2048

app = FastAPI()

# Define model paths
BASE_MODEL_PATH = "HuggingFaceTB/SmolLM2-135M-Instruct"
ADAPTER_PATH = "khurrameycon/SmolLM-135M-Instruct-qa_pairs_converted.json-25epochs"

def load_model_and_tokenizer():
    """Load the model, tokenizer, and adapter weights."""
    try:
        print("Loading base model...")
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_PATH,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map="auto",
        )

        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)

        print("Downloading adapter weights...")
        adapter_path_local = snapshot_download(repo_id=ADAPTER_PATH)

        print("Loading adapter weights...")
        adapter_file = f"{adapter_path_local}/adapter_model.safetensors"
        state_dict = load_file(adapter_file)

        print("Applying adapter weights...")
        # strict=False lets the adapter state dict cover only a subset of the model's parameters
        model.load_state_dict(state_dict, strict=False)

        print("Model and adapter loaded successfully!")
        return model, tokenizer
    except Exception as e:
        print(f"Error during model loading: {e}")
        raise

# Load model and tokenizer at startup
model, tokenizer = load_model_and_tokenizer()

def generate_response(model, tokenizer, instruction, max_new_tokens=2048):
    """Generate a response from the model based on an instruction."""
    try:
        # Encode input with truncation
        inputs = tokenizer.encode(
            instruction,
            return_tensors="pt",
            truncation=True,
            max_length=tokenizer.model_max_length,
        ).to(model.device)

        # Create attention mask
        attention_mask = torch.ones(inputs.shape, device=model.device)

        # Generate response
        outputs = model.generate(
            inputs,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )

        # Decode and strip the input prompt from the response
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_text = response[len(instruction):].strip()
        return generated_text
    except Exception as e:
        print(f"Error generating response: {e}")
        raise ValueError(f"Error generating response: {e}")
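
# Example of calling the helper directly (hypothetical prompt, shown for illustration only):
#   text = generate_response(model, tokenizer, "Briefly explain what a language model is.", max_new_tokens=128)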

# Route decorator was missing in the original snippet; the path "/generate" is an assumption
@app.post("/generate")
async def generate_text(input: ModelInput):
    """Generate text based on the input prompt."""
    try:
        print(f"Received prompt: {input.prompt}")
        response = generate_response(
            model=model,
            tokenizer=tokenizer,
            instruction=input.prompt,
            max_new_tokens=input.max_new_tokens,
        )
        print(f"Generated response: {response}")
        return {"generated_text": response}
    except Exception as e:
        print(f"Error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/")
async def root():
    """Root endpoint that returns a welcome message."""
    return {"message": "Welcome to the Model API!"}