from alb.tg.tg_helper import TargetGroupHelper


class ALBHelper(TargetGroupHelper):
    """Create, modify, validate and destroy an application ELB with its listeners and rules.

    ``self.deployed`` holds the desired state and ``self.previous_deployed`` the
    previously deployed state (used when modifying or destroying).  The AWS
    clients/helpers used below (``alb_client``, ``ec2_helper``, ``ni_helper``,
    ``get_subnet_ids``, ``get_target_group_arn``,
    ``fetch_network_interfaces_for_subnets``,
    ``remove_associated_network_interfaces``, ``is_modified``) are not defined in
    this class; presumably they come from ``TargetGroupHelper``.
    """

    def __init__(self, deployed, previous_deployed=None):
        super(ALBHelper, self).__init__(deployed, previous_deployed)

    # ------------------------------------------------------------------
    # Creation / modification
    # ------------------------------------------------------------------

    def create_alb(self):
        """Create the load balancer and return the raw ``create_load_balancer`` response."""
        return self.alb_client.create_load_balancer(
            Name=self.__resolve_load_balancer_name(self.deployed),
            SecurityGroups=self.ec2_helper.get_security_group_id_list(self.deployed.securityGroups),
            Subnets=self.get_subnet_ids(self.deployed.subnets),
            Scheme=self.deployed.scheme,
            IpAddressType=self.deployed.ipAddressType
        )

    def modify_load_balancer_attributes(self):
        """Push access-log, deletion-protection and idle-timeout settings to the load balancer."""
        if self.deployed.enableAccessLog:
            attributes = [
                {'Key': 'access_logs.s3.enabled', 'Value': 'true'},
                {'Key': 'access_logs.s3.bucket', 'Value': self.deployed.accessLogS3BucketName},
                # An unset prefix would otherwise be sent as None (not a string);
                # an empty string means "no prefix".
                {'Key': 'access_logs.s3.prefix', 'Value': self.deployed.accessLogS3BucketPrefix or ''}
            ]
        else:
            attributes = [
                {'Key': 'access_logs.s3.enabled', 'Value': 'false'}
            ]
        # Attribute values are strings, so booleans are sent as 'true'/'false'.
        attributes.extend([
            {'Key': 'deletion_protection.enabled', 'Value': str(self.deployed.deletionProtection).lower()},
            {'Key': 'idle_timeout.timeout_seconds', 'Value': str(self.deployed.idleTimeout)},
        ])

        self.alb_client.modify_load_balancer_attributes(
            LoadBalancerArn=self.deployed.arn,
            Attributes=attributes
        )

    def set_ip_address_type(self):
        """Apply ``deployed.ipAddressType`` to the existing load balancer."""
        self.alb_client.set_ip_address_type(
            LoadBalancerArn=self.deployed.arn,
            IpAddressType=self.deployed.ipAddressType
        )

    def create_listener_rules(self):
        """Create every listener rule of ``deployed`` and store the returned rule ARN on it."""
        for listener_rule in self.deployed.listenerRules:
            ALBHelper.__print_listener_rule_log("Creating", listener_rule)
            response = self.alb_client.create_rule(**self.__get_listener_rule_params(listener_rule))
            listener_rule.arn = response['Rules'][0]['RuleArn']

    def create_listeners(self):
        """Create every listener of ``deployed`` and store the returned listener ARN on it."""
        for listener in self.deployed.listeners:
            ALBHelper.__print_listener_log("Creating", listener)
            response = self.alb_client.create_listener(**self.__get_listener_params(listener))
            listener.arn = response['Listeners'][0]['ListenerArn']

    def set_security_groups(self, associate):
        """Add (``associate=True``) or remove security groups that changed since the previous deployment.

        When associating, the load balancer gets the union of old and new groups;
        when dissociating, it keeps only the groups present in both.  Nothing is
        sent to AWS if there is no difference in the requested direction.
        """
        if associate:
            difference = self.deployed.securityGroups - self.previous_deployed.securityGroups
        else:
            difference = self.previous_deployed.securityGroups - self.deployed.securityGroups
        if not difference:
            return

        ALBHelper.__print_association_logs(associate, difference, 'security groups',
                                           self.__resolve_load_balancer_name(self.deployed),
                                           'application ELB')
        if associate:
            target_groups = self.previous_deployed.securityGroups | self.deployed.securityGroups
        else:
            target_groups = self.previous_deployed.securityGroups & self.deployed.securityGroups
        self.alb_client.set_security_groups(
            LoadBalancerArn=self.deployed.arn,
            SecurityGroups=self.ec2_helper.get_security_group_id_list(target_groups)
        )

    def set_subnets(self, associate):
        """Add (``associate=True``) or remove subnets that changed since the previous deployment.

        When associating, the load balancer gets the union of old and new subnets;
        when dissociating, it keeps only the subnets present in both and then
        removes network interfaces still attached in the dropped subnets.
        """
        if associate:
            difference = self.deployed.subnets - self.previous_deployed.subnets
        else:
            difference = self.previous_deployed.subnets - self.deployed.subnets
        if not difference:
            return

        ALBHelper.__print_association_logs(associate, difference, 'subnets',
                                           self.__resolve_load_balancer_name(self.deployed),
                                           'application ELB')
        if associate:
            target_subnets = self.previous_deployed.subnets | self.deployed.subnets
        else:
            target_subnets = self.previous_deployed.subnets & self.deployed.subnets
        self.alb_client.set_subnets(
            LoadBalancerArn=self.deployed.arn,
            Subnets=self.get_subnet_ids(target_subnets)
        )

        if not associate:
            # Clean up the ALB's network interfaces left behind in the dropped subnets.
            associated_network_interfaces = self.fetch_network_interfaces_for_subnets(
                subnet_ids=difference,
                load_balancer_name=self.__get_alb_ni_description())
            if associated_network_interfaces:
                self.remove_associated_network_interfaces(network_interface_ids=associated_network_interfaces,
                                                          max_retries=self.deployed.maxDetachmentRetries)

    # ------------------------------------------------------------------
    # Destruction
    # ------------------------------------------------------------------

    def destroy_listeners(self):
        """Delete every previously deployed listener that has an ARN."""
        for listener in self.get_created_arn_resources("listeners"):
            ALBHelper.__print_listener_log("Destroying", listener)
            self.alb_client.delete_listener(
                ListenerArn=listener.arn
            )

    def destroy_listener_rules(self):
        """Delete every previously deployed listener rule that has an ARN."""
        for listener_rule in self.get_created_arn_resources("listenerRules"):
            ALBHelper.__print_listener_rule_log("Destroying", listener_rule)
            self.alb_client.delete_rule(
                RuleArn=listener_rule.arn
            )

    def destroy_alb(self):
        """Delete the load balancer and return the raw ``delete_load_balancer`` response."""
        return self.alb_client.delete_load_balancer(
            LoadBalancerArn=self.deployed.arn
        )

    def is_alb_destroyed(self):
        """Return True once the load balancer, its listeners, rules and network interfaces are gone."""
        try:
            response = self.alb_client.describe_load_balancers(LoadBalancerArns=[self.deployed.arn])
            destroyed = not response['LoadBalancers']
        except Exception:
            # Best-effort check: any failure to describe the load balancer
            # (presumably "not found" once it is deleted) counts as destroyed.
            destroyed = True
        return (destroyed and self.__is_listeners_destroyed() and self.__is_rules_destroyed()
                and self.__is_alb_ni_destroyed())

    def is_resources_dissociated(self):
        """Return True when no ALB network interface remains in subnets removed since the previous deployment."""
        subnets_dissociated = self.previous_deployed.subnets - self.deployed.subnets
        if subnets_dissociated:
            nis = self.fetch_network_interfaces_for_subnets(subnet_ids=subnets_dissociated,
                                                            load_balancer_name=self.__get_alb_ni_description())
            if nis:
                return False
        return True

    # ------------------------------------------------------------------
    # Tagging
    # ------------------------------------------------------------------

    def set_resource_name(self, resource_id, name):
        """Set the ``Name`` tag of the resource identified by ARN ``resource_id``."""
        print("Setting name for resource {0} to {1}".format(resource_id, name))
        self.create_resource_tag(resource_id, 'Name', name)

    def create_resource_tag(self, resource_id, key, value):
        """Add tag ``key=value`` to the resource identified by ARN ``resource_id``."""
        self.alb_client.add_tags(ResourceArns=[resource_id],
                                 Tags=[{'Key': key, 'Value': value}])
        print("Added tag {}, {} to resource".format(key, value))

    def delete_resource_tag(self, resource_id, key, value):
        """Remove tag ``key`` from the resource; ``value`` is only used for logging."""
        self.alb_client.remove_tags(ResourceArns=[resource_id],
                                    TagKeys=[key])
        print("Deleted tag {}, {} from resource".format(key, value))

    # ------------------------------------------------------------------
    # Change detection
    # ------------------------------------------------------------------

    def should_modify_alb(self):
        """Return whether any modifiable load balancer attribute changed since the previous deployment."""
        return TargetGroupHelper.is_modified(self.deployed, self.previous_deployed,
                                             ['enableAccessLog', 'accessLogS3BucketName', 'accessLogS3BucketPrefix',
                                              'deletionProtection', 'idleTimeout', 'ipAddressType'])

    def get_created_arn_resources(self, property_name):
        """Return the items of ``previous_deployed[property_name]`` that have a truthy ``arn``."""
        return [resource for resource in self.previous_deployed[property_name] if resource.arn]

    def should_associate_resources(self):
        """Return the (truthy) set of newly added subnets, else of newly added security groups."""
        return (self.deployed.subnets - self.previous_deployed.subnets
                or self.deployed.securityGroups - self.previous_deployed.securityGroups)

    def should_dissociate_resources(self):
        """Return the (truthy) set of removed subnets, else of removed security groups."""
        return (self.previous_deployed.subnets - self.deployed.subnets
                or self.previous_deployed.securityGroups - self.deployed.securityGroups)

    # ------------------------------------------------------------------
    # Validation
    # ------------------------------------------------------------------

    def validate_alb(self):
        """Raise RuntimeError if the desired load balancer configuration is invalid."""
        if self.deployed.ipAddressType != 'ipv4' and self.deployed.scheme == 'internal':
            raise RuntimeError("Internal application ELB must use ipv4 as IP address type.")
        if len(self.deployed.subnets) < 2:
            raise RuntimeError("application ELB must have subnets from at least two Availability Zones.")

    def validate_listeners(self):
        """Raise RuntimeError if an HTTPS listener has no SSL certificate."""
        for listener in self.deployed.listeners:
            if listener.protocol.upper() == 'HTTPS' and not listener.sslCertificate:
                raise RuntimeError("Listener with HTTPS protocol must have SSL Certificate.")

    def validate_listener_rules(self):
        """Raise RuntimeError for an unknown listener port, a duplicate priority or an invalid pattern.

        Priorities only need to be unique per listener port.  Each rule must set
        exactly one of ``hostNamePattern`` / ``pathPattern``.
        """
        seen_priorities = set()
        for listener_rule in self.deployed.listenerRules:
            # Called only for its check: raises if no listener uses this rule's port.
            self.__get_listener_arn(listener_rule)
            qualified_priority = "{}:{}".format(listener_rule.listenerPort, listener_rule.priority)
            if qualified_priority in seen_priorities:
                raise RuntimeError("Multiple listener rules must not use same priority.")
            seen_priorities.add(qualified_priority)

            # Both set or both unset is invalid.
            if bool(listener_rule.hostNamePattern) == bool(listener_rule.pathPattern):
                raise RuntimeError(
                    "Listener rule must use either host name pattern or path pattern.")

    def validate_alb_modification(self):
        """Raise RuntimeError if name, region or scheme changed; these are fixed at creation."""
        if (self.__resolve_load_balancer_name(self.deployed)
                != self.__resolve_load_balancer_name(self.previous_deployed)):
            raise RuntimeError("Application ELB name can not be updated once created")
        if self.deployed.region != self.previous_deployed.region:
            raise RuntimeError("Application ELB region can not be updated once created")
        if self.deployed.scheme != self.previous_deployed.scheme:
            raise RuntimeError("Application ELB scheme can not be updated once created")

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def __resolve_load_balancer_name(deployed):
        """Return ``deployed.loadBalancerName``, falling back to ``deployed.name`` when it is empty."""
        return deployed.loadBalancerName or deployed.name

    def __get_alb_ni_description(self):
        """Build ``app/<name>/<last ARN segment>`` used to look up the ALB's network interfaces."""
        index = self.deployed.arn.rfind('/')
        # arn[index:] keeps the leading '/' of the final ARN path segment.
        return 'app/{}{}'.format(self.__resolve_load_balancer_name(self.deployed), self.deployed.arn[index:])

    def __get_listener_arn(self, listener_rule):
        """Return the ARN of the listener whose port matches the rule; raise RuntimeError if none does."""
        for listener in self.deployed.listeners:
            if listener.port == listener_rule.listenerPort:
                return listener.arn
        raise RuntimeError('Listener rule has incorrect listener port {}'.format(listener_rule.listenerPort))

    def __get_listener_rule_params(self, listener_rule):
        """Build ``create_rule`` kwargs: one host-header or path-pattern condition forwarding to the target group."""
        if listener_rule.hostNamePattern:
            field, pattern = 'host-header', listener_rule.hostNamePattern
        else:
            field, pattern = 'path-pattern', listener_rule.pathPattern
        return {
            'ListenerArn': self.__get_listener_arn(listener_rule),
            'Priority': listener_rule.priority,
            'Conditions': [{'Field': field, 'Values': [pattern]}],
            'Actions': [{'Type': 'forward', 'TargetGroupArn': self.get_target_group_arn(listener_rule.targetGroup)}]
        }

    def __get_listener_params(self, listener):
        """Build ``create_listener`` kwargs; SSL policy/certificate are included only when set."""
        params = {
            'LoadBalancerArn': self.deployed.arn,
            'Protocol': listener.protocol.upper(),
            'Port': listener.port,
            'DefaultActions': [
                {'Type': 'forward', 'TargetGroupArn': self.get_target_group_arn(listener.targetGroup)}]
        }

        if listener.sslPolicy:
            params['SslPolicy'] = listener.sslPolicy
        if listener.sslCertificate:
            params['Certificates'] = [{'CertificateArn': listener.sslCertificate}]
        return params

    def __is_alb_ni_destroyed(self):
        """Return True when no network interface with description ``ELB <ni description>`` remains."""
        nis = self.ni_helper.fetch_ni_by_criteria([
            {"Name": "description", "Values": ["ELB {}".format(self.__get_alb_ni_description())]}
        ])
        return not nis

    def __is_listeners_destroyed(self):
        """Return True when none of the deployed listeners with an ARN can still be described."""
        arns = [listener.arn for listener in self.deployed.listeners if listener.arn]
        if not arns:
            return True
        try:
            response = self.alb_client.describe_listeners(
                ListenerArns=arns
            )
            return not response['Listeners']
        except Exception:
            # Best-effort: failing to describe the listeners counts as destroyed.
            return True

    def __is_rules_destroyed(self):
        """Return True when none of the deployed listener rules with an ARN can still be described."""
        arns = [rule.arn for rule in self.deployed.listenerRules if rule.arn]
        if not arns:
            return True
        try:
            response = self.alb_client.describe_rules(
                RuleArns=arns
            )
            return not response['Rules']
        except Exception:
            # Best-effort: failing to describe the rules counts as destroyed.
            return True

    @staticmethod
    def __print_association_logs(associate, children, child_type, target, target_type):
        """Print an 'Associating ... to' / 'Dissociating ... from' progress line."""
        if associate:
            verb = 'Associating'
            relation = 'to'
        else:
            verb = 'Dissociating'
            relation = 'from'

        print("{} {} {} {} {} {} ".format(verb, child_type, ', '.join(children), relation, target_type, target))

    @staticmethod
    def __print_listener_log(verb, listener):
        """Print a progress line describing a listener."""
        print('{} listener - protocol: {}, port: {}, targetGroup: {}'.format(verb, listener.protocol,
                                                                             listener.port,
                                                                             listener.targetGroup))

    @staticmethod
    def __print_listener_rule_log(verb, listener_rule):
        """Print a progress line describing a listener rule and its host or path pattern."""
        pattern_value = listener_rule.hostNamePattern or listener_rule.pathPattern
        print('{} listener rule - listener port: {}, pattern value: {}, targetGroup: {}'.format(
            verb, listener_rule.listenerPort, pattern_value, listener_rule.targetGroup))
