# server.py
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ModelInput(BaseModel):
    prompt: str = Field(..., description="The input prompt for text generation")
    max_new_tokens: int = Field(default=2048, gt=0, le=4096, description="Maximum number of tokens to generate")

app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Define model paths
BASE_MODEL_PATH = "HuggingFaceTB/SmolLM2-135M-Instruct"
ADAPTER_PATH = "khurrameycon/SmolLM-135M-Instruct-qa_pairs_converted.json-25epochs"

def format_prompt(instruction):
    """Format the prompt according to the model's expected format."""
    return f"""### Instruction:
{instruction}

### Response:
"""

def load_model_and_tokenizer():
    """Load the model, tokenizer, and adapter weights."""
    try:
        logger.info("Loading base model...")
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_PATH,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map="auto",
            use_cache=True
        )

        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL_PATH,
            padding_side="left",
            truncation_side="left"
        )
        
        # Ensure the tokenizer has the necessary special tokens
        special_tokens = {
            "pad_token": "<|padding|>",
            "eos_token": "</s>",
            "bos_token": "<s>",
            "unk_token": "<|unknown|>"
        }
        tokenizer.add_special_tokens(special_tokens)
        
        # Resize the model embeddings to match the new tokenizer size
        model.resize_token_embeddings(len(tokenizer))

        logger.info("Downloading adapter weights...")
        adapter_path_local = snapshot_download(repo_id=ADAPTER_PATH)

        logger.info("Loading adapter weights...")
        adapter_file = f"{adapter_path_local}/adapter_model.safetensors"
        state_dict = load_file(adapter_file)

        logger.info("Applying adapter weights...")
        model.load_state_dict(state_dict, strict=False)
        logger.info("Model and adapter loaded successfully!")

        return model, tokenizer
    except Exception as e:
        logger.error(f"Error during model loading: {e}", exc_info=True)
        raise
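
# Note: the code above applies the downloaded safetensors directly via
# load_state_dict(strict=False). If the adapter repo were instead a PEFT/LoRA
# adapter (one shipping an adapter_config.json), its keys would not match the
# base model and would be silently skipped; a minimal sketch of that route,
# assuming the `peft` package is installed:
#
#   from peft import PeftModel
#   model = PeftModel.from_pretrained(model, ADAPTER_PATH)
#   model = model.merge_and_unload()  # optionally fold the LoRA weights into the base model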

# Load model and tokenizer at startup
try:
    model, tokenizer = load_model_and_tokenizer()
except Exception as e:
    logger.error(f"Failed to load model at startup: {e}", exc_info=True)
    model = None
    tokenizer = None

def generate_response(model, tokenizer, instruction, max_new_tokens=2048):
    """Generate a response from the model based on an instruction."""
    try:
        # Format the prompt
        formatted_prompt = format_prompt(instruction)
        logger.info(f"Formatted prompt: {formatted_prompt}")
        
        # Encode input with truncation
        inputs = tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=tokenizer.model_max_length,
            padding=True,
            add_special_tokens=True
        ).to(model.device)

        logger.info(f"Input shape: {inputs.input_ids.shape}")
        
        # Generate response
        with torch.inference_mode():
            outputs = model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_new_tokens,
                temperature=0.7,
                top_p=0.9,
                top_k=50,
                do_sample=True,
                num_return_sequences=1,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,
                length_penalty=1.0,
                no_repeat_ngram_size=3
            )

        logger.info(f"Output shape: {outputs.shape}")

        # Decode only the newly generated tokens (slice off the prompt portion)
        response = tokenizer.decode(
            outputs[0, inputs.input_ids.shape[1]:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )

        response = response.strip()
        logger.info(f"Generated text length: {len(response)}")
        logger.info(f"Generated text preview: {response[:100]}...")
        
        if not response:
            logger.warning("Empty response generated")
            raise ValueError("Model generated an empty response")
            
        return response
    except Exception as e:
        logger.error(f"Error generating response: {e}", exc_info=True)
        raise ValueError(f"Error generating response: {e}")
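
# Example direct call (illustrative prompt; assumes the model and tokenizer loaded above):
#   text = generate_response(model, tokenizer, "List three uses of FastAPI.", max_new_tokens=128)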

@app.post("/generate")
async def generate_text(input: ModelInput, request: Request):
    """Generate text based on the input prompt."""
    # Check readiness before entering the try block so a missing model surfaces
    # as a 503 rather than being caught and re-raised as a 500 below.
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:

        logger.info(f"Received request from {request.client.host}")
        logger.info(f"Prompt: {input.prompt[:100]}...")
        
        response = generate_response(
            model=model,
            tokenizer=tokenizer,
            instruction=input.prompt,
            max_new_tokens=input.max_new_tokens
        )
        
        return {"generated_text": response}
    except Exception as e:
        logger.error(f"Error in generate_text endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/")
async def root():
    """Root endpoint that returns a welcome message."""
    return {"message": "Welcome to the Model API!", "status": "running"}

@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {
        "status": "healthy",
        "model_loaded": model is not None and tokenizer is not None,
        "model_device": str(next(model.parameters()).device) if model else None,
        "tokenizer_vocab_size": len(tokenizer) if tokenizer else None
    }
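
# Example /health response after a successful startup (illustrative device and vocab values):
#   {"status": "healthy", "model_loaded": true, "model_device": "cuda:0", "tokenizer_vocab_size": 49152}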

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
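
# Example client call (illustrative; assumes the server is running locally on port 8000):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/generate",
#       json={"prompt": "Explain what a FastAPI middleware does.", "max_new_tokens": 256},
#       timeout=300,
#   )
#   print(resp.json()["generated_text"])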