Spaces:
Sleeping
Sleeping
| import sentencepiece as spm | |
| import os, numpy as np, tensorflow as tf | |
| from tensorflow.keras import layers | |
| import gradio as gr | |
# --- 1. Environment setup and model architecture definition ---
# Only the bare file name is used, so the tokenizer file must be present in
# the current working directory.
TOKENIZER_PATH = "tokenizer.model"
sp = spm.SentencePieceProcessor(TOKENIZER_PATH)

# Padding id: use the "<pad>" piece's id, falling back to 0 when the lookup
# reports -1.
# NOTE(review): SentencePiece's piece_to_id may return the unknown-token id
# rather than -1 for a missing piece, so this fallback may never fire — confirm.
pad_id = sp.piece_to_id("<pad>")
if pad_id == -1:
    pad_id = 0

end_id = sp.piece_to_id("</s>")
vocab_size = sp.get_piece_size()
class TimeMix(layers.Layer):
    """RWKV-style time-mixing sublayer, evaluated one token step at a time.

    The recurrent state is ``[last_x, aa, bb, pp]``:

    * ``last_x``: the input ``x`` seen at the previous step.
    * ``aa``: decayed, weighted sum of values (the WKV numerator).
    * ``bb``: the matching decayed sum of weights (the WKV denominator).
    * ``pp``: running maximum exponent; ``aa``/``bb`` are stored scaled by
      ``exp(-pp)`` so the ``exp`` terms cannot overflow.

    The per-step decay is data-dependent: ``time_decay`` plus a learned
    projection (``w_proj``) of the token-shifted input.
    """
    def __init__(self, d_model, layer_id, n_layers):
        super().__init__()
        self.d_model = d_model
        # Relative depth of this layer in [0, 1]; 0.5 when there is only one layer.
        ratio = (layer_id / (n_layers - 1)) if n_layers > 1 else 0.5
        decay_speed = np.arange(d_model)
        # Per-channel base log-decay, ranging from -5 (channel 0) to +3 (last
        # channel) along a curve whose exponent grows with depth.
        self.time_decay = tf.Variable(-5 + 8 * (decay_speed / (d_model - 1)) ** (0.7 + 1.3 * ratio), dtype=tf.float32)
        # Extra weight given to the current token's key (``u`` in call);
        # initialized to log(0.3) for every channel.
        self.time_first = tf.Variable(np.ones(d_model) * np.log(0.3), dtype=tf.float32)
        # Data-dependent corrections computed from the shifted input ``dx``.
        # They are zero-initialized, so they contribute nothing until trained.
        self.w_proj = layers.Dense(d_model, kernel_initializer='zeros', use_bias=False)
        self.r_proj = layers.Dense(d_model, kernel_initializer='zeros', use_bias=False)
        self.k_proj = layers.Dense(d_model, kernel_initializer='zeros', use_bias=False)
        self.v_proj = layers.Dense(d_model, kernel_initializer='zeros', use_bias=False)
        self.key = layers.Dense(d_model, use_bias=False)
        self.value = layers.Dense(d_model, use_bias=False)
        self.receptance = layers.Dense(d_model, use_bias=False)
        self.output_projection = layers.Dense(d_model, use_bias=False)
        # Token-shift coefficients: how much of the current input (vs. the
        # previous step's input) feeds each projection. They are
        # depth-dependent and shallower layers lean more on the current token.
        self.tm_w = tf.Variable(1 - (ratio ** 0.5), dtype=tf.float32)
        self.tm_k = tf.Variable(1 - (ratio ** 0.5), dtype=tf.float32)
        self.tm_v = tf.Variable(1 - (ratio ** 0.5), dtype=tf.float32)
        self.tm_r = tf.Variable(1 - (ratio ** 0.2), dtype=tf.float32)
    def call(self, x, state):
        """Advance the recurrence by one step.

        Args:
            x: current input; presumably shaped (batch, 1, d_model) for
                single-token steps — TODO confirm against callers.
            state: ``[last_x, aa, bb, pp]`` from the previous step. ``aa``,
                ``bb`` and ``pp`` are combined with float32 tensors below,
                so they are expected to be float32.

        Returns:
            ``(output, new_state)``. ``new_state`` has the same layout, and
            ``x`` itself becomes the next step's ``last_x``.
        """
        last_x, aa, bb, pp = state
        t_type = x.dtype
        # Mix coefficients are stored as float32; cast them to the input dtype.
        tm_w, tm_k, tm_v, tm_r = tf.cast(self.tm_w, t_type), tf.cast(self.tm_k, t_type), tf.cast(self.tm_v, t_type), tf.cast(self.tm_r, t_type)
        # Shared token-shifted input that drives the data-dependent projections.
        dx = x * tm_w + last_x * (1 - tm_w)
        w = tf.cast(self.time_decay, t_type) + tf.cast(self.w_proj(dx), t_type)
        # Negative per-channel log-decay, in float32 for the WKV recurrence.
        w = -tf.exp(tf.cast(w, tf.float32))
        r = self.receptance(x * tm_r + last_x * (1 - tm_r)) + self.r_proj(dx)
        k = self.key(x * tm_k + last_x * (1 - tm_k)) + self.k_proj(dx)
        v = self.value(x * tm_v + last_x * (1 - tm_v)) + self.v_proj(dx)
        # The WKV recurrence below runs in float32 regardless of x's dtype.
        u = tf.cast(self.time_first, tf.float32)
        kv, vv = tf.cast(k, tf.float32), tf.cast(v, tf.float32)
        # Output for this step: blend the stored history (scaled by exp(pp))
        # with the current token (weight exp(u + k)). Both exponents are
        # shifted by their maximum p so exp() cannot overflow.
        ww = u + kv
        p = tf.maximum(pp, ww)
        e1, e2 = tf.exp(pp - p), tf.exp(ww - p)
        # 1e-12 guards the division against a zero denominator.
        wkv = (e1 * aa + e2 * vv) / (e1 * bb + e2 + 1e-12)
        # State update: decay the history by w and add the current token with
        # weight exp(k) (no ``u`` term). Again, rescale by the new maximum.
        ww_next = w + pp
        p_next = tf.maximum(ww_next, kv)
        e1_next, e2_next = tf.exp(ww_next - p_next), tf.exp(kv - p_next)
        new_state = [x, e1_next * aa + e2_next * vv, e1_next * bb + e2_next, p_next]
        # A sigmoid of the receptance gates the WKV output before the final projection.
        return self.output_projection(tf.nn.sigmoid(r) * tf.cast(wkv, t_type)), new_state
class ChannelMix(layers.Layer):
    """Channel-mixing (feed-forward) sublayer with token shift.

    The current input is blended with the previous step's input. The blend
    is expanded to ``int(d_model * 4.25)`` features, passed through a squared
    ReLU, projected back to ``d_model``, and gated by a sigmoid receptance.
    """

    def __init__(self, d_model, layer_id, n_layers):
        super().__init__()
        # Relative depth of this layer in [0, 1]; 0.5 for a single-layer model.
        depth = (layer_id / (n_layers - 1)) if n_layers > 1 else 0.5
        # Attribute names and creation order are intentionally unchanged; the
        # saved weight files are presumably keyed on them.
        self.time_mix_k = tf.Variable(1 - depth ** 0.5, dtype=tf.float32)
        self.time_mix_r = tf.Variable(1 - depth ** 0.5, dtype=tf.float32)
        self.key = layers.Dense(int(d_model * 4.25), use_bias=False)
        self.receptance = layers.Dense(d_model, use_bias=False)
        self.value = layers.Dense(d_model, use_bias=False)

    def call(self, x, last_x):
        """Return ``(output, x)``; the returned ``x`` is the next step's ``last_x``."""
        compute_dtype = x.dtype
        mix_k = tf.cast(self.time_mix_k, compute_dtype)
        mix_r = tf.cast(self.time_mix_r, compute_dtype)
        hidden = self.key(x * mix_k + last_x * (1 - mix_k))
        gate = self.receptance(x * mix_r + last_x * (1 - mix_r))
        # Squared-ReLU activation on the expanded features.
        projected = self.value(tf.square(tf.nn.relu(hidden)))
        return tf.nn.sigmoid(gate) * projected, x
class Block(layers.Layer):
    """One residual layer: time mixing followed by channel mixing.

    A single LayerNorm output feeds both sublayers, and each sublayer's
    result is added onto the residual stream.
    """

    def __init__(self, d_model, layer_id, n_layers):
        super().__init__()
        self.ln = layers.LayerNormalization(epsilon=1e-5)
        self.time_mix = TimeMix(d_model, layer_id, n_layers)
        self.channel_mix = ChannelMix(d_model, layer_id, n_layers)

    def call(self, x, state):
        """Advance this layer by one step.

        ``state`` holds five entries. The first four belong to ``time_mix``
        and the fifth is ``channel_mix``'s previous input. Returns
        ``(x_out, next_state)`` with ``next_state`` in the same layout.
        """
        normed = self.ln(x)
        time_state, prev_channel_x = state[:4], state[4]
        time_out, next_time_state = self.time_mix(normed, time_state)
        residual = x + time_out
        channel_out, next_channel_x = self.channel_mix(normed, prev_channel_x)
        return residual + channel_out, next_time_state + [next_channel_x]
class Head(tf.keras.Model):
    """Vocabulary projection that turns final hidden states into logits."""

    def __init__(self, vocab_size):
        super().__init__()
        # The explicit layer name is kept unchanged so saved weights still match.
        self.lm_head = layers.Dense(vocab_size, use_bias=False, name="output_head")

    def call(self, x):
        """Return bias-free logits for ``x``, always as float32."""
        logits = self.lm_head(x)
        return tf.cast(logits, tf.float32)
class LM(tf.keras.Model):
    """Token embedding, a stack of ``Block`` layers, and a final LayerNorm.

    The recurrent state is a flat list holding five entries per block.
    """

    def __init__(self, d_model, n_layers):
        super().__init__()
        self.token_embedding = layers.Embedding(vocab_size, d_model)
        self.blocks = [Block(d_model, i, n_layers) for i in range(n_layers)]
        self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype=tf.float32)

    def call(self, x, states):
        """Embed token ids ``x``, run every block, and return ``(hidden, new_states)``."""
        hidden = self.token_embedding(x)
        next_states = []
        for layer_idx, block in enumerate(self.blocks):
            offset = 5 * layer_idx  # each block owns five consecutive state slots
            hidden, block_state = block(hidden, states[offset:offset + 5])
            next_states += block_state
        return self.ln_f(hidden), next_states
# --- 2. Initialization and weight loading ---
# Model dimensions. These presumably must match the shapes stored in the
# checkpoint files — confirm before changing.
d_model = 512
n_layers = 10
blocklm = LM(d_model, n_layers)
head = Head(vocab_size)
def get_init_state():
    """Build a fresh recurrent state for every layer.

    Returns a flat list of ``n_layers * 5`` tensors, each shaped
    ``(1, 1, d_model)``. Within each layer's five slots, slot 3 (which
    ``TimeMix`` unpacks as ``pp``) starts at -1e30. Every other slot starts
    at zero.
    """
    return [
        tf.ones((1, 1, d_model)) * -1e30 if slot == 3 else tf.zeros((1, 1, d_model))
        for _layer in range(n_layers)
        for slot in range(5)
    ]
# Run one dummy forward pass so every layer creates its variables; the weight
# files loaded next need those variables to exist.
_o, _s = blocklm(tf.constant([[0]]), get_init_state())
_ = head(_o)
for _model, _weights_file in ((blocklm, "blocklm.weights.h5"), (head, "head.weights.h5")):
    _model.load_weights(_weights_file)
# --- 3. Inference engine ---
class InferenceEngine:
    """Autoregressive text generation on top of a recurrent model and LM head.

    The model is driven one token at a time, carrying its recurrent state
    between calls. Decoded text is streamed back incrementally.
    """

    def __init__(self, model, head, sp):
        self.model = model
        self.head = head
        self.sp = sp
        # NOTE(review): SentencePiece's piece_to_id may return the unknown-token
        # id instead of -1 for a missing piece, so these fallbacks may never
        # fire — confirm against the tokenizer's vocabulary.
        self.pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
        self.eos_id = sp.piece_to_id("</s>") if sp.piece_to_id("</s>") != -1 else sp.piece_to_id("[EOS]")

    def apply_repetition_penalty(self, logits, generated_ids, penalty, window=64):
        """Penalize tokens seen among the last ``window`` ids, modifying ``logits`` in place.

        Positive logits are divided by ``penalty`` and non-positive logits are
        multiplied by it. With ``penalty > 1`` this lowers every nonzero
        repeated logit. Each distinct id is penalized once. Returns ``logits``.
        """
        if not generated_ids:
            return logits
        for token_id in set(generated_ids[-window:]):
            if logits[token_id] > 0:
                logits[token_id] /= penalty
            else:
                logits[token_id] *= penalty
        return logits

    def sample(self, logits, temp, top_k, top_p):
        """Pick the next token id from ``logits``.

        If ``temp <= 0``, return the argmax (greedy). Otherwise:

        1. Scale the logits by ``1 / temp``.
        2. If ``top_k > 0``, keep only the ``top_k`` largest.
        3. Keep the smallest most-probable prefix whose cumulative probability
           exceeds ``top_p``, always keeping at least one token.
        4. Draw from the renormalized distribution.
        """
        if temp <= 0: return np.argmax(logits)
        logits = logits / temp
        if top_k > 0:
            indices_to_remove = logits < np.sort(logits)[-min(top_k, logits.shape[-1])]
            logits[indices_to_remove] = -float('inf')
        probs = tf.nn.softmax(logits).numpy()
        sorted_indices = np.argsort(probs)[::-1]
        sorted_probs = probs[sorted_indices]
        cumulative_probs = np.cumsum(sorted_probs)
        idx_to_remove = cumulative_probs > top_p
        if np.any(idx_to_remove):
            # Keep tokens up to and including the one that crosses top_p.
            cutoff_idx = max(1, np.where(idx_to_remove)[0][0] + 1)
            probs[sorted_indices[cutoff_idx:]] = 0
        if np.sum(probs) > 0: probs /= np.sum(probs)
        else: probs[sorted_indices[0]] = 1.0
        return np.random.choice(len(probs), p=probs)

    def model_step(self, token_id, states):
        """Feed one token through the model. Returns ``(logits, next_states)``."""
        out, next_states = self.model(token_id, states)
        logits = self.head(out)
        return logits, next_states

    @staticmethod
    def _stable_suffix(full_text, emitted_len):
        """Return ``(new_text, new_emitted_len)`` for streaming ``full_text``.

        ``emitted_len`` characters of ``full_text`` have already been yielded.
        Trailing U+FFFD characters are held back. If the tokenizer uses
        byte-fallback pieces (TODO confirm), a partially generated multi-byte
        character presumably decodes to U+FFFD and is only replaced by the
        real character once its remaining bytes arrive.
        """
        stable = full_text.rstrip("\ufffd")
        if len(stable) <= emitted_len:
            return "", emitted_len
        return stable[emitted_len:], len(stable)

    def generate(self, prompt, max_new_tokens, temp, top_k, top_p, penalty):
        """Yield newly decoded text chunks while generating a continuation of ``prompt``.

        All prompt tokens but the last are fed through the model to build the
        recurrent state. The last prompt token then seeds generation.
        Generation stops at ``eos_id`` or after ``max_new_tokens`` tokens.
        Nothing is yielded if ``prompt`` encodes to no tokens.
        """
        input_ids = self.sp.encode(prompt)
        if not input_ids:
            # There is no token to seed generation with (input_ids[-1] below).
            return
        states = get_init_state()
        # Prefill: every prompt token except the last only updates the state.
        for token_id in input_ids[:-1]:
            _, states = self.model_step(tf.constant([[token_id]]), states)

        generated = []
        curr_token_id = input_ids[-1]
        full_text = ""
        emitted = 0
        for _ in range(max_new_tokens):
            logits_out, states = self.model_step(tf.constant([[curr_token_id]]), states)
            logits = logits_out[0, 0].numpy()
            logits = self.apply_repetition_penalty(logits, input_ids + generated, penalty)
            # Never sample the padding token.
            logits[self.pad_id] = -float('inf')
            next_id = int(self.sample(logits, temp, top_k, top_p))
            if next_id == self.eos_id:
                break
            generated.append(next_id)
            # Decode the whole continuation and emit only the not-yet-sent,
            # stable part of it.
            full_text = self.sp.decode(generated)
            new_part, emitted = self._stable_suffix(full_text, emitted)
            if new_part:
                yield new_part
            curr_token_id = next_id
        # Flush anything held back (e.g. a trailing U+FFFD) once generation ends.
        tail = full_text[emitted:]
        if tail:
            yield tail
| engine = InferenceEngine(blocklm, head, sp) | |
| # --- 4. Gradio UI (๋จ์ ํ ์คํธ ์ ์ถ๋ ฅ ๋ฐฉ์) --- | |
| with gr.Blocks(title="RWKV Text Generator") as demo: | |
| gr.Markdown("## ๐๏ธ Dynamic RWKV Text Generation") | |
| gr.Markdown("์ง๋ฌธ์ ์ ๋ ฅํ๊ณ Generate๋ฅผ ๋๋ฅด๋ฉด ๋ต๋ณ์ด ์๋ ํ ์คํธ ๋ฐ์ค์ ์ค์๊ฐ์ผ๋ก ์์ฑ๋ฉ๋๋ค.") | |
| with gr.Row(): | |
| with gr.Column(): | |
| input_text = gr.Textbox(lines=5, label="Input Prompt", placeholder="์ฌ๊ธฐ์ ์ง๋ฌธ์ด๋ ๋ฌธ์ฅ์ ์ ๋ ฅํ์ธ์...") | |
| with gr.Row(): | |
| temp_slider = gr.Slider(0, 2, value=0.7, label="Temperature") | |
| top_p_slider = gr.Slider(0, 1, value=0.92, label="Top-P") | |
| with gr.Row(): | |
| penalty_slider = gr.Slider(1, 2, value=1.2, label="Penalty") | |
| max_tokens = gr.Slider(1, 2048, value=512, step=1, label="Max Tokens") | |
| submit_btn = gr.Button("Generate", variant="primary") | |
| clear_btn = gr.Button("Clear") | |
| with gr.Column(): | |
| output_text = gr.Textbox(lines=15, label="Generated Output", interactive=False) | |
| def run_generation(prompt, tokens, temp, top_p, penalty): | |
| if not prompt.strip(): | |
| return "ํ๋กฌํํธ๋ฅผ ์ ๋ ฅํด์ฃผ์ธ์." | |
| full_prompt = f"Question: {prompt}\nAnswer:" | |
| current_output = "" | |
| for chunk in engine.generate(full_prompt, int(tokens), temp, 40, top_p, penalty): | |
| current_output += chunk | |
| yield current_output | |
| # ๋ฒํผ ํด๋ฆญ ๋ฐ ์ํฐ ํค ์ ๋ ฅ ์ด๋ฒคํธ | |
| submit_btn.click( | |
| fn=run_generation, | |
| inputs=[input_text, max_tokens, temp_slider, top_p_slider, penalty_slider], | |
| outputs=output_text | |
| ) | |
| clear_btn.click(lambda: ("", ""), outputs=[input_text, output_text]) | |
| if __name__ == "__main__": | |
| demo.queue().launch() |