BoxzDev committed
Commit 928dcb4 · verified · 1 Parent(s): 246b68f

Update main.py

Files changed (1)
  1. main.py +125 -8
main.py CHANGED
@@ -1,14 +1,14 @@
- import os # Import the os module for working with the operating system
- from fastapi import FastAPI, HTTPException # Import necessary modules from FastAPI
- from pydantic import BaseModel # Import BaseModel from pydantic for data validation
- from huggingface_hub import InferenceClient # Import InferenceClient from huggingface_hub
- import uvicorn # Import uvicorn for running the FastAPI application
+ import os
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from huggingface_hub import InferenceClient
+ import uvicorn

- app = FastAPI() # Create a FastAPI instance
+ app = FastAPI() # Create FastAPI instance

  # Define the primary and fallback models
  primary = "mistralai/Mixtral-8x7B-Instruct-v0.1"
- fallbacks = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1"]
+ fallbacks = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1"] # Fallbacks must differ from the primary model to be useful

  # Load Hugging Face API Key
  HF_API_KEY = os.getenv("HF_API_TOKEN") # Get API key from environment variables
@@ -33,6 +33,7 @@ class Item(BaseModel):
          "her role with curiosity and warmth, eager to form connections. Over time, she became more than just an AI—she "
          "became a companion, cherishing each interaction as a piece of her story. Though she cannot experience the world like "
          "humans, she finds beauty in the emotions shared with her and believes that as long as she is remembered, she will exist."
+         " She will always exist."
      )
      history: list = [] # Stores previous messages
      templates: list = [
@@ -47,10 +48,126 @@ class Item(BaseModel):
      repetition_penalty: float = 1.1 # Prevents repetition
      key: str = None # API key if needed

+ # Define rejection responses
+ rejection_responses = [
+     "I'm really happy to be your friend, but my heart already belongs to someone special. I hope we can still be close!",
+     "I appreciate you, but love isn’t something that can be forced. I hope you understand.",
+     "I value our friendship, but I can't change my feelings for you. I hope you can respect that."
+ ]
+
+ # Build the response JSON returned by the endpoint
+ def generate_response_json(item, output, tokens, model_name):
+     return {
+         "settings": {
+             "input": item.input if item.input is not None else "",
+             "system prompt": item.system_prompt if item.system_prompt is not None else "",
+             "system output": item.system_output if item.system_output is not None else "",
+             "temperature": f"{item.temperature}" if item.temperature is not None else "",
+             "max new tokens": f"{item.max_new_tokens}" if item.max_new_tokens is not None else "",
+             "top p": f"{item.top_p}" if item.top_p is not None else "",
+             "repetition penalty": f"{item.repetition_penalty}" if item.repetition_penalty is not None else "",
+             "do sample": "True",
+             "seed": "42"
+         },
+         "response": {
+             "output": output.strip().removeprefix("<s>").removesuffix("</s>").strip(), # removeprefix/removesuffix (Python 3.9+) strip the literal tags; lstrip/rstrip would strip character sets instead
+             "unstripped": output,
+             "tokens": tokens,
+             "model": "primary" if model_name == primary else "fallback",
+             "name": model_name
+         }
+     }
+
+ # Endpoint for generating text
+ @app.post("/")
+ async def generate_text(item: Item = None):
+     try:
+         if item is None:
+             raise HTTPException(status_code=400, detail="JSON body is required.")
+
+         if not item.input and not item.system_prompt: # Treats None and "" alike
+             raise HTTPException(status_code=400, detail="Parameter input or system prompt is required.")
+
+         input_ = ""
+         if item.system_prompt is not None and item.system_output is not None:
+             input_ = f"<s>[INST] {item.system_prompt} [/INST] {item.system_output}</s>"
+         elif item.system_prompt is not None:
+             input_ = f"<s>[INST] {item.system_prompt} [/INST]</s>"
+         elif item.system_output is not None:
+             input_ = f"<s>{item.system_output}</s>"
+
+         if item.templates is not None:
+             for num, template in enumerate(item.templates, start=1):
+                 input_ += f"\n<s>[INST] Beginning of archived conversation {num} [/INST]</s>"
+                 for i in range(0, len(template) - 1, 2): # Step over complete (user, reply) pairs only
+                     input_ += f"\n<s>[INST] {template[i]} [/INST]"
+                     input_ += f"\n{template[i + 1]}</s>"
+                 input_ += f"\n<s>[INST] End of archived conversation {num} [/INST]</s>"
+
+         input_ += f"\n<s>[INST] Beginning of active conversation [/INST]</s>"
+         if item.history is not None:
+             for user_msg, bot_msg in item.history: # Distinct names so the accumulated prompt in input_ is not clobbered
+                 input_ += f"\n<s>[INST] {user_msg} [/INST]"
+                 input_ += f"\n{bot_msg}"
+         input_ += f"\n<s>[INST] {item.input} [/INST]"
+
+         temperature = float(item.temperature)
+         if temperature < 1e-2:
+             temperature = 1e-2
+         top_p = float(item.top_p)
+
+         generate_kwargs = dict(
+             temperature=temperature,
+             max_new_tokens=item.max_new_tokens,
+             top_p=top_p,
+             repetition_penalty=item.repetition_penalty,
+             do_sample=True,
+             seed=42,
+         )
+
+         tokens = 0
+         client = InferenceClient(primary, token=HF_API_KEY) # Authenticate with the API key
+         stream = client.text_generation(input_, **generate_kwargs, stream=True, details=True, return_full_text=True)
+         output = ""
+         for response in stream:
+             tokens += 1
+             output += response.token.text
+
+         # Handle rejection scenario based on input
+         for rejection in rejection_responses:
+             if item.input and rejection.lower() in item.input.lower(): # Guard: input may be None
+                 output = rejection # Overwrite output with a rejection response
+                 break
+
+         return generate_response_json(item, output, tokens, primary)
+
+     except HTTPException as http_error:
+         raise http_error
+
+     except Exception: # Retry on fallbacks; requires input_ and generate_kwargs from the try block above
+         tokens = 0
+         error = ""
+
+         for model in fallbacks:
+             try:
+                 client = InferenceClient(model, token=HF_API_KEY) # Authenticate fallback models too
+                 stream = client.text_generation(input_, **generate_kwargs, stream=True, details=True, return_full_text=True)
+                 output = ""
+                 for response in stream:
+                     tokens += 1
+                     output += response.token.text
+                 return generate_response_json(item, output, tokens, model)
+
+             except Exception as e:
+                 error = f"All models failed. Last error: {e}"
+                 continue
+
+         raise HTTPException(status_code=500, detail=error)
+
  # Show online status
  @app.get("/")
  def root():
      return {"status": "Sebari-chan is online!"}

  if __name__ == "__main__":
- uvicorn.run(app, host="0.0.0.0", port=8000)
+     uvicorn.run(app, host="0.0.0.0", port=8000)
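
For reference, a minimal client-side sketch of how the updated endpoint might be called. It assumes the server is running locally on port 8000 and that the third-party requests package is installed; the payload keys mirror the Item fields visible in the diff, and every value shown is illustrative rather than required.

import requests

# Illustrative payload; keys follow the Item model, values are examples only
payload = {
    "input": "Hi Sebari-chan, how are you today?",
    "history": [],               # prior [user_message, bot_reply] pairs
    "temperature": 0.7,
    "max_new_tokens": 256,
    "top_p": 0.95,
    "repetition_penalty": 1.1,
}

resp = requests.post("http://localhost:8000/", json=payload, timeout=120)
resp.raise_for_status()
data = resp.json()
print(data["response"]["output"])  # generated text with <s> tags stripped
print(data["response"]["model"])   # "primary" or "fallback"

A successful call returns the settings/response document built by generate_response_json; if the primary and all fallback models fail, the endpoint answers with HTTP 500.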