import gradio as gr
from huggingface_hub import hf_hub_download  # Still useful if the model is private and needs a custom token
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, GenerationConfig
from transformers.pipelines import pipeline
import re
import os
import torch  # Required for transformers models
import threading
import time  # For short sleeps in the streamer

# --- Model Configuration ---
# Your SmilyAI model ID on Hugging Face Hub
MODEL_REPO_ID = "Smilyai-labs/Sam-reason-S3"
N_CTX = 2048  # Context window size (kept for reference; not passed to generate() here)
MAX_TOKENS = 500
TEMPERATURE = 0.7
TOP_P = 0.9
STOP_SEQUENCES = ["USER:", "\n\n"]  # Intended stop markers (not currently enforced; see note below)
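# NOTE: N_CTX and STOP_SEQUENCES are declared for reference only and are not wired into
# model.generate() below. TEMPERATURE scales the sampling distribution, TOP_P enables
# nucleus sampling, and MAX_TOKENS caps the number of newly generated tokens. Stop
# sequences could be enforced with a custom transformers StoppingCriteria if needed;
# that is omitted here.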

# --- Safety Configuration ---
print("Loading safety model (unitary/toxic-bert)...")
try:
    safety_classifier = pipeline(
        "text-classification",
        model="unitary/toxic-bert",
        framework="pt"  # Use PyTorch backend
    )
    print("Safety model loaded successfully.")
except Exception as e:
    print(f"Error loading safety model: {e}")
    exit(1)

TOXICITY_THRESHOLD = 0.9

def is_text_safe(text: str) -> tuple[bool, str | None]:
    """
    Checks whether the given text contains unsafe content using the safety classifier.
    Returns (True, None) if safe, or (False, detected_label) if unsafe.
    """
    if not text.strip():
        return True, None
    try:
        results = safety_classifier(text)
        if results and results[0]['label'] == 'toxic' and results[0]['score'] > TOXICITY_THRESHOLD:
            print(f"Detected unsafe content: '{text.strip()}' (Score: {results[0]['score']:.4f})")
            return False, results[0]['label']
        return True, None
    except Exception as e:
        print(f"Error during safety check: {e}")
        # Fail closed: if the safety check itself errors, treat the text as unsafe.
        return False, "safety_check_failed"
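
# Illustrative usage (the exact label/score returned depends on the classifier):
#   is_text_safe("Have a nice day!")  ->  (True, None)
#   is_text_safe(<toxic text>)        ->  (False, "toxic") once the score exceeds TOXICITY_THRESHOLD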

# --- Main Model Loading (using Transformers) ---
print(f"Loading tokenizer for {MODEL_REPO_ID}...")
try:
    # AutoTokenizer fetches the correct tokenizer for the model
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_ID)
    print("Tokenizer loaded.")
except Exception as e:
    print(f"Error loading tokenizer: {e}")
    print("Make sure the model ID is correct and, if it's a private repo, you've set the HF_TOKEN secret in your Space.")
    exit(1)
print(f"Loading model {MODEL_REPO_ID} (this will be VERY slow on CPU and might take a long time)...") | |
try: | |
# AutoModelForCausalLM loads the language model. | |
# device_map="cpu" ensures all model layers are loaded onto the CPU. | |
# torch_dtype=torch.float32 is standard for CPU; float16 can save memory but might not be faster on all CPUs. | |
model = AutoModelForCausalLM.from_pretrained(MODEL_REPO_ID, device_map="cpu", torch_dtype=torch.float32) | |
model.eval() # Set model to evaluation mode for inference | |
print("Model loaded successfully.") | |
except Exception as e: | |
print(f"Error loading model: {e}") | |
print("Ensure it's a standard Transformers model and you have HF_TOKEN secret if private.") | |
exit(1) | |
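
# NOTE: passing device_map to from_pretrained generally requires the `accelerate` package
# to be installed; if loading fails with a message about Accelerate, add it to the Space's
# requirements.txt alongside torch and transformers.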

# Configure generation for streaming.
# Use GenerationConfig from the model for default parameters, then override as needed.
generation_config = GenerationConfig.from_pretrained(MODEL_REPO_ID)
generation_config.max_new_tokens = MAX_TOKENS
generation_config.temperature = TEMPERATURE
generation_config.top_p = TOP_P
generation_config.do_sample = True  # Enable sampling so temperature/top_p take effect

# Set EOS and PAD token IDs for proper generation stopping and padding.
if tokenizer.eos_token_id is not None:
    generation_config.eos_token_id = tokenizer.eos_token_id
# Many causal LMs have no pad token: fall back to the EOS id, then to 0 as a last resort.
if tokenizer.pad_token_id is not None:
    generation_config.pad_token_id = tokenizer.pad_token_id
elif tokenizer.eos_token_id is not None:
    generation_config.pad_token_id = tokenizer.eos_token_id
else:
    generation_config.pad_token_id = 0  # Not ideal for all models, but avoids an unset pad id
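
# Rationale: generate() uses eos_token_id to decide when to stop and pad_token_id when it
# needs to pad; reusing the EOS id as the pad id is the common convention for causal LMs
# that ship without a dedicated pad token.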

# --- Custom Streamer for Gradio and Safety Check ---
class GradioSafetyStreamer(TextIteratorStreamer):
    def __init__(self, tokenizer, safety_checker_fn, toxicity_threshold, skip_special_tokens=True, **kwargs):
        super().__init__(tokenizer, skip_special_tokens=skip_special_tokens, **kwargs)
        self.safety_checker_fn = safety_checker_fn
        self.toxicity_threshold = toxicity_threshold
        self.current_sentence_buffer = ""
        self.output_queue = []  # Queue of safety-checked sentences to be yielded to Gradio
        self.sentence_regex = re.compile(r'(?<=[.!?])\s+')  # Split after sentence-ending punctuation, keeping it attached
        self.text_done = threading.Event()  # Signals that internal text processing is complete

    def on_finalized_text(self, text: str, stream_end: bool = False):
        # Called by the superclass whenever a decoded chunk of text is ready.
        self.current_sentence_buffer += text

        # Split the buffer into sentences; keep the last part buffered if it is incomplete.
        sentences = self.sentence_regex.split(self.current_sentence_buffer)
        sentences_to_process = []
        if not stream_end and sentences and self.sentence_regex.search(sentences[-1]) is None:
            # Not the end of the stream and the last part is not a complete sentence: buffer it for next time.
            sentences_to_process = sentences[:-1]
            self.current_sentence_buffer = sentences[-1]
        else:
            # Otherwise, process all segments and clear the buffer.
            sentences_to_process = sentences
            self.current_sentence_buffer = ""

        for sentence in sentences_to_process:
            if not sentence.strip():
                continue  # Skip empty strings produced by splitting
            is_safe, detected_label = self.safety_checker_fn(sentence)
            if not is_safe:
                print(f"Safety check failed for: '{sentence.strip()}' (Detected: {detected_label})")
                self.output_queue.append("[Content removed due to safety guidelines]")
                # Sentinel telling the Gradio iterator below to stop yielding
                # (the background generation thread still runs to completion).
                self.output_queue.append("__STOP_GENERATION__")
                return  # Stop processing further sentences from this chunk
            else:
                self.output_queue.append(sentence)

        if stream_end:
            # The stream has ended: safety-check any leftover text in the buffer.
            if self.current_sentence_buffer.strip():
                is_safe, detected_label = self.safety_checker_fn(self.current_sentence_buffer)
                if not is_safe:
                    self.output_queue.append("[Content removed due to safety guidelines]")
                else:
                    self.output_queue.append(self.current_sentence_buffer)
                self.current_sentence_buffer = ""  # Clear after the final check
            self.text_done.set()  # Signal that all text processing is complete

    def __iter__(self):
        # Lets Gradio iterate over the safety-checked output.
        while True:
            if self.output_queue:
                item = self.output_queue.pop(0)
                if item == "__STOP_GENERATION__":
                    # Tell the outer Gradio loop to stop yielding.
                    # Use return (not raise StopIteration) to end a generator cleanly.
                    return
                yield item
            elif self.text_done.is_set():
                # Internal generation and safety processing is finished.
                return
            else:
                time.sleep(0.01)  # Small sleep to avoid busy-waiting while new tokens arrive
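
# Data flow for the streamer above: model.generate() (running in a worker thread) calls
# streamer.put() with each new batch of token ids; the TextIteratorStreamer base class
# decodes them and forwards printable text to on_finalized_text(); that method buffers
# text into sentences, runs the safety check, and pushes approved sentences onto
# output_queue, which __iter__() drains for the Gradio generator below.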

# --- Inference Function with Safety and Streaming ---
def generate_word_by_word_with_safety(prompt_text: str):
    formatted_prompt = f"USER: {prompt_text}\nASSISTANT:"
    # Encode the input on the model's device (CPU)
    input_ids = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(model.device)

    # Initialize the custom streamer
    streamer = GradioSafetyStreamer(tokenizer, is_text_safe, TOXICITY_THRESHOLD)

    # Run model.generate in a separate thread because it is a blocking call.
    # This lets the streamer keep filling its queue while Gradio yields.
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "generation_config": generation_config,
        # Explicitly pass these for clarity, even though they are in generation_config
        "do_sample": True,
        "temperature": TEMPERATURE,
        "top_p": TOP_P,
        "max_new_tokens": MAX_TOKENS,
        "eos_token_id": generation_config.eos_token_id,
        "pad_token_id": generation_config.pad_token_id,
    }
    # Start generation in a separate thread
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    # Yield safety-checked sentences from the streamer so Gradio can display them progressively
    full_generated_text = ""
    try:
        for new_sentence_or_chunk in streamer:
            # Sentences arrive without their trailing whitespace, so re-join them with a space.
            full_generated_text = (full_generated_text + " " + new_sentence_or_chunk).strip()
            yield full_generated_text  # Gradio expects the accumulated string for streaming display
    except Exception as e:
        print(f"Error during streaming: {e}")
        yield full_generated_text + f"\n\n[Error during streaming: {e}]"  # Show the error in the output
    finally:
        thread.join()  # Ensure the generation thread finishes gracefully
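
# Because generate_word_by_word_with_safety is a generator, Gradio streams each yielded
# (accumulated) string to the output textbox as it arrives instead of waiting for the
# full response.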

# --- Gradio Blocks Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # SmilyAI: Sam-reason-S3 Inference (Transformers on CPU with Safety Filter)
        Enter a prompt and get a streamed, sentence-by-sentence response from the **Smilyai-labs/Sam-reason-S3** model.
        **⚠️ WARNING: This model is running on a free CPU tier via the `transformers` library. Inference will be VERY slow.**
        All generated sentences are checked for safety using an AI filter; unsafe content will be replaced.
        """
    )
    with gr.Row():
        user_prompt = gr.Textbox(
            lines=5,
            label="Enter your prompt here:",
            placeholder="e.g., Explain the concept of quantum entanglement in simple terms.",
            scale=4
        )
        generated_text = gr.Textbox(label="Generated Text", show_copy_button=True, scale=6)

    send_button = gr.Button("Send", variant="primary")

    send_button.click(
        fn=generate_word_by_word_with_safety,
        inputs=user_prompt,
        outputs=generated_text,
        api_name="predict",
    )
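
    # Illustrative client call (assumes the `gradio_client` package and that this Space is
    # reachable; the Space name below is a placeholder):
    #   from gradio_client import Client
    #   client = Client("Smilyai-labs/<your-space-name>")
    #   print(client.predict("Hello!", api_name="/predict"))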

if __name__ == "__main__":
    print("Launching Gradio app...")
    demo.launch(server_name="0.0.0.0", server_port=7860)