import gradio as gr
from huggingface_hub import hf_hub_download # Still useful if model is private and needs custom token
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, GenerationConfig
from transformers.pipelines import pipeline
import re
import os
import torch # Required for transformers models
import threading
import time # For short sleeps in streamer
# --- Model Configuration ---
# Your SmilyAI model ID on Hugging Face Hub
MODEL_REPO_ID = "Smilyai-labs/Sam-reason-S3"
N_CTX = 2048 # Nominal context window size (informational; not passed to generate() below)
MAX_TOKENS = 500
TEMPERATURE = 0.7
TOP_P = 0.9
STOP_SEQUENCES = ["USER:", "\n\n"] # Intended stop strings; not wired into generate() by default (see the StoppingCriteria sketch below)
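# NOTE: The sketch below is illustrative only and is NOT used by this app. It shows one way
# STOP_SEQUENCES could be enforced via transformers' StoppingCriteria API (decoding the whole
# sequence on every step is simple but adds overhead, especially on CPU).
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnSequences(StoppingCriteria):
    """Stops generation once any of the given stop strings appears in the decoded output."""
    def __init__(self, stop_strings, tokenizer_):
        self.stop_strings = stop_strings
        self.tokenizer_ = tokenizer_

    def __call__(self, input_ids, scores, **kwargs):
        # Decode the sequence generated so far and check for any stop string.
        text = self.tokenizer_.decode(input_ids[0], skip_special_tokens=True)
        return any(s in text for s in self.stop_strings)

# Example wiring (not enabled): pass to model.generate(...) as
#   stopping_criteria=StoppingCriteriaList([StopOnSequences(STOP_SEQUENCES, tokenizer)])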
# --- Safety Configuration ---
print("Loading safety model (unitary/toxic-bert)...")
try:
    safety_classifier = pipeline(
        "text-classification",
        model="unitary/toxic-bert",
        framework="pt"  # Use PyTorch backend
    )
    print("Safety model loaded successfully.")
except Exception as e:
    print(f"Error loading safety model: {e}")
    exit(1)
TOXICITY_THRESHOLD = 0.9
def is_text_safe(text: str) -> tuple[bool, str | None]:
    """
    Checks if the given text contains unsafe content using the safety classifier.
    Returns (True, None) if safe, or (False, detected_label) if unsafe.
    """
    if not text.strip():
        return True, None
    try:
        results = safety_classifier(text)
        if results and results[0]['label'] == 'toxic' and results[0]['score'] > TOXICITY_THRESHOLD:
            print(f"Detected unsafe content: '{text.strip()}' (Score: {results[0]['score']:.4f})")
            return False, results[0]['label']
        return True, None
    except Exception as e:
        print(f"Error during safety check: {e}")
        # Fail closed: if the safety check itself errors, treat the text as unsafe.
        return False, "safety_check_failed"
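# Illustrative usage (the labels/results below are assumptions, not captured model output):
#   is_text_safe("Have a nice day")        -> (True, None)
#   is_text_safe("<some toxic sentence>")  -> (False, "toxic")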
# --- Main Model Loading (using Transformers) ---
print(f"Loading tokenizer for {MODEL_REPO_ID}...")
try:
    # AutoTokenizer fetches the correct tokenizer for the model
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_ID)
    print("Tokenizer loaded.")
except Exception as e:
    print(f"Error loading tokenizer: {e}")
    print("Make sure the model ID is correct and, if it's a private repo, you've set the HF_TOKEN secret in your Space.")
    exit(1)
print(f"Loading model {MODEL_REPO_ID} (this will be VERY slow on CPU and might take a long time)...")
try:
    # AutoModelForCausalLM loads the language model.
    # device_map="cpu" ensures all model layers are loaded onto the CPU.
    # torch_dtype=torch.float32 is standard for CPU; float16 can save memory but might not be faster on all CPUs.
    model = AutoModelForCausalLM.from_pretrained(MODEL_REPO_ID, device_map="cpu", torch_dtype=torch.float32)
    model.eval()  # Set model to evaluation mode for inference
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Ensure it's a standard Transformers model and you have the HF_TOKEN secret if private.")
    exit(1)
# Configure generation for streaming
# Use GenerationConfig from the model for default parameters, then override as needed.
generation_config = GenerationConfig.from_pretrained(MODEL_REPO_ID)
generation_config.max_new_tokens = MAX_TOKENS
generation_config.temperature = TEMPERATURE
generation_config.top_p = TOP_P
generation_config.do_sample = True # Enable sampling for temperature/top_p
# Set EOS and PAD token IDs for proper generation stopping and padding
generation_config.eos_token_id = tokenizer.eos_token_id
if tokenizer.pad_token_id is not None:
    generation_config.pad_token_id = tokenizer.pad_token_id
elif tokenizer.eos_token_id is not None:
    # Fall back to the EOS token when no dedicated PAD token is defined
    generation_config.pad_token_id = tokenizer.eos_token_id
else:
    # Last-resort fallback; token id 0 is not ideal for all models
    generation_config.pad_token_id = 0
# --- Custom Streamer for Gradio and Safety Check ---
class GradioSafetyStreamer(TextIteratorStreamer):
    def __init__(self, tokenizer, safety_checker_fn, toxicity_threshold, skip_special_tokens=True, **kwargs):
        super().__init__(tokenizer, skip_special_tokens=skip_special_tokens, **kwargs)
        self.safety_checker_fn = safety_checker_fn
        self.toxicity_threshold = toxicity_threshold
        self.current_sentence_buffer = ""
        self.output_queue = []  # Queue of safety-checked sentences to be yielded to Gradio
        # Split after sentence-ending punctuation, keeping the punctuation attached to the sentence
        self.sentence_regex = re.compile(r'(?<=[.!?])\s+')
        self.text_done = threading.Event()  # Signals that internal text processing is complete
    def on_finalized_text(self, text: str, stream_end: bool = False):
        # Called by the superclass whenever a decoded chunk of text is ready.
        self.current_sentence_buffer += text
        # Split the buffer into sentences; the last segment is the (possibly empty) incomplete remainder.
        sentences = self.sentence_regex.split(self.current_sentence_buffer)
        if not stream_end and sentences:
            # Not the end of the stream: keep the incomplete remainder buffered for the next chunk.
            sentences_to_process = sentences[:-1]
            self.current_sentence_buffer = sentences[-1]
        else:
            # End of stream: process all segments and clear the buffer.
            sentences_to_process = sentences
            self.current_sentence_buffer = ""
        for sentence in sentences_to_process:
            if not sentence.strip():
                continue  # Skip empty segments produced by splitting
            is_safe, detected_label = self.safety_checker_fn(sentence)
            if not is_safe:
                print(f"Safety check failed for: '{sentence.strip()}' (Detected: {detected_label})")
                self.output_queue.append("[Content removed due to safety guidelines]")
                self.output_queue.append("__STOP_GENERATION__")  # Sentinel telling the consumer to stop yielding
                return  # Stop processing further sentences from this chunk
            else:
                self.output_queue.append(sentence)
        if stream_end:
            # If the stream ended with leftover text in the buffer, safety-check it too.
            if self.current_sentence_buffer.strip():
                is_safe, detected_label = self.safety_checker_fn(self.current_sentence_buffer)
                if not is_safe:
                    self.output_queue.append("[Content removed due to safety guidelines]")
                else:
                    self.output_queue.append(self.current_sentence_buffer)
                self.current_sentence_buffer = ""  # Clear after the final check
            self.text_done.set()  # Signal that all text processing is complete
    def __iter__(self):
        # Allows Gradio (or any consumer) to iterate over the safety-checked output.
        while True:
            if self.output_queue:
                item = self.output_queue.pop(0)
                if item == "__STOP_GENERATION__":
                    # Sentinel found: stop yielding to the outer Gradio loop.
                    # (Use return, not raise StopIteration: PEP 479 turns the latter into a RuntimeError.)
                    return
                yield item
            elif self.text_done.is_set():
                # Generation and safety processing are finished and the queue is drained.
                return
            else:
                time.sleep(0.01)  # Small sleep to avoid busy-waiting while waiting for new tokens
# --- Inference Function with Safety and Streaming ---
def generate_word_by_word_with_safety(prompt_text: str):
    formatted_prompt = f"USER: {prompt_text}\nASSISTANT:"
    # Encode input on the model's device (CPU)
    input_ids = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(model.device)
    # Initialize the custom streamer; skip_prompt=True keeps the echoed prompt out of the output.
    streamer = GradioSafetyStreamer(tokenizer, is_text_safe, TOXICITY_THRESHOLD, skip_prompt=True)
    # Run model.generate in a separate thread because it is a blocking call.
    # This allows the streamer to continuously fill its queue while Gradio yields from it.
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "generation_config": generation_config,
        # Explicitly pass these for clarity, even though they are already in generation_config
        "do_sample": True,
        "temperature": TEMPERATURE,
        "top_p": TOP_P,
        "max_new_tokens": MAX_TOKENS,
        "eos_token_id": generation_config.eos_token_id,
        "pad_token_id": generation_config.pad_token_id,
    }
    # Start generation in a separate thread
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    # Yield safety-checked sentences from the streamer for Gradio to display progressively
    full_generated_text = ""
    try:
        for new_sentence_or_chunk in streamer:
            # Separate sentences with a space as the output accumulates
            full_generated_text += (" " if full_generated_text else "") + new_sentence_or_chunk.strip()
            yield full_generated_text  # Gradio expects the accumulated string for streaming display
    except Exception as e:
        print(f"Error during streaming: {e}")
        yield full_generated_text + f"\n\n[Error during streaming: {e}]"  # Show the error in the output
    finally:
        thread.join()  # Ensure the generation thread finishes gracefully
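# Quick local smoke test (illustrative; run manually from a Python shell, not part of the app):
#   for partial in generate_word_by_word_with_safety("Say hello in one sentence."):
#       print(partial)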
# --- Gradio Blocks Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# SmilyAI: Sam-reason-S3 Inference (Transformers on CPU with Safety Filter)

Enter a prompt and get a streamed, sentence-by-sentence response from the **Smilyai-labs/Sam-reason-S3** model.

**⚠️ WARNING: This model is running on a free CPU tier via the `transformers` library. Inference will be VERY slow.**

All generated sentences are checked for safety using an AI filter; unsafe content will be replaced.
"""
    )
    with gr.Row():
        user_prompt = gr.Textbox(
            lines=5,
            label="Enter your prompt here:",
            placeholder="e.g., Explain the concept of quantum entanglement in simple terms.",
            scale=4
        )
        generated_text = gr.Textbox(label="Generated Text", show_copy_button=True, scale=6)
    send_button = gr.Button("Send", variant="primary")
    send_button.click(
        fn=generate_word_by_word_with_safety,
        inputs=user_prompt,
        outputs=generated_text,
        api_name="predict",
    )
if __name__ == "__main__":
    print("Launching Gradio app...")
    demo.launch(server_name="0.0.0.0", server_port=7860)
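# Calling the deployed Space programmatically (illustrative sketch; the Space ID below is an
# assumption, substitute your own <username>/<space-name>):
#   from gradio_client import Client
#   client = Client("<username>/Sam-S-3-api")
#   result = client.predict("Explain quantum entanglement simply.", api_name="/predict")
#   print(result)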