package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.gpu;

import java.io.IOException;
import java.io.InputStream;
import java.io.Serializable;
import java.io.StringWriter;
import java.net.URL;
import java.net.URLConnection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.apache.hadoop.classification.VisibleForTesting;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.shaded.org.apache.commons.io.IOUtils;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.gpu.GpuResourceAllocator;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerRunCommand;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerVolumeCommand;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.DockerCommandPlugin;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerExecutionException;
import org.apache.hadoop.yarn.server.nodemanager.webapp.ContainerLogsPage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/gpu/NvidiaDockerV1CommandPlugin.class */
/**
 * {@link DockerCommandPlugin} that adds GPU-related {@code docker run} options obtained from an
 * nvidia-docker v1 style HTTP endpoint.
 *
 * <p>The endpoint is read from
 * {@code yarn.nodemanager.resource-plugins.gpu.docker-plugin.nvidia-docker-v1.endpoint}
 * (default {@code http://localhost:3476/v1.0/docker/cli}). It is queried lazily, the first time a
 * GPU-requesting container is seen, and the parsed options are cached for the lifetime of this
 * instance.
 *
 * <p>The endpoint's output must consist only of these options, separated by whitespace:
 * <ul>
 *   <li>{@code --device=PATH}</li>
 *   <li>{@code --volume-driver=NAME}</li>
 *   <li>read-only mounts {@code --volume=SRC:DST:ro}</li>
 * </ul>
 * Anything else makes initialization fail.
 *
 * <p>Both entry points that can trigger initialization are {@code synchronized}, so the cached
 * state is only built and read under this instance's monitor.
 */
public class NvidiaDockerV1CommandPlugin implements DockerCommandPlugin {
    static final Logger LOG = LoggerFactory.getLogger(NvidiaDockerV1CommandPlugin.class);

    /** Configuration key of the plugin CLI endpoint; an empty value disables querying it. */
    private static final String ENDPOINT_KEY =
        "yarn.nodemanager.resource-plugins.gpu.docker-plugin.nvidia-docker-v1.endpoint";
    private static final String DEFAULT_ENDPOINT = "http://localhost:3476/v1.0/docker/cli";
    private static final String GPU_RESOURCE_NAME = "yarn.io/gpu";
    /** Device names of the form {@code ...nvidia<N>} are treated as GPU number N. */
    private static final String NVIDIA_DEVICE_PREFIX = "nvidia";
    private static final String DEVICE_OPTION = "--device";
    private static final String VOLUME_DRIVER_OPTION = "--volume-driver";
    private static final String MOUNT_RO_OPTION = "--volume";
    private static final String READ_ONLY_SUFFIX = ":ro";

    private final Configuration conf;

    /**
     * Parsed plugin options keyed by option name ({@link #DEVICE_OPTION} / {@link #MOUNT_RO_OPTION}).
     * Stays {@code null} until an initialization attempt completes successfully. Mount values
     * are stored as {@code SRC:DST}, with the {@code :ro} suffix removed.
     */
    private Map<String, Set<String>> additionalCommands = null;

    /** Driver used for named GPU volumes; replaced by the plugin's {@code --volume-driver}, if any. */
    private String volumeDriver = "local";

    public NvidiaDockerV1CommandPlugin(Configuration configuration) {
        this.conf = configuration;
    }

    /**
     * Returns the part of {@code option} after the first {@code '='}.
     *
     * @throws IllegalArgumentException if {@code option} contains no {@code '='}
     */
    private String getValue(String option) throws IllegalArgumentException {
        int eq = option.indexOf('=');
        if (eq < 0) {
            throw new IllegalArgumentException("Failed to locate '=' from input=" + option);
        }
        return option.substring(eq + 1);
    }

    /** Adds {@code value} to the set stored under {@code key} in {@code commands}. */
    private static void addToCommand(Map<String, Set<String>> commands, String key, String value) {
        commands.computeIfAbsent(key, k -> new HashSet<>()).add(value);
    }

    /**
     * Parses one {@code --volume=SRC:DST:ro} option and returns {@code SRC:DST}.
     *
     * @throws IllegalArgumentException if the mount is not read-only or lacks a {@code ':'}
     *     between source and destination (later code splits on it)
     */
    private String parseReadOnlyMount(String option) {
        String value = getValue(option);
        if (!value.endsWith(READ_ONLY_SUFFIX)) {
            throw new IllegalArgumentException(
                "Should not have mount other than ro, command=" + option);
        }
        String mount = value.substring(0, value.length() - READ_ONLY_SUFFIX.length());
        if (mount.indexOf(':') < 0) {
            throw new IllegalArgumentException(
                "Mount should be in form of <source>:<destination>:ro, command=" + option);
        }
        return mount;
    }

    /**
     * Queries the configured endpoint once and caches the parsed docker options.
     *
     * <p>If the endpoint is configured empty, an empty option map is cached, so the plugin adds
     * nothing and does not retry. Results are built in locals and published only after the whole
     * output parses. A failed attempt therefore leaves the plugin uninitialized, and the next call
     * retries, instead of keeping a partial option set.
     *
     * @throws ContainerExecutionException wrapping any I/O or parsing failure
     */
    private void init() throws ContainerExecutionException {
        String endpoint = this.conf.get(ENDPOINT_KEY, DEFAULT_ENDPOINT);
        if (endpoint == null || endpoint.isEmpty()) {
            LOG.info("{} set to empty, skip init ..", ENDPOINT_KEY);
            this.additionalCommands = Collections.emptyMap();
            return;
        }
        Map<String, Set<String>> parsed = new HashMap<>();
        String driver = this.volumeDriver;
        try {
            URLConnection connection = new URL(endpoint).openConnection();
            connection.setRequestProperty("X-Requested-With", "Curl");
            StringWriter writer = new StringWriter();
            // Close the response stream; the original code leaked it.
            try (InputStream in = connection.getInputStream()) {
                IOUtils.copy(in, writer, "utf-8");
            }
            String cliOptions = writer.toString();
            LOG.info("Additional docker CLI options from plugin to run GPU containers:{}",
                cliOptions);
            // Split on any whitespace run and skip empties, so newlines or doubled spaces in the
            // response are not rejected as unsupported options.
            for (String token : cliOptions.split("\\s+")) {
                String option = token.trim();
                if (option.isEmpty()) {
                    continue;
                }
                // --volume-driver must be tested before --volume, which is its prefix.
                if (option.startsWith(DEVICE_OPTION)) {
                    addToCommand(parsed, DEVICE_OPTION, getValue(option));
                } else if (option.startsWith(VOLUME_DRIVER_OPTION)) {
                    driver = getValue(option);
                    LOG.debug("Found volume-driver:{}", driver);
                } else if (option.startsWith(MOUNT_RO_OPTION)) {
                    addToCommand(parsed, MOUNT_RO_OPTION, parseReadOnlyMount(option));
                } else {
                    throw new IllegalArgumentException("Unsupported option:" + option);
                }
            }
        } catch (IOException e) {
            LOG.warn("IOException of {} init:", getClass().getSimpleName(), e);
            throw new ContainerExecutionException(e);
        } catch (RuntimeException e) {
            LOG.warn("RuntimeException of {} init:", getClass().getSimpleName(), e);
            throw new ContainerExecutionException(e);
        }
        this.volumeDriver = driver;
        this.additionalCommands = parsed;
    }

    /**
     * Extracts the GPU index from a device path such as {@code /dev/nvidia3}.
     *
     * @return the non-negative index, or -1 if the name has no {@code nvidia} prefix, has no digits
     *     after it, has any non-digit after it, or the number does not fit in an {@code int}
     */
    private int getGpuIndexFromDeviceName(String deviceName) {
        int pos = deviceName.lastIndexOf(NVIDIA_DEVICE_PREFIX);
        if (pos < 0) {
            return -1;
        }
        String suffix = deviceName.substring(pos + NVIDIA_DEVICE_PREFIX.length());
        // An empty suffix (e.g. "/dev/nvidia") previously reached Integer.parseInt("") and threw.
        if (suffix.isEmpty()) {
            return -1;
        }
        for (int i = 0; i < suffix.length(); i++) {
            if (!Character.isDigit(suffix.charAt(i))) {
                return -1;
            }
        }
        try {
            return Integer.parseInt(suffix);
        } catch (NumberFormatException overflow) {
            return -1;
        }
    }

    /** Returns the GPUs assigned to {@code container}; never {@code null}. */
    private Set<GpuDevice> getAssignedGpus(Container container) {
        ResourceMappings resourceMappings = container.getResourceMappings();
        if (resourceMappings == null) {
            return Collections.emptySet();
        }
        Set<GpuDevice> gpus = new HashSet<>();
        for (Serializable resource : resourceMappings.getAssignedResources(GPU_RESOURCE_NAME)) {
            gpus.add((GpuDevice) resource);
        }
        return gpus;
    }

    @VisibleForTesting
    protected boolean requestsGpu(Container container) {
        return GpuResourceAllocator.getRequestedGpus(container.getResource()) > 0;
    }

    /**
     * Lazily initializes the plugin if {@code container} requests GPUs.
     *
     * @return {@code true} if the container requests GPUs; in that case {@link #additionalCommands}
     *     is non-null on return
     */
    private boolean initializeWhenGpuRequested(Container container)
            throws ContainerExecutionException {
        if (!requestsGpu(container)) {
            return false;
        }
        if (this.additionalCommands == null) {
            init();
        }
        return true;
    }

    /**
     * Adds the plugin-reported devices this container should see.
     *
     * <p>Devices whose name does not parse as {@code nvidia<N>} are added unconditionally. A
     * {@code nvidia<N>} device is added only if GPU N is in {@code assignedGpus}.
     *
     * @throws ContainerExecutionException if fewer assigned GPUs were matched than were assigned.
     *     This also fires when the plugin reported no devices at all.
     */
    private void addDevices(DockerRunCommand dockerRunCommand, Set<GpuDevice> assignedGpus)
            throws ContainerExecutionException {
        Set<String> devices =
            this.additionalCommands.getOrDefault(DEVICE_OPTION, Collections.emptySet());
        int found = 0;
        for (String device : devices) {
            int gpuIndex = getGpuIndexFromDeviceName(device);
            if (gpuIndex < 0) {
                dockerRunCommand.addDevice(device, device);
                continue;
            }
            for (GpuDevice gpu : assignedGpus) {
                if (gpu.getIndex() == gpuIndex) {
                    found++;
                    dockerRunCommand.addDevice(device, device);
                }
            }
        }
        if (found < assignedGpus.size()) {
            throw new ContainerExecutionException(
                "Cannot get all assigned Gpu devices from docker plugin output");
        }
    }

    /**
     * Adds each cached {@code SRC:DST} mount as a read-only mount, splitting at the first
     * {@code ':'}. {@link #parseReadOnlyMount} guarantees that the separator is present.
     */
    private void addReadOnlyMounts(DockerRunCommand dockerRunCommand)
            throws ContainerExecutionException {
        Set<String> mounts =
            this.additionalCommands.getOrDefault(MOUNT_RO_OPTION, Collections.emptySet());
        for (String mount : mounts) {
            int sep = mount.indexOf(':');
            dockerRunCommand.addReadOnlyMountLocation(
                mount.substring(0, sep), mount.substring(sep + 1), true);
        }
    }

    /**
     * Adds the plugin's device and read-only mount options to {@code dockerRunCommand} for a
     * container that requests GPUs and has GPUs assigned. Otherwise this is a no-op.
     */
    @Override // org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.DockerCommandPlugin
    public synchronized void updateDockerRunCommand(DockerRunCommand dockerRunCommand, Container container) throws ContainerExecutionException {
        if (!initializeWhenGpuRequested(container)) {
            return;
        }
        Set<GpuDevice> assignedGpus = getAssignedGpus(container);
        if (assignedGpus.isEmpty()) {
            return;
        }
        addDevices(dockerRunCommand, assignedGpus);
        addReadOnlyMounts(dockerRunCommand);
    }

    /**
     * Builds a {@code docker volume create} command for the first plugin mount whose source
     * matches {@link DockerVolumeCommand#VOLUME_NAME_PATTERN}, using the plugin's volume driver.
     *
     * <p>Synchronized because it may trigger the same lazy initialization as
     * {@link #updateDockerRunCommand}.
     *
     * @return the command, or {@code null} if the container requests no GPUs or no mount source is
     *     a named volume
     */
    @Override // org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.DockerCommandPlugin
    public synchronized DockerVolumeCommand getCreateDockerVolumeCommand(Container container) throws ContainerExecutionException {
        if (!initializeWhenGpuRequested(container)) {
            return null;
        }
        Set<String> mounts =
            this.additionalCommands.getOrDefault(MOUNT_RO_OPTION, Collections.emptySet());
        String volumeName = null;
        for (String mount : mounts) {
            int sep = mount.indexOf(':');
            if (sep < 0) {
                continue;
            }
            String source = mount.substring(0, sep);
            if (DockerVolumeCommand.VOLUME_NAME_PATTERN.matcher(source).matches()) {
                volumeName = source;
                LOG.debug("Found volume name for GPU:{}", volumeName);
                break;
            }
            LOG.debug("Failed to match {} to named-volume regex pattern", source);
        }
        if (volumeName == null) {
            return null;
        }
        DockerVolumeCommand dockerVolumeCommand = new DockerVolumeCommand("create");
        dockerVolumeCommand.setDriverName(this.volumeDriver);
        dockerVolumeCommand.setVolumeName(volumeName);
        return dockerVolumeCommand;
    }

    /** This plugin never removes volumes; always returns {@code null}. */
    @Override // org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.DockerCommandPlugin
    public DockerVolumeCommand getCleanupDockerVolumesCommand(Container container) throws ContainerExecutionException {
        return null;
    }
}
