"""Inference script for the XOR-inversion LSTM.

Loads a trained model and runs windowed beam-search decoding over one
binary file, or over every .bin file in a directory.
"""

import os
import sys

import torch
import torch.nn as nn

# -------------------------------
# Model hyperparameters (must match training)
VOCAB_SIZE = 256
EMBEDDING_DIM = 512   # Update as per your training
HIDDEN_DIM = 2048     # Update as per your training
NUM_LAYERS = 6
DROPOUT = 0
SKIP_HEADER = 0
BIDIRECTIONAL = True
# -------------------------------


# Define the LSTM-based model
class XorInversionLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim,
                            num_layers=num_layers,
                            batch_first=True,
                            dropout=dropout,
                            bidirectional=BIDIRECTIONAL)
        if BIDIRECTIONAL:
            # A bidirectional LSTM concatenates forward and backward states.
            self.fc = nn.Linear(hidden_dim * 2, vocab_size)
        else:
            self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        # x: [batch_size, seq_len]
        emb = self.embedding(x)     # [batch_size, seq_len, embedding_dim]
        output, _ = self.lstm(emb)  # [batch_size, seq_len, hidden_dim * num_directions]
        logits = self.fc(output)    # [batch_size, seq_len, vocab_size]
        return logits


def load_model(model_path, device):
    model = XorInversionLSTM(VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_DIM, NUM_LAYERS, DROPOUT)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model


def beam_search_window(logits, beam_width=5, bonus=0.1):
    """
    Perform beam search for a single window.

    Args:
        logits: Tensor of shape [seq_len, vocab_size] for one window.
        beam_width: Number of candidates to keep.
        bonus: Bonus (in log-probability) added if the current token is
            the same as the previous token.

    Returns:
        A list of integers (predicted token IDs) of length seq_len.
    """
    seq_len = logits.shape[0]
    # Convert logits to log probabilities. log_softmax is numerically
    # stabler than log(softmax(...)).
    log_probs = torch.log_softmax(logits, dim=-1)  # [seq_len, vocab_size]

    # Beam: each element is (sequence, cumulative_log_prob)
    beam = [([], 0.0)]
    for t in range(seq_len):
        new_beam = []
        for seq, score in beam:
            for token in range(VOCAB_SIZE):
                token_score = log_probs[t, token].item()
                # If token is the same as the last token, add the bonus.
                if seq and token == seq[-1]:
                    token_score += bonus
                new_beam.append((seq + [token], score + token_score))
        # Keep the top beam_width sequences.
        new_beam.sort(key=lambda x: x[1], reverse=True)
        beam = new_beam[:beam_width]

    # Return the best sequence (highest score).
    best_seq, _ = beam[0]
    return best_seq


def predict_file_in_windows(file_path, model, device, window_size=16, stride=16,
                            skip_header=64, beam_width=5, bonus=0.1):
    # Read the file as binary.
    with open(file_path, 'rb') as f:
        data = f.read()
    if len(data) == 0:
        print(f"Warning: {file_path} is empty!")
        return b""

    # Check for a header: if the first skip_header bytes are all zeros, skip them.
    header = data[:skip_header]
    if skip_header > 0 and all(b == 0 for b in header):
        print(f"Detected header of {skip_header} zero bytes in {file_path}. "
              f"Skipping header for inference.")
        data = data[skip_header:]
    else:
        print(f"No header to skip in {file_path}.")
        header = b""  # Nothing was skipped, so nothing should be prepended later.

    data_list = list(data)
    full_length = len(data_list)
    predicted_bytes = []

    # Process the file in sliding windows.
    for start in range(0, full_length, stride):
        if start % 1024 == 0:
            print("Progress:", round((start / full_length) * 100, 2))
        end = min(start + window_size, full_length)
        window = data_list[start:end]
        # Pad the window if it's shorter than window_size.
        if len(window) < window_size:
            window = window + [0] * (window_size - len(window))
        input_tensor = torch.tensor(window, dtype=torch.long, device=device).unsqueeze(0)
        with torch.no_grad():
            logits = model(input_tensor)   # [1, window_size, vocab_size]
        # Remove the batch dimension.
        window_logits = logits.squeeze(0)  # [window_size, vocab_size]
        # Use beam search to decode this window.
        best_seq = beam_search_window(window_logits, beam_width=beam_width, bonus=bonus)
        # Only add the unpadded portion.
        predicted_bytes.extend(best_seq[:end - start])

    # Prepend the header (empty if none was skipped).
    output_bytes = header + bytes(predicted_bytes)
    return output_bytes


def main(model_path, input_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model = load_model(model_path, device)

    if os.path.isfile(input_path):
        print(f"Processing file: {input_path}")
        output_bytes = predict_file_in_windows(input_path, model, device,
                                               window_size=16, stride=16,
                                               skip_header=SKIP_HEADER,
                                               beam_width=5, bonus=0.1)
        output_file = input_path + ".predicted"
        with open(output_file, 'wb') as f:
            f.write(output_bytes)
        print(f"Predicted output saved to {output_file}")
    elif os.path.isdir(input_path):
        for fname in os.listdir(input_path):
            if fname.endswith(".bin"):
                file_path = os.path.join(input_path, fname)
                print(f"Processing file: {file_path}")
                output_bytes = predict_file_in_windows(file_path, model, device,
                                                       window_size=16, stride=16,
                                                       skip_header=SKIP_HEADER,
                                                       beam_width=5, bonus=0.1)
                output_file = file_path + ".predicted"
                with open(output_file, 'wb') as f:
                    f.write(output_bytes)
                print(f"Predicted output saved to {output_file}")
    else:
        print("Error: Input path is neither a file nor a directory.")


if __name__ == "__main__":
    if len(sys.argv) < 3:
        print("Usage: python xor-inverse-infer.py <model_path> <input_path>")
    else:
        model_path = sys.argv[1]
        input_path = sys.argv[2]
        main(model_path, input_path)
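
# Example invocations (illustrative only; model.pt, encrypted.bin, and
# bin_files/ are hypothetical paths, not files shipped with this script):
#
#   python xor-inverse-infer.py model.pt encrypted.bin
#   python xor-inverse-infer.py model.pt ./bin_files/
#
# Each input produces a sibling <name>.predicted file next to the original.
# Note that beam_search_window scores all 256 byte values per beam per step,
# so decoding cost grows with beam_width * VOCAB_SIZE per byte of input.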