"""Experiment 1b: Deeper training for 768-dim configs.

Observation from exp01: 768-dim configs converge more slowly, but their MSE
is actually lower. Train longer (200 epochs) to see whether they surpass the
256-dim configs in cosine similarity.

Also tests: does the encoder need a wider bottleneck (more neurons)?
"""
import sys
import time
import json
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim

sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from nuonuo.encoder import SpikeAutoencoder

# Fall back to CPU so the script still runs (slowly) on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
RESULTS_DIR = Path(__file__).parent.parent / "doc"


def cosine_sim(a, b):
    """Return the mean cosine similarity between ``a`` and ``b`` along the last dim, as a float."""
    return nn.functional.cosine_similarity(a, b, dim=-1).mean().item()


def _random_unit_embeddings(num_samples, embed_dim):
    """Draw ``num_samples`` Gaussian vectors of size ``embed_dim`` on DEVICE, L2-normalised along the last dim."""
    emb = torch.randn(num_samples, embed_dim, device=DEVICE)
    return nn.functional.normalize(emb, dim=-1)


def _evaluate(model, embed_dim, mse_fn, num_samples=512):
    """Return ``(mse, mean_cos)`` of ``model`` on fresh random unit vectors, without gradients."""
    model.eval()
    with torch.no_grad():
        test_emb = _random_unit_embeddings(num_samples, embed_dim)
        recon, _, _ = model(test_emb)
        return mse_fn(recon, test_emb).item(), cosine_sim(recon, test_emb)


def run(embed_dim, num_neurons, num_steps, epochs=200, lr=3e-4, *,
        batch_size=64, num_batches=30, log_every=20):
    """Train a SpikeAutoencoder on random unit vectors, then evaluate it.

    Each epoch draws ``num_batches`` fresh batches of ``batch_size`` random
    unit-norm embeddings. The loss combines MSE, a cosine-embedding term
    (weight 0.5) and a firing-rate sparsity penalty (weight 0.1). A
    cosine-annealed learning rate is used, stepped once per epoch.

    Args:
        embed_dim: size of the input/reconstructed embedding.
        num_neurons: number of spiking neurons passed to SpikeAutoencoder.
        num_steps: number of simulation steps passed to SpikeAutoencoder.
        epochs: number of training epochs.
        lr: initial AdamW learning rate.
        batch_size: embeddings per training batch.
        num_batches: batches per epoch.
        log_every: print/record a milestone every this many epochs
            (0 disables logging).

    Returns:
        dict with "dim", "neurons", "steps", "final_mse", "final_cos",
        "best_train_cos" (best epoch-average training CosSim) and
        "milestones" (list of dicts with "epoch", "mse", "cos", "fr").
    """
    model = SpikeAutoencoder(embed_dim, num_neurons, num_steps).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    # T_max = epochs because the scheduler is stepped once per epoch.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    mse_fn = nn.MSELoss()

    # cosine_embedding_loss with target +1 penalises 1 - cos(recon, emb).
    target = torch.ones(batch_size, device=DEVICE)
    best_cos = float("-inf")
    milestones = []

    for epoch in range(epochs):
        model.train()
        epoch_mse = 0.0
        epoch_cos = 0.0
        epoch_fr = 0.0
        for _ in range(num_batches):
            emb = _random_unit_embeddings(batch_size, embed_dim)
            recon, spikes, _ = model(emb)

            loss_mse = mse_fn(recon, emb)
            loss_cos = nn.functional.cosine_embedding_loss(recon, emb, target)
            firing_rate = spikes.mean()
            # Pull the mean spike activity towards 0.1 to keep the code sparse.
            loss_sparse = (firing_rate - 0.1).pow(2)
            loss = loss_mse + 0.5 * loss_cos + 0.1 * loss_sparse

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            epoch_mse += loss_mse.item()
            epoch_cos += cosine_sim(recon.detach(), emb)
            epoch_fr += firing_rate.item()
        scheduler.step()

        # Report epoch averages (FR included, so all logged metrics agree).
        epoch_mse /= num_batches
        epoch_cos /= num_batches
        epoch_fr /= num_batches
        best_cos = max(best_cos, epoch_cos)

        if log_every and (epoch + 1) % log_every == 0:
            print(f" Epoch {epoch+1:3d}: MSE={epoch_mse:.6f}, CosSim={epoch_cos:.4f}, "
                  f"FR={epoch_fr:.3f}, LR={scheduler.get_last_lr()[0]:.6f}")
            milestones.append({"epoch": epoch + 1, "mse": epoch_mse,
                               "cos": epoch_cos, "fr": epoch_fr})

    final_mse, final_cos = _evaluate(model, embed_dim, mse_fn)
    print(f" ** Final: MSE={final_mse:.6f}, CosSim={final_cos:.4f}")

    return {"dim": embed_dim, "neurons": num_neurons, "steps": num_steps,
            "final_mse": final_mse, "final_cos": final_cos,
            "best_train_cos": best_cos, "milestones": milestones}


def main():
    """Run every 768-dim config for 200 epochs, print a summary and dump JSON results."""
    print("Experiment 1b: Deeper training (200 epochs)")
    print("=" * 60)

    configs = [
        (768, 2048, 64),
        (768, 4096, 64),
        (768, 4096, 128),
        (768, 8192, 64),  # wider
    ]

    # Create the output directory before training. Otherwise the final
    # open() below fails only after hours of training, losing all results.
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)

    results = []
    for dim, neurons, steps in configs:
        print(f"\n--- dim={dim}, neurons={neurons}, steps={steps} ---")
        results.append(run(dim, neurons, steps, epochs=200, lr=3e-4))
        if DEVICE == "cuda":
            # Release cached blocks before the next (possibly larger) model.
            torch.cuda.empty_cache()

    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY (200 epochs)")
    print(f"{'Dim':>5} {'Neurons':>8} {'Steps':>6} {'MSE':>10} {'CosSim':>8}")
    print("-" * 40)
    for r in results:
        print(f"{r['dim']:>5} {r['neurons']:>8} {r['steps']:>6} "
              f"{r['final_mse']:>10.6f} {r['final_cos']:>8.4f}")

    with open(RESULTS_DIR / "exp01b_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, default=float)


if __name__ == "__main__":
    main()