Reality123b committed
Commit ee9527e · verified · 1 Parent(s): a04b12b

Update app.py

Files changed (1)
  1. app.py +63 -144
app.py CHANGED
@@ -1,20 +1,25 @@
- # server.py
  from fastapi import FastAPI, HTTPException, Request
  from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import StreamingResponse
  from pydantic import BaseModel, Field
- from transformers import AutoModelForCausalLM, AutoTokenizer
- import torch
- from huggingface_hub import snapshot_download
- from safetensors.torch import load_file
+ from typing import List
+ import os
+ from huggingface_hub import InferenceClient
  import logging

  # Set up logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

- class ModelInput(BaseModel):
-     prompt: str = Field(..., description="The input prompt for text generation")
-     max_new_tokens: int = Field(default=2048, gt=0, le=4096, description="Maximum number of tokens to generate")
+ class Message(BaseModel):
+     role: str = Field(..., description="Role of the message sender (system/user/assistant)")
+     content: str = Field(..., description="Content of the message")
+
+ class ChatInput(BaseModel):
+     messages: List[Message] = Field(..., description="List of conversation messages")
+     max_tokens: int = Field(default=2048, gt=0, le=4096, description="Maximum number of tokens to generate")
+     temperature: float = Field(default=0.5, gt=0, le=2.0, description="Temperature for sampling")
+     top_p: float = Field(default=0.7, gt=0, le=1.0, description="Top-p sampling parameter")

  app = FastAPI()

@@ -27,167 +32,81 @@ app.add_middleware(
      allow_headers=["*"],
  )

- # Define model paths
- BASE_MODEL_PATH = "HuggingFaceTB/SmolLM2-135M-Instruct"
- ADAPTER_PATH = "khurrameycon/SmolLM-135M-Instruct-qa_pairs_converted.json-25epochs"
-
- def format_prompt(instruction):
-     """Format the prompt according to the model's expected format."""
-     return f"""### Instruction:
- {instruction}
-
- ### Response:
- """
-
- def load_model_and_tokenizer():
-     """Load the model, tokenizer, and adapter weights."""
-     try:
-         logger.info("Loading base model...")
-         model = AutoModelForCausalLM.from_pretrained(
-             BASE_MODEL_PATH,
-             torch_dtype=torch.float16,
-             trust_remote_code=True,
-             device_map="auto",
-             use_cache=True
-         )
-
-         logger.info("Loading tokenizer...")
-         tokenizer = AutoTokenizer.from_pretrained(
-             BASE_MODEL_PATH,
-             padding_side="left",
-             truncation_side="left"
-         )
-
-         # Ensure the tokenizer has the necessary special tokens
-         special_tokens = {
-             "pad_token": "<|padding|>",
-             "eos_token": "</s>",
-             "bos_token": "<s>",
-             "unk_token": "<|unknown|>"
-         }
-         tokenizer.add_special_tokens(special_tokens)
-
-         # Resize the model embeddings to match the new tokenizer size
-         model.resize_token_embeddings(len(tokenizer))
-
-         logger.info("Downloading adapter weights...")
-         adapter_path_local = snapshot_download(repo_id=ADAPTER_PATH)
-
-         logger.info("Loading adapter weights...")
-         adapter_file = f"{adapter_path_local}/adapter_model.safetensors"
-         state_dict = load_file(adapter_file)
-
-         logger.info("Applying adapter weights...")
-         model.load_state_dict(state_dict, strict=False)
-         logger.info("Model and adapter loaded successfully!")
-
-         return model, tokenizer
-     except Exception as e:
-         logger.error(f"Error during model loading: {e}", exc_info=True)
-         raise
+ # Initialize Hugging Face client
+ hf_client = InferenceClient(
+     api_key=os.getenv("HF_TOKEN"),
+     timeout=30
+ )

- # Load model and tokenizer at startup
- try:
-     model, tokenizer = load_model_and_tokenizer()
- except Exception as e:
-     logger.error(f"Failed to load model at startup: {e}", exc_info=True)
-     model = None
-     tokenizer = None
+ MODEL_ID = "mistralai/Mistral-Nemo-Instruct-2407"

- def generate_response(model, tokenizer, instruction, max_new_tokens=2048):
-     """Generate a response from the model based on an instruction."""
+ async def generate_stream(messages: List[Message], max_tokens: int, temperature: float, top_p: float):
+     """Generate streaming response using Hugging Face Inference API."""
      try:
-         # Format the prompt
-         formatted_prompt = format_prompt(instruction)
-         logger.info(f"Formatted prompt: {formatted_prompt}")
-
-         # Encode input with truncation
-         inputs = tokenizer(
-             formatted_prompt,
-             return_tensors="pt",
-             truncation=True,
-             max_length=tokenizer.model_max_length,
-             padding=True,
-             add_special_tokens=True
-         ).to(model.device)
-
-         logger.info(f"Input shape: {inputs.input_ids.shape}")
+         # Convert messages to the format expected by the API
+         formatted_messages = [{"role": msg.role, "content": msg.content} for msg in messages]

-         # Generate response
-         with torch.inference_mode():
-             outputs = model.generate(
-                 input_ids=inputs.input_ids,
-                 attention_mask=inputs.attention_mask,
-                 max_new_tokens=max_new_tokens,
-                 temperature=0.7,
-                 top_p=0.9,
-                 top_k=50,
-                 do_sample=True,
-                 num_return_sequences=1,
-                 pad_token_id=tokenizer.pad_token_id,
-                 eos_token_id=tokenizer.eos_token_id,
-                 repetition_penalty=1.1,
-                 length_penalty=1.0,
-                 no_repeat_ngram_size=3
-             )
-
-         logger.info(f"Output shape: {outputs.shape}")
-
-         # Decode the response
-         response = tokenizer.decode(
-             outputs[0, inputs.input_ids.shape[1]:],
-             skip_special_tokens=True,
-             clean_up_tokenization_spaces=True
+         # Create the streaming completion
+         stream = hf_client.chat.completions.create(
+             model=MODEL_ID,
+             messages=formatted_messages,
+             temperature=temperature,
+             max_tokens=max_tokens,
+             top_p=top_p,
+             stream=True
          )

-         response = response.strip()
-         logger.info(f"Generated text length: {len(response)}")
-         logger.info(f"Generated text preview: {response[:100]}...")
-
-         if not response:
-             logger.warning("Empty response generated")
-             raise ValueError("Model generated an empty response")
-
-         return response
+         # Stream the response chunks
+         for chunk in stream:
+             if chunk.choices[0].delta.content is not None:
+                 yield chunk.choices[0].delta.content
+
      except Exception as e:
-         logger.error(f"Error generating response: {e}", exc_info=True)
+         logger.error(f"Error in generate_stream: {e}", exc_info=True)
          raise ValueError(f"Error generating response: {e}")

- @app.post("/generate")
- async def generate_text(input: ModelInput, request: Request):
-     """Generate text based on the input prompt."""
+ @app.post("/chat")
+ async def chat_stream(input: ChatInput, request: Request):
+     """Stream chat completions based on the input messages."""
      try:
-         if model is None or tokenizer is None:
-             raise HTTPException(status_code=503, detail="Model not loaded")
+         if not os.getenv("HF_TOKEN"):
+             raise HTTPException(
+                 status_code=500,
+                 detail="HF_TOKEN environment variable not set"
+             )

-         logger.info(f"Received request from {request.client.host}")
-         logger.info(f"Prompt: {input.prompt[:100]}...")
+         logger.info(f"Received chat request from {request.client.host}")
+         logger.info(f"Number of messages: {len(input.messages)}")

-         response = generate_response(
-             model=model,
-             tokenizer=tokenizer,
-             instruction=input.prompt,
-             max_new_tokens=input.max_new_tokens
+         return StreamingResponse(
+             generate_stream(
+                 messages=input.messages,
+                 max_tokens=input.max_tokens,
+                 temperature=input.temperature,
+                 top_p=input.top_p
+             ),
+             media_type="text/event-stream"
          )
-
-         return {"generated_text": response}
      except Exception as e:
-         logger.error(f"Error in generate_text endpoint: {e}", exc_info=True)
+         logger.error(f"Error in chat_stream endpoint: {e}", exc_info=True)
          raise HTTPException(status_code=500, detail=str(e))

  @app.get("/")
  async def root():
      """Root endpoint that returns a welcome message."""
-     return {"message": "Welcome to the Model API!", "status": "running"}
+     return {
+         "message": "Welcome to the Hugging Face Inference API Streaming Chat!",
+         "status": "running",
+         "model": MODEL_ID
+     }

  @app.get("/health")
  async def health_check():
      """Health check endpoint."""
      return {
          "status": "healthy",
-         "model_loaded": model is not None and tokenizer is not None,
-         "model_device": str(next(model.parameters()).device) if model else None,
-         "tokenizer_vocab_size": len(tokenizer) if tokenizer else None
+         "model": MODEL_ID,
+         "hf_token_set": bool(os.getenv("HF_TOKEN"))
      }

  if __name__ == "__main__":
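
For reference, a minimal client sketch for exercising the new streaming /chat endpoint. This is an illustration, not part of the commit: the base URL and port are assumptions (adjust them to wherever uvicorn serves this app), and the payload fields mirror the ChatInput model introduced above.

# client_example.py - hypothetical helper, not part of this commit.
# Assumes the app is served locally, e.g. `uvicorn app:app --port 8000`,
# and that HF_TOKEN is set in the server's environment.
import requests

payload = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Explain streaming responses in one sentence."},
    ],
    "max_tokens": 256,
    "temperature": 0.5,
    "top_p": 0.7,
}

# /chat streams plain text chunks, so read the body incrementally
# instead of waiting for the whole response to finish.
with requests.post("http://localhost:8000/chat", json=payload, stream=True, timeout=60) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        if chunk:
            print(chunk, end="", flush=True)
print()

Because generate_stream yields raw text chunks (the response is merely declared as text/event-stream, without data: framing), a plain streaming HTTP read like the one above is sufficient; an SSE client library is not required.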