Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| import os, json, random | |
| import torch | |
| import gradio as gr | |
| import spaces | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
| from huggingface_hub import login, hf_hub_download | |
| import pyvene as pv | |
| from threading import Thread | |
| from typing import Iterator | |
# Hub access token read from the environment; None when HF_TOKEN is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Authenticate only when a token is actually present: calling login() with
# token=None falls back to an interactive prompt, which cannot be answered
# in a headless deployment.
if HF_TOKEN:
    login(token=HF_TOKEN)

# Generation limits used by the UI slider and by prompt trimming.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 512  # smaller default to save memory
MAX_INPUT_TOKEN_LENGTH = 4096
def load_jsonl(jsonl_path):
    """Read a JSON Lines file and return its records as a list.

    Args:
        jsonl_path: Path of a file holding one JSON document per line.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Decode explicitly as UTF-8 instead of depending on the locale default.
    with open(jsonl_path, "r", encoding="utf-8") as f:
        # Blank lines (e.g. a trailing empty line) carry no record and would
        # make json.loads raise, so they are skipped.
        return [json.loads(line) for line in f if line.strip()]
class Steer(pv.SourcelessIntervention):
    """Intervention that adds scaled rows of a projection matrix to activations.

    The projection is a bias-free ``torch.nn.Linear`` whose weight rows are
    looked up by index in ``forward``.
    """

    def __init__(self, **kwargs):
        """Build the intervention.

        Args:
            **kwargs: Forwarded to the parent constructor (together with
                ``keep_last_dim=True``); must contain ``latent_dim``.
        """
        super().__init__(**kwargs, keep_last_dim=True)
        latent_dim = kwargs["latent_dim"]
        self.proj = torch.nn.Linear(self.embed_dim, latent_dim, bias=False)

    def forward(self, base, source=None, subspaces=None):
        """Return ``base`` shifted by every requested steering direction.

        Args:
            base: Activations to steer.
            source: Accepted for interface compatibility; not used here.
            subspaces: Optional iterable of dicts, each with ``"idx"`` (row of
                the projection weight) and ``"internal_mag"`` (multiplier
                applied to that row).

        Returns:
            ``base`` itself when ``subspaces`` is None, otherwise the sum of
            ``base`` and all scaled rows.
        """
        if subspaces is None:
            return base

        steered = base
        for entry in subspaces:
            # Select one weight row and give it a leading axis before adding.
            direction = self.proj.weight[entry["idx"]].unsqueeze(dim=0)
            steered = steered + entry["internal_mag"] * direction
        return steered
# Load model & dictionary
model_id = "google/gemma-2-2b-it"

# Placeholders used when no GPU is available; the CUDA branch below
# overwrites them with the real objects.
pv_model = None
tokenizer = None
concept_list = []
concept_id_map = {}

if torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda", torch_dtype=torch.bfloat16
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Download dictionary: steering weights plus per-concept metadata.
    weight_path = hf_hub_download(
        repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/weight.pt"
    )
    meta_path = hf_hub_download(
        repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/metadata.jsonl"
    )
    # NOTE(review): torch.load unpickles the downloaded file; confirm the
    # repository is trusted.
    params = torch.load(weight_path).cuda()
    md = load_jsonl(meta_path)

    # Concept names in metadata order, and the name -> concept_id lookup.
    concept_list = [record["concept"] for record in md]
    concept_id_map = {record["concept"]: record["concept_id"] for record in md}

    # Install the downloaded matrix (as float32) as the projection weight.
    steer = Steer(embed_dim=params.shape[0], latent_dim=params.shape[1])
    steer.proj.weight.data = params.float()

    # Attach the intervention to the output of layer 20.
    pv_model = pv.IntervenableModel(
        {
            "component": "model.layers[20].output",
            "intervention": steer,
        },
        model=model,
    )
else:
    # Check GPU
    print("Warning: Running on CPU, may be slow.")

# Stop generation at the tokenizer's EOS token when a tokenizer was loaded.
if tokenizer:
    terminators = [tokenizer.eos_token_id]
else:
    terminators = []
def _history_to_messages(chat_history, max_turns=3):
    """Convert Gradio chat history into role/content messages for the template.

    Accepts both history layouts: a list of ``(user, model)`` pairs and a
    list of ``{"role": ..., "content": ...}`` dicts (what a chat UI created
    with ``type="messages"`` delivers). Only the last ``max_turns`` exchanges
    are kept.
    """
    messages = []
    for entry in chat_history or []:
        if isinstance(entry, dict):
            role, content = entry.get("role"), entry.get("content")
            # Skip entries that cannot be fed to the template as plain text.
            if not isinstance(content, str) or role not in ("user", "assistant", "model"):
                continue
            # Use the same "model" label the pair-based path emits.
            messages.append(
                {"role": "user" if role == "user" else "model", "content": content}
            )
        else:
            user_msg, model_msg = entry
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "model", "content": model_msg})

    # One turn is a user message plus its reply.
    messages = messages[-2 * max_turns:] if max_turns > 0 else []
    # Keep the window user-first and reply-last, as the pair-based layout
    # always was, so the new user message follows a model reply.
    while messages and messages[0]["role"] != "user":
        del messages[0]
    while messages and messages[-1]["role"] == "user":
        messages.pop()
    return messages


def generate(
    message: str,
    chat_history: list,
    max_new_tokens: int,
    subspaces_list: list[dict],
) -> Iterator[str]:
    """Stream a steered model reply for the chat UI.

    Args:
        message: The new user message.
        chat_history: Prior conversation, either ``(user, model)`` pairs or
            role/content dicts; only the last 3 turns are used.
        max_new_tokens: Upper bound on generated tokens.
        subspaces_list: Steering entries forwarded to the intervention as
            ``subspaces``.

    Yields:
        The reply accumulated so far, once per streamed chunk. If the prompt
        had to be trimmed, a ``"[Truncated prior text]\\n"`` notice is
        yielded first. If the model was never loaded, a single explanatory
        message is yielded instead.
    """
    # pv_model/tokenizer stay None when no CUDA device was found at startup.
    if pv_model is None or tokenizer is None:
        yield "Model is not loaded (no CUDA device was available at startup)."
        return

    # limit to last 3 turns, then append the new user message
    messages = _history_to_messages(chat_history, max_turns=3)
    messages.append({"role": "user", "content": message})

    input_ids = torch.tensor([tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True)]).cuda()

    # trim if needed: keep only the most recent tokens
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        yield "[Truncated prior text]\n"

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = {
        "base": {"input_ids": input_ids},
        "unit_locations": None,
        # Coerce in case the UI delivers the slider value as a float.
        "max_new_tokens": int(max_new_tokens),
        "intervene_on_prompt": True,
        "subspaces": subspaces_list,
        "streamer": streamer,
        "eos_token_id": terminators,
        "early_stopping": True,
        "do_sample": True,
    }

    # Generation runs in a worker thread while this generator drains the
    # streamer and re-yields the growing reply.
    t = Thread(target=pv_model.generate, kwargs=generate_kwargs)
    t.start()

    partial_text = []
    for token_str in streamer:
        partial_text.append(token_str)
        yield "".join(partial_text)

    # Wait for the worker so it is not left dangling after the reply ends.
    t.join()
def filter_concepts(search_text: str):
    """Return up to 500 concepts whose name contains the search text.

    Matching is case-insensitive and ignores leading/trailing whitespace in
    the query, so the text tested for emptiness is the same text searched.

    Args:
        search_text: Query typed by the user; None or blank means "no filter".

    Returns:
        The first 500 entries of ``concept_list`` when there is no query,
        otherwise the first 500 matching entries, in ``concept_list`` order.
    """
    # Normalise once instead of lowercasing the query for every concept.
    needle = (search_text or "").strip().lower()
    if not needle:
        return concept_list[:500]
    matches = [concept for concept in concept_list if needle in concept.lower()]
    return matches[:500]
def add_concept_to_list(selected_concept, user_slider_val, current_list):
    """Add (or update) a steering concept and refresh the dependent widgets.

    A concept appears at most once in the list: adding one that is already
    present replaces its entry with the new magnitude. This keeps the
    remove dropdown free of duplicate choices and agrees with removal, which
    matches entries by their text.

    Args:
        selected_concept: Concept name chosen in the dropdown (may be empty).
        user_slider_val: Steering factor from the slider.
        current_list: Current list of steering entries.

    Returns:
        A tuple ``(entries, table_rows, dropdown_update)``. When nothing is
        selected, or the name is not in ``concept_id_map``, ``entries`` is
        ``current_list`` unchanged.
    """
    # .get() so an unknown name leaves the list untouched instead of raising.
    idx = concept_id_map.get(selected_concept) if selected_concept else None

    if idx is None:
        updated_list = current_list
    else:
        new_entry = {
            "text": selected_concept,
            "idx": idx,
            "display_mag": user_slider_val,
            # The magnitude applied to the model is the slider value times 50.
            "internal_mag": user_slider_val * 50,
        }
        # Drop any earlier entry for the same concept, then append the new one.
        updated_list = [
            entry for entry in current_list if entry["text"] != selected_concept
        ] + [new_entry]

    return (
        updated_list,
        _build_table_data(updated_list),
        gr.update(choices=_build_remove_choices(updated_list)),
    )
def remove_concept_from_list(selected_text, current_list):
    """Remove every entry whose text equals the selection and refresh widgets.

    Args:
        selected_text: Concept name chosen in the remove dropdown (may be empty).
        current_list: Current list of steering entries.

    Returns:
        A tuple ``(entries, table_rows, dropdown_update)``; ``entries`` is
        ``current_list`` itself when nothing is selected.
    """
    if selected_text:
        remaining = [
            entry for entry in current_list if entry["text"] != selected_text
        ]
    else:
        remaining = current_list

    table_rows = _build_table_data(remaining)
    dropdown_update = gr.update(choices=_build_remove_choices(remaining))
    return remaining, table_rows, dropdown_update
def _build_table_data(subspaces):
    """Return one ``[text, display_mag]`` row per steering entry."""
    return [
        [entry["text"], entry["display_mag"]]
        for entry in subspaces
    ]
def _build_remove_choices(subspaces):
    """Return the ``text`` of every steering entry, in list order."""
    return [
        entry["text"]
        for entry in subspaces
    ]
def update_dropdown_choices(search_text):
    """Return a Gradio update setting the dropdown choices to the search hits."""
    return gr.update(choices=filter_concepts(search_text))
with gr.Blocks(css="style.css") as demo:
    # A short title only
    gr.Markdown("## Model Steering with ReFT-r1 (16K concepts)")

    # Pre-populate with a random concept if available
    default_subspaces = []
    if pv_model and concept_list:
        default_concept = random.choice(concept_list)
        default_subspaces = [{
            "text": default_concept,
            "idx": concept_id_map[default_concept],
            "display_mag": 3,
            "internal_mag": 150.0,
        }]
    selected_subspaces = gr.State(default_subspaces)

    # generate() takes (message, history, max_new_tokens, subspaces_list), so
    # the token budget has to be passed as the first additional input.
    # The slider is created unrendered so that ChatInterface renders it
    # itself, in its additional-inputs area under the chat.
    max_token_slider = gr.Slider(
        minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1,
        value=DEFAULT_MAX_NEW_TOKENS,
        label="Max New Tokens",
        render=False,
    )

    with gr.Row():
        # Left side: bigger chat area
        with gr.Column(scale=7):
            chat_interface = gr.ChatInterface(
                fn=generate,
                title="",
                type="messages",
                # Order must match generate()'s parameters after the history.
                additional_inputs=[max_token_slider, selected_subspaces],
            )

        # Right side: concept management
        with gr.Column(scale=3):
            gr.Markdown("### Steering Concepts")
            search_box = gr.Textbox(
                label="Search concepts",
                placeholder="e.g. 'time travel'"
            )
            concept_dropdown = gr.Dropdown(
                label="Filtered Concepts",
                choices=[]
            )
            concept_magnitude = gr.Slider(
                label="Steering Factor",
                minimum=-5,
                maximum=5,
                step=1,
                value=3
            )
            add_button = gr.Button("Add Concept")

            active_subspaces_table = gr.Dataframe(
                headers=["Concept", "Mag (scaled)"],
                datatype=["str", "number"],
                value=_build_table_data(default_subspaces),
                interactive=False,
                label="Active Concept Subspaces",
            )

            # Row with the remove dropdown + button
            with gr.Row():
                remove_dropdown = gr.Dropdown(
                    label="Remove concept",
                    choices=_build_remove_choices(default_subspaces),
                    multiselect=False
                )
                remove_button = gr.Button("Remove", variant="secondary")

    # Wire up events
    search_box.change(update_dropdown_choices, [search_box], [concept_dropdown])
    add_button.click(
        add_concept_to_list,
        [concept_dropdown, concept_magnitude, selected_subspaces],
        [selected_subspaces, active_subspaces_table, remove_dropdown]
    )
    remove_button.click(
        remove_concept_from_list,
        [remove_dropdown, selected_subspaces],
        [selected_subspaces, active_subspaces_table, remove_dropdown]
    )

demo.launch()
