Euryeth committed on
Commit a5e8a2b · verified · 1 Parent(s): c9c300e

Update app.py

Files changed (1):
  1. app.py +15 -21
app.py CHANGED
@@ -11,19 +11,21 @@ import gradio as gr
 login(os.getenv("HUGGINGFACEHUB_API_TOKEN"))
 
 # Token authentication for requests
-API_TOKEN = os.getenv("HF_API_TOKEN") # You set this in Space secrets
+API_TOKEN = os.getenv("HF_API_TOKEN")
 
 # Set up model loading and pipeline
 torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
 os.environ['HF_HOME'] = '/tmp/cache'
-
 model_name = "cerebras/btlm-3b-8k-chat"
-tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+revision = "main" # Pin to stable revision
+
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, revision=revision)
 model = AutoModelForCausalLM.from_pretrained(
     model_name,
     torch_dtype=torch_dtype,
     device_map="auto",
-    trust_remote_code=True
+    trust_remote_code=True,
+    revision=revision
 )
 
 generator = pipeline(
@@ -36,7 +38,6 @@ generator = pipeline(
     trust_remote_code=True
 )
 
-# Flask app
 app = Flask(__name__)
 
 @app.route("/")
@@ -45,7 +46,6 @@ def home():
 
 @app.route("/v1/chat/completions", methods=["POST"])
 def chat():
-    # Token auth: require Bearer token
     auth_header = request.headers.get("Authorization", "")
     if not auth_header.startswith("Bearer ") or auth_header.split(" ")[1] != API_TOKEN:
         return jsonify({"error": "Unauthorized"}), 401
@@ -56,7 +56,6 @@ def chat():
     temperature = data.get("temperature", 0.7)
     stream = data.get("stream", False)
 
-    # Build the prompt from chat history
     prompt = ""
     for msg in messages:
         role = msg.get("role", "user").capitalize()
@@ -64,7 +63,6 @@
         prompt += f"{role}: {content}\n"
     prompt += "Assistant:"
 
-    # If stream = True, stream response like OpenAI
     if stream:
         def generate_stream():
             output = generator(
@@ -97,7 +95,6 @@ def chat():
 
     return Response(generate_stream(), content_type="text/event-stream")
 
-    # Non-streamed response
    output = generator(
        prompt,
        max_new_tokens=max_tokens,
@@ -109,24 +106,21 @@ def chat():
    reply = output[0]["generated_text"].replace(prompt, "").strip()
 
    return jsonify({
-        "choices": [
-            {
-                "message": {
-                    "role": "assistant",
-                    "content": reply
-                },
-                "finish_reason": "stop",
-                "index": 0
-            }
-        ]
+        "choices": [{
+            "message": {
+                "role": "assistant",
+                "content": reply
+            },
+            "finish_reason": "stop",
+            "index": 0
+        }]
     })
 
-# Optional Gradio frontend to keep Hugging Face Space active
+# Optional Gradio frontend to keep Space alive
 with gr.Blocks() as demo:
     gr.Markdown("### LLM backend is running and ready for API calls.")
 
 demo.launch()
 
 if __name__ == "__main__":
-    # Listen on port 8080 as required by HF Spaces
     app.run(host="0.0.0.0", port=8080)
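
For reference, a client can exercise the endpoint this commit touches as sketched below. This is not part of the commit: the base URL is hypothetical (substitute the actual Space URL), and HF_API_TOKEN must hold the same value as the Space secret that the Bearer check above compares against.

import os
import requests

BASE_URL = "https://your-space.hf.space"  # hypothetical; replace with the real Space URL

resp = requests.post(
    f"{BASE_URL}/v1/chat/completions",
    # Bearer token must match the API_TOKEN read from the Space secret
    headers={"Authorization": f"Bearer {os.getenv('HF_API_TOKEN')}"},
    json={
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 128,
        "temperature": 0.7,
        "stream": False,  # True would return a text/event-stream response instead
    },
)
resp.raise_for_status()
# Response shape matches the jsonify payload built in chat()
print(resp.json()["choices"][0]["message"]["content"])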