import gradio as gr
from huggingface_hub import hf_hub_download # Still useful if model is private and needs custom token
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, GenerationConfig
from transformers.pipelines import pipeline
import re
import os
import torch # Required for transformers models
import threading
import time # For short sleeps in streamer

# --- Model Configuration ---
# Your SmilyAI model ID on Hugging Face Hub
MODEL_REPO_ID = "Smilyai-labs/Sam-reason-S3"
N_CTX = 2048 # Nominal context window size (not passed to generate() in this script; kept for reference)
MAX_TOKENS = 500
TEMPERATURE = 0.7
TOP_P = 0.9
STOP_SEQUENCES = ["USER:", "\n\n"] # Intended stop strings (not wired into generate() here; see the StoppingCriteria sketch after the generation config)

# --- Safety Configuration ---
print("Loading safety model (unitary/toxic-bert)...")
try:
    safety_classifier = pipeline(
        "text-classification",
        model="unitary/toxic-bert",
        framework="pt" # Use PyTorch backend
    )
    print("Safety model loaded successfully.")
except Exception as e:
    print(f"Error loading safety model: {e}")
    exit(1)

TOXICITY_THRESHOLD = 0.9

def is_text_safe(text: str) -> tuple[bool, str | None]:
    """
    Checks if the given text contains unsafe content using the safety classifier.
    Returns (True, None) if safe, or (False, detected_label) if unsafe.
    """
    if not text.strip():
        return True, None

    try:
        results = safety_classifier(text)
        if results and results[0]['label'] == 'toxic' and results[0]['score'] > TOXICITY_THRESHOLD:
            print(f"Detected unsafe content: '{text.strip()}' (Score: {results[0]['score']:.4f})")
            return False, results[0]['label']
        
        return True, None

    except Exception as e:
        print(f"Error during safety check: {e}")
        # Fail closed: if the safety check itself errors, treat the text as unsafe.
        return False, "safety_check_failed"
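
# Illustrative expectation (not part of the original script): benign input such as
# is_text_safe("Hello there!") should return (True, None), while clearly toxic input
# should return (False, "toxic") once the classifier's score exceeds TOXICITY_THRESHOLD.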


# --- Main Model Loading (using Transformers) ---
print(f"Loading tokenizer for {MODEL_REPO_ID}...")
try:
    # AutoTokenizer fetches the correct tokenizer for the model
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_ID)
    print("Tokenizer loaded.")
except Exception as e:
    print(f"Error loading tokenizer: {e}")
    print("Make sure the model ID is correct and, if it's a private repo, you've set the HF_TOKEN secret in your Space.")
    exit(1)

print(f"Loading model {MODEL_REPO_ID} (this will be VERY slow on CPU and might take a long time)...")
try:
    # AutoModelForCausalLM loads the language model.
    # device_map="cpu" ensures all model layers are loaded onto the CPU.
    # torch_dtype=torch.float32 is standard for CPU; float16 can save memory but might not be faster on all CPUs.
    model = AutoModelForCausalLM.from_pretrained(MODEL_REPO_ID, device_map="cpu", torch_dtype=torch.float32)
    model.eval() # Set model to evaluation mode for inference
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Ensure it's a standard Transformers model and you have HF_TOKEN secret if private.")
    exit(1)
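
# Note (assumption, not in the original script): on memory-constrained CPU Spaces, passing
# low_cpu_mem_usage=True to AutoModelForCausalLM.from_pretrained() can reduce peak RAM while
# the checkpoint is being loaded.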

# Configure generation for streaming
# Use GenerationConfig from the model for default parameters, then override as needed.
generation_config = GenerationConfig.from_pretrained(MODEL_REPO_ID)
generation_config.max_new_tokens = MAX_TOKENS
generation_config.temperature = TEMPERATURE
generation_config.top_p = TOP_P
generation_config.do_sample = True # Enable sampling for temperature/top_p
# Set EOS and PAD token IDs for proper generation stopping and padding
if tokenizer.eos_token_id is not None:
    generation_config.eos_token_id = tokenizer.eos_token_id
if tokenizer.pad_token_id is not None:
    generation_config.pad_token_id = tokenizer.pad_token_id
elif tokenizer.eos_token_id is not None:
    generation_config.pad_token_id = tokenizer.eos_token_id
elif generation_config.pad_token_id is None:
    generation_config.pad_token_id = 0 # Last-resort fallback; not ideal for every model
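
# The STOP_SEQUENCES defined above are never passed to model.generate() in this script.
# Below is a minimal sketch of how they could be honored with the standard transformers
# StoppingCriteria API. The class name and the wiring shown in the comment are illustrative
# assumptions, not part of the original app; the class is defined here but not used.
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnSequences(StoppingCriteria):
    """Stops generation once any configured stop string appears in the generated continuation."""
    def __init__(self, stop_sequences, tokenizer, prompt_length):
        self.stop_sequences = stop_sequences
        self.tokenizer = tokenizer
        self.prompt_length = prompt_length  # Number of prompt tokens to skip when decoding

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # Decode only the newly generated tokens and check them for any stop string.
        generated = self.tokenizer.decode(input_ids[0][self.prompt_length:], skip_special_tokens=True)
        return any(seq in generated for seq in self.stop_sequences)

# Illustrative wiring (inside generate_word_by_word_with_safety, after input_ids is built):
#   generate_kwargs["stopping_criteria"] = StoppingCriteriaList(
#       [StopOnSequences(STOP_SEQUENCES, tokenizer, input_ids.shape[1])]
#   )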

# --- Custom Streamer for Gradio and Safety Check ---
class GradioSafetyStreamer(TextIteratorStreamer):
    def __init__(self, tokenizer, safety_checker_fn, toxicity_threshold, skip_special_tokens=True, **kwargs):
        super().__init__(tokenizer, skip_special_tokens=skip_special_tokens, **kwargs)
        self.safety_checker_fn = safety_checker_fn
        self.toxicity_threshold = toxicity_threshold
        self.current_sentence_buffer = ""
        self.output_queue = [] # Queue to store safety-checked sentences to be yielded by Gradio
        self.sentence_regex = re.compile(r'(?<=[.!?])\s+') # Simple heuristic: split on whitespace after sentence-ending punctuation, keeping the punctuation with each sentence
        self.text_done = threading.Event() # Event to signal when internal text processing is complete

    def on_finalized_text(self, text: str, stream_end: bool = False):
        # This method is called by the superclass when a decoded token chunk is ready.
        self.current_sentence_buffer += text
        
        # Split buffer into sentences. Keep the last part in buffer if it's incomplete.
        sentences = self.sentence_regex.split(self.current_sentence_buffer)
        
        sentences_to_process = []
        if not stream_end and sentences:
            # Mid-stream: keep the last (possibly incomplete) segment buffered for the next chunk
            sentences_to_process = sentences[:-1]
            self.current_sentence_buffer = sentences[-1]
        else:
            # End of stream: process every segment and clear the buffer
            sentences_to_process = sentences
            self.current_sentence_buffer = ""

        for sentence in sentences_to_process:
            if not sentence.strip(): continue # Skip empty strings from splitting

            is_safe, detected_label = self.safety_checker_fn(sentence)
            if not is_safe:
                print(f"Safety check failed for: '{sentence.strip()}' (Detected: {detected_label})")
                self.output_queue.append("[Content removed due to safety guidelines]")
                self.output_queue.append("__STOP_GENERATION__") # Special signal to stop LLM generation
                return # Stop processing further sentences from this chunk if unsafe

            else:
                self.output_queue.append(sentence)

        if stream_end:
            # If stream ends and there's leftover text in buffer, process it
            if self.current_sentence_buffer.strip():
                is_safe, detected_label = self.safety_checker_fn(self.current_sentence_buffer)
                if not is_safe:
                    self.output_queue.append("[Content removed due to safety guidelines]")
                else:
                    self.output_queue.append(self.current_sentence_buffer)
                self.current_sentence_buffer = "" # Clear after final check
            self.text_done.set() # Signal that all text processing is complete

    def __iter__(self):
        # This generator lets Gradio iterate over the safety-checked output.
        while True:
            if self.output_queue:
                item = self.output_queue.pop(0)
                if item == "__STOP_GENERATION__":
                    # Stop yielding to the UI. Use return rather than raise StopIteration:
                    # raising StopIteration inside a generator is a RuntimeError (PEP 479).
                    return
                yield item
            elif self.text_done.is_set():
                # Generation and safety processing have finished and the queue is drained.
                return
            else:
                time.sleep(0.01) # Small sleep to avoid busy-waiting while new tokens arrive


# --- Inference Function with Safety and Streaming ---
def generate_word_by_word_with_safety(prompt_text: str):
    formatted_prompt = f"USER: {prompt_text}\nASSISTANT:"
    # Encode input on the model's device (CPU)
    input_ids = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(model.device)

    # Initialize the custom streamer
    streamer = GradioSafetyStreamer(tokenizer, is_text_safe, TOXICITY_THRESHOLD)

    # Use a separate thread for model generation because model.generate is a blocking call.
    # This allows the streamer to continuously fill its queue while Gradio yields.
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "generation_config": generation_config,
        # Explicitly pass these for clarity, even if in generation_config
        "do_sample": True,
        "temperature": TEMPERATURE,
        "top_p": TOP_P,
        "max_new_tokens": MAX_TOKENS,
        "eos_token_id": generation_config.eos_token_id,
        "pad_token_id": generation_config.pad_token_id,
    }
    
    # Start generation in a separate thread
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    # Yield safety-checked sentences from the streamer so Gradio can display them progressively
    full_generated_text = ""
    try:
        for new_sentence_or_chunk in streamer:
            separator = " " if full_generated_text else ""
            full_generated_text += separator + new_sentence_or_chunk
            yield full_generated_text # Gradio expects the accumulated string for streaming display
    except Exception as e:
        print(f"Error during streaming: {e}")
        yield full_generated_text + f"\n\n[Error during streaming: {e}]" # Show error in output
    finally:
        thread.join() # Ensure the generation thread finishes gracefully


# --- Gradio Blocks Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # SmilyAI: Sam-reason-S3 Inference (Transformers on CPU with Safety Filter)
        Enter a prompt and get a word-by-word response from the **Smilyai-labs/Sam-reason-S3** model.
        **⚠️ WARNING: This model is running on a free CPU tier via the `transformers` library. Inference will be VERY slow.**
        All generated sentences are checked for safety using an AI filter; unsafe content will be replaced.
        """
    )

    with gr.Row():
        user_prompt = gr.Textbox(
            lines=5,
            label="Enter your prompt here:",
            placeholder="e.g., Explain the concept of quantum entanglement in simple terms.",
            scale=4
        )
        generated_text = gr.Textbox(label="Generated Text", show_copy_button=True, scale=6)

    send_button = gr.Button("Send", variant="primary")

    send_button.click(
        fn=generate_word_by_word_with_safety,
        inputs=user_prompt,
        outputs=generated_text,
        api_name="predict",
    )

if __name__ == "__main__":
    print("Launching Gradio app...")
    demo.launch(server_name="0.0.0.0", server_port=7860)
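
# A minimal sketch of calling this endpoint remotely with gradio_client (assumptions: the
# Space is public and reachable; the Space name below is a placeholder, not from the original):
#
#   from gradio_client import Client
#   client = Client("your-username/your-space-name")
#   result = client.predict("Explain quantum entanglement simply.", api_name="/predict")
#   print(result)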