import json
from commons.aws_helper import AWSHelper
from alb.alb_helper import TargetGroupHelper
from ec2.ec2_helper import EC2Helper


class ECSHelper(AWSHelper):
    """Base helper that owns an ECS client created for ``deployed.region``."""

    def __init__(self, deployed):
        super(ECSHelper, self).__init__(deployed)
        ecs_region = deployed.region
        self.ecs_client = self.get_aws_client(resource_name='ecs', region=ecs_region)


class TaskServiceHelper(ECSHelper):
    """Helper for ECS task definitions, one-off tasks and services of a deployed.

    ``deployed`` is the item being deployed; ``deployed.container`` is the ECS
    cluster it targets. The ECS client is created for
    ``deployed.container.region`` and all cluster-scoped calls use
    ``deployed.container.clusterArn``.
    """

    # (lastStatus, desiredStatus) pairs of a task that is still starting up.
    _PENDING_STATUSES = frozenset([('PENDING', 'RUNNING'), ('PROVISIONING', 'RUNNING')])
    # Pairs that count as "starting or started successfully"; any other pair is a failure.
    _HEALTHY_STATUSES = _PENDING_STATUSES | frozenset([('RUNNING', 'RUNNING')])

    def __init__(self, deployed):
        # Deliberately bypass ECSHelper.__init__ (which reads ``deployed.region``):
        # AWSHelper is initialised with the cluster (deployed.container) instead.
        super(ECSHelper, self).__init__(deployed.container)
        self.ecs_client = self.get_aws_client(resource_name='ecs', region=deployed.container.region)
        self.deployed = deployed
        self.taskPlacementTemplates = self.initialize_task_placement_templates()
        self.tg_helper = TargetGroupHelper(deployed.container)
        self.ec2_helper = EC2Helper(deployed.container)

    def _cluster_display_name(self):
        """Return the cluster name for log messages, falling back to the container's name."""
        container = self.deployed.container
        return container.clusterName if container.clusterName else container.name

    def create_task_definition(self):
        """Register a task definition revision in family ``deployed.name``.

        ``executionRole`` falls back to ``taskRole`` when it is not set; the
        parameter dict is filtered through ``remove_none_keys`` before the call.

        :return: the new task definition ARN, or None when the response is not a success.
        """
        container_definitions = self.read_container_definitions()
        params = {'family': self.deployed.name,
                  'taskRoleArn': self.deployed.taskRole,
                  'executionRoleArn': self.deployed.executionRole if self.deployed.executionRole else self.deployed.taskRole,
                  'networkMode': self.deployed.networkMode,
                  'containerDefinitions': container_definitions,
                  'volumes': self.get_volumes(),
                  'requiresCompatibilities': self.get_requires_compatibilities(),
                  'cpu': self.deployed.cpu,
                  'memory': self.deployed.memory}
        params = self.remove_none_keys(params)
        print("Creating task definition {0} {1}".format(self.deployed.name, params))
        response = self.ecs_client.register_task_definition(**params)
        if not self.is_success(response):
            return None
        task_definition = response['taskDefinition']
        task_def_arn = task_definition['taskDefinitionArn']
        print("Created task definition {0} with arn {1}.".format(task_definition['family'], task_def_arn))
        return task_def_arn

    def get_requires_compatibilities(self):
        """Return the single-element compatibility list built from ``deployed.launchType``."""
        return [self.deployed.launchType]

    def get_network_configurations(self):
        """Build the ``networkConfiguration`` (awsvpc) parameter from subnets, security groups and public IP."""
        aws_vpc_configurations = {'subnets': self.get_subnet_ids(self.deployed.subnets),
                                  'securityGroups': self.get_security_group_ids(self.deployed.securityGroups),
                                  'assignPublicIp': self.get_assign_public_ip()}
        aws_vpc_configurations = self.remove_none_keys(aws_vpc_configurations)
        return {'awsvpcConfiguration': aws_vpc_configurations}

    def get_volumes(self):
        """Convert the ``deployed.volumes`` name -> host path mapping into ECS volume dicts ([] if empty)."""
        volumes = self.deployed.volumes
        if not volumes:
            return []
        return [{'name': key, 'host': {'sourcePath': volumes[key]}} for key in volumes.keys()]

    def read_container_definitions(self):
        """Convert ``deployed.containerDefinitions`` into ECS container definition dicts.

        :raises RuntimeError: if a container has neither a hard nor a soft memory limit.
        """
        container_definitions = []
        for definition in self.deployed.containerDefinitions:
            if not definition.hardMemoryLimit and not definition.softMemoryLimit:
                raise RuntimeError("One of hard memory limit or soft memory limit is required.")
            container_definition = {'name': definition.containerName,
                                    'image': definition.image,
                                    'memory': definition.hardMemoryLimit,
                                    'memoryReservation': definition.softMemoryLimit,
                                    'portMappings': self.read_port_mappings(definition),
                                    'logConfiguration': self.read_log_configuration(definition),
                                    'mountPoints': self.get_mount_points(definition),
                                    'links': list(definition.links),
                                    'cpu': definition.cpu,
                                    'environment': self.get_environment_variables(definition),
                                    'command': definition.command,
                                    'repositoryCredentials': self.read_repository_credentials(definition)
                                    }
            container_definitions.append(self.remove_none_keys(container_definition))
        return container_definitions

    def __convert_mapping_to_params(self, mapping):
        """Convert one port mapping into its ECS dict, dropping unset fields."""
        port_mapping = {'containerPort': mapping.containerPort, 'hostPort': mapping.hostPort, 'protocol': mapping.protocol}
        return self.remove_none_keys(port_mapping)

    @staticmethod
    def __get_mount_point_params(mount_point):
        """Convert one mount point into its ECS dict."""
        return {'sourceVolume': mount_point.sourceVolume, 'containerPath': mount_point.containerPath, 'readOnly': mount_point.readOnly}

    def read_port_mappings(self, container_definition):
        """Return the ECS port mapping dicts of a container definition."""
        return [self.__convert_mapping_to_params(mapping) for mapping in container_definition.portMappings]

    def read_log_configuration(self, container_definition):
        """Return the first log configuration as an ECS dict, or None when there is none."""
        if container_definition.logConfiguration and container_definition.logConfiguration[0]:
            mapping = container_definition.logConfiguration[0]
            return {'logDriver': mapping.logDriver, 'options': mapping.options}
        return None

    def read_repository_credentials(self, container_definition):
        """Return the private registry credentials parameter, or None when not configured."""
        if container_definition.repositoryCredentials:
            return {'credentialsParameter': container_definition.repositoryCredentials}
        return None

    def get_mount_points(self, container_definition):
        """Return the ECS mount point dicts of a container definition."""
        return [self.__get_mount_point_params(mount_point) for mount_point in container_definition.mountPoints]

    def get_environment_variables(self, container_definition):
        """Convert the container's environment mapping into ECS name/value pairs."""
        return [{'name': key, 'value': value} for key, value in container_definition.environment.iteritems()]

    def remove_task_definition(self):
        """Deregister ``deployed.taskDefinitionArn`` if one is set."""
        if self.deployed.taskDefinitionArn:
            print("Deleting task definition {0}".format(self.deployed.taskDefinitionArn))
            self.ecs_client.deregister_task_definition(taskDefinition=self.deployed.taskDefinitionArn)

    def initialize_task_placement_templates(self):
        """Return the named placement strategy/constraint presets selectable via ``taskPlacementTemplateName``."""
        spread_zone = {'type': 'spread', 'field': 'attribute:ecs.availability-zone'}
        spread_instance_id = {'type': 'spread', 'field': 'instanceId'}
        binpack_memory = {'type': 'binpack', 'field': 'memory'}
        distinct_instance_constraint = {'type': 'distinctInstance'}
        return {'AZ Balanced Spread': {'placementStrategy': [spread_zone, spread_instance_id]},
                'AZ Balanced BinPack': {'placementStrategy': [spread_zone, binpack_memory]},
                'BinPack': {'placementStrategy': [binpack_memory]},
                'One Task Per Host': {'placementConstraints': [distinct_instance_constraint]}}

    def run_task(self):
        """Start ``deployed.nrOfTasks`` tasks of the task definition.

        :return: the started task ARNs, or None when the response is not a success.
        :raises RuntimeError: if ECS reports failures; tasks that did start are stopped first.
        """
        print("Running {0} tasks of type {1} on cluster {2}"
              .format(self.deployed.nrOfTasks, self.deployed.name, self._cluster_display_name()))
        response = self.ecs_client.run_task(**self.get_run_task_params())
        if not self.is_success(response):
            return None
        task_arns = [task['taskArn'] for task in response['tasks']]
        if not response['failures']:
            return task_arns
        print("{0} out of {1} tasks failed to start".format(self.deployed.nrOfTasks - len(task_arns), self.deployed.nrOfTasks))
        if task_arns:
            # Roll back the partially started batch before reporting the failure.
            print("Cleaning {0} tasks {1}.".format(len(task_arns), task_arns))
            self.stop_tasks(task_arns)
        raise RuntimeError("Could not start tasks of type {0} because of failures: {1}"
                           .format(self.deployed.taskDefinitionArn, response['failures']))

    def stop_tasks(self, tasks=None):
        """Stop the given task ARNs (default: ``deployed.taskArns``) on the cluster."""
        tasks = self.deployed.taskArns if not tasks else tasks
        for task_arn in tasks:
            print("Stopping task {0} on cluster {1}.".format(task_arn, self._cluster_display_name()))
            self.ecs_client.stop_task(cluster=self.deployed.container.clusterArn, task=task_arn)

    def get_run_task_params(self):
        """Build the ``run_task`` keyword arguments, including Fargate networking and placement presets."""
        params_dict = {'taskDefinition': self.deployed.taskDefinitionArn,
                       'cluster': self.deployed.container.clusterArn,
                       'count': self.deployed.nrOfTasks,
                       'group': self.deployed.taskGroupName if self.deployed.taskGroupName else self.get_deployable_name(),
                       'launchType': self.deployed.launchType
                       }
        if self.deployed.launchType == 'FARGATE':
            params_dict['networkConfiguration'] = self.get_network_configurations()
        if self.deployed.taskPlacementTemplateName:
            params_dict.update(self.taskPlacementTemplates[self.deployed.taskPlacementTemplateName])
        return self.remove_none_keys(params_dict)

    def get_task_statuses(self, tasks=None):
        """Return one status dict per described task.

        Each dict has ``taskArn``, ``lastStatus`` and ``desiredStatus``, plus
        ``reason`` taken from the last container of the task that reports one.

        :param tasks: task ARNs to describe; defaults to ``deployed.taskArns``.
        """
        tasks = self.deployed.taskArns if not tasks else tasks
        descriptions = self.get_task_descriptions(tasks)
        task_statuses = []
        for task in descriptions['tasks']:
            status = {'taskArn': task['taskArn'], 'lastStatus': task['lastStatus'], 'desiredStatus': task['desiredStatus']}
            for container in task['containers']:
                if 'reason' in container:
                    status['reason'] = container['reason']
            task_statuses.append(status)
        return task_statuses

    def get_task_descriptions(self, tasks):
        """Describe the given tasks on the cluster.

        :raises RuntimeError: on a non-success response or any reported failure.
        """
        descriptions = self.ecs_client.describe_tasks(cluster=self.deployed.container.clusterArn, tasks=tasks)
        if not self.is_success(descriptions) or descriptions['failures']:
            reason = descriptions['failures'] if descriptions['failures'] else descriptions['ResponseMetadata'][
                'HTTPStatusCode']
            raise RuntimeError("Could not describe tasks {0}".format(reason))
        return descriptions

    def is_status_not_success(self, last, desired):
        """Return True unless the task is pending/provisioning/running towards RUNNING."""
        return (last, desired) not in self._HEALTHY_STATUSES

    def filter_tasks_failed_to_start(self, tasks_statuses):
        """Return the ARNs of tasks whose status pair is not a healthy one."""
        return [task['taskArn'] for task in tasks_statuses
                if self.is_status_not_success(task['lastStatus'], task['desiredStatus'])]

    def filter_tasks_reason_failed_to_start(self, tasks_statuses):
        """Return the first ``reason`` found in the statuses, or None."""
        for status in tasks_statuses:
            if 'reason' in status:
                return status['reason']
        return None

    def filter_tasks_pending(self, tasks_statuses):
        """Return the ARNs of tasks that are still PENDING/PROVISIONING towards RUNNING."""
        return [task['taskArn'] for task in tasks_statuses
                if (task['lastStatus'], task['desiredStatus']) in self._PENDING_STATUSES]

    def filter_tasks_not_stopped(self, tasks_statuses):
        """Return the ARNs of tasks whose last status is not STOPPED."""
        return [task['taskArn'] for task in tasks_statuses if task['lastStatus'] != 'STOPPED']

    def get_service_params_for_update(self):
        """Return the desired ``update_service`` parameters for this deployed.

        NOTE(review): values are not passed through remove_none_keys, so a
        None ``deploymentConfiguration`` stays in the dict — confirm callers
        filter it before calling update_service.
        """
        return {'service': self.deployed.serviceName,
                'taskDefinition': self.deployed.name,
                'desiredCount': self.deployed.desiredCount,
                'cluster': self.deployed.container.clusterArn,
                'deploymentConfiguration': self.get_deployment_configuration_params()
                }

    def get_service_params(self):
        """Build the ``create_service`` keyword arguments, including Fargate networking and placement presets."""
        params = {'serviceName': self.deployed.serviceName,
                  'taskDefinition': self.deployed.name,
                  'desiredCount': self.deployed.desiredCount,
                  'cluster': self.deployed.container.clusterArn,
                  'loadBalancers': self.get_load_balancer_params(),
                  'role': self.deployed.role,
                  'deploymentConfiguration': self.get_deployment_configuration_params(),
                  'launchType': self.deployed.launchType
                  }
        if self.deployed.launchType == 'FARGATE':
            params['networkConfiguration'] = self.get_network_configurations()
        if self.deployed.taskPlacementTemplateName:
            params.update(self.taskPlacementTemplates[self.deployed.taskPlacementTemplateName])
        return self.remove_none_keys(params)

    def create_service(self):
        """Create the ECS service and return the ``service`` part of the response."""
        params = self.get_service_params()
        print("Creating service {0} with task definition {1} with desired count {2} on cluster {3}"
              .format(params['serviceName'], self.deployed.name, self.deployed.desiredCount, self._cluster_display_name()))
        response = self.ecs_client.create_service(**params)
        print("Service {0} created: {1}".format(params['serviceName'], response['service']['serviceArn']))
        return response['service']

    def get_deployment_configuration_params(self):
        """Return maximumPercent/minimumHealthyPercent as a dict, or None if neither is set."""
        if not self.deployed.maximumPercent and not self.deployed.minimumHealthyPercent:
            return None
        deployment_configuration = {'maximumPercent': self.deployed.maximumPercent,
                                    'minimumHealthyPercent': self.deployed.minimumHealthyPercent}
        return self.remove_none_keys(deployment_configuration)

    def get_load_balancer_params(self):
        """Return the service's load balancer dicts, or None when none are configured."""
        if self.deployed.loadBalancers:
            return [self.__convert_to_lb_params(lb) for lb in self.deployed.loadBalancers]
        return None

    def __convert_to_lb_params(self, lb):
        """Convert one load balancer entry, resolving a target group given by name to its ARN."""
        target_group_arn = lb.targetGroupArn
        if self.is_starts_with_name(target_group_arn):
            target_group_arn = self.tg_helper.find_target_group(self.get_property_name(target_group_arn))['TargetGroupArn']
        return self.remove_none_keys({'targetGroupArn': target_group_arn,
                                      'loadBalancerName': lb.loadBalancerName,
                                      'containerName': lb.containerName,
                                      'containerPort': lb.containerPort})

    def update_service(self, changed_params):
        """Update the service with ``changed_params`` and return the ``service`` part of the response.

        The cluster is addressed by ARN, like every other call in this helper;
        ``clusterName`` may be unset (see ``_cluster_display_name``).
        """
        params = {'service': self.deployed.serviceName, 'cluster': self.deployed.container.clusterArn}
        params.update(changed_params)
        response = self.ecs_client.update_service(**params)
        print("Service {0} updated with arn: {1}".format(params['service'], response['service']['serviceArn']))
        return response['service']

    def delete_service(self):
        """Delete the service and return the ``service`` part of the response."""
        params = {'service': self.deployed.serviceName, 'cluster': self.deployed.container.clusterArn}
        print("Deleting service {0} on cluster {1}".format(params['service'], self._cluster_display_name()))
        response = self.ecs_client.delete_service(**params)
        print("Service {0} deleted: {1}".format(params['service'], response['service']['serviceArn']))
        return response['service']

    def get_service_status(self, service_name=None):
        """Return the ECS ``status`` of the named service (default: ``get_service_name()``)."""
        service_name = self.get_service_name() if not service_name else service_name
        return self.get_services_descriptions([service_name])[0]['status']

    def get_services_descriptions(self, services):
        """Describe the given services on the cluster.

        :raises RuntimeError: on a non-success response or any reported failure.
        """
        response = self.ecs_client.describe_services(cluster=self.deployed.container.clusterArn, services=services)
        if self.is_success(response) and not response['failures']:
            return response['services']
        raise RuntimeError("Could not fetch description of ECS service(s) {0} because of {1}"
                           .format(services, response['failures'] if self.is_success(response) else response['ResponseMetadata']))

    def get_service_running_task_count(self):
        """Return the service's ``runningCount``."""
        return self.get_services_descriptions(services=[self.get_service_name()])[0]['runningCount']

    def get_service_name(self):
        """Return ``serviceName``, falling back to the deployed's name."""
        return self.deployed.serviceName if self.deployed.serviceName else self.deployed.name

    def get_subnet_ids(self, subnets):
        """Resolve subnets to their ids via the EC2 helper."""
        return self.ec2_helper.get_subnet_id_list(subnets)

    def get_security_group_ids(self, security_groups):
        """Resolve security groups to their ids via the EC2 helper."""
        return self.ec2_helper.get_security_group_id_list(security_groups)

    def get_assign_public_ip(self):
        """Map the boolean ``assignPublicIp`` to ECS's 'ENABLED'/'DISABLED'."""
        return 'ENABLED' if self.deployed.assignPublicIp else 'DISABLED'
