import sentencepiece as spm
import os, numpy as np, tensorflow as tf
from tensorflow.keras import layers
import gradio as gr

# --- 1. Environment setup and model architecture ---
# Bare filename only: the tokenizer file must be in the current working directory.
TOKENIZER_PATH = "tokenizer.model"
sp = spm.SentencePieceProcessor(TOKENIZER_PATH)
# NOTE(review): the piece-name literals below are empty strings in this copy; they
# look like special-token names (e.g. angle-bracket tokens) that were stripped
# somewhere -- confirm against the tokenizer vocabulary.
# NOTE(review): SentencePiece's piece_to_id() is documented to return unk_id() for
# unknown pieces rather than -1, so the "!= -1" fallback to 0 presumably never
# fires -- confirm.
# NOTE(review): pad_id / end_id are not referenced by the code shown here;
# InferenceEngine looks up its own pad/eos ids.
pad_id = sp.piece_to_id("") if sp.piece_to_id("") != -1 else 0
end_id = sp.piece_to_id("")
vocab_size = sp.get_piece_size()

class TimeMix(layers.Layer):
    """Time-mixing (WKV) sub-layer evaluated as a recurrence, one step per call.

    call() has no loop over a time axis, so each call advances the recurrence
    by exactly one step. The recurrent state is ``[last_x, aa, bb, pp]``:
      * ``last_x`` -- the input this layer saw on the previous step (token shift);
      * ``aa`` / ``bb`` -- running numerator / denominator of the WKV weighted
        average of values, both stored scaled by ``exp(-pp)``;
      * ``pp`` -- running exponent offset that keeps the ``exp`` terms bounded.
    """
    def __init__(self, d_model, layer_id, n_layers):
        super().__init__()
        self.d_model = d_model
        # Relative depth of this layer in [0, 1] (0.5 when there is only one layer).
        ratio = (layer_id / (n_layers - 1)) if n_layers > 1 else 0.5
        decay_speed = np.arange(d_model)
        # Per-channel base decay, from -5 (channel 0) up to 3 (last channel); the
        # curve's exponent grows with depth. Turned into a negative rate in call().
        self.time_decay = tf.Variable(-5 + 8 * (decay_speed / (d_model - 1)) ** (0.7 + 1.3 * ratio), dtype=tf.float32)
        # "u": extra weight given only to the current token's key in the output.
        self.time_first = tf.Variable(np.ones(d_model) * np.log(0.3), dtype=tf.float32)
        # Zero-initialised projections of the shifted input dx; they add
        # input-dependent offsets to decay / r / k / v and contribute nothing at init.
        self.w_proj = layers.Dense(d_model, kernel_initializer='zeros', use_bias=False)
        self.r_proj = layers.Dense(d_model, kernel_initializer='zeros', use_bias=False)
        self.k_proj = layers.Dense(d_model, kernel_initializer='zeros', use_bias=False)
        self.v_proj = layers.Dense(d_model, kernel_initializer='zeros', use_bias=False)
        self.key = layers.Dense(d_model, use_bias=False)
        self.value = layers.Dense(d_model, use_bias=False)
        self.receptance = layers.Dense(d_model, use_bias=False)
        self.output_projection = layers.Dense(d_model, use_bias=False)
        # Token-shift interpolation weights: mixed = x * tm + last_x * (1 - tm).
        # At layer 0 they start at 1 (current input only); at the last layer at 0.
        self.tm_w = tf.Variable(1 - (ratio ** 0.5), dtype=tf.float32)
        self.tm_k = tf.Variable(1 - (ratio ** 0.5), dtype=tf.float32)
        self.tm_v = tf.Variable(1 - (ratio ** 0.5), dtype=tf.float32)
        self.tm_r = tf.Variable(1 - (ratio ** 0.2), dtype=tf.float32)
    def call(self, x, state):
        """Advance one step.

        Args:
            x: normalised input for this step (presumably shape (1, 1, d_model),
                matching get_init_state() -- TODO confirm).
            state: ``[last_x, aa, bb, pp]`` for this layer.

        Returns:
            ``(output, new_state)``; ``new_state`` has the same layout as ``state``.
        """
        last_x, aa, bb, pp = state
        t_type = x.dtype
        tm_w, tm_k, tm_v, tm_r = tf.cast(self.tm_w, t_type), tf.cast(self.tm_k, t_type), tf.cast(self.tm_v, t_type), tf.cast(self.tm_r, t_type)
        dx = x * tm_w + last_x * (1 - tm_w)
        w = tf.cast(self.time_decay, t_type) + tf.cast(self.w_proj(dx), t_type)
        # Negative per-channel log-decay; the WKV recurrence itself runs in float32.
        w = -tf.exp(tf.cast(w, tf.float32))
        r = self.receptance(x * tm_r + last_x * (1 - tm_r)) + self.r_proj(dx)
        k = self.key(x * tm_k + last_x * (1 - tm_k)) + self.k_proj(dx)
        v = self.value(x * tm_v + last_x * (1 - tm_v)) + self.v_proj(dx)
        u = tf.cast(self.time_first, tf.float32)
        kv, vv = tf.cast(k, tf.float32), tf.cast(v, tf.float32)
        # Output: blend the stored history with the current token (boosted by u),
        # subtracting the common max exponent p before exp for stability.
        ww = u + kv
        p = tf.maximum(pp, ww)
        e1, e2 = tf.exp(pp - p), tf.exp(ww - p)
        wkv = (e1 * aa + e2 * vv) / (e1 * bb + e2 + 1e-12)
        # State update: decay the history by exp(w), then add the current token
        # (without the u bonus), again rescaled by the new max exponent p_next.
        ww_next = w + pp
        p_next = tf.maximum(ww_next, kv)
        e1_next, e2_next = tf.exp(ww_next - p_next), tf.exp(kv - p_next)
        new_state = [x, e1_next * aa + e2_next * vv, e1_next * bb + e2_next, p_next]
        # Receptance gates the WKV average; cast back to the input dtype.
        return self.output_projection(tf.nn.sigmoid(r) * tf.cast(wkv, t_type)), new_state

class ChannelMix(layers.Layer):
    """Feed-forward sub-layer with token shift.

    Hidden width is ``int(d_model * 4.25)``. The squared-ReLU activation is
    projected back to d_model and gated by ``sigmoid(receptance)``.
    """
    def __init__(self, d_model, layer_id, n_layers):
        super().__init__()
        # Relative depth of this layer in [0, 1] (0.5 when there is only one layer).
        ratio = (layer_id / (n_layers - 1)) if n_layers > 1 else 0.5
        # Token-shift interpolation weights (1 at layer 0, 0 at the last layer).
        self.time_mix_k = tf.Variable(1 - (ratio ** 0.5), dtype=tf.float32)
        self.time_mix_r = tf.Variable(1 - (ratio ** 0.5), dtype=tf.float32)
        self.key = layers.Dense(int(d_model * 4.25), use_bias=False)
        self.receptance = layers.Dense(d_model, use_bias=False)
        self.value = layers.Dense(d_model, use_bias=False)
    def call(self, x, last_x):
        """Return ``(output, x)``; the returned ``x`` is the next step's ``last_x``."""
        t_type = x.dtype
        tm_k, tm_r = tf.cast(self.time_mix_k, t_type), tf.cast(self.time_mix_r, t_type)
        k = self.key(x * tm_k + last_x * (1 - tm_k))
        r = self.receptance(x * tm_r + last_x * (1 - tm_r))
        kv = self.value(tf.square(tf.nn.relu(k)))
        return tf.nn.sigmoid(r) * kv, x

class Block(layers.Layer):
    """One residual layer: time mix followed by channel mix.

    Its state is 5 tensors: the 4 TimeMix state tensors plus ChannelMix's last_x.
    """
    def __init__(self, d_model, layer_id, n_layers):
        super().__init__()
        self.ln = layers.LayerNormalization(epsilon=1e-5)
        self.time_mix = TimeMix(d_model, layer_id, n_layers)
        self.channel_mix = ChannelMix(d_model, layer_id, n_layers)
    def call(self, x, state):
        # NOTE(review): a single LayerNorm output ln_x feeds BOTH sub-layers, so
        # the channel mix does not see the time-mix residual added below. The
        # loaded weights depend on this exact structure; confirm it is intended
        # before changing it.
        ln_x = self.ln(x)
        tm_out, tm_state = self.time_mix(ln_x, state[:4])
        x = x + tm_out
        cm_out, cm_last_x = self.channel_mix(ln_x, state[4])
        x = x + cm_out
        return x, tm_state + [cm_last_x]

class Head(tf.keras.Model):
    """Bias-free projection from hidden states to vocabulary logits (returned as float32)."""
    def __init__(self, vocab_size):
        super().__init__()
        self.lm_head = layers.Dense(vocab_size, use_bias=False, name="output_head")
    def call(self, x):
        return tf.cast(self.lm_head(x), tf.float32)

class LM(tf.keras.Model):
    """Token embedding -> stack of Blocks -> final float32 LayerNorm.

    ``states`` is a flat list holding 5 tensors per block, in block order.
    call() returns ``(normalised_hidden, new_states)`` with the same flat layout.
    """
    def __init__(self, d_model, n_layers):
        super().__init__()
        self.token_embedding = layers.Embedding(vocab_size, d_model)
        self.blocks = [Block(d_model, i, n_layers) for i in range(n_layers)]
        self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype=tf.float32)
    def call(self, x, states):
        x = self.token_embedding(x)
        new_states = []
        for i, block in enumerate(self.blocks):
            # Each block owns a contiguous slice of 5 state tensors.
            x, b_state = block(x, states[i*5 : (i+1)*5])
            new_states.extend(b_state)
        return self.ln_f(x), new_states

# --- 2. Initialization and weight loading ---
d_model, n_layers = 512, 10
blocklm = LM(d_model, n_layers)
head = Head(vocab_size)

def get_init_state():
    """Return a fresh flat state list: n_layers * 5 tensors of shape (1, 1, d_model).

    Slot 3 of each block (TimeMix's ``pp``) is -1e30 so that ``exp(pp - p)``
    underflows to 0 on the first step, i.e. the history starts out empty. All
    other slots are zeros.
    """
    return [tf.zeros((1, 1, d_model)) if i%5!=3 else tf.ones((1, 1, d_model))*-1e30 for i in range(n_layers*5)]

# Dummy forward pass so every layer's variables are created before load_weights().
_o, _s = blocklm(tf.constant([[0]]), get_init_state())
_ = head(_o)
blocklm.load_weights("blocklm.weights.h5")
head.load_weights("head.weights.h5")

# --- 3. Inference engine ---
class InferenceEngine:
    """Streaming token-by-token text generator for the recurrent LM.

    Wraps the language model (``model``), its logits head (``head``) and a
    SentencePiece tokenizer (``sp``). The recurrent state is rebuilt from
    ``get_init_state()`` on every :meth:`generate` call.
    """

    def __init__(self, model, head, sp):
        self.model = model
        self.head = head
        self.sp = sp
        # NOTE(review): the piece names below are empty strings in this copy; they
        # look like special-token names (e.g. angle-bracket tokens) that were
        # stripped somewhere. Confirm against the tokenizer vocabulary.
        pad = self._piece_id_or_none("")
        self.pad_id = pad if pad is not None else 0
        eos = self._piece_id_or_none("")
        self.eos_id = eos if eos is not None else sp.piece_to_id("[EOS]")

    def _piece_id_or_none(self, piece):
        """Return the vocabulary id of ``piece``, or ``None`` if it is unknown.

        SentencePiece's ``piece_to_id`` maps out-of-vocabulary pieces to
        ``unk_id()`` rather than -1, so that id is treated as "missing".
        """
        piece_id = self.sp.piece_to_id(piece)
        return None if piece_id == self.sp.unk_id() else piece_id

    def apply_repetition_penalty(self, logits, generated_ids, penalty, window=64):
        """Penalise tokens that occur among the last ``window`` ids, in place.

        Positive logits are divided by ``penalty`` and non-positive logits are
        multiplied by it, so ``penalty > 1`` always lowers their score.

        Args:
            logits: 1-D numpy array of vocabulary logits (modified in place).
            generated_ids: token-id history; only its last ``window`` ids count.
            penalty: penalty factor (1.0 is a no-op).
            window: how many trailing ids to consider.

        Returns:
            The same ``logits`` array.
        """
        if not generated_ids:
            return logits
        recent = np.fromiter(set(generated_ids[-window:]), dtype=np.int64)
        selected = logits[recent]
        logits[recent] = np.where(selected > 0, selected / penalty, selected * penalty)
        return logits

    def sample(self, logits, temp, top_k, top_p):
        """Pick the next token id using temperature, top-k and top-p (nucleus) filtering.

        ``temp <= 0`` means greedy decoding (argmax). Otherwise the ``top_k``
        highest logits are kept (when ``top_k > 0``). Among those, the smallest
        prefix of tokens by descending probability whose cumulative probability
        exceeds ``top_p`` is kept, always including at least one token.
        """
        if temp <= 0:
            return np.argmax(logits)
        logits = logits / temp  # new array; the caller's logits are untouched
        if top_k > 0:
            kth_largest = np.sort(logits)[-min(top_k, logits.shape[-1])]
            logits[logits < kth_largest] = -np.inf
        # Softmax in float64 so np.random.choice's sum-to-one check is robust;
        # masked (-inf) entries become exactly 0.
        probs = np.exp(logits.astype(np.float64) - np.max(logits))
        probs /= probs.sum()
        order = np.argsort(probs)[::-1]
        cumulative = np.cumsum(probs[order])
        over = np.nonzero(cumulative > top_p)[0]
        if over.size:
            # Keep tokens up to and including the one that crosses top_p.
            cutoff = max(1, over[0] + 1)
            probs[order[cutoff:]] = 0.0
        total = probs.sum()
        if total > 0:
            probs /= total
        else:
            probs[order[0]] = 1.0
        return np.random.choice(len(probs), p=probs)

    @tf.function(reduce_retracing=True)
    def model_step(self, token_id, states):
        """Run one recurrent step; return ``(logits, next_states)``."""
        out, next_states = self.model(token_id, states)
        logits = self.head(out)
        return logits, next_states

    def generate(self, prompt, max_new_tokens, temp, top_k, top_p, penalty):
        """Stream a continuation of ``prompt``.

        Yields successive text fragments; concatenating them gives the decoded
        continuation. Generation stops at the EOS id or after ``max_new_tokens``
        tokens. Nothing is yielded if ``prompt`` encodes to no tokens.
        """
        input_ids = self.sp.encode(prompt)
        if not input_ids:
            return
        states = get_init_state()
        # Prefill: feed every prompt token except the last to build up the state;
        # the last one is fed by the first iteration of the decode loop.
        for token_id in input_ids[:-1]:
            _, states = self.model_step(tf.constant([[token_id]]), states)
        curr_token_id = input_ids[-1]
        generated = []
        emitted_text = ""
        for _ in range(max_new_tokens):
            logits_out, states = self.model_step(tf.constant([[curr_token_id]]), states)
            logits = logits_out[0, 0].numpy()
            logits = self.apply_repetition_penalty(logits, input_ids + generated, penalty)
            logits[self.pad_id] = -float('inf')
            next_id = int(self.sample(logits, temp, top_k, top_p))
            if next_id == self.eos_id:
                break
            generated.append(next_id)
            curr_token_id = next_id
            full_text = self.sp.decode(generated)
            # A multi-byte character split across byte pieces decodes to U+FFFD
            # until it is complete; hold it back so already-streamed text stays
            # a true prefix of the final text.
            if full_text.endswith("\ufffd"):
                continue
            new_part = full_text[len(emitted_text):]
            if new_part:
                yield new_part
            emitted_text = full_text
        # Flush anything held back when generation ended mid-character.
        if generated:
            tail = self.sp.decode(generated)[len(emitted_text):]
            if tail:
                yield tail

engine = InferenceEngine(blocklm, head, sp)

# --- 4. Gradio UI (plain text input / text output) ---
with gr.Blocks(title="RWKV Text Generator") as demo:
    gr.Markdown("## 🖋️ Dynamic RWKV Text Generation")
    gr.Markdown("질문을 입력하고 Generate를 누르면 답변이 아래 텍스트 박스에 실시간으로 생성됩니다.")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(lines=5, label="Input Prompt", placeholder="여기에 질문이나 문장을 입력하세요...")
            with gr.Row():
                temp_slider = gr.Slider(0, 2, value=0.7, label="Temperature")
                top_p_slider = gr.Slider(0, 1, value=0.92, label="Top-P")
            with gr.Row():
                penalty_slider = gr.Slider(1, 2, value=1.2, label="Penalty")
                max_tokens = gr.Slider(1, 2048, value=512, step=1, label="Max Tokens")
            submit_btn = gr.Button("Generate", variant="primary")
            clear_btn = gr.Button("Clear")
        with gr.Column():
            output_text = gr.Textbox(lines=15, label="Generated Output", interactive=False)

    def run_generation(prompt, tokens, temp, top_p, penalty):
        """Stream the model's answer to ``prompt`` into the output box.

        This is a generator, so only yielded values reach the UI; a plain
        ``return value`` here would be silently discarded.
        """
        if not prompt or not prompt.strip():
            yield "프롬프트를 입력해주세요."
            return
        full_prompt = f"Question: {prompt}\nAnswer:"
        current_output = ""
        # top_k is fixed at 40; the UI exposes temperature, top-p and penalty only.
        for chunk in engine.generate(full_prompt, int(tokens), temp, 40, top_p, penalty):
            current_output += chunk
            yield current_output

    # Generate on the button click and on the prompt box's keyboard submit event.
    generation_inputs = [input_text, max_tokens, temp_slider, top_p_slider, penalty_slider]
    submit_btn.click(
        fn=run_generation,
        inputs=generation_inputs,
        outputs=output_text
    )
    input_text.submit(
        fn=run_generation,
        inputs=generation_inputs,
        outputs=output_text
    )
    clear_btn.click(lambda: ("", ""), outputs=[input_text, output_text])

if __name__ == "__main__":
    demo.queue().launch()