#
# Copyright (c) 2018. All rights reserved.
#
# This software and all trademarks, trade names, and logos included herein are the property of XebiaLabs, Inc. and its affiliates, subsidiaries, and licensors.
#

from kubernetes import client
from kubernetes.client.api_client import ApiClient
from xld.kubernetes.commons.common_utils import CommonUtils
import cPickle as pk


class PodHelper(object):
    """Builds Kubernetes pod definitions from "deployed" pod objects.

    The deployed objects are duck-typed: only the attributes read below
    (``containers``, ``labels``, ``volumes``, ``imagePullSecrets``, ...) are
    required. Models are built with the ``kubernetes`` client classes and
    returned as plain serializable dicts.
    """

    def read_pod(self, deployed_pod, pod=None):
        """Return the serializable dict describing ``deployed_pod``.

        :param deployed_pod: object exposing the pod settings (containers,
            labels, volumes, imagePullSecrets, hostNetwork, restartPolicy).
        :param pod: optional ``client.V1Pod`` to populate. A fresh instance is
            created per call when omitted, so no state leaks between calls.
        :return: dict produced by ``ApiClient.sanitize_for_serialization`` with
            ``spec.volumes`` set from the deployed volumes.
        """
        if pod is None:
            # Created lazily: a ``client.V1Pod()`` default argument would be a
            # single instance shared (and mutated) by every call.
            pod = client.V1Pod()
        self.__apply_properties(pod, deployed_pod)
        pod_dict = ApiClient().sanitize_for_serialization(pod)
        # Volumes are added as raw dicts after serialization because their
        # content is free-form (volume type -> properties).
        pod_dict['spec']['volumes'] = PodHelper.__get_volume_dict(deployed_pod)
        return pod_dict

    @staticmethod
    def enrich_app_selectors(deployed_pod):
        """Ensure ``deployed_pod.labels`` contains an ``app`` label.

        When the label is missing, a new dict is assigned that holds the
        existing labels plus ``app`` set to the pod name. Existing ``app``
        labels are left untouched.
        """
        if not deployed_pod.labels or 'app' not in deployed_pod.labels:
            # ``labels`` may be None/empty here; fall back to an empty mapping
            # so that ``dict(...)`` does not fail on None.
            deployed_pod.labels = dict(deployed_pod.labels or {},
                                       app=PodHelper.get_pod_name(deployed_pod))

    @staticmethod
    def get_pod_name(deployed):
        """Return ``deployed.podName`` when set, otherwise ``deployed.name``."""
        return deployed.podName if deployed.podName else deployed.name

    @staticmethod
    def get_container_name(deployed):
        """Return ``deployed.containerName`` when set, otherwise ``deployed.name``."""
        return deployed.containerName if deployed.containerName else deployed.name

    @staticmethod
    def __read_container(deployed_container):
        """Create a ``V1Container`` with name, image, command, args and probes."""
        container = client.V1Container(name=PodHelper.get_container_name(deployed_container),
                                       image=deployed_container.image,
                                       command=deployed_container.command,
                                       args=deployed_container.args)
        if deployed_container.probes:
            for container_probe in deployed_container.probes:
                probe = PodHelper.__read_container_probe(container_probe)
                # Any probe type other than 'Readiness' is treated as liveness.
                if container_probe.probeType == 'Readiness':
                    container.readiness_probe = probe
                else:
                    container.liveness_probe = probe
        return container

    @staticmethod
    def __read_container_probe(deployed_probe):
        """Create a ``V1Probe`` for the probe's action type.

        :raises Exception: when ``probeActionType`` is not one of ``Exec``,
            ``HTTPGet`` or ``TCPSocket``.
        """
        probe = client.V1Probe()
        PodHelper.__apply_common_probe_properties(deployed_probe, probe)
        if deployed_probe.probeActionType == 'Exec':
            # The client model exposes the "exec" action as ``_exec``.
            probe._exec = client.V1ExecAction(command=deployed_probe.command)
        elif deployed_probe.probeActionType == 'HTTPGet':
            probe.http_get = PodHelper.__read_http_action_probe(deployed_probe)
        elif deployed_probe.probeActionType == 'TCPSocket':
            probe.tcp_socket = client.V1TCPSocketAction(
                port=CommonUtils.convert_to_int_if_possible(deployed_probe.tcpPort))
        else:
            raise Exception("Probe type {} not supported".format(deployed_probe.probeActionType))
        return probe

    @staticmethod
    def __read_http_action_probe(deployed_probe):
        """Create a ``V1HTTPGetAction``, including HTTP headers when present."""
        http_get_action = client.V1HTTPGetAction(port=CommonUtils.convert_to_int_if_possible(deployed_probe.port),
                                                 host=deployed_probe.host,
                                                 path=deployed_probe.path,
                                                 scheme=deployed_probe.scheme)
        if deployed_probe.httpHeaders:
            http_get_action.http_headers = [client.V1HTTPHeader(name=k, value=v)
                                            for k, v in deployed_probe.httpHeaders.iteritems()]
        return http_get_action

    @staticmethod
    def __apply_common_probe_properties(deployed_probe, probe):
        """Copy the timing and threshold settings shared by all probe kinds."""
        probe.initial_delay_seconds = deployed_probe.initialDelaySeconds
        probe.timeout_seconds = deployed_probe.timeoutSeconds
        probe.failure_threshold = deployed_probe.failureThreshold
        probe.success_threshold = deployed_probe.successThreshold
        probe.period_seconds = deployed_probe.periodSeconds

    @staticmethod
    def __read_container_port(deployed_port):
        """Create a ``V1ContainerPort`` from a deployed port."""
        return client.V1ContainerPort(container_port=CommonUtils.convert_to_int_if_possible(deployed_port.containerPort),
                                      host_port=deployed_port.hostPort,
                                      protocol=deployed_port.protocol,
                                      host_ip=deployed_port.hostIP)

    @staticmethod
    def __read_volume_mount(deployed_volume_mount):
        """Create a ``V1VolumeMount``; ``volumeName`` wins over ``name`` when set."""
        name = deployed_volume_mount.volumeName if bool(
            deployed_volume_mount.volumeName) else deployed_volume_mount.name
        return client.V1VolumeMount(mount_path=deployed_volume_mount.mountPath,
                                    name=name,
                                    read_only=deployed_volume_mount.readOnly,
                                    sub_path=deployed_volume_mount.subPath)

    def __apply_properties(self, pod, deployed_pod):
        """Populate ``pod.metadata`` and ``pod.spec`` from ``deployed_pod``.

        Ports, environment variables and volume mounts are only set on a
        container when the corresponding deployed collection is non-empty.
        """
        containers = []
        for deployed_container in deployed_pod.containers:
            container = PodHelper.__read_container(deployed_container)
            if deployed_container.ports:
                container.ports = [PodHelper.__read_container_port(deployed_port)
                                   for deployed_port in deployed_container.ports]

            if deployed_container.envVars:
                container.env = PodHelper.__read_env_var(deployed_container.envVars)

            if deployed_container.volumeBindings:
                container.volume_mounts = [PodHelper.__read_volume_mount(volume_binding)
                                           for volume_binding in deployed_container.volumeBindings]
            containers.append(container)

        # Tolerate an unset collection, like the other optional collections.
        image_pull_secrets = [client.V1LocalObjectReference(name=secret_name)
                              for secret_name in (deployed_pod.imagePullSecrets or [])]

        pod.metadata = client.V1ObjectMeta(name=self.get_pod_name(deployed_pod))
        if deployed_pod.labels:
            pod.metadata.labels = deployed_pod.labels

        pod.spec = client.V1PodSpec(containers=containers,
                                    host_network=deployed_pod.hostNetwork,
                                    restart_policy=deployed_pod.restartPolicy,
                                    image_pull_secrets=image_pull_secrets)

    @staticmethod
    def __get_volume_dict(deployed_pod):
        """Return the pod volumes as a list of ``{'name': ..., <volumeType>: {...}}`` dicts."""
        volumes = []
        # Tolerate an unset collection, like the other optional collections.
        for deployed_volume in (deployed_pod.volumes or []):
            volume = {'name': deployed_volume.volumeName if bool(deployed_volume.volumeName) else deployed_volume.name}
            vol_prop = {}
            for key, val in deployed_volume.properties.items():
                # Only the string 'true' (any case) becomes a boolean; every
                # other value is kept as a copy made via a pickle round-trip.
                # NOTE(review): 'false' stays a string - confirm this is intended.
                # NOTE(review): the pickle round-trip presumably detaches the
                # value from its source object - confirm it is still needed.
                vol_prop[key] = True if val.lower() == 'true' else pk.loads(pk.dumps(val))

            volume[deployed_volume.volumeType] = vol_prop
            volumes.append(volume)
        return volumes

    @staticmethod
    def validate_pod(deployed_pod):
        """Validate every container of ``deployed_pod``.

        :raises Exception: when a container declares more than one probe of
            the same type (readiness or liveness).
        """
        if deployed_pod.containers:
            for deployed_container in deployed_pod.containers:
                PodHelper.__validate_container(deployed_container)

    @staticmethod
    def __validate_container(deployed_container):
        """Check the readiness and liveness probe counts of one container."""
        if deployed_container.probes:
            container_name = PodHelper.get_container_name(deployed_container)
            for probe_type in ('Readiness', 'Liveness'):
                PodHelper.__validate_probes(deployed_container.probes, probe_type, container_name)

    @staticmethod
    def __validate_probes(all_probes, probe_type, container_name):
        """Raise when more than one probe of ``probe_type`` is configured."""
        probes = [probe for probe in all_probes if probe.probeType == probe_type]
        if len(probes) > 1:
            raise Exception(
                "Maximum of 2 probes, one each for readiness and liveness, can be added for each container, but found {} of type {} for container {}"
                    .format(len(probes), probe_type, container_name))

    @staticmethod
    def __read_config_map_key_env_var(config_map_env_var, env):
        """Append one env var sourced from a config map key to ``env``."""
        name = config_map_env_var.envVarName if config_map_env_var.envVarName else config_map_env_var.keyName
        env_var = client.V1EnvVar(name=name)
        env_var.value_from = client.V1EnvVarSource()
        env_var.value_from.config_map_key_ref = client.V1ConfigMapKeySelector(key=config_map_env_var.keyName,
                                                                              name=config_map_env_var.configMapName)
        env.append(env_var)

    @staticmethod
    def __read_secret_key_env_var(secret_env_var, env):
        """Append one env var sourced from a secret key to ``env``."""
        name = secret_env_var.envVarName if secret_env_var.envVarName else secret_env_var.keyName
        env_var = client.V1EnvVar(name=name)
        env_var.value_from = client.V1EnvVarSource()
        env_var.value_from.secret_key_ref = client.V1SecretKeySelector()
        env_var.value_from.secret_key_ref.key = secret_env_var.keyName
        env_var.value_from.secret_key_ref.name = secret_env_var.secretName
        env.append(env_var)

    @staticmethod
    def __read_field_env_var(field_env_var, env):
        """Append one env var per property, each referencing an object field path."""
        for key, value in field_env_var.properties.iteritems():
            env_var = client.V1EnvVar(name=key,
                                      value_from=client.V1EnvVarSource(
                                          field_ref=client.V1ObjectFieldSelector(field_path=value)))
            env.append(env_var)

    @staticmethod
    def __read_resource_field_env_var(resource_field_env_var, env):
        """Append one env var per property, each referencing a resource field."""
        for key, value in resource_field_env_var.properties.iteritems():
            env_var = client.V1EnvVar(name=key,
                                      value_from=client.V1EnvVarSource(
                                          resource_field_ref=client.V1ResourceFieldSelector(resource=value)))
            env.append(env_var)

    @staticmethod
    def __read_key_value_env_var(key_val_env_var, env):
        """Append one literal-valued env var per property."""
        for key, value in key_val_env_var.properties.iteritems():
            env_var = client.V1EnvVar(name=key, value=value)
            env.append(env_var)

    @staticmethod
    def __read_env_var(deployed_env_vars):
        """Convert the deployed env var definitions into a list of ``V1EnvVar``.

        Each definition is dispatched on ``type.toString()``.

        :raises KeyError: when a definition's type has no registered reader.
        """
        env = []
        env_var_functions = {'k8s.envVar.KeyValue': PodHelper.__read_key_value_env_var,
                             'k8s.envVar.ResourceField': PodHelper.__read_resource_field_env_var,
                             'k8s.envVar.Field': PodHelper.__read_field_env_var,
                             'k8s.envVar.ConfigMap': PodHelper.__read_config_map_key_env_var,
                             'k8s.envVar.Secret': PodHelper.__read_secret_key_env_var}

        for deployed_env_var in deployed_env_vars:
            function_to_call = env_var_functions[deployed_env_var.type.toString()]
            function_to_call(deployed_env_var, env)
        return env
