Data loading


pip install torch datasets tqdm


import torch
import torch.nn as nn
import copy
import time
import os
from datasets import load_dataset
from tqdm import tqdm

# --- 1. Model Architecture (~10k parameters) ---
class TinyTransformer(nn.Module):
    def __init__(self, vocab_size=128, d_model=16, n_heads=4, n_layers=1, seq_len=128, dim_feedforward=64):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = nn.Parameter(torch.zeros(1, seq_len, d_model))
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, 
            nhead=n_heads, 
            dim_feedforward=dim_feedforward, 
            batch_first=True,
            norm_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        x = self.embedding(x) + self.pos_encoding[:, :x.size(1), :]
        x = self.transformer(x)
        return self.head(x)
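
# Helper (not in the original listing): count trainable parameters, so the log
# below can report the real figure instead of a hardcoded one.
def count_params(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())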

# --- 2. EML-V4 Optimizer Logic (Compiled) ---
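# Note: torch.compile specializes on input shapes, so the first call for each
# distinct parameter shape pays a one-time compilation cost.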
@torch.compile
def eml_v4_step(p, g, m_s, v_s, g_m, sr_val, lr_base=1e-3):
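    """One EML-V4 update for a single tensor.

    p: parameter, g: current gradient, m_s/v_s: EMA first/second moments,
    g_m: previous step's gradient, sr_val: schedule ratio in [0, 1],
    lr_base: scale applied to the symbolic update Delta_W.
    """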
    def clamp(x): return torch.clamp(x, min=-10.0, max=5.0)
    exp, ln, eps = torch.exp, lambda x: torch.log1p(torch.abs(x)), 1e-6
    gn, mn = g/(torch.sqrt(v_s)+eps), m_s/(torch.sqrt(v_s)+eps)
    lv, ws, dg = torch.log(v_s+eps), p*1e-3, g-g_m
    sr = torch.full_like(p, sr_val)
    
    # Symbolic DAG Nodes (Simplified V4)
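    # Each node h_k below has the form a*L1 + b*(exp(clamp(L1)) - ln(L2)), where
    # L1 and L2 are affine mixes of the six features and all earlier nodes.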
    h1 = (-0.058077*((0.326463*gn) + (-0.088942*mn) + (0.116864*lv) + (0.027878*ws) + (0.447840*sr) + (0.128187*dg))) + (0.073431*(exp(clamp((0.326463*gn) + (-0.088942*mn) + (0.116864*lv) + (0.027878*ws) + (0.447840*sr) + (0.128187*dg))) - ln((-0.167744*gn) + (-0.801673*mn) + (0.080196*lv) + (0.333559*ws) + (-0.235951*sr) + (-0.554490*dg))))
    h2 = (-0.300506*((0.175481*gn) + (0.014998*mn) + (0.101129*lv) + (-0.175984*ws) + (0.373183*sr) + (0.006823*dg) + (-0.595734*h1))) + (0.579206*(exp(clamp((0.175481*gn) + (0.014998*mn) + (0.101129*lv) + (-0.175984*ws) + (0.373183*sr) + (0.006823*dg) + (-0.595734*h1))) - ln((0.203413*gn) + (-0.292612*mn) + (0.030940*lv) + (-0.109206*ws) + (-0.424965*sr) + (-0.042272*dg) + (0.271790*h1))))
    h3 = (-0.400698*((-0.194373*gn) + (0.120059*mn) + (0.184991*lv) + (-0.036116*ws) + (-0.035269*sr) + (0.294774*dg) + (-0.099381*h1) + (0.130899*h2))) + (0.149856*(exp(clamp((-0.194373*gn) + (0.120059*mn) + (0.184991*lv) + (-0.036116*ws) + (-0.035269*sr) + (0.294774*dg) + (-0.099381*h1) + (0.130899*h2))) - ln((0.083350*gn) + (-1.189922*mn) + (-0.518277*lv) + (-0.394501*ws) + (0.037102*sr) + (0.157617*dg) + (0.435979*h1) + (-0.108964*h2))))
    h4 = (0.145793*((-0.070677*gn) + (0.054053*mn) + (0.032187*lv) + (-0.118977*ws) + (-0.531387*sr) + (0.105140*dg) + (-0.154515*h1) + (-0.009049*h2) + (0.213377*h3))) + (0.039548*(exp(clamp((-0.070677*gn) + (0.054053*mn) + (0.032187*lv) + (-0.118977*ws) + (-0.531387*sr) + (0.105140*dg) + (-0.154515*h1) + (-0.009049*h2) + (0.213377*h3))) - ln((0.076588*gn) + (0.079540*mn) + (-0.209503*lv) + (-0.047369*ws) + (0.010977*sr) + (0.519068*dg) + (-0.154961*h1) + (0.441311*h2) + (0.055281*h3))))
    h5 = (-0.594586*((-0.067614*gn) + (0.015879*mn) + (0.119583*lv) + (0.411063*ws) + (-0.168859*sr) + (-0.175732*dg) + (-0.036869*h1) + (0.304058*h2) + (-0.021353*h3) + (0.152075*h4))) + (0.748769*(exp(clamp((-0.067614*gn) + (0.015879*mn) + (0.119583*lv) + (0.411063*ws) + (-0.168859*sr) + (-0.175732*dg) + (-0.036869*h1) + (0.304058*h2) + (-0.021353*h3) + (0.152075*h4))) - ln((0.078501*gn) + (-0.072800*mn) + (0.420199*lv) + (-0.100246*ws) + (-0.107114*sr) + (0.156440*dg) + (0.244622*h1) + (-0.045126*h2) + (0.380606*h3) + (-0.624639*h4))))
    h6 = (-0.179390*((0.928254*gn) + (0.650038*mn) + (-0.047455*lv) + (-0.226875*ws) + (-0.614850*sr) + (-0.433946*dg) + (-0.540925*h1) + (0.146426*h2) + (-0.029819*h3) + (0.149721*h4) + (0.208455*h5))) + (-0.493183*(exp(clamp((0.928254*gn) + (0.650038*mn) + (-0.047455*lv) + (-0.226875*ws) + (-0.614850*sr) + (-0.433946*dg) + (-0.540925*h1) + (0.146426*h2) + (-0.029819*h3) + (0.149721*h4) + (0.208455*h5))) - ln((-0.073715*gn) + (-0.092505*mn) + (0.824915*lv) + (-0.008520*ws) + (0.124421*sr) + (-0.077104*dg) + (0.190160*h1) + (-0.302955*h2) + (-0.169784*h3) + (0.209646*h4) + (0.078545*h5))))
    h7 = (0.370934*((0.069138*gn) + (-0.380477*mn) + (0.412478*lv) + (0.192075*ws) + (0.181433*sr) + (-0.396660*dg) + (-0.050689*h1) + (-0.262612*h2) + (-0.216929*h3) + (-0.005291*h4) + (-0.439095*h5) + (0.608022*h6))) + (0.145111*(exp(clamp((0.069138*gn) + (-0.380477*mn) + (0.412478*lv) + (0.192075*ws) + (0.181433*sr) + (-0.396660*dg) + (-0.050689*h1) + (-0.262612*h2) + (-0.216929*h3) + (-0.005291*h4) + (-0.439095*h5) + (0.608022*h6))) - ln((0.346874*gn) + (-0.025481*mn) + (0.472702*lv) + (0.261349*ws) + (-0.459945*sr) + (-0.597384*dg) + (-0.071627*h1) + (-0.037873*h2) + (0.749209*h3) + (0.190380*h4) + (-0.040818*h5) + (-0.019762*h6))))
    h8 = (-0.175042*((0.018960*gn) + (-0.230718*mn) + (0.351541*lv) + (0.336566*ws) + (0.147692*sr) + (-0.165015*dg) + (0.225870*h1) + (-0.344514*h2) + (0.218345*h3) + (-0.056044*h4) + (0.369750*h5) + (-0.117038*h6) + (0.438704*h7))) + (0.111757*(exp(clamp((0.018960*gn) + (-0.230718*mn) + (0.351541*lv) + (0.336566*ws) + (0.147692*sr) + (-0.165015*dg) + (0.225870*h1) + (-0.344514*h2) + (0.218345*h3) + (-0.056044*h4) + (0.369750*h5) + (-0.117038*h6) + (0.438704*h7))) - ln((-0.104335*gn) + (0.335018*mn) + (-0.385627*lv) + (0.072936*ws) + (-1.268249*sr) + (0.118827*dg) + (0.097950*h1) + (-0.286033*h2) + (0.579968*h3) + (-0.448825*h4) + (-0.102034*h5) + (-0.106708*h6) + (-0.310861*h7))))
    
    Delta_W = (-1.262488*((0.045908*gn) + (-1.763110*mn) + (-0.029681*lv) + (0.179422*ws) + (0.090199*sr) + (0.006834*dg) + (0.358762*h1) + (0.083950*h2) + (-0.189795*h3) + (0.401436*h4) + (0.060015*h5) + (0.147876*h6) + (0.151525*h7) + (0.084969*h8))) + (-0.115983*(exp(clamp((0.045908*gn) + (-1.763110*mn) + (-0.029681*lv) + (0.179422*ws) + (0.090199*sr) + (0.006834*dg) + (0.358762*h1) + (0.083950*h2) + (-0.189795*h3) + (0.401436*h4) + (0.060015*h5) + (0.147876*h6) + (0.151525*h7) + (0.084969*h8))) - ln((0.412255*gn) + (-0.114782*mn) + (-0.403829*lv) + (0.142990*ws) + (-0.021270*sr) + (0.504233*dg) + (0.343557*h1) + (-0.125353*h2) + (0.009652*h3) + (0.771302*h4) + (0.808182*h5) + (0.005922*h6) + (-0.352103*h7) + (-0.263151*h8))))
    return Delta_W * lr_base
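
# A minimal sanity check (hypothetical shapes, not part of the benchmark): the
# update returned by eml_v4_step must match the parameter's shape.
def _smoke_test_eml_v4():
    p = torch.randn(16, 16)
    g = torch.randn_like(p)
    dw = eml_v4_step(p, g, torch.zeros_like(p), torch.full_like(p, 1e-2),
                     torch.zeros_like(p), sr_val=0.5)
    assert dw.shape == p.shape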

# --- 3. Logging Utility ---
def log_to_file(msg, filename="tinystories_bench.log"):
    with open(filename, "a") as f:
        f.write(msg + "\n")
    print(msg)

# --- 4. Data Loading (TinyStories Char-Level) ---
def get_dataloader(batch_size=32, seq_len=128):
    ds = load_dataset("karpathy/tinystories-gpt4-clean", split="train", streaming=True)
    def gen():
        buffer = []
        for example in ds:
            text = example['text']
            encoded = [ord(c) if ord(c) < 128 else 32 for c in text]
            buffer.extend(encoded)
            while len(buffer) >= seq_len + 1:
                chunk = buffer[:seq_len + 1]
                buffer = buffer[seq_len:]
                yield torch.tensor(chunk[:-1]), torch.tensor(chunk[1:])
    
    # torch.utils.data.IterableDataset has no from_generator constructor; wrap
    # the generator in a minimal IterableDataset subclass instead.
    class CharStream(torch.utils.data.IterableDataset):
        def __iter__(self):
            return gen()

    loader = torch.utils.data.DataLoader(CharStream(), batch_size=batch_size)
    return loader
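
# Usage sketch (assumes network access to the Hugging Face Hub; the dataset id
# above is kept verbatim from the listing):
#   x, y = next(iter(get_dataloader(batch_size=2, seq_len=16)))
#   # x, y: LongTensors of shape (2, 16); y is x shifted one character ahead.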

# --- 5. Training Routine ---
def train(name, steps=5000, lr=1e-3):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = TinyTransformer().to(device)
    loader = get_dataloader(batch_size=32, seq_len=128)
    criterion = nn.CrossEntropyLoss()
    
    log_to_file(f"\n--- Training {name} on TinyStories ---")
    log_to_file(f"Device: {device} | Params: {count_params(model)} | Steps: {steps} | LR: {lr}")

    if name == "AdamW":
        opt = torch.optim.AdamW(model.parameters(), lr=lr)
    elif name == "Muon":
        # Divide parameters as per standard practice
        muon_params = [p for n, p in model.named_parameters() if not any(x in n for x in ["embedding", "head", "pos_encoding"]) and len(p.shape) == 2]
        # `p not in muon_params` would compare tensors elementwise; use identity.
        muon_ids = {id(p) for p in muon_params}
        adam_params = [p for n, p in model.named_parameters() if id(p) not in muon_ids]
        
        # Use torch.optim.Muon if this PyTorch build ships one, otherwise fall back
        try:
            from torch.optim import Muon
            log_to_file("Using built-in torch.optim.Muon")
            opt_muon = Muon(muon_params, lr=lr*2) # Muon usually likes higher LR
        except ImportError:
            log_to_file("torch.optim.Muon NOT found. Falling back to custom implementation.")
            # Compact fallback for one-file portability: SGD-momentum whose update
            # is orthogonalized with the quintic Newton-Schulz iteration used by
            # the Muon reference implementation.
            class CustomMuon(torch.optim.Optimizer):
                def __init__(self, params, lr=0.02, momentum=0.95):
                    super().__init__(params, dict(lr=lr, momentum=momentum))
                @torch.no_grad()
                def step(self):
                    for group in self.param_groups:
                        for p in group['params']:
                            if p.grad is None: continue
                            buf = self.state[p].setdefault('momentum_buffer', torch.zeros_like(p))
                            buf.mul_(group['momentum']).add_(p.grad)
                            # 5 Newton-Schulz steps approximate the nearest orthogonal matrix
                            X = buf / (buf.norm() + 1e-7)
                            for _ in range(5):
                                A = X @ X.T
                                X = 3.4445 * X + (-4.7750 * A + 2.0315 * A @ A) @ X
                            p.add_(X, alpha=-group['lr'])
            opt_muon = CustomMuon(muon_params, lr=lr*2)
            
        opt_adam = torch.optim.AdamW(adam_params, lr=lr)
        
    elif name == "EML-V4":
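        # EML-V4 keeps Adam-like per-parameter state: m_s/v_s are EMA first and
        # second moments; g_m stores the previous gradient for the dg feature.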
        params = dict(model.named_parameters())
        m_s = {k: torch.zeros_like(v) for k, v in params.items()}
        v_s = {k: torch.zeros_like(v) for k, v in params.items()}
        g_m = {k: torch.zeros_like(v) for k, v in params.items()}

    it = iter(loader)
    for i in range(steps):
        try:
            x, y = next(it)
        except StopIteration:
            break
        x, y = x.to(device), y.to(device)
        
        if name in ["AdamW", "Muon"]:
            if name == "Muon": 
                opt_muon.zero_grad(); opt_adam.zero_grad()
            else: 
                opt.zero_grad()
            
            out = model(x)
            loss = criterion(out.reshape(-1, out.size(-1)), y.reshape(-1))
            loss.backward()
            
            if name == "Muon": 
                opt_muon.step(); opt_adam.step()
            else: 
                opt.step()
            
        elif name == "EML-V4":
            out = model(x)
            loss = criterion(out.reshape(-1, out.size(-1)), y.reshape(-1))
            model.zero_grad()
            loss.backward()
            sr = min(i / steps, 1.0)
            with torch.no_grad():
                for n, p in model.named_parameters():
                    g = p.grad
                    m_s[n] = 0.9 * m_s[n] + 0.1 * g
                    v_s[n] = 0.999 * v_s[n] + 0.001 * (g**2)
                    dw = eml_v4_step(p, g, m_s[n], v_s[n], g_m[n], sr, lr_base=lr)
                    g_m[n] = g.detach().clone()  # clone: zero_grad() may zero grads in place
                    p.sub_(dw)
        
        if (i+1) % 100 == 0:
            log_to_file(f"Step {i+1:<5} | Loss: {loss.item():.4f}")

if __name__ == "__main__":
    if os.path.exists("tinystories_bench.log"):
        os.remove("tinystories_bench.log")
        
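    # Run all three optimizers back-to-back; each appends its step/loss trace
    # to tinystories_bench.log via log_to_file.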
    for opt in ["AdamW", "Muon", "EML-V4"]:
        train(opt, steps=10000)