import torch import torch.nn as nn from torch.utils.data import Dataset, DataLoader import torch.nn.functional as F import os import datetime # Hyperparameters VOCAB_SIZE = 256 # Byte values 0-255 EMBEDDING_DIM = 512 HIDDEN_DIM = 2048 BATCH_SIZE = 128 NUM_EPOCHS = 50 LEARNING_RATE = 1e-3 DROPOUT = 0 NUM_LAYERS = 6 BIDIRECTIONAL = True RL = True # Define hyperparameters for Reinforced learning (RL) component: lambda_RL_start = 0.5 # initial value lambda_RL_end = 0.7 # final value print("Start time:", datetime.datetime.now()) print("Embedding dim:", EMBEDDING_DIM) print("Bidirectional:", BIDIRECTIONAL) print("Hidden dim:", HIDDEN_DIM) print("Batch size:", BATCH_SIZE) print("Learning rate:", LEARNING_RATE) print("Number of Layers:", NUM_LAYERS) # A simple dataset: each sample is (xor_window, original_window) class FirmwareXORDataset(Dataset): def __init__(self, data_folder): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") #device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") print("[+] Metal settings", device) self.samples = [] # Each sample is a tuple: (xor_bytes, orig_bytes) for fname in os.listdir(data_folder): if fname.endswith('.bin'): # Assume each file is a synthetic XOR sample with a corresponding target file. # Here, we expect the file to have been preprocessed so that we can load both XOR and original. # For example, you might store two files per sample or encode both in one file. # This is a simplified example. path = os.path.join(data_folder, fname) with open(path, 'rb') as f: xor_data = f.read() # For illustration, assume we have a corresponding target file with suffix '.target' target_path1 = path + '.orig1' if os.path.exists(target_path1): with open(target_path1, 'rb') as f1: orig_data1 = f1.read() WINDOW_SIZE = len(xor_data) self.samples.append(( torch.tensor(list(xor_data[:WINDOW_SIZE]), dtype=torch.long, device=device), torch.tensor(list(orig_data1[:WINDOW_SIZE]), dtype=torch.long, device=device) )) print(f"[*] Found {len(self.samples)} samples in {data_folder}.") def __len__(self): return len(self.samples) def __getitem__(self, idx): return self.samples[idx] # A simple LSTM-based model for sequence prediction class XorInversionModel(nn.Module): def __init__(self, vocab_size, embedding_dim, hidden_dim): super(XorInversionModel, self).__init__() self.embedding = nn.Embedding(vocab_size, embedding_dim) self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, num_layers=NUM_LAYERS, dropout=DROPOUT, bidirectional=BIDIRECTIONAL) if BIDIRECTIONAL: self.fc = nn.Linear(hidden_dim*2, vocab_size) else: self.fc = nn.Linear(hidden_dim, vocab_size) def forward(self, x): # x: [batch_size, seq_len] emb = self.embedding(x) # [batch_size, seq_len, embedding_dim] output, _ = self.lstm(emb) # [batch_size, seq_len, hidden_dim] logits = self.fc(output) # [batch_size, seq_len, vocab_size] return logits def get_lambda_rl(epoch): # Linear schedule: increases gradually from lambda_RL_start to lambda_RL_end return lambda_RL_start + (lambda_RL_end - lambda_RL_start) * (epoch / NUM_EPOCHS) def compute_reward(pred_seq, tgt_seq, alpha=1.1): """ Compute a reward for a predicted sequence that strongly rewards consecutive correct predictions. Here, we use an exponential function: for each run of consecutive correct tokens, reward += (alpha^(run_length) - alpha). This gives zero reward if no consecutive correct tokens occur, and higher reward for longer runs. """ reward = 0.0 run_length = 0 seq_len = len(pred_seq) for i in range(seq_len): if pred_seq[i] == tgt_seq[i]: run_length += 1 else: if run_length > 0: reward += (alpha ** run_length) - alpha run_length = 0 # For any run at the end if run_length > 0: reward += (alpha ** run_length) - alpha # Optionally, normalize by sequence length return reward / (seq_len) print(f"[*] {datetime.datetime.now()} Loading dataset") # Create dataset and data loader data_folder = "training-data" # Folder with your synthetic samples dataset = FirmwareXORDataset(data_folder) data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) valid_data_folder = "valid-data" # Folder with your synthetic samples valid_dataset = FirmwareXORDataset(valid_data_folder) valid_data_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True) print(f"[*] {datetime.datetime.now()} Building datamodel") # Initialize model, loss, optimizer model = XorInversionModel(VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_DIM) criterion = nn.CrossEntropyLoss() optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4) print(f"[*] {datetime.datetime.now()} Starting training loop") # Training loop device = torch.device("cuda" if torch.cuda.is_available() else "cpu") #device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") model.to(device) baseline = 0.0 # initialize baseline for self-critical RL baseline_decay = 0.85 # exponential moving average decay for baseline training_loss = 1.0 epoch = 0 accuracy = 0 while epoch < NUM_EPOCHS: running_loss = 0.0 epoch_start = datetime.datetime.now() epoch += 1 current_lambda_rl = get_lambda_rl(epoch) # dynamically compute the weight model.train() for xor_seq, target_seq in data_loader: optimizer.zero_grad() logits = model(xor_seq) # [batch_size, seq_len, vocab_size] # Standard CE loss: ce_loss = criterion(logits.view(-1, VOCAB_SIZE), target_seq.view(-1)) # RL Component: probs = F.softmax(logits, dim=-1) m = torch.distributions.Categorical(probs) sampled_seq = m.sample() # [batch_size, seq_len] log_probs = m.log_prob(sampled_seq) # [batch_size, seq_len] # Compute rewards for each sequence in the batch. batch_rewards = [] for i in range(sampled_seq.size(0)): pred_seq = sampled_seq[i].cpu().tolist() tgt_seq = target_seq[i].cpu().tolist() r = compute_reward(pred_seq, tgt_seq, alpha=1.1) batch_rewards.append(r) batch_rewards = torch.tensor(batch_rewards, dtype=torch.float, device=device) # Compute greedy baseline for self-critical training: greedy_seq = torch.argmax(logits, dim=-1) greedy_rewards = [] for i in range(greedy_seq.size(0)): g_seq = greedy_seq[i].cpu().tolist() tgt_seq = target_seq[i].cpu().tolist() r_g = compute_reward(g_seq, tgt_seq, alpha=1.1) greedy_rewards.append(r_g) greedy_rewards = torch.tensor(greedy_rewards, dtype=torch.float, device=device) # Use self-critical advantage: advantages = batch_rewards - greedy_rewards log_probs_sum = log_probs.sum(dim=1) # [batch_size] rl_loss = - (log_probs_sum * advantages).mean() # Optionally update baseline using an exponential moving average: baseline = baseline_decay * baseline + (1 - baseline_decay) * batch_rewards.mean().item() # Combine losses: loss = ce_loss + current_lambda_rl * rl_loss loss.backward() optimizer.step() running_loss += loss.item() training_loss = running_loss / len(data_loader) # Validation phase (using greedy decoding): model.eval() val_loss = 0.0 correct = 0 total = 0 with torch.no_grad(): for xor_seq, target_seq in valid_data_loader: v_logits = model(xor_seq) v_loss = criterion(v_logits.view(-1, VOCAB_SIZE), target_seq.view(-1)) val_loss += v_loss.item() preds = torch.argmax(v_logits, dim=-1) correct += (preds == target_seq).sum().item() total += target_seq.numel() validation_loss = val_loss / len(valid_data_loader) accuracy_new = correct / total * 100 if BIDIRECTIONAL: HD = HIDDEN_DIM*2 else: HD = HIDDEN_DIM if accuracy_new > accuracy: MODEL_SAVE_PATH = "xor_inversion_model_ED"+str(EMBEDDING_DIM)+"_HD"+str(HD)+"_BS"+str(BATCH_SIZE)+"_NL"+str(NUM_LAYERS)+".pth" torch.save(model.state_dict(), MODEL_SAVE_PATH) accuracy = accuracy_new print("Saved model, current best accuracy:", accuracy) print("Epoch duration (min):", ((datetime.datetime.now()-epoch_start).total_seconds())/60) print(f"[*] {datetime.datetime.now()} Epoch {epoch} | Train Loss: {training_loss:.4f} | Validation Loss: {validation_loss:.4f} | Accuracy: {accuracy_new:.2f}%") print("[*] Training complete.")