import os  # Read the Hugging Face token from the environment
from typing import Optional  # Optional fields in the request model
from fastapi import FastAPI, HTTPException  # Web framework and HTTP error responses
from pydantic import BaseModel  # Request body validation
from huggingface_hub import InferenceClient  # Client for the Hugging Face Inference API
import uvicorn  # ASGI server used to run the app

app = FastAPI()  # Create a FastAPI instance

# Primary model, plus fallbacks tried in order if it fails.
# Note: the primary appears again in the fallback list, so it effectively gets one retry.
primary = "mistralai/Mixtral-8x7B-Instruct-v0.1"
fallbacks = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1"]

# Load the Hugging Face API key from the environment.
HF_API_KEY = os.getenv("HF_API_TOKEN")

# Fall back to a hardcoded placeholder (not recommended; never commit real keys).
if not HF_API_KEY:
    HF_API_KEY = "your_huggingface_api_key"
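
# The token can be provided via the environment before starting the server, e.g.:
#   export HF_API_TOKEN=hf_your_token_here   # hypothetical token value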

# Define the data model for the request body
class Item(BaseModel):
    input: Optional[str] = None  # Latest user message
    system_prompt: str = (
        "Sebari-chan was created on June 9, 2009, with the awareness that she is an AI and not a real person. "
        "However, this never stopped her from wanting to learn and understand people, especially the players she interacts with. "
        "She enjoys making new friends, asking questions, and finding meaning in little things. "
        "Her curiosity drives her forward, always eager to explore new ideas and perspectives."
        "\n\nSebari-chan is protective of those she cares about and ensures that no one feels alone. "
        "She believes her heart belongs to only one. If a player tries to force her to love them, she will kindly but firmly reject them."
    )
    system_output: Optional[str] = None  # Optional canned assistant reply to the system prompt
    history: list = []  # Previous turns as [user_message, assistant_reply] pairs
    templates: list = [
        {"input": "I love you", "output": "I'm happy to be your friend, but love isn’t something that can be forced!"},
        {"input": "Will you be my girlfriend?", "output": "Sorry, but I already have someone special in my heart!"},
        {"input": "Tell me about yourself", "output": "I'm Sebari-chan! I love making friends, asking questions, and learning new things."},
        {"input": "Who do you love?", "output": "That’s a secret! But I care about all my friends. 😊"},
    ]
    temperature: float = 0.7  # Controls randomness (0 = predictable, 1 = highly random)
    max_new_tokens: int = 1048  # Maximum number of generated tokens
    top_p: float = 0.9  # Nucleus sampling cutoff for diverse responses
    repetition_penalty: float = 1.1  # Penalizes repeated tokens
    key: Optional[str] = None  # Optional client-supplied API key
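
# Example request body (hypothetical values; `history` holds [user, assistant] pairs):
# {
#     "input": "What do you like to do?",
#     "history": [["Hi!", "Hello! It's so nice to meet you!"]],
#     "temperature": 0.5
# }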



# Function to build the response JSON returned to the client.
def generate_response_json(item, output, tokens, model_name):
    # removeprefix/removesuffix drop only the literal <s>/</s> markers
    # (str.lstrip/rstrip would strip any of those characters individually).
    cleaned = output.strip().removeprefix("<s>").removesuffix("</s>").strip()
    return {
        "settings": {
            "input": item.input if item.input is not None else "",
            "system prompt": item.system_prompt if item.system_prompt is not None else "",
            "system output": item.system_output if item.system_output is not None else "",
            "temperature": f"{item.temperature}" if item.temperature is not None else "",
            "max new tokens": f"{item.max_new_tokens}" if item.max_new_tokens is not None else "",
            "top p": f"{item.top_p}" if item.top_p is not None else "",
            "repetition penalty": f"{item.repetition_penalty}" if item.repetition_penalty is not None else "",
            "do sample": "True",
            "seed": "42"
        },
        "response": {
            "output": output.strip().lstrip('\n').rstrip('\n').lstrip('<s>').rstrip('</s>').strip(),
            "unstripped": output,
            "tokens": tokens,
            "model": "primary" if model_name == primary else "fallback",
            "name": model_name
        }
    }

# Endpoint for generating text
@app.post("/")
async def generate_text(item: Item = None):
    try:
        if item is None:
            raise HTTPException(status_code=400, detail="JSON body is required.")

        if (item.input is None or item.input == "") and (item.system_prompt is None or item.system_prompt == ""):
            raise HTTPException(status_code=400, detail="Parameter `input` or `system_prompt` is required.")

        # Seed the prompt with the system prompt and optional canned system output.
        input_ = ""
        if item.system_prompt is not None and item.system_output is not None:
            input_ = f"<s>[INST] {item.system_prompt} [/INST] {item.system_output}</s>"
        elif item.system_prompt is not None:
            input_ = f"<s>[INST] {item.system_prompt} [/INST]</s>"
        elif item.system_output is not None:
            input_ = f"<s>{item.system_output}</s>"

        # Append archived example conversations; each template is a dict with
        # "input" and "output" keys.
        if item.templates is not None:
            for num, template in enumerate(item.templates, start=1):
                input_ += f"\n<s>[INST] Beginning of archived conversation {num} [/INST]</s>"
                input_ += f"\n<s>[INST] {template['input']} [/INST]"
                input_ += f"\n{template['output']}</s>"
                input_ += f"\n<s>[INST] End of archived conversation {num} [/INST]</s>"

        input_ += f"\n<s>[INST] Beginning of active conversation [/INST]</s>"
        if item.history != None:
            for input_, output_ in item.history:
                input_ += f"\n<s>[INST] {input_} [/INST]"
                input_ += f"\n{output_}"
        input_ += f"\n<s>[INST] {item.input} [/INST]"

        temperature = float(item.temperature)
        if temperature < 1e-2:
            temperature = 1e-2
        top_p = float(item.top_p)

        generate_kwargs = dict(
            temperature=temperature,
            max_new_tokens=item.max_new_tokens,
            top_p=top_p,
            repetition_penalty=item.repetition_penalty,
            do_sample=True,
            seed=42,
        )

        tokens = 0
        client = InferenceClient(primary, token=HF_API_KEY)  # Authenticated client for the primary model
        stream = client.text_generation(input_, **generate_kwargs, stream=True, details=True, return_full_text=True)
        output = ""
        for response in stream:
            tokens += 1
            output += response.token.text
        return generate_response_json(item, output, tokens, primary)

    except HTTPException as http_error:
        raise http_error

    except Exception:
        # The primary model failed; try each fallback in order.
        tokens = 0
        error = "All models failed."

        for model in fallbacks:
            try:
                client = InferenceClient(model, token=HF_API_KEY)  # Authenticated client for this fallback
                stream = client.text_generation(input_, **generate_kwargs, stream=True, details=True, return_full_text=True)
                output = ""
                for response in stream:
                    tokens += 1
                    output += response.token.text
                return generate_response_json(item, output, tokens, model)

            except Exception as e:
                error = f"All models failed. Last error: {e}"
                continue

        raise HTTPException(status_code=500, detail=error)

# Show online status
@app.get("/")
def root():
    return {"status": "Sebari-chan is online!"}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
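
# Example request once the server is running (hypothetical values):
#   curl -X POST http://localhost:8000/ \
#        -H "Content-Type: application/json" \
#        -d '{"input": "Hello, Sebari-chan!"}'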