Reality123b committed (verified)
Commit fbf5fda · 1 Parent(s): 8faa1c2

Update app.py

Files changed (1)
  1. app.py +66 -18
app.py CHANGED
@@ -1,16 +1,31 @@
-from fastapi import FastAPI, HTTPException
-from pydantic import BaseModel
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel, Field
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 from huggingface_hub import snapshot_download
 from safetensors.torch import load_file
+import logging
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
 class ModelInput(BaseModel):
-    prompt: str
-    max_new_tokens: int = 2048
+    prompt: str = Field(..., description="The input prompt for text generation")
+    max_new_tokens: int = Field(default=2048, gt=0, le=4096, description="Maximum number of tokens to generate")
 
 app = FastAPI()
 
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
 # Define model paths
 BASE_MODEL_PATH = "HuggingFaceTB/SmolLM2-135M-Instruct"
 ADAPTER_PATH = "khurrameycon/SmolLM-135M-Instruct-qa_pairs_converted.json-25epochs"
@@ -18,7 +33,7 @@ ADAPTER_PATH = "khurrameycon/SmolLM-135M-Instruct-qa_pairs_converted.json-25epochs"
 def load_model_and_tokenizer():
     """Load the model, tokenizer, and adapter weights."""
     try:
-        print("Loading base model...")
+        logger.info("Loading base model...")
         model = AutoModelForCausalLM.from_pretrained(
             BASE_MODEL_PATH,
             torch_dtype=torch.float16,
@@ -26,31 +41,38 @@ def load_model_and_tokenizer():
             device_map="auto"
         )
 
-        print("Loading tokenizer...")
+        logger.info("Loading tokenizer...")
         tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)
 
-        print("Downloading adapter weights...")
+        logger.info("Downloading adapter weights...")
         adapter_path_local = snapshot_download(repo_id=ADAPTER_PATH)
 
-        print("Loading adapter weights...")
+        logger.info("Loading adapter weights...")
         adapter_file = f"{adapter_path_local}/adapter_model.safetensors"
         state_dict = load_file(adapter_file)
 
-        print("Applying adapter weights...")
+        logger.info("Applying adapter weights...")
         model.load_state_dict(state_dict, strict=False)
-        print("Model and adapter loaded successfully!")
+        logger.info("Model and adapter loaded successfully!")
 
         return model, tokenizer
     except Exception as e:
-        print(f"Error during model loading: {e}")
+        logger.error(f"Error during model loading: {e}", exc_info=True)
         raise
 
 # Load model and tokenizer at startup
-model, tokenizer = load_model_and_tokenizer()
+try:
+    model, tokenizer = load_model_and_tokenizer()
+except Exception as e:
+    logger.error(f"Failed to load model at startup: {e}", exc_info=True)
+    model = None
+    tokenizer = None
 
 def generate_response(model, tokenizer, instruction, max_new_tokens=2048):
     """Generate a response from the model based on an instruction."""
     try:
+        logger.info(f"Generating response for instruction: {instruction[:100]}...")
+
         # Encode input with truncation
         inputs = tokenizer.encode(
             instruction,
@@ -59,6 +81,8 @@ def generate_response(model, tokenizer, instruction, max_new_tokens=2048):
             max_length=tokenizer.model_max_length
         ).to(model.device)
 
+        logger.info(f"Input shape: {inputs.shape}")
+
         # Create attention mask
         attention_mask = torch.ones(inputs.shape, device=model.device)
 
@@ -70,35 +94,59 @@ def generate_response(model, tokenizer, instruction, max_new_tokens=2048):
             temperature=0.7,
             top_p=0.9,
             do_sample=True,
+            pad_token_id=tokenizer.pad_token_id,
+            eos_token_id=tokenizer.eos_token_id,
         )
 
+        logger.info(f"Output shape: {outputs.shape}")
+
         # Decode and strip input prompt from response
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
         generated_text = response[len(instruction):].strip()
 
+        logger.info(f"Generated text length: {len(generated_text)}")
         return generated_text
     except Exception as e:
-        print(f"Error generating response: {e}")
+        logger.error(f"Error generating response: {e}", exc_info=True)
         raise ValueError(f"Error generating response: {e}")
 
 @app.post("/generate")
-async def generate_text(input: ModelInput):
+async def generate_text(input: ModelInput, request: Request):
     """Generate text based on the input prompt."""
     try:
-        print(f"Received prompt: {input.prompt}")
+        if model is None or tokenizer is None:
+            raise HTTPException(status_code=503, detail="Model not loaded")
+
+        logger.info(f"Received request from {request.client.host}")
+        logger.info(f"Prompt: {input.prompt[:100]}...")
+
         response = generate_response(
             model=model,
             tokenizer=tokenizer,
             instruction=input.prompt,
             max_new_tokens=input.max_new_tokens
        )
-        print(f"Generated response: {response}")
+
+        if not response:
+            logger.warning("Generated empty response")
+            return {"generated_text": "", "warning": "Empty response generated"}
+
+        logger.info(f"Generated response length: {len(response)}")
         return {"generated_text": response}
     except Exception as e:
-        print(f"Error: {str(e)}")
+        logger.error(f"Error in generate_text endpoint: {e}", exc_info=True)
         raise HTTPException(status_code=500, detail=str(e))
 
 @app.get("/")
 async def root():
     """Root endpoint that returns a welcome message."""
-    return {"message": "Welcome to the Model API!"}
+    return {"message": "Welcome to the Model API!", "status": "running"}
+
+@app.get("/health")
+async def health_check():
+    """Health check endpoint."""
+    return {
+        "status": "healthy",
+        "model_loaded": model is not None and tokenizer is not None,
+        "model_device": str(next(model.parameters()).device) if model else None
+    }
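
For reference, a minimal client sketch exercising the updated endpoints. The base URL (a local uvicorn server on port 8000), the example prompt, and the use of the requests library are assumptions for illustration, not part of this commit; the /health route reports whether the model loaded, and /generate takes the Field-validated payload.

import requests  # hypothetical client-side dependency, not used by app.py itself

BASE_URL = "http://localhost:8000"  # assumption: app served locally via uvicorn

# Confirm the model loaded before sending generation requests
health = requests.get(f"{BASE_URL}/health").json()
print(health["model_loaded"], health["model_device"])

# max_new_tokens must satisfy 0 < value <= 4096 per the new Field constraints
payload = {"prompt": "What is a safetensors adapter?", "max_new_tokens": 256}
resp = requests.post(f"{BASE_URL}/generate", json=payload)
resp.raise_for_status()
print(resp.json()["generated_text"])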