File size: 5,889 Bytes
9d43dda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
#!/usr/bin/env python3
"""
BINARY TRANSFORMER - Raw network bytes → neural network
No tokenizer. No preprocessing. Just bytes.

Vocab = 256 (one token per byte value 0x00-0xFF)
Input: Raw bytes from network stream via stdin
"""

import sys
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque

# Prefer the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
# Let CUDA matmuls use TF32.
torch.backends.cuda.matmul.allow_tf32 = True

# Hyper-parameters of the byte-level model, kept deliberately small for speed.
CONFIG = dict(
    d=128,       # embedding width
    layers=3,    # number of transformer blocks
    heads=4,     # attention heads per block
    vocab=256,   # one token per possible byte value
    ctx=1024,    # context length in bytes (bytes are fine-grained)
)

LR = 0.0003           # optimizer learning rate
UPDATE_EVERY = 64     # bytes between updates
PRINT_EVERY = 50_000  # bytes between stats

class ByteAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a fused QKV projection.

    Args:
        d: Model width (size of the last dimension of the input).
        h: Number of attention heads; must be positive and divide ``d``.

    Raises:
        ValueError: If ``h`` is not positive or ``d`` is not a multiple of ``h``.
    """

    def __init__(self, d, h):
        super().__init__()
        # Fail here rather than in forward(): with a remainder, the 3 * d
        # features of the fused projection cannot be viewed as (3, h, d // h).
        if h <= 0 or d % h != 0:
            raise ValueError(
                f"d ({d}) must be a multiple of a positive head count h ({h})"
            )
        self.h, self.dk = h, d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (B, N, D) with D equal to the constructor's ``d``.
            mask: Optional additive mask that is summed onto the (B, h, N, N)
                attention scores before the softmax (must broadcast to them).

        Returns:
            Tensor of shape (B, N, D).
        """
        B, N, D = x.shape
        # One matmul yields Q, K and V; split into (3, B, h, N, dk).
        fused = self.qkv(x).view(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = fused.unbind(0)
        scores = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        if mask is not None:
            scores = scores + mask
        mixed = F.softmax(scores, dim=-1) @ v
        # Put heads back next to their features: (B, h, N, dk) -> (B, N, D).
        return self.proj(mixed.transpose(1, 2).reshape(B, N, D))

class ByteBlock(nn.Module):
    """Transformer block: LayerNorm -> attention and LayerNorm -> MLP, each with a residual.

    Args:
        d: Model width.
        h: Number of attention heads, forwarded to ``ByteAttention``.
    """

    def __init__(self, d, h):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.ln2 = nn.LayerNorm(d)
        self.attn = ByteAttention(d, h)
        # Feed-forward part expands to four times the model width.
        hidden = 4 * d
        self.ff = nn.Sequential(
            nn.Linear(d, hidden),
            nn.GELU(),
            nn.Linear(hidden, d),
        )

    def forward(self, x, mask):
        """Return ``x`` after the attention and feed-forward residual updates.

        ``mask`` is handed unchanged to the attention layer.
        """
        attended = self.attn(self.ln1(x), mask)
        x = x + attended
        expanded = self.ff(self.ln2(x))
        return x + expanded

class BinaryTransformer(nn.Module):
    """Causal transformer over byte ids with tied input/output embeddings.

    Args:
        cfg: Mapping providing ``"d"`` (width), ``"layers"`` (block count),
            ``"heads"`` (attention heads) and ``"vocab"`` (number of tokens).
    """

    def __init__(self, cfg):
        super().__init__()
        width = cfg["d"]
        depth = cfg["layers"]
        n_heads = cfg["heads"]
        vocab = cfg["vocab"]

        # One embedding row per token id.
        self.emb = nn.Embedding(vocab, width)
        self.blocks = nn.ModuleList()
        for _ in range(depth):
            self.blocks.append(ByteBlock(width, n_heads))
        self.ln = nn.LayerNorm(width)
        self.head = nn.Linear(width, vocab, bias=False)
        # Share one weight matrix between the embedding and the output head.
        self.head.weight = self.emb.weight

    def forward(self, x):
        """Map a (B, N) tensor of token ids to (B, N, vocab) logits."""
        _, seq_len = x.shape
        # Additive causal mask: large negative above the diagonal so that a
        # position cannot attend to later positions.
        upper = torch.ones(seq_len, seq_len, device=x.device).triu(1)
        mask = upper * -1e9

        hidden = self.emb(x)
        for layer in self.blocks:
            hidden = layer(hidden, mask)
        normed = self.ln(hidden)
        return self.head(normed)

    def count_params(self):
        """Return the total number of elements across ``self.parameters()``."""
        total = 0
        for param in self.parameters():
            total += param.numel()
        return total

class BinaryTrainer:
    """Online next-byte trainer fed one byte at a time.

    Incoming bytes are kept in a sliding window of at most ``CONFIG["ctx"] + 1``
    values. Every ``UPDATE_EVERY`` bytes one optimizer step is taken on that
    window, scoring only the newest ``UPDATE_EVERY`` targets.

    Args:
        model: Module called as ``model(x)`` with a (1, N) long tensor of byte
            ids and returning (1, N, vocab) logits.
        lr: Learning rate passed to AdamW.
    """

    def __init__(self, model, lr=LR):
        self.model = model.to(DEVICE)
        self.opt = torch.optim.AdamW(model.parameters(), lr=lr)
        self.ctx_size = CONFIG["ctx"]
        # One extra slot so the window yields ctx_size (input, target) pairs.
        self.buffer = deque(maxlen=self.ctx_size + 1)

        self.bytes_seen = 0
        self.total_loss = 0.0   # sum of per-update losses since start
        self.updates = 0
        self.start_time = time.time()

    def ingest_byte(self, byte_val):
        """Absorb a single byte (0-255) and run any update/report/save now due."""
        self.buffer.append(byte_val)
        self.bytes_seen += 1

        # Need at least UPDATE_EVERY targets plus one preceding input byte.
        if len(self.buffer) >= UPDATE_EVERY + 1 and self.bytes_seen % UPDATE_EVERY == 0:
            self._update()

        if self.bytes_seen % PRINT_EVERY == 0:
            self._print_stats()

        # Save checkpoint every 500k bytes
        if self.bytes_seen % 500000 == 0:
            self._save()

    def _update(self):
        """Take one gradient step on the current window."""
        tokens = list(self.buffer)
        # Inputs are the window without its last byte; targets are shifted by one.
        x = torch.tensor(tokens[:-1], device=DEVICE, dtype=torch.long).unsqueeze(0)
        y = torch.tensor(tokens[1:], device=DEVICE, dtype=torch.long).unsqueeze(0)

        self.model.train()
        logits = self.model(x)
        # Only the newest UPDATE_EVERY positions contribute to the loss; the
        # class count is taken from the logits rather than hard-coded.
        loss = F.cross_entropy(
            logits[:, -UPDATE_EVERY:].reshape(-1, logits.size(-1)),
            y[:, -UPDATE_EVERY:].reshape(-1),
        )

        self.opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.opt.step()

        self.total_loss += loss.item()
        self.updates += 1

    def _print_stats(self):
        """Print elapsed time, throughput and the running-average loss."""
        elapsed = time.time() - self.start_time
        rate = self.bytes_seen / elapsed if elapsed > 0 else 0
        # Average over every update so far, not a recent window.
        avg_loss = self.total_loss / max(1, self.updates)
        mb = self.bytes_seen / 1_000_000

        # Bits per byte (compression metric) - log2(256)=8 is random, lower is learning
        bpb = avg_loss / math.log(2)

        print(f"[{elapsed:.0f}s] {mb:.2f}MB | {rate/1000:.1f} KB/s | "
              f"loss={avg_loss:.3f} | bpb={bpb:.2f} | updates={self.updates}", flush=True)

    @staticmethod
    def _mb_label(n_bytes):
        """Format ``n_bytes`` as decimal megabytes without losing precision.

        Whole megabytes keep the plain integer form (``"1"``); anything else
        gets an exact fraction (``"0.5"``, ``"1.5"``), so distinct byte counts
        never share a label.
        """
        whole, rem = divmod(n_bytes, 1_000_000)
        frac = f"{rem:06d}".rstrip("0")
        return f"{whole}.{frac}" if frac else str(whole)

    def _save(self):
        """Write model weights, byte count and average loss to a checkpoint file."""
        avg_loss = self.total_loss / max(1, self.updates)
        # Integer division here made the 1.5 MB checkpoint overwrite the 1 MB
        # one (and labelled the first checkpoint "0mb"); use an exact label.
        mb = self._mb_label(self.bytes_seen)
        ckpt = {
            "model": self.model.state_dict(),
            "bytes": self.bytes_seen,
            "loss": avg_loss,
        }
        torch.save(ckpt, f"byte_ckpt_{mb}mb.pt")
        print(f"[SAVED] {mb}MB checkpoint", flush=True)

def main():
    """Build the model and train it on raw bytes from stdin until EOF."""
    print("BINARY TRANSFORMER - Raw bytes learning", flush=True)
    print(f"Config: {CONFIG}", flush=True)
    print(f"Device: {DEVICE}", flush=True)

    model = BinaryTransformer(CONFIG)
    params = model.count_params()
    print(f"Parameters: {params:,} ({params/1e6:.1f}M)", flush=True)
    print("Vocab: 256 (one per byte)", flush=True)

    trainer = BinaryTrainer(model)

    print("Listening for raw bytes on stdin...", flush=True)

    # Read raw bytes from stdin in chunks rather than one call per byte.
    # read1() returns as soon as some data is available (it does not wait to
    # fill the whole request), so a live stream is still consumed promptly.
    stream = sys.stdin.buffer
    while True:
        chunk = stream.read1(65536)
        if not chunk:
            break
        # Iterating a bytes object yields ints in 0-255.
        for byte_val in chunk:
            trainer.ingest_byte(byte_val)

    print(f"Stream ended. Total bytes: {trainer.bytes_seen:,}", flush=True)

# Start training only when run as a script, not when imported as a module.
if __name__ == "__main__":
    main()