import gradio as gr
from huggingface_hub import hf_hub_download  # Not used below, but handy if a private model needs a manual download with a custom token
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, GenerationConfig
from transformers import pipeline
import re
import os
import torch # Required for transformers models
import threading
import time # For short sleeps in streamer
# --- Model Configuration ---
# Your SmilyAI model ID on Hugging Face Hub
MODEL_REPO_ID = "Smilyai-labs/Sam-reason-S3"
N_CTX = 2048  # Context window size (llama.cpp-style leftover; not consumed by transformers below)
MAX_TOKENS = 500  # Maximum new tokens to generate per request
TEMPERATURE = 0.7  # Sampling temperature
TOP_P = 0.9  # Nucleus sampling probability mass
STOP_SEQUENCES = ["USER:", "\n\n"]  # Intended stop strings; generate() does not apply these automatically (see sketch below)
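# Hedged sketch: transformers' generate() does not act on plain string stop
# sequences on its own, so honoring STOP_SEQUENCES would take a custom
# StoppingCriteria such as the hypothetical StopOnSequences below. It is shown
# for illustration only and is not wired into generate() later in this script.
from transformers import StoppingCriteria

class StopOnSequences(StoppingCriteria):
    """Stop generation once any stop string appears in the newly generated text."""
    def __init__(self, stop_sequences, tokenizer, prompt_len):
        self.stop_sequences = stop_sequences
        self.tokenizer = tokenizer
        self.prompt_len = prompt_len  # number of prompt tokens to skip when decoding

    def __call__(self, input_ids, scores, **kwargs):
        # Decode only the newly generated tokens and stop if any stop string appears.
        new_text = self.tokenizer.decode(input_ids[0][self.prompt_len:], skip_special_tokens=True)
        return any(seq in new_text for seq in self.stop_sequences)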
# --- Safety Configuration ---
print("Loading safety model (unitary/toxic-bert)...")
try:
safety_classifier = pipeline(
"text-classification",
model="unitary/toxic-bert",
framework="pt" # Use PyTorch backend
)
print("Safety model loaded successfully.")
except Exception as e:
print(f"Error loading safety model: {e}")
exit(1)
TOXICITY_THRESHOLD = 0.9  # Sentences scoring above this as 'toxic' are blocked
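# For reference: the classifier returns results shaped like
# [{'label': 'toxic', 'score': 0.98}] (top label only, by default).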
def is_text_safe(text: str) -> tuple[bool, str | None]:
"""
Checks if the given text contains unsafe content using the safety classifier.
Returns (True, None) if safe, or (False, detected_label) if unsafe.
"""
if not text.strip():
return True, None
try:
results = safety_classifier(text)
if results and results[0]['label'] == 'toxic' and results[0]['score'] > TOXICITY_THRESHOLD:
print(f"Detected unsafe content: '{text.strip()}' (Score: {results[0]['score']:.4f})")
return False, results[0]['label']
return True, None
except Exception as e:
print(f"Error during safety check: {e}")
        # Fail closed: if the safety check itself errors, treat the text as unsafe.
        return False, "safety_check_failed"
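# Illustrative behavior: is_text_safe("The sky is blue.") -> (True, None);
# clearly toxic text -> (False, 'toxic'); a classifier error -> (False, 'safety_check_failed').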
# --- Main Model Loading (using Transformers) ---
print(f"Loading tokenizer for {MODEL_REPO_ID}...")
try:
# AutoTokenizer fetches the correct tokenizer for the model
tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_ID)
print("Tokenizer loaded.")
except Exception as e:
print(f"Error loading tokenizer: {e}")
print("Make sure the model ID is correct and, if it's a private repo, you've set the HF_TOKEN secret in your Space.")
exit(1)
print(f"Loading model {MODEL_REPO_ID} (this will be VERY slow on CPU and might take a long time)...")
try:
# AutoModelForCausalLM loads the language model.
# device_map="cpu" ensures all model layers are loaded onto the CPU.
# torch_dtype=torch.float32 is standard for CPU; float16 can save memory but might not be faster on all CPUs.
model = AutoModelForCausalLM.from_pretrained(MODEL_REPO_ID, device_map="cpu", torch_dtype=torch.float32)
model.eval() # Set model to evaluation mode for inference
print("Model loaded successfully.")
except Exception as e:
print(f"Error loading model: {e}")
print("Ensure it's a standard Transformers model and you have HF_TOKEN secret if private.")
exit(1)
# Configure generation for streaming
# Use GenerationConfig from the model for default parameters, then override as needed.
generation_config = GenerationConfig.from_pretrained(MODEL_REPO_ID)
generation_config.max_new_tokens = MAX_TOKENS
generation_config.temperature = TEMPERATURE
generation_config.top_p = TOP_P
generation_config.do_sample = True # Enable sampling for temperature/top_p
# Set EOS and PAD token IDs for proper generation stopping and padding
generation_config.eos_token_id = tokenizer.eos_token_id
if tokenizer.pad_token_id is not None:
    generation_config.pad_token_id = tokenizer.pad_token_id
elif tokenizer.eos_token_id is not None:
    # Common fallback: reuse the EOS token as the PAD token.
    generation_config.pad_token_id = tokenizer.eos_token_id
else:
    generation_config.pad_token_id = 0  # Last-resort fallback; not ideal for all models
# --- Custom Streamer for Gradio and Safety Check ---
class GradioSafetyStreamer(TextIteratorStreamer):
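    """Streams decoded model output sentence-by-sentence, running each sentence
    through the safety classifier before queueing it for display in Gradio.
    Replaces TextIteratorStreamer's queue and iterator with safety-checking ones."""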
    def __init__(self, tokenizer, safety_checker_fn, skip_special_tokens=True, **kwargs):
        # skip_prompt=True keeps the echoed input prompt out of the streamed output.
        super().__init__(tokenizer, skip_prompt=True, skip_special_tokens=skip_special_tokens, **kwargs)
        self.safety_checker_fn = safety_checker_fn
        self.current_sentence_buffer = ""
        # Safety-checked sentences waiting to be yielded to Gradio. A plain list
        # suffices for this single-producer/single-consumer setup, though
        # queue.Queue would be the more robust choice.
        self.output_queue = []
        # Split on whitespace that follows sentence-ending punctuation, so the
        # punctuation stays attached to each sentence.
        self.sentence_regex = re.compile(r'(?<=[.!?])\s+')
        self.text_done = threading.Event()  # Set once all text has been processed
    def on_finalized_text(self, text: str, stream_end: bool = False):
        # Called by the superclass whenever a decoded chunk of text is ready.
        self.current_sentence_buffer += text
        # Split the buffer into sentences, keeping punctuation attached.
        sentences = self.sentence_regex.split(self.current_sentence_buffer)
        if not stream_end:
            # The final segment may be an incomplete sentence; hold it back for the next chunk.
            sentences_to_process = sentences[:-1]
            self.current_sentence_buffer = sentences[-1]
        else:
            # End of stream: process everything, including the final segment.
            sentences_to_process = sentences
            self.current_sentence_buffer = ""
        for sentence in sentences_to_process:
            if not sentence.strip():
                continue  # Skip empty segments produced by the split
            is_safe, detected_label = self.safety_checker_fn(sentence)
            if not is_safe:
                print(f"Safety check failed for: '{sentence.strip()}' (Detected: {detected_label})")
                self.output_queue.append("[Content removed due to safety guidelines]")
                self.output_queue.append("__STOP_GENERATION__")  # Sentinel: stop yielding to Gradio
                return  # Drop the rest of this chunk
            # The split consumes inter-sentence whitespace, so restore a single space.
            self.output_queue.append(sentence + " ")
        if stream_end:
            self.text_done.set()  # Signal that all text has been processed
    def __iter__(self):
        # Lets Gradio iterate over the safety-checked output.
        while True:
            if self.output_queue:
                item = self.output_queue.pop(0)
                if item == "__STOP_GENERATION__":
                    # Sentinel reached: stop yielding (the generation thread may still be finishing).
                    return
                yield item
            elif self.text_done.is_set():
                # All text generated and safety-checked. PEP 479: a generator must
                # end with return, not a raised StopIteration.
                return
            else:
                time.sleep(0.01)  # Brief sleep to avoid busy-waiting for new tokens
# --- Inference Function with Safety and Streaming ---
def generate_word_by_word_with_safety(prompt_text: str):
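    """Stream a safety-filtered response for prompt_text, yielding the accumulated
    text after each approved sentence (despite the name, output is streamed
    sentence-by-sentence, since the safety filter needs whole sentences)."""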
formatted_prompt = f"USER: {prompt_text}\nASSISTANT:"
# Encode input on the model's device (CPU)
input_ids = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(model.device)
# Initialize the custom streamer
    streamer = GradioSafetyStreamer(tokenizer, is_text_safe)
# Use a separate thread for model generation because model.generate is a blocking call.
# This allows the streamer to continuously fill its queue while Gradio yields.
generate_kwargs = {
"input_ids": input_ids,
"streamer": streamer,
"generation_config": generation_config,
# Explicitly pass these for clarity, even if in generation_config
"do_sample": True,
"temperature": TEMPERATURE,
"top_p": TOP_P,
"max_new_tokens": MAX_TOKENS,
"eos_token_id": generation_config.eos_token_id,
"pad_token_id": generation_config.pad_token_id,
}
# Start generation in a separate thread
thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()
# Yield tokens from the streamer's output queue for Gradio to display progressively
full_generated_text = ""
try:
for new_sentence_or_chunk in streamer:
full_generated_text += new_sentence_or_chunk
yield full_generated_text # Gradio expects accumulated string for streaming display
except Exception as e:
print(f"Error during streaming: {e}")
yield full_generated_text + f"\n\n[Error during streaming: {e}]" # Show error in output
finally:
thread.join() # Ensure the generation thread finishes gracefully
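# Illustrative smoke test outside Gradio (uncomment to run from the command line):
# for partial in generate_word_by_word_with_safety("Explain photosynthesis briefly."):
#     print(partial)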
# --- Gradio Blocks Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
# SmilyAI: Sam-reason-S3 Inference (Transformers on CPU with Safety Filter)
        Enter a prompt and get a streamed, sentence-by-sentence response from the **Smilyai-labs/Sam-reason-S3** model.
**⚠️ WARNING: This model is running on a free CPU tier via the `transformers` library. Inference will be VERY slow.**
All generated sentences are checked for safety using an AI filter; unsafe content will be replaced.
"""
)
with gr.Row():
user_prompt = gr.Textbox(
lines=5,
label="Enter your prompt here:",
placeholder="e.g., Explain the concept of quantum entanglement in simple terms.",
scale=4
)
generated_text = gr.Textbox(label="Generated Text", show_copy_button=True, scale=6)
send_button = gr.Button("Send", variant="primary")
send_button.click(
fn=generate_word_by_word_with_safety,
inputs=user_prompt,
outputs=generated_text,
api_name="predict",
)
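    # NOTE: depending on the Gradio version, streaming generator outputs may
    # require enabling the queue explicitly (demo.queue()) before launch.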
if __name__ == "__main__":
print("Launching Gradio app...")
demo.launch(server_name="0.0.0.0", server_port=7860)