Text Generation
Transformers
PyTorch
English
experimental
research
bit-level
transformer
reversible
safety
telemetry
language-modeling
Instructions to use WCNegentropy/BitTransformerLM with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use WCNegentropy/BitTransformerLM with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="WCNegentropy/BitTransformerLM")

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("WCNegentropy/BitTransformerLM", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use WCNegentropy/BitTransformerLM with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "WCNegentropy/BitTransformerLM"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "WCNegentropy/BitTransformerLM",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
Use Docker
docker model run hf.co/WCNegentropy/BitTransformerLM
- SGLang
How to use WCNegentropy/BitTransformerLM with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "WCNegentropy/BitTransformerLM" \
  --host 0.0.0.0 \
  --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "WCNegentropy/BitTransformerLM",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
Use Docker images
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
    --model-path "WCNegentropy/BitTransformerLM" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "WCNegentropy/BitTransformerLM",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
- Docker Model Runner
How to use WCNegentropy/BitTransformerLM with Docker Model Runner:
docker model run hf.co/WCNegentropy/BitTransformerLM
| import math | |
| import contextlib | |
| import logging | |
| from typing import Dict, List, Tuple, Optional | |
| import torch | |
| import torch.distributed as dist | |
| import sys | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.utils.checkpoint as checkpoint | |
| from .torch_utils import cpu_autocast | |
| from .optimization import configure_optimizer | |
| from .compression import decompress_bits | |
| from .parity import enforce_parity | |
# Upper-triangular boolean masks built by get_tri_mask(), keyed by
# (sequence length, device) so each distinct mask is constructed only once.
_mask_cache: Dict[Tuple[int, torch.device], torch.Tensor] = {}
_attention_cache: Dict[str, torch.Tensor] = {}  # For caching attention patterns
# Entry count above which get_tri_mask() flushes both caches.
_MAX_CACHE_SIZE = 50  # Limit cache growth
def clear_cache():
    """Empty the module-level mask and attention caches.

    Both dictionaries are cleared in place, so any reference held elsewhere
    observes the emptied cache as well.
    """
    for cache in (_mask_cache, _attention_cache):
        cache.clear()
def get_tri_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    """Return the strictly upper-triangular boolean mask for ``seq_len``.

    Masks are memoised per ``(seq_len, device)``. Before the lookup, the
    caches are flushed if the mask cache holds more than ``_MAX_CACHE_SIZE``
    entries.

    Returns:
        A ``(seq_len, seq_len)`` bool tensor on ``device`` that is ``True``
        strictly above the main diagonal.
    """
    if len(_mask_cache) > _MAX_CACHE_SIZE:
        clear_cache()
    cache_key = (seq_len, device)
    cached = _mask_cache.get(cache_key)
    if cached is None:
        all_true = torch.ones(seq_len, seq_len, device=device, dtype=torch.bool)
        cached = torch.triu(all_true, 1)
        _mask_cache[cache_key] = cached
    return cached
try:  # torch.compile may not work on all Python versions
    # Any failure below (including an unparsable version string) falls
    # through to the no-op replacement defined in the except branch.
    if not torch.__version__:
        raise RuntimeError
    if sys.version_info >= (3, 11):
        raise RuntimeError
    if tuple(int(part) for part in torch.__version__.split(".")[:2]) < (2, 0):
        raise RuntimeError
    compile_fn = torch.compile
except Exception:  # pragma: no cover - handle missing torch or unsupported version

    def compile_fn(fn=None, **kwargs):
        """No-op stand-in for ``torch.compile``.

        Works both as ``compile_fn(fn)`` and as a decorator factory
        ``compile_fn(**options)``; in either form the function is returned
        unchanged and the options are ignored.
        """
        return (lambda f: f) if fn is None else fn
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding.

    The table is stored in the buffer ``pe`` with shape
    ``(max_len, 1, d_model)``: even feature columns hold sines and odd
    feature columns hold cosines of ``position * inverse_frequency``.
    """

    def __init__(self, d_model: int, max_len: int = 1024) -> None:
        """Precompute the encoding table.

        Args:
            d_model: Feature dimension of the encoding (even or odd).
            max_len: Number of positions to precompute.
        """
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        inv = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        angles = pos * inv
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so the cosine block must be trimmed to avoid a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encoding to input tensor.

        The first ``x.size(0)`` rows of the table are added, i.e. dimension 0
        of ``x`` is treated as the position axis.
        """
        return x + self.pe[: x.size(0)]
class LoggingTransformerEncoderLayer(nn.Module):
    """Transformer encoder layer that exposes attention weights.

    It optionally performs chunked attention with a fixed window size.
    Inputs and outputs are laid out as ``(T, B, D)``; tensors are transposed
    internally because the attention module is built with ``batch_first=True``.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 512,
        dropout: float = 0.1,
        chunk_size: Optional[int] = None,
        overlap: int = 0,
        full_attn_logging: Optional[bool] = None,
    ) -> None:
        """Build the attention and feed-forward sub-modules.

        Args:
            d_model: Feature dimension.
            nhead: Number of attention heads.
            dim_feedforward: Hidden width of the feed-forward block.
            dropout: Dropout probability used in attention and feed-forward.
            chunk_size: Window size for chunked attention; ``None`` disables it.
            overlap: Extra context positions attached to each chunk.
            full_attn_logging: Whether chunked attention returns weights.
                ``None`` means "enabled only when chunking is disabled".
        """
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.chunk_size = chunk_size
        self.overlap = overlap
        if full_attn_logging is None:
            full_attn_logging = chunk_size is None
        self.full_attn_logging = full_attn_logging
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = F.relu

    def _chunked_attn(
        self, src: torch.Tensor, attn_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Perform memory-efficient chunked self attention with overlap.

        Args:
            src: Input of shape ``(T, B, D)``.
            attn_mask: When not ``None`` a per-chunk causal mask is applied
                (the passed mask itself is only used by the full-attention
                fallback).

        Returns:
            ``(output, weights)``. ``output`` has shape ``(T, B, D)``.
            ``weights`` is an empty tensor unless ``full_attn_logging`` is set
            and ``T <= 1024``; otherwise it has shape ``(B, heads, T, W)``
            where each row holds the chunk-local attention of that query,
            zero-padded on the right to the widest chunk ``W``.
        """
        T, B, D = src.shape

        # Early return for small sequences
        if T <= 128 or self.chunk_size is None or self.chunk_size >= T:
            return self._full_attn(src, attn_mask)

        src_b = src.transpose(0, 1)  # [B, T, D]
        C = self.chunk_size
        O = self.overlap
        n_chunks = (T + C - 1) // C
        min_len = C + 2 * O
        # Weights for very long sequences are discarded anyway, so do not
        # ask the attention module to materialise them.
        log_weights = bool(self.full_attn_logging) and T <= 1024

        outputs = []
        weights_list = []

        with torch.autocast(device_type="cuda", enabled=torch.cuda.is_available()):
            for chunk_idx in range(n_chunks):
                start_idx = chunk_idx * C
                end_idx = min(start_idx + C + 2 * O, T + O)

                # Extract chunk with overlap
                chunk_start = max(0, start_idx - O)
                chunk_end = min(T, end_idx)
                chunk = src_b[:, chunk_start:chunk_end]

                # Pad short chunks on the right up to the nominal length
                if chunk.size(1) < min_len:
                    chunk = F.pad(chunk, (0, 0, 0, min_len - chunk.size(1)))

                chunk_len = chunk.size(1)
                mask = get_tri_mask(chunk_len, src.device) if attn_mask is not None else None

                out, weights = self.self_attn(
                    chunk, chunk, chunk,
                    attn_mask=mask,
                    need_weights=log_weights,
                    average_attn_weights=False,
                )

                # Offset of this chunk's own tokens inside the extracted
                # window. It is O once the left context fits, but smaller
                # when the window was clipped at position 0.
                core_start = start_idx - chunk_start
                core_end = core_start + min(C, T - start_idx)
                outputs.append(out[:, core_start:core_end])

                if log_weights and weights is not None:
                    weights_list.append(weights[:, :, core_start:core_end])

                # Clear intermediate tensors to save memory
                del out, weights, chunk
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        seq = torch.cat(outputs, dim=1)

        if weights_list:
            # Chunks can differ in key width when overlap > 0; pad to a
            # common width so the query-axis concatenation is well defined.
            width = max(w.size(-1) for w in weights_list)
            attn_out = torch.cat(
                [F.pad(w, (0, width - w.size(-1))) for w in weights_list], dim=2
            )
        else:
            attn_out = torch.empty(0, device=src.device)

        return seq.transpose(0, 1), attn_out

    def _full_attn(
        self, src: torch.Tensor, attn_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Standard full attention; returns ``(T, B, D)`` output and per-head weights."""
        qkv = src.transpose(0, 1)
        attn_output, attn_weights = self.self_attn(
            qkv, qkv, qkv,
            attn_mask=attn_mask,
            need_weights=True,
            average_attn_weights=False,
        )
        return attn_output.transpose(0, 1), attn_weights

    def forward(
        self, src: torch.Tensor, attn_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return output and attention map.

        Applies attention then feed-forward, each followed by a residual
        connection and layer norm. The returned weights are detached.
        """
        if self.chunk_size is not None:
            attn_output, attn_weights = self._chunked_attn(src, attn_mask)
        else:
            attn_output, attn_weights = self._full_attn(src, attn_mask)

        src = self.norm1(src + self.dropout1(attn_output))
        ff_out = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = self.norm2(src + self.dropout2(ff_out))
        return src, attn_weights.detach()
class ReversibleLoggingTransformerEncoderLayer(nn.Module):
    """Reversible-style transformer encoder layer operating on two streams.

    ``forward`` takes two ``(T, B, D)`` streams and returns updated streams
    plus detached attention weights. Attention may run over fixed-size
    chunks with optional overlap.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 512,
        dropout: float = 0.1,
        chunk_size: Optional[int] = None,
        overlap: int = 0,
        full_attn_logging: Optional[bool] = None,
    ) -> None:
        """Build the attention and feed-forward sub-modules.

        Args:
            d_model: Feature dimension.
            nhead: Number of attention heads.
            dim_feedforward: Hidden width of the feed-forward block.
            dropout: Dropout probability used in attention and feed-forward.
            chunk_size: Window size for chunked attention; ``None`` disables it.
            overlap: Context positions added on both sides of each chunk.
            full_attn_logging: Whether chunked attention reconstructs the
                full ``T x T`` map. ``None`` means "enabled only when
                chunking is disabled".
        """
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.chunk_size = chunk_size
        self.overlap = overlap
        if full_attn_logging is None:
            full_attn_logging = chunk_size is None
        self.full_attn_logging = full_attn_logging
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = F.relu

    def _sa_block(
        self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Self-attention with residual and layer norm.

        Args:
            x: Input of shape ``(T, B, D)``.
            attn_mask: Passed through for full attention. In chunked mode a
                non-``None`` value only switches on a per-chunk causal mask.

        Returns:
            ``(normed_output, weights)``. In chunked mode ``weights`` is the
            reconstructed ``(B, heads, T, T)`` map when ``full_attn_logging``
            is set and the chunk is shorter than the sequence, else an empty
            tensor.
        """
        if self.chunk_size is not None:
            T, B, D = x.shape
            x_b = x.transpose(0, 1)
            C = self.chunk_size or T
            O = self.overlap
            heads = self.self_attn.num_heads
            n_chunks = (T + C - 1) // C
            pad_len = n_chunks * C - T
            # Pad O positions on the left and enough on the right so that
            # every chunk sees a window of exactly C + 2 * O positions.
            src_pad = F.pad(x_b, (0, 0, O, pad_len + O))
            chunk_len = C + 2 * O
            # Tensor.unfold appends the window axis LAST, giving
            # (B, n_chunks, D, chunk_len); move it before the feature axis
            # so that the reshape below keeps time and features separate.
            chunks = src_pad.unfold(1, chunk_len, C).permute(0, 1, 3, 2)
            flat = chunks.reshape(B * n_chunks, chunk_len, D)
            mask = get_tri_mask(chunk_len, x.device) if attn_mask is not None else None
            out, weights = self.self_attn(
                flat,
                flat,
                flat,
                attn_mask=mask,
                need_weights=True,
                average_attn_weights=False,
            )
            # Keep only each chunk's own C query positions.
            out = out.reshape(B, n_chunks, chunk_len, D)[:, :, O : O + C]
            weights = weights.reshape(B, n_chunks, heads, chunk_len, chunk_len)[
                :, :, :, O : O + C
            ]
            seq = out.reshape(B, n_chunks * C, D)[:, :T]
            if self.full_attn_logging and C < T:
                full_attn = torch.zeros(
                    B, heads, n_chunks * C, n_chunks * C, device=x.device
                )
                for idx in range(n_chunks):
                    s = idx * C
                    start = max(s - O, 0)
                    end = min(s + C, n_chunks * C)
                    src_start = O - (s - start)
                    src_end = src_start + (end - start)
                    # Slice the KEY axis (last); the query axis already
                    # holds exactly this chunk's C rows.
                    full_attn[:, :, s : s + C, start:end] = weights[
                        :, idx, :, :, src_start:src_end
                    ]
                full_attn = full_attn[:, :, :T, :T]
                weights = full_attn.detach()
            else:
                weights = torch.empty(0, device=x.device)
            attn_out = seq.transpose(0, 1)
        else:
            qkv = x.transpose(0, 1)
            attn_out, weights = self.self_attn(
                qkv,
                qkv,
                qkv,
                attn_mask=attn_mask,
                need_weights=True,
                average_attn_weights=False,
            )
            attn_out = attn_out.transpose(0, 1)
        x = self.norm1(x + self.dropout1(attn_out))
        return x, weights.detach()

    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        """Feed-forward network with residual connection and layer norm."""
        out = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.norm2(x + self.dropout2(out))

    def forward(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Update both streams: ``y1 = x1 + SA(x2)``, ``y2 = x2 + FF(y1)``.

        Returns:
            ``(y1, y2, attention_weights)``.
        """
        y1, weights = self._sa_block(x2, attn_mask)
        y1 = x1 + y1
        y2 = x2 + self._ff_block(y1)
        return y1, y2, weights
class BitTransformerLM(nn.Module):
    """Transformer language model that operates on raw bits (0/1) with telemetry."""

    def __init__(
        self,
        d_model: int = 128,
        nhead: int = 8,
        num_layers: int = 4,
        dim_feedforward: int = 512,
        max_seq_len: int = 1024,
        lambda_K: float = 1.0,
        lambda_C: float = 1.0,
        lambda_S: float = 1.0,
        reversible: bool = False,
        use_checkpoint: bool = True,
        use_autocast: bool = False,
        use_act: bool = False,
        act_threshold: float = 0.9,
        chunk_size: Optional[int] = None,
        overlap: int = 0,
        full_attn_logging: Optional[bool] = None,
    ) -> None:
        """Create a BitTransformer language model.

        Args:
            d_model: Embedding / hidden dimension.
            nhead: Attention heads per layer.
            num_layers: Number of encoder layers.
            dim_feedforward: Hidden width of each feed-forward block.
            max_seq_len: Number of positions in the positional table.
            lambda_K: Weight of the negentropy term in the symbiosis score.
            lambda_C: Weight of the LZ-complexity term in the symbiosis score.
            lambda_S: Weight of the KL term in the symbiosis score.
            reversible: Use the two-stream reversible layer class.
            use_checkpoint: Wrap each layer in activation checkpointing.
            use_autocast: Run the forward pass inside ``cpu_autocast()``.
            use_act: Enable adaptive-computation-time halting.
            act_threshold: Minimum halting probability that stops the stack.
            chunk_size: Chunk length for chunked attention (``None`` = off).
            overlap: Overlap passed to each layer's chunked attention.
            full_attn_logging: When ``False`` and ``chunk_size`` is
                smaller than the sequence length, the model skips
                reconstructing the full ``T×T`` attention matrices for
                telemetry to reduce memory use.
        """
        super().__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.lambda_K = lambda_K
        self.lambda_C = lambda_C
        self.lambda_S = lambda_S
        self.reversible = reversible
        self.use_checkpoint = use_checkpoint
        self.use_autocast = use_autocast
        self.use_act = use_act
        self.act_threshold = act_threshold
        self.chunk_size = chunk_size
        self.overlap = overlap
        if full_attn_logging is None:
            full_attn_logging = chunk_size is None
        self.full_attn_logging = full_attn_logging

        # Bit embedding: two possible input values
        self.embedding = nn.Embedding(2, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len=max_seq_len)

        layer_cls = (
            ReversibleLoggingTransformerEncoderLayer
            if reversible
            else LoggingTransformerEncoderLayer
        )
        self.layers = nn.ModuleList(
            [
                layer_cls(
                    d_model=d_model,
                    nhead=nhead,
                    dim_feedforward=dim_feedforward,
                    chunk_size=chunk_size,
                    overlap=overlap,
                    full_attn_logging=full_attn_logging,
                )
                for _ in range(num_layers)
            ]
        )

        if self.use_act:
            self.halt_projs = nn.ModuleList(
                [nn.Linear(d_model, 1) for _ in range(num_layers)]
            )

        self.out_head = nn.Linear(d_model, 2)  # output logits for bit=0 or bit=1

    def expand_positional_encoding(self, new_len: int) -> None:
        """Expand positional encoding to at least ``new_len``.

        Existing rows are copied unchanged; only the new positions are
        computed. The buffer keeps its device and dtype.
        """
        old = self.pos_enc.pe
        cur_len = old.size(0)
        if new_len <= cur_len:
            return
        device = old.device
        d_model = self.d_model
        pe = torch.zeros(new_len, d_model, device=device, dtype=old.dtype)
        pe[:cur_len] = old.squeeze(1)
        pos = torch.arange(cur_len, new_len, dtype=torch.float32, device=device).unsqueeze(1)
        inv = torch.exp(
            torch.arange(0, d_model, 2, device=device).float()
            * -(math.log(10000.0) / d_model)
        )
        angles = pos * inv
        pe[cur_len:, 0::2] = torch.sin(angles)
        # Trim for odd d_model, which has one fewer cosine column.
        pe[cur_len:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.pos_enc.pe = pe.unsqueeze(1)

    def set_lambdas(self, lambda_K: float, lambda_C: float, lambda_S: float) -> None:
        """Update weighting coefficients for telemetry metrics."""
        self.lambda_K = lambda_K
        self.lambda_C = lambda_C
        self.lambda_S = lambda_S

    def _maybe_decompress(self, codes: torch.Tensor) -> torch.Tensor:
        """Return raw bit sequences, decompressing if input appears run-length encoded.

        Input is treated as compressed when it contains a value above 1, or
        when its width is even and every even-indexed column differs from the
        previous even-indexed column. Decompressed rows are right-padded with
        zeros to a common length.
        """
        # NOTE(review): the second rule is a heuristic and can also match
        # genuine raw bit sequences - confirm callers are aware of this.
        if codes.dim() <= 1:
            return codes
        needs_decompress = codes.max().item() > 1
        if not needs_decompress and codes.size(1) % 2 == 0:
            vals = codes[:, 0::2]
            if torch.all(vals[:, 1:] != vals[:, :-1]):
                needs_decompress = True
        if not needs_decompress:
            return codes
        seqs = [decompress_bits(row.to(torch.uint8)) for row in codes]
        max_len = max(seq.numel() for seq in seqs)
        padded = [F.pad(seq, (0, max_len - seq.numel())) for seq in seqs]
        return torch.stack(padded)

    def negentropy_kpi(self, codes: torch.Tensor) -> torch.Tensor:
        """Approximate negentropy of bit sequences.

        Returns a value in ``[0, 1]`` where ``1`` denotes a perfectly ordered
        sequence (all zeros or ones) and ``0`` reflects maximal entropy.
        """
        codes = self._maybe_decompress(codes)
        p = codes.float().mean(dim=1)
        entropy = -(p * torch.log(p + 1e-9) + (1 - p) * torch.log(1 - p + 1e-9))
        max_e = math.log(2.0)
        return 1 - entropy / max_e

    def lz_complexity(self, codes: torch.Tensor) -> torch.Tensor:
        """Differentiable proxy for Lempel–Ziv complexity.

        Values near ``0`` indicate highly compressible sequences while values
        approaching ``1`` correspond to rapid bit alternation.
        """
        # Convert before subtracting: unsigned integer inputs would wrap
        # around (0 - 1 == 255) and inflate the result.
        codes = self._maybe_decompress(codes).float()
        diffs = torch.abs(codes[:, 1:] - codes[:, :-1])
        return diffs.mean(dim=1)

    def negentropy_logits(self, logits: torch.Tensor, detach: bool = True) -> torch.Tensor:
        """Negentropy computed from model logits.

        Parameters
        ----------
        logits: ``torch.Tensor``
            Logit tensor of shape ``(B, T, 2)``.
        detach: bool, default ``True``
            When ``True`` the computation is detached from the autograd graph.
        """
        assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]"
        prob = logits.softmax(-1)
        if detach:
            prob = prob.detach()
        p = prob[..., 1].mean(dim=1)
        entropy = -(p * torch.log(p + 1e-9) + (1 - p) * torch.log(1 - p + 1e-9))
        max_e = math.log(2.0)
        return 1 - entropy / max_e

    def lz_complexity_logits(self, logits: torch.Tensor, detach: bool = True) -> torch.Tensor:
        """LZ complexity proxy computed from logits.

        Parameters
        ----------
        logits: ``torch.Tensor``
            Logit tensor of shape ``(B, T, 2)``.
        detach: bool, default ``True``
            When ``True`` the computation is detached from the autograd graph.
        """
        assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]"
        prob = logits.softmax(-1)
        if detach:
            prob = prob.detach()
        prob1 = prob[..., 1]
        diffs = torch.abs(prob1[:, 1:] - prob1[:, :-1])
        return diffs.mean(dim=1)

    def symbiosis_kl_logits(
        self, logits: torch.Tensor, ref_prob: float = 0.5, detach: bool = True
    ) -> torch.Tensor:
        """Symbiosis score from KL divergence to a reference distribution.

        Computes ``1 - KL(p || ref) / log(2)`` averaged over positions, where
        ``ref = [1 - ref_prob, ref_prob]``. With the default uniform
        reference the result lies in ``[0, 1]``, ``1`` meaning perfect
        agreement with the reference.
        """
        assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]"
        probs = logits.softmax(-1)
        if detach:
            probs = probs.detach()
        ref = torch.tensor([1 - ref_prob, ref_prob], device=logits.device)
        # Clamp so a degenerate reference (0 or 1) yields a large finite
        # divergence instead of inf/NaN.
        log_ref = ref.clamp_min(1e-9).log()
        kl = (probs * (probs.clamp_min(1e-9).log() - log_ref)).sum(-1).mean(dim=1)
        max_kl = math.log(2.0)
        return 1 - kl / max_kl

    def _act_step(
        self,
        hidden: torch.Tensor,
        idx: int,
        halt_prob: torch.Tensor,
        act_state: torch.Tensor,
        halt_history: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, bool]:
        """Apply one step of ACT halting logic.

        Returns the updated halting probability, the updated accumulated
        state, and whether the smallest halting probability (reduced across
        ranks when ``torch.distributed`` is initialised) reached
        ``act_threshold``.
        """
        p = torch.sigmoid(self.halt_projs[idx](hidden))
        delta = (1 - halt_prob) * p
        halt_prob = halt_prob + delta
        act_state = act_state + hidden * delta
        halt_history.append(halt_prob.detach())
        min_prob = halt_prob.detach().min()
        if dist.is_initialized():
            dist.all_reduce(min_prob, op=dist.ReduceOp.MIN)
        return halt_prob, act_state, min_prob.item() >= self.act_threshold

    def forward(
        self, bit_seq: torch.Tensor, causal: bool = True
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Forward pass returning logits and telemetry from the same graph.

        By default the model uses causal masking and (optional) chunked
        attention. When ``causal`` is ``False`` the model operates in
        "Diffusion LM" mode. In this mode chunked attention is temporarily
        disabled so that every token can attend to the full sequence
        bidirectionally. The original chunking configuration is restored after
        the forward pass.

        Args:
            bit_seq: Integer tensor of shape ``(B, T)`` holding bits.
            causal: Whether to apply the causal mask.

        Returns:
            ``(logits, telemetry)`` where ``logits`` has shape ``(B, T, 2)``
            and ``telemetry`` is a dict of per-layer activations, attention
            maps, entropies and the derived scalar metrics.
        """
        # Disable chunking when running in bidirectional (non-causal) mode
        orig_chunks = None
        orig_model_chunk = None
        if not causal and self.chunk_size is not None:
            orig_model_chunk = self.chunk_size
            orig_chunks = [layer.chunk_size for layer in self.layers]
            self.chunk_size = None
            for layer in self.layers:
                layer.chunk_size = None

        try:
            ctx = cpu_autocast() if self.use_autocast else contextlib.nullcontext()
            with ctx:
                x = self.embedding(bit_seq).transpose(0, 1) * math.sqrt(self.d_model)
                x = self.pos_enc(x)

                attn_mask = get_tri_mask(x.size(0), x.device) if causal else None

                # Checkpointing only trades compute for memory during
                # backward, so skip it whenever no graph is being recorded.
                use_ckpt = self.use_checkpoint and torch.is_grad_enabled()

                activations: List[torch.Tensor] = []
                attn_maps: List[torch.Tensor] = []
                halt_history: List[torch.Tensor] = []
                if self.use_act:
                    halt_prob = torch.zeros(x.size(0), x.size(1), 1, device=x.device)
                    act_state = torch.zeros_like(x)

                if self.reversible:
                    x1, x2 = x, x
                    for idx, layer in enumerate(self.layers):
                        if use_ckpt:
                            x1, x2, attn = checkpoint.checkpoint(
                                layer, x1, x2, attn_mask, use_reentrant=False
                            )
                        else:
                            x1, x2, attn = layer(x1, x2, attn_mask)
                        combined = (x1 + x2) / 2
                        activations.append(combined)
                        if attn.numel() > 0:
                            attn_maps.append(attn)
                        if self.use_act:
                            halt_prob, act_state, should_break = self._act_step(
                                combined, idx, halt_prob, act_state, halt_history
                            )
                            if should_break:
                                break
                    x = (x1 + x2) / 2
                else:
                    for idx, layer in enumerate(self.layers):
                        if use_ckpt:
                            x, attn = checkpoint.checkpoint(
                                layer, x, attn_mask, use_reentrant=False
                            )
                        else:
                            x, attn = layer(x, attn_mask)
                        activations.append(x)
                        if attn.numel() > 0:
                            attn_maps.append(attn)
                        if self.use_act:
                            halt_prob, act_state, should_break = self._act_step(
                                x, idx, halt_prob, act_state, halt_history
                            )
                            if should_break:
                                break

                if self.use_act:
                    # Assign the remaining probability mass to the last state.
                    act_state = act_state + x * (1 - halt_prob)
                    x = act_state
                logits = self.out_head(x)

                # Per-layer entropy of activations
                entropies = []
                for act in activations:
                    prob = act.softmax(-1)
                    entropies.append(-(prob * prob.clamp_min(1e-9).log()).sum(-1).mean())

                # Attention weights are already softmaxed: entropy over the
                # key axis, averaged over heads.
                attn_entropies = []
                for attn in attn_maps:
                    ent = -(attn * attn.clamp_min(1e-9).log()).sum(-1)
                    attn_entropies.append(ent.mean(1))
                if attn_entropies:
                    attn_entropy_map = torch.stack(attn_entropies).mean(0)
                else:
                    attn_entropy_map = torch.zeros(
                        bit_seq.size(0), bit_seq.size(1), device=bit_seq.device
                    )
                max_ent = math.log(attn_entropy_map.size(-1))
                # A length-1 sequence gives log(1) == 0; skip the division
                # rather than turning the telemetry into NaN.
                if max_ent > 0:
                    attn_entropy_map = attn_entropy_map / max_ent
                attn_entropy = attn_entropy_map.mean(1)

                logits_bt = logits.transpose(0, 1)
                negentropy_in = self.negentropy_kpi(bit_seq)
                lz_in = self.lz_complexity(bit_seq.float())
                negentropy_logits_b = self.negentropy_logits(logits_bt, detach=False)
                lz_logits_b = self.lz_complexity_logits(logits_bt, detach=False)
                kl_div_b = self.symbiosis_kl_logits(logits_bt, detach=False)

                raw_sym = (
                    (self.lambda_K * negentropy_logits_b + self.lambda_C * lz_logits_b) / 2
                    + negentropy_logits_b * lz_logits_b
                    - self.lambda_S * kl_div_b
                    - 0.1 * attn_entropy
                )
                # The norm is used as a detached constant, so do not record
                # a graph for it.
                with torch.no_grad():
                    weight_norm = torch.stack([p.norm() for p in self.parameters()]).mean()
                raw_sym = raw_sym - 0.01 * weight_norm
                sym_score = torch.sigmoid(raw_sym)

                B, T = bit_seq.shape
                assert logits_bt.shape[:2] == (B, T)
                assert attn_entropy_map.shape == (B, T)

                telemetry = {
                    "activations": activations,
                    "attention_maps": attn_maps,
                    "attention_entropy": attn_entropy_map,
                    "entropy": entropies,
                    "attention_entropy_mean": attn_entropy,
                    "negentropy_input": negentropy_in.detach(),
                    "lz_complexity_input": lz_in.detach(),
                    "negentropy_logits": negentropy_logits_b.detach(),
                    "lz_complexity_logits": lz_logits_b.detach(),
                    "symbiosis_kl": kl_div_b.detach(),
                    "symbiosis_score": sym_score.detach(),
                }
                if self.use_act:
                    telemetry["halt_probs"] = halt_history

                return logits_bt, telemetry
        finally:
            if orig_chunks is not None:
                self.chunk_size = orig_model_chunk
                for layer, chunk in zip(self.layers, orig_chunks):
                    layer.chunk_size = chunk

    def forward_compressed(
        self, compressed_bits, causal: bool = True
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Decompress bit sequences then run the normal forward pass.

        Accepts a single 1-D tensor or an iterable of compressed sequences.

        Raises:
            ValueError: If the sequences decompress to different lengths.
        """
        if isinstance(compressed_bits, torch.Tensor) and compressed_bits.dim() == 1:
            sequences = [decompress_bits(compressed_bits).to(torch.long)]
        else:
            sequences = [decompress_bits(c).to(torch.long) for c in compressed_bits]
        lengths = [seq.numel() for seq in sequences]
        if len(set(lengths)) != 1:
            raise ValueError("Sequences decompress to different lengths")
        bits = torch.stack(sequences)
        return self.forward(bits, causal=causal)

    def _current_params(self) -> Dict:
        """Return a dictionary with the current model hyperparameters."""
        # NOTE(review): full_attn_logging is not included here, so models
        # produced by the double_* helpers may fall back to its default -
        # confirm against expand_model before adding it.
        return {
            "d_model": self.d_model,
            "nhead": self.layers[0].self_attn.num_heads,
            "num_layers": self.num_layers,
            "dim_feedforward": self.layers[0].linear1.out_features,
            "max_seq_len": self.pos_enc.pe.size(0),
            "lambda_K": self.lambda_K,
            "lambda_C": self.lambda_C,
            "lambda_S": self.lambda_S,
            "reversible": self.reversible,
            "use_checkpoint": self.use_checkpoint,
            "use_autocast": self.use_autocast,
            "use_act": self.use_act,
            "act_threshold": self.act_threshold,
            "chunk_size": self.chunk_size,
            "overlap": self.overlap,
        }

    def double_width(self) -> "BitTransformerLM":
        """Return a copy of the model with doubled hidden size."""
        from .scale import expand_model

        params = self._current_params()
        params["d_model"] *= 2
        params["dim_feedforward"] *= 2
        return expand_model(self, params)

    def double_layers(self) -> "BitTransformerLM":
        """Return a copy of the model with twice as many layers."""
        from .scale import expand_model

        params = self._current_params()
        params["num_layers"] *= 2
        return expand_model(self, params)

    def double_length(self) -> "BitTransformerLM":
        """Return a copy of the model with doubled maximum sequence length.

        The chunk size of the copy is set to the new maximum length.
        """
        from .scale import expand_model

        params = self._current_params()
        params["max_seq_len"] *= 2
        params["chunk_size"] = params["max_seq_len"]
        return expand_model(self, params)

    def train_full_sequence(
        self,
        bits: torch.Tensor,
        *,
        ctx_bits: int = 4096,
        detach_every_n: int = 1_048_576,
    ) -> float:
        """Train on a long bit tensor using sliding windows.

        Each window holds ``ctx_bits + 1`` bits so that ``ctx_bits``
        next-bit predictions can be scored. Gradients accumulate across
        windows and are applied (with norm clipping at 1.0) every
        ``detach_every_n`` bits, plus once for any remainder.

        Parameters
        ----------
        bits: ``torch.Tensor``
            1D tensor containing the full bit sequence.
        ctx_bits: int
            Size of the training context window.
        detach_every_n: int
            Interval in bits for optimizer updates and graph detachment.

        Returns
        -------
        float
            Mean loss over all windows.
        """
        self.train()
        optimizer, scheduler = configure_optimizer(
            self, lr=1e-3, total_steps=max(1, bits.numel() // ctx_bits)
        )

        def _apply_update() -> None:
            # Clip, step both optimizer and scheduler, then reset gradients.
            torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        accum = 0
        total_loss = 0.0
        count = 0
        # The last valid start is numel - ctx_bits - 1 (a full window of
        # ctx_bits + 1 bits), hence the exclusive stop of numel - ctx_bits.
        for start in range(0, bits.numel() - ctx_bits, ctx_bits):
            segment = bits[start : start + ctx_bits + 1].unsqueeze(0)
            logits, _ = self(segment)
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = segment[:, 1:].reshape(-1)
            loss = F.cross_entropy(pred, target)
            loss.backward()
            accum += ctx_bits
            total_loss += loss.item()
            count += 1
            if accum >= detach_every_n:
                _apply_update()
                accum = 0
        if accum > 0:
            _apply_update()
        return total_loss / max(1, count)
def infer_long_sequence(
    model: "BitTransformerLM",
    bits: torch.Tensor,
    *,
    ctx_bits: int = 4096,
    overlap: int = 256,
) -> Tuple[torch.Tensor, List[Dict[str, torch.Tensor]]]:
    """Infer a long bit sequence using sliding windows with overlap.

    Windows of ``ctx_bits`` bits start every ``ctx_bits - overlap`` bits.
    The first window contributes all of its predictions; each later window
    contributes only the predictions past its leading ``overlap`` positions,
    so every output index lines up with the same input index. Trailing
    windows that would add no new positions are not evaluated.

    Args:
        model: Model called as ``model(window, causal=True)`` and expected
            to return ``(logits, telemetry)``.
        bits: 1-D tensor holding the full bit sequence.
        ctx_bits: Window length in bits.
        overlap: Number of leading bits each window shares with the
            previous one.

    Returns:
        ``(predictions, logs)`` where ``predictions`` has ``bits.numel()``
        entries and ``logs`` holds one telemetry dict per evaluated window.

    Raises:
        ValueError: If ``overlap`` is negative or not smaller than
            ``ctx_bits``.
    """
    if not 0 <= overlap < ctx_bits:
        raise ValueError("overlap must satisfy 0 <= overlap < ctx_bits")

    model.eval()
    device = next(model.parameters()).device
    bits = bits.to(device)
    total = bits.numel()
    step = ctx_bits - overlap
    outputs: List[torch.Tensor] = []
    logs: List[Dict[str, torch.Tensor]] = []

    # Inference only: avoid recording autograd graphs for every window.
    with torch.no_grad():
        for start in range(0, total, step):
            # Everything up to start + overlap was already predicted by the
            # previous window, so a window ending there adds nothing.
            if start > 0 and start + overlap >= total:
                break
            window = bits[start : start + ctx_bits].unsqueeze(0)
            logits, tele = model(window, causal=True)
            pred = logits.argmax(-1).squeeze(0)
            outputs.append(pred if start == 0 else pred[overlap:])
            logs.append(tele)

    if not outputs:
        return torch.empty(0, dtype=torch.long, device=device), logs
    out = torch.cat(outputs)[:total]
    return out, logs
def diffusion_inference(
    model: "BitTransformerLM",
    *,
    length: int,
    steps: int = 8,
    batch_size: int = 1,
    init_bits: Optional[torch.Tensor] = None,
    schedule: str = "linear",
) -> torch.Tensor:
    """Generate bit sequences using iterative denoising diffusion.

    At each step the model is run non-causally, a random subset of
    positions (chosen with the scheduled mask probability) is resampled
    from the predicted bit probabilities, and the rest are kept.

    Parameters
    ----------
    model: ``BitTransformerLM``
        The model used for denoising. It is run in non-causal mode with
        chunked attention disabled, enabling full-context bidirectional
        attention.
    length: int
        Length of the bit sequences to generate.
    steps: int, default ``8``
        Number of denoising iterations. More steps generally yield sharper
        samples at the cost of compute.
    batch_size: int, default ``1``
        Number of sequences to generate in parallel.
    init_bits: ``torch.Tensor`` | ``None``
        Optional initial noisy bits of shape ``(batch_size, length)``. When
        ``None`` random noise is used.
    schedule: str, default ``"linear"``
        Noise schedule for the denoising mask probability. Options are
        ``"linear"``, ``"cosine"``, and ``"exp"``.

    Returns
    -------
    ``torch.Tensor``
        A tensor of shape ``(batch_size, length)`` containing generated bits.

    Raises
    ------
    ValueError
        If ``schedule`` is unknown or ``init_bits`` has the wrong shape.
    """
    # Map progress t in (0, 1] to the probability of resampling a position.
    schedules = {
        "linear": lambda t: 1.0 - t,
        "cosine": lambda t: math.cos(math.pi * t / 2),
        "exp": lambda t: math.exp(-5 * t),
    }
    # Validate up front so a bad name fails before any model call, and
    # also when steps == 0.
    if schedule not in schedules:
        raise ValueError(f"unknown schedule: {schedule}")
    mask_prob_at = schedules[schedule]

    model.eval()
    device = next(model.parameters()).device
    if init_bits is None:
        bits = torch.randint(0, 2, (batch_size, length), device=device)
    else:
        # The embedding lookup needs integer indices.
        bits = init_bits.to(device=device, dtype=torch.long)
        if bits.shape != (batch_size, length):
            raise ValueError("init_bits must have shape (batch_size, length)")

    # Sampling never needs gradients.
    with torch.no_grad():
        for step in range(steps):
            logits, _ = model(bits, causal=False)
            prob = logits.softmax(-1)[..., 1]
            mask_prob = mask_prob_at((step + 1) / steps)
            mask = torch.rand_like(bits, dtype=torch.float32) < mask_prob
            sampled = torch.bernoulli(prob).long()
            bits = torch.where(mask, sampled, bits)

    # Parity is only enforced when the length is a multiple of 9.
    if bits.shape[-1] % 9 == 0:
        bits, corrections = enforce_parity(bits)
        if corrections:
            logging.info("Parity corrections applied: %d", corrections)
    try:
        from .safety import hil_safe_inference

        hil_safe_inference(model, bits, causal=False, strict=False)
    except RuntimeError as exc:
        # A failing safety gate is reported but does not block the result.
        logging.warning("Safety gate warning: %s", exc)
    return bits
def example_usage() -> float:
    """Run the example from the README and return the loss.

    Builds a small model, feeds it a random batch of bits and scores the
    next-bit predictions with cross-entropy.
    """
    batch, seq_len = 4, 16
    model = BitTransformerLM(
        d_model=64,
        nhead=4,
        num_layers=2,
        dim_feedforward=256,
        max_seq_len=seq_len,
    )
    bits = torch.randint(0, 2, (batch, seq_len), dtype=torch.long)
    logits, _telemetry = model(bits)
    # Position t predicts bit t + 1, so drop the last logit and first bit.
    next_bit_logits = logits[:, :-1, :].reshape(-1, 2)
    next_bits = bits[:, 1:].reshape(-1)
    return F.cross_entropy(next_bit_logits, next_bits).item()
def example_training_step() -> Tuple[float, Dict[str, torch.Tensor]]:
    """Demonstrate a training step where metrics do not affect gradients.

    Returns:
        The scalar cross-entropy loss and the telemetry dict produced by the
        forward pass.
    """
    batch, seq_len = 4, 16
    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=seq_len,
    )
    optimizer, scheduler = configure_optimizer(model, lr=1e-3, total_steps=1)

    bits = torch.randint(0, 2, (batch, seq_len), dtype=torch.long)
    logits, telemetry = model(bits)

    # Only the cross-entropy term is back-propagated.
    next_bit_logits = logits[:, :-1, :].reshape(-1, 2)
    next_bits = bits[:, 1:].reshape(-1)
    loss = F.cross_entropy(next_bit_logits, next_bits)
    loss.backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
    return loss.item(), telemetry
if __name__ == "__main__":
    # Script entry point: run one demo training step and report the result.
    loss, telemetry = example_training_step()
    print("Composite loss:", loss)
    print("Telemetry keys:", list(telemetry))