from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file

class ModelInput(BaseModel):
    prompt: str
    max_new_tokens: int = 2048

app = FastAPI()
# Define model paths
BASE_MODEL_PATH = "HuggingFaceTB/SmolLM2-135M-Instruct"
ADAPTER_PATH = "khurrameycon/SmolLM-135M-Instruct-qa_pairs_converted.json-25epochs"

def load_model_and_tokenizer():
    """Load the model, tokenizer, and adapter weights."""
    try:
        print("Loading base model...")
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_PATH,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map="auto"
        )

        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)

        print("Downloading adapter weights...")
        adapter_path_local = snapshot_download(repo_id=ADAPTER_PATH)

        print("Loading adapter weights...")
        adapter_file = f"{adapter_path_local}/adapter_model.safetensors"
        state_dict = load_file(adapter_file)

        print("Applying adapter weights...")
        model.load_state_dict(state_dict, strict=False)

        print("Model and adapter loaded successfully!")
        return model, tokenizer

    except Exception as e:
        print(f"Error during model loading: {e}")
        raise

# Load model and tokenizer at startup
model, tokenizer = load_model_and_tokenizer()

def generate_response(model, tokenizer, instruction, max_new_tokens=2048):
    """Generate a response from the model based on an instruction."""
    try:
        # Encode input with truncation
        inputs = tokenizer.encode(
            instruction,
            return_tensors="pt",
            truncation=True,
            max_length=tokenizer.model_max_length
        ).to(model.device)

        # Create attention mask
        attention_mask = torch.ones(inputs.shape, device=model.device)

        # Generate response
        outputs = model.generate(
            inputs,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )

        # Decode and strip input prompt from response
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_text = response[len(instruction):].strip()
        return generated_text

    except Exception as e:
        print(f"Error generating response: {e}")
        raise ValueError(f"Error generating response: {e}")

@app.post("/generate")
async def generate_text(input: ModelInput):
"""Generate text based on the input prompt."""
try:
print(f"Received prompt: {input.prompt}")
response = generate_response(
model=model,
tokenizer=tokenizer,
instruction=input.prompt,
max_new_tokens=input.max_new_tokens
)
print(f"Generated response: {response}")
return {"generated_text": response}
except Exception as e:
print(f"Error: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
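# Example request against the /generate endpoint (a sketch; the host and port are
# assumptions and depend on how the app is deployed):
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "What is SmolLM2?", "max_new_tokens": 128}'
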
@app.get("/")
async def root():
"""Root endpoint that returns a welcome message."""
return {"message": "Welcome to the Model API!"} |