from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"
class Inference:
    def __init__(self, model, pretrained_sae, layer):
        self.layer = layer
        if model == "gemma-2b":
            self.sae_id = f"blocks.{layer}.hook_resid_post"
        elif model == "gpt2-small":
            print(f"using {model}")
            self.sae_id = f"blocks.{layer}.hook_resid_pre"
        self.sampling_kwargs = dict(temperature=1.0, top_p=0.1, freq_penalty=1.0)
        self.set_coeff(1)
        self.set_model(model)
        self.set_SAE(pretrained_sae)
    def set_model(self, model):
        self.model = HookedTransformer.from_pretrained(model, device=device)

    def set_coeff(self, coeff):
        self.coeff = coeff

    def set_temperature(self, temperature):
        self.sampling_kwargs['temperature'] = temperature

    def set_steering_vector_prompt(self, prompt: str):
        self.steering_vector_prompt = prompt

    def set_SAE(self, sae_name):
        sae, cfg_dict, _ = SAE.from_pretrained(
            release=sae_name,
            sae_id=self.sae_id,
            device=device,
        )
        self.sae = sae
        self.cfg_dict = cfg_dict
    def _get_sae_out_and_feature_activations(self):
        # Run the model on the steering prompt, cache its activations, and use the SAE
        # to find which features (SAE latents) are active at the chosen hook point.
        sv_logits, activation_cache = self.model.run_with_cache(self.steering_vector_prompt, prepend_bos=True)
        sv_feature_acts = self.sae.encode(activation_cache[self.sae.cfg.hook_name])
        return self.sae.decode(sv_feature_acts), sv_feature_acts
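    # Shape note (added for clarity, standard for SAE Lens autoencoders): encode() maps
    # activations of shape (batch, seq, d_model) to feature activations of shape
    # (batch, seq, d_sae), and decode() maps them back to (batch, seq, d_model).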
    def _hooked_generate(self, prompt_batch, fwd_hooks, seed=None, **kwargs):
        if seed is not None:
            torch.manual_seed(seed)

        with self.model.hooks(fwd_hooks=fwd_hooks):
            tokenized = self.model.to_tokens(prompt_batch)
            result = self.model.generate(
                stop_at_eos=False,  # avoids a bug on MPS
                input=tokenized,
                max_new_tokens=50,
                do_sample=True,
                **kwargs,
            )
        return result
    def _get_features(self, sv_feature_activations):
        # Keep only the single most active feature at each token position of the steering prompt.
        features = torch.topk(sv_feature_activations, 1).indices
        print(f'features that align with the text prompt: {features}')
        print("look these feature indices up in a feature-interpretation tool (e.g. Neuronpedia) to see what each one represents")
        return features
    def _get_steering_hook(self, feature, sae_out):
        coeff = self.coeff
        steering_vector = self.sae.W_dec[feature]
        steering_vector = steering_vector[0]

        def steering_hook(resid_pre, hook):
            # During generation with KV caching, later forward passes only see the newest
            # token (sequence length 1); skip those and only steer the full prompt pass.
            if resid_pre.shape[1] == 1:
                return
            position = sae_out.shape[1]
            # Add the scaled feature direction to the residual stream at the prompt positions.
            resid_pre[:, :position - 1, :] += coeff * steering_vector

        return steering_hook
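    # Rough sketch of the steering arithmetic above (illustrative numbers, not from the
    # original code): with coeff = 300 and feature direction d = W_dec[f] of shape
    # (d_model,), every steered position p is updated as
    #     resid_pre[:, p, :] <- resid_pre[:, p, :] + 300 * d
    # i.e. the residual stream is pushed along the direction that feature f writes,
    # which biases generation toward the concept that feature represents.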
    def _get_steering_hooks(self):
        # TODO: refactor this. It works because sae_out.shape[1] == sv_feature_acts.shape[1] == len(features[0]),
        # so the hooks could be built more cleanly by manipulating views
        # instead of going through the separate _get_steering_hook() helper.
        sae_out, sv_feature_acts = self._get_sae_out_and_feature_activations()
        features = self._get_features(sv_feature_acts)
        steering_hooks = [self._get_steering_hook(feature, sae_out) for feature in features[0]]
        return steering_hooks
    def _run_generate(self, example_prompt, steering_on: bool):
        self.model.reset_hooks()
        if steering_on:
            steer_hooks = self._get_steering_hooks()
            editing_hooks = [(self.sae_id, steer_hook) for steer_hook in steer_hooks]
            print(f"steering by {len(editing_hooks)} hooks")
            res = self._hooked_generate([example_prompt] * 3, editing_hooks, seed=None, **self.sampling_kwargs)
        else:
            tokenized = self.model.to_tokens([example_prompt])
            res = self.model.generate(
                stop_at_eos=False,  # avoids a bug on MPS
                input=tokenized,
                max_new_tokens=50,
                do_sample=True,
                **self.sampling_kwargs,
            )

        # Print the results, dropping the beginning-of-sequence token.
        res_str = self.model.to_string(res[:, 1:])
        response = ("\n\n" + "-" * 80 + "\n\n").join(res_str)
        print(response)
        return response
    def generate(self, message: str, steering_on: bool):
        return self._run_generate(message, steering_on)
# MODEL = "gemma-2b"
# PRETRAINED_SAE = "gemma-2b-res-jb"
MODEL = "gpt2-small"
PRETRAINED_SAE = "gpt2-small-res-jb"
LAYER = 10

chatbot_model = Inference(MODEL, PRETRAINED_SAE, LAYER)
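# Illustrative usage sketch (not part of the running Space, hence commented out):
# the Inference object can also be exercised directly, without the Gradio UI below.
# chatbot_model.set_steering_vector_prompt("Golden Gate Bridge")
# chatbot_model.set_coeff(300)
# chatbot_model.generate("Tell me about your favourite place.", steering_on=True)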
import time
import gradio as gr

default_image = "Hexter-Hackathon.png"
def slow_echo(message, history):
    # Stream the unsteered response one character at a time.
    result = chatbot_model.generate(message, False)
    for i in range(len(result)):
        time.sleep(0.01)
        yield result[: i + 1]


def slow_echo_steering(message, history):
    # Stream the steered response one character at a time.
    result = chatbot_model.generate(message, True)
    for i in range(len(result)):
        time.sleep(0.01)
        yield result[: i + 1]
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("*STANDARD HEXTER BOT*")
    with gr.Row():
        chatbot = gr.ChatInterface(
            slow_echo,
            chatbot=gr.Chatbot(min_width=1000),
            textbox=gr.Textbox(placeholder="Ask Hexter anything!", min_width=1000),
            theme="soft",
            cache_examples=False,
            retry_btn=None,
            clear_btn=None,
            undo_btn=None,
        )
    with gr.Row():
        gr.Markdown("*STEERED HEXTER BOT*")
    with gr.Row():
        chatbot_steered = gr.ChatInterface(
            slow_echo_steering,
            chatbot=gr.Chatbot(min_width=1000),
            textbox=gr.Textbox(placeholder="Ask Hexter anything!", min_width=1000),
            theme="soft",
            cache_examples=False,
            retry_btn=None,
            clear_btn=None,
            undo_btn=None,
        )
    with gr.Row():
        steering_prompt = gr.Textbox(label="Steering prompt", value="Golden Gate Bridge")
    with gr.Row():
        coeff = gr.Slider(1, 1000, 300, label="Coefficient",
                          info="How strongly the steering vector is added to the residual stream",
                          interactive=True)
    with gr.Row():
        temp = gr.Slider(0, 5, 1, label="Temperature",
                         info="Sampling temperature; higher values give more random output",
                         interactive=True)

    temp.change(chatbot_model.set_temperature, inputs=[temp], outputs=[])
    coeff.change(chatbot_model.set_coeff, inputs=[coeff], outputs=[])
    # Initialise the model state from the widgets' default values.
    chatbot_model.set_steering_vector_prompt(steering_prompt.value)
    chatbot_model.set_coeff(coeff.value)
    steering_prompt.change(chatbot_model.set_steering_vector_prompt, inputs=[steering_prompt], outputs=[])
demo.queue()

if __name__ == "__main__":
    demo.launch(debug=True, allowed_paths=["/"])