from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file

class ModelInput(BaseModel):
    prompt: str
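    # Caps only the generated continuation; the prompt length is limited
    # separately by truncation in generate_response below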
    max_new_tokens: int = 2048

app = FastAPI()

# Define model paths
BASE_MODEL_PATH = "HuggingFaceTB/SmolLM2-135M-Instruct"
ADAPTER_PATH = "khurrameycon/SmolLM-135M-Instruct-qa_pairs_converted.json-25epochs"

def load_model_and_tokenizer():
    """Load the model, tokenizer, and adapter weights."""
    try:
        print("Loading base model...")
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_PATH,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map="auto"
        )

        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)

        print("Downloading adapter weights...")
        adapter_path_local = snapshot_download(repo_id=ADAPTER_PATH)

        print("Loading adapter weights...")
        adapter_file = f"{adapter_path_local}/adapter_model.safetensors"
        state_dict = load_file(adapter_file)

        print("Applying adapter weights...")
        model.load_state_dict(state_dict, strict=False)
        print("Model and adapter loaded successfully!")

        return model, tokenizer
    except Exception as e:
        print(f"Error during model loading: {e}")
        raise

# Load model and tokenizer at startup
model, tokenizer = load_model_and_tokenizer()
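# NOTE: loading at import time blocks the server until the weights are in
# memory; a FastAPI lifespan handler could defer this, but eager loading keeps
# the example simple and fails fast if the weights cannot be loaded.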

def generate_response(model, tokenizer, instruction, max_new_tokens=2048):
    """Generate a response from the model based on an instruction."""
    try:
        # Encode input with truncation
        inputs = tokenizer.encode(
            instruction,
            return_tensors="pt",
            truncation=True,
            max_length=tokenizer.model_max_length
        ).to(model.device)

        # Create attention mask (all ones: a single unpadded sequence);
        # ones_like keeps the integer dtype and device of the input ids
        attention_mask = torch.ones_like(inputs)

        # Generate response
        outputs = model.generate(
            inputs,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,  # silence missing-pad-token warning
        )

        # Decode only the newly generated tokens; slicing by token count is
        # more robust than stripping len(instruction) characters, since
        # decoding does not always reproduce the prompt text verbatim
        generated_text = tokenizer.decode(
            outputs[0][inputs.shape[-1]:],
            skip_special_tokens=True
        ).strip()

        return generated_text
    except Exception as e:
        print(f"Error generating response: {e}")
        raise ValueError(f"Error generating response: {e}")

@app.post("/generate")
async def generate_text(request: ModelInput):
    """Generate text based on the input prompt."""
    try:
        print(f"Received prompt: {request.prompt}")
        response = generate_response(
            model=model,
            tokenizer=tokenizer,
            instruction=request.prompt,
            max_new_tokens=request.max_new_tokens
        )
        print(f"Generated response: {response}")
        return {"generated_text": response}
    except Exception as e:
        print(f"Error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/")
async def root():
    """Root endpoint that returns a welcome message."""
    return {"message": "Welcome to the Model API!"}