File size: 3,108 Bytes
6e0397b
 
 
 
 
 
 
 
 
498ae97
6e0397b
 
 
 
 
 
 
3b63b4a
 
 
 
 
 
 
 
 
 
 
 
 
6e0397b
3b63b4a
 
6e0397b
3b63b4a
 
 
6e0397b
3b63b4a
 
6e0397b
3b63b4a
6e0397b
3b63b4a
 
 
 
6e0397b
3b63b4a
6e0397b
498ae97
6e0397b
 
da78b12
 
 
 
 
 
 
3b63b4a
6e0397b
 
 
 
 
 
 
 
 
da78b12
6e0397b
da78b12
3b63b4a
da78b12
 
 
 
6e0397b
 
3b63b4a
6e0397b
 
da78b12
6e0397b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file

class ModelInput(BaseModel):
    """Request body accepted by the ``/generate`` endpoint."""

    # Text handed to the model as the generation instruction.
    prompt: str
    # Forwarded to generation as ``max_new_tokens``; 2048 when the client omits it.
    max_new_tokens: int = 2048

# FastAPI application object; the route decorators further down register on it.
app = FastAPI()

# Hugging Face Hub repository ids used when loading the model.
# base_model_path: passed to from_pretrained() for both the model and the tokenizer.
# NOTE(review): the base repo is "SmolLM2-..." while the adapter repo name says
# "SmolLM-..." — confirm the adapter was actually trained against this base.
base_model_path = "HuggingFaceTB/SmolLM2-135M-Instruct"
# adapter_path: passed to snapshot_download() as the repo to fetch adapter weights from.
adapter_path = "khurrameycon/SmolLM-135M-Instruct-qa_pairs_converted.json-25epochs"

# Load the model and tokenizer
def load_model_and_tokenizer():
    """Load the base causal LM and tokenizer, then overlay the adapter weights.

    The base model and tokenizer come from ``base_model_path``. The adapter
    repository ``adapter_path`` is downloaded and its
    ``adapter_model.safetensors`` file is loaded into the model with
    ``load_state_dict(..., strict=False)``.

    Because ``strict=False`` ignores tensors whose names do not exist in the
    model, the load report is inspected and a warning is printed when adapter
    tensors were skipped, so a silently un-applied adapter is visible in the
    logs instead of being reported as a success.

    Returns:
        tuple: ``(model, tokenizer)``.

    Raises:
        Exception: any error raised while downloading or loading is printed
            and re-raised unchanged.
    """
    import os  # local import: only needed to build the adapter file path

    try:
        print("Loading base model...")
        model = AutoModelForCausalLM.from_pretrained(
            base_model_path,
            torch_dtype=torch.float16,
            # NOTE(review): trust_remote_code lets the model repository run its
            # own Python code at load time — confirm this is really required.
            trust_remote_code=True,
            device_map="auto",
        )

        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(base_model_path)

        print("Downloading adapter weights...")
        adapter_path_local = snapshot_download(repo_id=adapter_path)

        print("Loading adapter weights...")
        adapter_file = os.path.join(adapter_path_local, "adapter_model.safetensors")
        state_dict = load_file(adapter_file)

        print("Applying adapter weights...")
        # strict=False skips unknown keys without raising, so check the report.
        load_report = model.load_state_dict(state_dict, strict=False)
        skipped = list(load_report.unexpected_keys)
        applied = len(state_dict) - len(skipped)

        if skipped:
            print(
                f"Warning: {len(skipped)} of {len(state_dict)} adapter tensors "
                f"did not match any model parameter and were ignored "
                f"(first: {skipped[0]})."
            )

        if applied > 0:
            print("Model and adapter loaded successfully!")
        else:
            # Nothing from the adapter file landed in the model. Keep serving
            # (as the original did) but make the situation explicit.
            print(
                "Warning: no adapter tensors were applied; "
                "the model is running with base weights only."
            )

        return model, tokenizer
    except Exception as e:
        print(f"Error during model loading: {e}")
        raise

# Loaded once at import time; the /generate handler below reuses these
# module-level objects for every request.
model, tokenizer = load_model_and_tokenizer()

def generate_response(model, tokenizer, instruction, max_new_tokens=2048):
    """Generate a sampled completion for ``instruction``.

    Args:
        model: Model object exposing ``device`` and ``generate(...)``.
        tokenizer: Tokenizer exposing ``encode``, ``decode`` and
            ``model_max_length``.
        instruction: Prompt text to continue.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        str: The newly generated text only (prompt removed), stripped of
        leading and trailing whitespace.

    Raises:
        ValueError: If tokenization, generation or decoding fails; the
            original exception is attached as ``__cause__``.
    """
    try:
        # Tokenize the prompt, truncating to the tokenizer's configured limit.
        prompt_ids = tokenizer.encode(
            instruction,
            return_tensors="pt",
            truncation=True,
            max_length=tokenizer.model_max_length,
        ).to(model.device)

        # Sample a continuation.
        outputs = model.generate(
            prompt_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )

        # The generated sequence begins with the prompt tokens. Drop them by
        # token count rather than by len(instruction): the decoded prompt is
        # not guaranteed to equal the raw string (truncation, whitespace
        # normalisation, stripped special tokens), so character slicing can
        # eat generated text or leave prompt text behind.
        prompt_len = prompt_ids.shape[-1]
        generated_text = tokenizer.decode(
            outputs[0][prompt_len:],
            skip_special_tokens=True,
        ).strip()

        print(f"Instruction: {instruction}")  # Debugging line
        print(f"Generated Response: {generated_text}")  # Debugging line

        return generated_text

    except Exception as e:
        print(f"Error generating response: {e}")
        raise ValueError(f"Error generating response: {e}") from e


@app.post("/generate")
async def generate_text(input: ModelInput):
    """Handle ``POST /generate``: run the model on the request's prompt.

    Returns a JSON object with a single ``generated_text`` key. Any exception
    raised during generation is reported to the client as HTTP 500 with the
    exception text as the detail.
    """
    # NOTE(review): generate_response is a synchronous call inside an async
    # handler, so it presumably holds the event loop for the whole generation
    # — confirm whether that is acceptable for this service.
    try:
        generated = generate_response(
            model,
            tokenizer,
            instruction=input.prompt,
            max_new_tokens=input.max_new_tokens,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    return {"generated_text": generated}

@app.get("/")
async def root():
    """Handle ``GET /``: return a static welcome payload."""
    greeting = {"message": "Welcome to the Model API!"}
    return greeting