/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kafka.common.security.authenticator;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.kafka.clients.NetworkClient;
import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.network.CertStores;
import org.apache.kafka.common.network.ChannelBuilder;
import org.apache.kafka.common.network.ChannelBuilders;
import org.apache.kafka.common.network.LoginType;
import org.apache.kafka.common.network.Mode;
import org.apache.kafka.common.network.NetworkReceive;
import org.apache.kafka.common.network.NetworkSend;
import org.apache.kafka.common.network.NetworkTestUtils;
import org.apache.kafka.common.network.NioEchoServer;
import org.apache.kafka.common.network.Selector;
import org.apache.kafka.common.network.Send;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.protocol.Protocol;
import org.apache.kafka.common.protocol.SecurityProtocol;
import org.apache.kafka.common.protocol.types.Struct;
import org.apache.kafka.common.requests.AbstractRequestResponse;
import org.apache.kafka.common.requests.ApiVersionsRequest;
import org.apache.kafka.common.requests.ApiVersionsResponse;
import org.apache.kafka.common.requests.MetadataRequest;
import org.apache.kafka.common.requests.RequestHeader;
import org.apache.kafka.common.requests.RequestSend;
import org.apache.kafka.common.requests.ResponseHeader;
import org.apache.kafka.common.requests.SaslHandshakeRequest;
import org.apache.kafka.common.requests.SaslHandshakeResponse;
import org.apache.kafka.common.security.authenticator.TestJaasConfig;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/**
 * Tests SASL authentication between a client {@link Selector} and an in-process
 * {@link NioEchoServer}.
 *
 * <p>Covers SASL/PLAIN and DIGEST-MD5 over {@code SASL_PLAINTEXT} and {@code SASL_SSL},
 * invalid credentials and configurations, and clients that drive the Kafka-level
 * {@code ApiVersions} / {@code SaslHandshake} exchange by hand over a plain
 * ({@code PLAINTEXT} or {@code SSL}) client channel.
 */
public class SaslAuthenticatorTest {
    /** Send and receive buffer sizes passed to {@link Selector#connect} for every client connection. */
    private static final int BUFFER_SIZE = 4096;

    private NioEchoServer server;
    private Selector selector;
    private ChannelBuilder channelBuilder;
    private CertStores serverCertStores;
    private CertStores clientCertStores;
    private Map<String, Object> saslClientConfigs;
    private Map<String, Object> saslServerConfigs;

    /**
     * Creates server and client cert stores for "localhost". It then derives mutually trusting
     * server and client configs from them; tests add the SASL settings on top.
     */
    @Before
    public void setup() throws Exception {
        serverCertStores = new CertStores(true, "localhost");
        clientCertStores = new CertStores(false, "localhost");
        saslServerConfigs = serverCertStores.getTrustingConfig(clientCertStores);
        saslClientConfigs = clientCertStores.getTrustingConfig(serverCertStores);
    }

    /** Closes the echo server and any client selector the test left open. */
    @After
    public void teardown() throws Exception {
        if (server != null) {
            server.close();
        }
        if (selector != null) {
            selector.close();
        }
    }

    /** SASL/PLAIN with valid credentials succeeds over SASL_SSL. */
    @Test
    public void testValidSaslPlainOverSsl() throws Exception {
        String node = "0";
        SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
        configureMechanisms("PLAIN", Arrays.asList("PLAIN"));
        server = NetworkTestUtils.createEchoServer(securityProtocol, saslServerConfigs);
        createAndCheckClientConnection(securityProtocol, node);
    }

    /** SASL/PLAIN with valid credentials succeeds over SASL_PLAINTEXT. */
    @Test
    public void testValidSaslPlainOverPlaintext() throws Exception {
        String node = "0";
        SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT;
        configureMechanisms("PLAIN", Arrays.asList("PLAIN"));
        server = NetworkTestUtils.createEchoServer(securityProtocol, saslServerConfigs);
        createAndCheckClientConnection(securityProtocol, node);
    }

    /** A wrong password causes the server to close the client channel. */
    @Test
    public void testInvalidPasswordSaslPlain() throws Exception {
        String node = "0";
        SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
        TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Arrays.asList("PLAIN"));
        jaasConfig.setPlainClientOptions("myuser", "invalidpassword");
        server = NetworkTestUtils.createEchoServer(securityProtocol, saslServerConfigs);
        createClientConnection(securityProtocol, node);
        NetworkTestUtils.waitForChannelClose(selector, node);
    }

    /** An unknown username causes the server to close the client channel. */
    @Test
    public void testInvalidUsernameSaslPlain() throws Exception {
        String node = "0";
        SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
        TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Arrays.asList("PLAIN"));
        jaasConfig.setPlainClientOptions("invaliduser", "mypassword");
        server = NetworkTestUtils.createEchoServer(securityProtocol, saslServerConfigs);
        createClientConnection(securityProtocol, node);
        NetworkTestUtils.waitForChannelClose(selector, node);
    }

    /** Connecting must fail with a KafkaException when the client PLAIN config has no username. */
    @Test
    public void testMissingUsernameSaslPlain() throws Exception {
        String node = "0";
        TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Arrays.asList("PLAIN"));
        jaasConfig.setPlainClientOptions(null, "mypassword");
        SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
        server = NetworkTestUtils.createEchoServer(securityProtocol, saslServerConfigs);
        createSelector(securityProtocol, saslClientConfigs);
        InetSocketAddress addr = new InetSocketAddress("127.0.0.1", server.port());
        try {
            selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
            Assert.fail("SASL/PLAIN channel created without username");
        } catch (KafkaException expected) {
            // Expected: the channel must not be created without a username.
        }
    }

    /** Connecting must fail with a KafkaException when the client PLAIN config has no password. */
    @Test
    public void testMissingPasswordSaslPlain() throws Exception {
        String node = "0";
        TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Arrays.asList("PLAIN"));
        jaasConfig.setPlainClientOptions("myuser", null);
        SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
        server = NetworkTestUtils.createEchoServer(securityProtocol, saslServerConfigs);
        createSelector(securityProtocol, saslClientConfigs);
        InetSocketAddress addr = new InetSocketAddress("127.0.0.1", server.port());
        try {
            selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
            Assert.fail("SASL/PLAIN channel created without password");
        } catch (KafkaException expected) {
            // Expected: the channel must not be created without a password.
        }
    }

    /** A mechanism other than PLAIN (DIGEST-MD5) can be used end to end. */
    @Test
    public void testMechanismPluggability() throws Exception {
        String node = "0";
        SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
        configureMechanisms("DIGEST-MD5", Arrays.asList("DIGEST-MD5"));
        server = NetworkTestUtils.createEchoServer(securityProtocol, saslServerConfigs);
        createAndCheckClientConnection(securityProtocol, node);
    }

    /** A server with several enabled mechanisms accepts clients using any one of them. */
    @Test
    public void testMultipleServerMechanisms() throws Exception {
        SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
        configureMechanisms("DIGEST-MD5", Arrays.asList("DIGEST-MD5", "PLAIN"));
        server = NetworkTestUtils.createEchoServer(securityProtocol, saslServerConfigs);

        String node1 = "1";
        saslClientConfigs.put("sasl.mechanism", "PLAIN");
        createAndCheckClientConnection(securityProtocol, node1);

        String node2 = "2";
        saslClientConfigs.put("sasl.mechanism", "DIGEST-MD5");
        createClientConnection(securityProtocol, node2);
        NetworkTestUtils.checkClientConnection(selector, node2, 100, 10);
    }

    @Test
    public void testUnauthenticatedApiVersionsRequestOverPlaintext() throws Exception {
        testUnauthenticatedApiVersionsRequest(SecurityProtocol.SASL_PLAINTEXT);
    }

    @Test
    public void testUnauthenticatedApiVersionsRequestOverSsl() throws Exception {
        testUnauthenticatedApiVersionsRequest(SecurityProtocol.SASL_SSL);
    }

    /**
     * An ApiVersions request with an unsupported version gets an UNSUPPORTED_VERSION error
     * response. The same connection can then still complete the normal ApiVersions,
     * SaslHandshake and SASL/PLAIN sequence.
     */
    @Test
    public void testApiVersionsRequestWithUnsupportedVersion() throws Exception {
        SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT;
        configureMechanisms("PLAIN", Arrays.asList("PLAIN"));
        server = NetworkTestUtils.createEchoServer(securityProtocol, saslServerConfigs);

        String node = "1";
        createClientConnection(SecurityProtocol.PLAINTEXT, node);
        RequestHeader header = new RequestHeader(ApiKeys.API_VERSIONS.id, Short.MAX_VALUE, "someclient", 1);
        sendRequest(node, header, new ApiVersionsRequest());
        ByteBuffer responseBuffer = waitForResponse();
        // Skip past the response header before parsing the body.
        ResponseHeader.parse(responseBuffer);
        ApiVersionsResponse response = ApiVersionsResponse.parse(responseBuffer);
        Assert.assertEquals(Errors.UNSUPPORTED_VERSION.code(), response.errorCode());

        sendVersionRequestReceiveResponse(node);
        sendHandshakeRequestReceiveResponse(node);
        authenticateUsingSaslPlainAndCheckConnection(node);
    }

    /**
     * A SaslHandshake request with an unsupported version makes the server close the
     * connection. A fresh, well-behaved client can still connect afterwards.
     */
    @Test
    public void testSaslHandshakeRequestWithUnsupportedVersion() throws Exception {
        SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT;
        configureMechanisms("PLAIN", Arrays.asList("PLAIN"));
        server = NetworkTestUtils.createEchoServer(securityProtocol, saslServerConfigs);

        String node1 = "invalid1";
        createClientConnection(SecurityProtocol.PLAINTEXT, node1);
        RequestHeader header = new RequestHeader(ApiKeys.SASL_HANDSHAKE.id, Short.MAX_VALUE, "someclient", 2);
        sendRequest(node1, header, new SaslHandshakeRequest("PLAIN"));
        NetworkTestUtils.waitForChannelClose(selector, node1);
        selector.close();

        createAndCheckClientConnection(securityProtocol, "good1");
    }

    /**
     * Random bytes sent in place of a SASL token make the server close the connection. This is
     * checked both after a successful handshake and as the very first packet.
     */
    @Test
    public void testInvalidSaslPacket() throws Exception {
        SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT;
        configureMechanisms("PLAIN", Arrays.asList("PLAIN"));
        server = NetworkTestUtils.createEchoServer(securityProtocol, saslServerConfigs);

        // Garbage after a successful SaslHandshake.
        String node1 = "invalid1";
        createClientConnection(SecurityProtocol.PLAINTEXT, node1);
        sendHandshakeRequestReceiveResponse(node1);
        Random random = new Random();
        byte[] bytes = new byte[1024];
        random.nextBytes(bytes);
        sendBuffer(node1, ByteBuffer.wrap(bytes));
        NetworkTestUtils.waitForChannelClose(selector, node1);
        selector.close();

        createAndCheckClientConnection(securityProtocol, "good1");

        // Garbage as the first packet, with no handshake.
        String node2 = "invalid2";
        createClientConnection(SecurityProtocol.PLAINTEXT, node2);
        random.nextBytes(bytes);
        sendBuffer(node2, ByteBuffer.wrap(bytes));
        NetworkTestUtils.waitForChannelClose(selector, node2);
        selector.close();

        createAndCheckClientConnection(securityProtocol, "good2");
    }

    /** An ApiVersions request sent after SaslHandshake makes the server close the connection. */
    @Test
    public void testInvalidApiVersionsRequestSequence() throws Exception {
        SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT;
        configureMechanisms("PLAIN", Arrays.asList("PLAIN"));
        server = NetworkTestUtils.createEchoServer(securityProtocol, saslServerConfigs);

        String node1 = "invalid1";
        createClientConnection(SecurityProtocol.PLAINTEXT, node1);
        sendHandshakeRequestReceiveResponse(node1);
        RequestHeader versionsHeader = new RequestHeader(ApiKeys.API_VERSIONS.id, "someclient", 2);
        sendRequest(node1, versionsHeader, new ApiVersionsRequest());
        NetworkTestUtils.waitForChannelClose(selector, node1);
        selector.close();

        createAndCheckClientConnection(securityProtocol, "good1");
    }

    /**
     * A packet whose size prefix is Integer.MAX_VALUE makes the server close the connection.
     * This is checked both after a handshake and as the first packet.
     */
    @Test
    public void testPacketSizeTooBig() throws Exception {
        SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT;
        configureMechanisms("PLAIN", Arrays.asList("PLAIN"));
        server = NetworkTestUtils.createEchoServer(securityProtocol, saslServerConfigs);

        String node1 = "invalid1";
        createClientConnection(SecurityProtocol.PLAINTEXT, node1);
        sendHandshakeRequestReceiveResponse(node1);
        ByteBuffer buffer = ByteBuffer.allocate(1024);
        buffer.putInt(Integer.MAX_VALUE);
        buffer.put(new byte[buffer.capacity() - 4]);
        buffer.rewind();
        sendBuffer(node1, buffer);
        NetworkTestUtils.waitForChannelClose(selector, node1);
        selector.close();

        createAndCheckClientConnection(securityProtocol, "good1");

        String node2 = "invalid2";
        createClientConnection(SecurityProtocol.PLAINTEXT, node2);
        // Reuse the buffer: reset it and write the same oversized-length packet again.
        buffer.clear();
        buffer.putInt(Integer.MAX_VALUE);
        buffer.put(new byte[buffer.capacity() - 4]);
        buffer.rewind();
        sendBuffer(node2, buffer);
        NetworkTestUtils.waitForChannelClose(selector, node2);
        selector.close();

        createAndCheckClientConnection(securityProtocol, "good2");
    }

    /**
     * A Metadata request sent before authentication makes the server close the connection.
     * This is checked both with no handshake and after a SaslHandshake.
     */
    @Test
    public void testDisallowedKafkaRequestsBeforeAuthentication() throws Exception {
        SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT;
        configureMechanisms("PLAIN", Arrays.asList("PLAIN"));
        server = NetworkTestUtils.createEchoServer(securityProtocol, saslServerConfigs);

        String node1 = "invalid1";
        createClientConnection(SecurityProtocol.PLAINTEXT, node1);
        RequestHeader metadataRequestHeader1 = new RequestHeader(ApiKeys.METADATA.id, "someclient", 1);
        MetadataRequest metadataRequest1 = new MetadataRequest(Collections.singletonList("sometopic"));
        sendRequest(node1, metadataRequestHeader1, metadataRequest1);
        NetworkTestUtils.waitForChannelClose(selector, node1);
        selector.close();

        createAndCheckClientConnection(securityProtocol, "good1");

        String node2 = "invalid2";
        createClientConnection(SecurityProtocol.PLAINTEXT, node2);
        sendHandshakeRequestReceiveResponse(node2);
        RequestHeader metadataRequestHeader2 = new RequestHeader(ApiKeys.METADATA.id, "someclient", 2);
        MetadataRequest metadataRequest2 = new MetadataRequest(Collections.singletonList("sometopic"));
        sendRequest(node2, metadataRequestHeader2, metadataRequest2);
        NetworkTestUtils.waitForChannelClose(selector, node2);
        selector.close();

        createAndCheckClientConnection(securityProtocol, "good2");
    }

    /** Building the client channel must fail with a KafkaException for an unknown JAAS login module. */
    @Test
    public void testInvalidLoginModule() throws Exception {
        TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Arrays.asList("PLAIN"));
        jaasConfig.createOrUpdateEntry("KafkaClient", "InvalidLoginModule", TestJaasConfig.defaultClientOptions());
        SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
        server = NetworkTestUtils.createEchoServer(securityProtocol, saslServerConfigs);
        try {
            createSelector(securityProtocol, saslClientConfigs);
            Assert.fail("SASL/PLAIN channel created without valid login module");
        } catch (KafkaException expected) {
            // Expected: an invalid login module must prevent channel creation.
        }
    }

    /** A client mechanism the server has not enabled makes the server close the channel. */
    @Test
    public void testDisabledMechanism() throws Exception {
        String node = "0";
        SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
        configureMechanisms("PLAIN", Arrays.asList("DIGEST-MD5"));
        server = NetworkTestUtils.createEchoServer(securityProtocol, saslServerConfigs);
        createClientConnection(securityProtocol, node);
        NetworkTestUtils.waitForChannelClose(selector, node);
    }

    /** An unrecognised client mechanism name makes the server close the channel. */
    @Test
    public void testInvalidMechanism() throws Exception {
        String node = "0";
        SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
        configureMechanisms("PLAIN", Arrays.asList("PLAIN"));
        saslClientConfigs.put("sasl.mechanism", "INVALID");
        server = NetworkTestUtils.createEchoServer(securityProtocol, saslServerConfigs);
        createClientConnection(securityProtocol, node);
        NetworkTestUtils.waitForChannelClose(selector, node);
    }

    /**
     * Starts a SASL server on {@code securityProtocol} and connects using the corresponding
     * non-SASL client protocol. It then performs the whole exchange by hand:
     * ApiVersions (checking the advertised SaslHandshake versions), SaslHandshake (checking
     * that only PLAIN is enabled), and finally a raw SASL/PLAIN token.
     *
     * @throws IllegalArgumentException if {@code securityProtocol} is not SASL_PLAINTEXT or SASL_SSL
     */
    private void testUnauthenticatedApiVersionsRequest(SecurityProtocol securityProtocol) throws Exception {
        SecurityProtocol clientProtocol;
        configureMechanisms("PLAIN", Arrays.asList("PLAIN"));
        server = NetworkTestUtils.createEchoServer(securityProtocol, saslServerConfigs);
        String node = "1";
        switch (securityProtocol) {
            case SASL_PLAINTEXT:
                clientProtocol = SecurityProtocol.PLAINTEXT;
                break;
            case SASL_SSL:
                clientProtocol = SecurityProtocol.SSL;
                break;
            default:
                throw new IllegalArgumentException("Server protocol " + securityProtocol + " is not SASL");
        }
        createClientConnection(clientProtocol, node);
        NetworkTestUtils.waitForChannelReady(selector, node);

        ApiVersionsResponse versionsResponse = sendVersionRequestReceiveResponse(node);
        short handshakeKey = ApiKeys.SASL_HANDSHAKE.id;
        Assert.assertEquals(Protocol.MIN_VERSIONS[handshakeKey], versionsResponse.apiVersion(handshakeKey).minVersion);
        Assert.assertEquals(Protocol.CURR_VERSION[handshakeKey], versionsResponse.apiVersion(handshakeKey).maxVersion);

        SaslHandshakeResponse handshakeResponse = sendHandshakeRequestReceiveResponse(node);
        Assert.assertEquals(Collections.singletonList("PLAIN"), handshakeResponse.enabledMechanisms());

        authenticateUsingSaslPlainAndCheckConnection(node);
    }

    /**
     * Sends a raw SASL/PLAIN token (empty authzid, user "myuser", password "mypassword",
     * NUL-separated as in RFC 4616). It waits for the server's reply and then verifies that
     * the connection echoes data.
     */
    private void authenticateUsingSaslPlainAndCheckConnection(String node) throws Exception {
        String authString = "\u0000myuser\u0000mypassword";
        sendBuffer(node, ByteBuffer.wrap(authString.getBytes(StandardCharsets.UTF_8)));
        waitForResponse();
        NetworkTestUtils.checkClientConnection(selector, node, 100, 10);
    }

    /**
     * Sets {@code sasl.mechanism} on the client config and {@code sasl.enabled.mechanisms} on the
     * server config, then creates a matching JAAS configuration.
     *
     * @return the JAAS configuration, so tests can adjust client credentials or login modules
     */
    private TestJaasConfig configureMechanisms(String clientMechanism, List<String> serverMechanisms) {
        saslClientConfigs.put("sasl.mechanism", clientMechanism);
        saslServerConfigs.put("sasl.enabled.mechanisms", serverMechanisms);
        return TestJaasConfig.createConfiguration(clientMechanism, serverMechanisms);
    }

    /**
     * Builds a client channel builder for {@code securityProtocol} from {@code clientConfigs} and
     * a selector on top of it. The SASL mechanism is read from the same {@code clientConfigs}
     * map that is passed to the builder, so the two cannot disagree.
     */
    private void createSelector(SecurityProtocol securityProtocol, Map<String, Object> clientConfigs) {
        String saslMechanism = (String) clientConfigs.get("sasl.mechanism");
        channelBuilder = ChannelBuilders.create(securityProtocol, Mode.CLIENT, LoginType.CLIENT,
                clientConfigs, saslMechanism, true);
        selector = NetworkTestUtils.createSelector(channelBuilder);
    }

    /** Creates a new selector and starts connecting {@code node} to the echo server on 127.0.0.1. */
    private void createClientConnection(SecurityProtocol securityProtocol, String node) throws Exception {
        createSelector(securityProtocol, saslClientConfigs);
        InetSocketAddress addr = new InetSocketAddress("127.0.0.1", server.port());
        selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
    }

    /** Connects {@code node}, verifies the connection works, then closes and clears the selector. */
    private void createAndCheckClientConnection(SecurityProtocol securityProtocol, String node) throws Exception {
        createClientConnection(securityProtocol, node);
        NetworkTestUtils.checkClientConnection(selector, node, 100, 10);
        selector.close();
        selector = null;
    }

    /** Serializes {@code header} plus {@code request} into one buffer and queues it for {@code node}. */
    private void sendRequest(String node, RequestHeader header, AbstractRequestResponse request) {
        sendBuffer(node, RequestSend.serialize(header, request.toStruct()));
    }

    /** Queues {@code buffer} for {@code node} as a single network send. */
    private void sendBuffer(String node, ByteBuffer buffer) {
        Send send = new NetworkSend(node, new ByteBuffer[]{buffer});
        selector.send(send);
    }

    /**
     * Sends {@code request} to {@code node} with a version-less header for {@code apiKey} and
     * correlation id 1. It waits for the single response and returns its parsed body.
     */
    private Struct sendKafkaRequestReceiveResponse(String node, ApiKeys apiKey, AbstractRequestResponse request) throws IOException {
        RequestHeader header = new RequestHeader(apiKey.id, "someclient", 1);
        sendRequest(node, header, request);
        ByteBuffer responseBuffer = waitForResponse();
        return NetworkClient.parseResponse(responseBuffer, header);
    }

    /** Sends a SaslHandshake request for PLAIN and asserts the response carries no error. */
    private SaslHandshakeResponse sendHandshakeRequestReceiveResponse(String node) throws Exception {
        SaslHandshakeRequest handshakeRequest = new SaslHandshakeRequest("PLAIN");
        Struct responseStruct = sendKafkaRequestReceiveResponse(node, ApiKeys.SASL_HANDSHAKE, handshakeRequest);
        SaslHandshakeResponse response = new SaslHandshakeResponse(responseStruct);
        Assert.assertEquals(Errors.NONE.code(), response.errorCode());
        return response;
    }

    /** Sends an ApiVersions request and asserts the response carries no error. */
    private ApiVersionsResponse sendVersionRequestReceiveResponse(String node) throws Exception {
        ApiVersionsRequest versionsRequest = new ApiVersionsRequest();
        Struct responseStruct = sendKafkaRequestReceiveResponse(node, ApiKeys.API_VERSIONS, versionsRequest);
        ApiVersionsResponse response = new ApiVersionsResponse(responseStruct);
        Assert.assertEquals(Errors.NONE.code(), response.errorCode());
        return response;
    }

    /**
     * Polls the selector in 1-second steps, up to 11 polls in total, until a receive completes.
     * It asserts that exactly one receive completed and returns that receive's payload.
     */
    private ByteBuffer waitForResponse() throws IOException {
        int waitSeconds = 10;
        do {
            selector.poll(1000L);
        } while (selector.completedReceives().isEmpty() && waitSeconds-- > 0);
        Assert.assertEquals(1, selector.completedReceives().size());
        NetworkReceive receive = selector.completedReceives().get(0);
        return receive.payload();
    }
}

