from elb.elb_helper import ElbHelper
from ec2.vpc.vpc_helper import VPCHelper


class TargetGroupHelper(ElbHelper):
    """Create, modify, validate and destroy an AWS (elbv2) target group.

    ``deployed`` carries the desired target group configuration.
    ``previous_deployed`` (optional) carries the previously applied
    configuration. It is used for destroy, target dissociation and change
    detection.
    """

    def __init__(self, deployed, previous_deployed=None):
        super(TargetGroupHelper, self).__init__(deployed)
        self.alb_client = self.get_aws_client(region=deployed.region, resource_name='elbv2')
        self.previous_deployed = previous_deployed
        self.vpc_helper = VPCHelper(deployed)

    def create_target_group(self):
        """Create the target group and store its ARN on ``self.deployed.arn``."""
        response = self.alb_client.create_target_group(**self.__get_target_group_params())
        self.deployed.arn = response['TargetGroups'][0]['TargetGroupArn']

    def destroy_target_group(self):
        """Delete the target group identified by ``self.previous_deployed.arn``."""
        self.alb_client.delete_target_group(
            TargetGroupArn=self.previous_deployed.arn
        )

    def modify_target_group_attributes(self):
        """Apply the deregistration delay and stickiness attributes.

        Load-balancer-cookie stickiness is enabled only when
        ``stickinessDuration`` is greater than 0. Otherwise it is disabled.
        """
        attributes = [
            {'Key': 'deregistration_delay.timeout_seconds', 'Value': str(self.deployed.deRegistrationDelay)}
        ]
        if self.deployed.stickinessDuration > 0:
            attributes.append({'Key': 'stickiness.enabled', 'Value': 'true'})
            attributes.append({'Key': 'stickiness.type', 'Value': 'lb_cookie'})
            attributes.append(
                {'Key': 'stickiness.lb_cookie.duration_seconds', 'Value': str(self.deployed.stickinessDuration)})
        else:
            attributes.append({'Key': 'stickiness.enabled', 'Value': 'false'})

        self.alb_client.modify_target_group_attributes(
            TargetGroupArn=self.deployed.arn,
            Attributes=attributes
        )

    def modify_target_group(self):
        """Update the health-check settings of the existing target group.

        Relies on ``self.deployed.arn`` being set. Only then does
        ``__get_target_group_params`` send ``TargetGroupArn`` instead of the
        create-only ``Name``/``VpcId``/``Protocol``/``Port`` keys.
        """
        self.alb_client.modify_target_group(**self.__get_target_group_params())

    def associate_instances(self):
        """Register the instances listed on ``self.deployed`` with the target group."""
        targets = self.__get_targets(self.deployed)
        if targets:
            TargetGroupHelper.__print_association_logs(True, targets, 'instances',
                                                       self.__get_target_name(self.deployed), 'target group')
            self.alb_client.register_targets(
                TargetGroupArn=self.deployed.arn,
                Targets=targets
            )

    def dissociate_instances(self):
        """Deregister the instances listed on ``self.previous_deployed``.

        The targets are removed from the target group identified by
        ``self.deployed.arn``.
        """
        targets = self.__get_targets(self.previous_deployed)
        if targets:
            TargetGroupHelper.__print_association_logs(False, targets, 'instances',
                                                       self.__get_target_name(self.deployed), 'target group')
            self.alb_client.deregister_targets(
                TargetGroupArn=self.deployed.arn,
                Targets=targets
            )

    def get_target_group_arn(self, target_group_name):
        """Return the ARN of the target group named ``target_group_name``."""
        target_group = self.find_target_group(target_group_name)
        return target_group['TargetGroupArn']

    def find_target_group(self, target_group_name):
        """Return the first target group description matching ``target_group_name``.

        :raises RuntimeError: if the response contains no target groups.
            NOTE(review): AWS may instead raise its own TargetGroupNotFound
            client error for unknown names. Confirm which path callers expect.
        """
        response = self.alb_client.describe_target_groups(
            Names=[target_group_name]
        )
        if not response['TargetGroups']:
            raise RuntimeError('No target group found for name {}'.format(target_group_name))
        return response['TargetGroups'][0]

    def should_modify_target_group(self):
        """Return True if any health-check setting changed since the previous deployment.

        Every health-check value sent by ``__get_target_group_params`` is
        compared, including ``healthCheckPath``.
        """
        return TargetGroupHelper.is_modified(self.deployed, self.previous_deployed,
                                             ['healthCheckProtocol', 'healthCheckPort', 'healthCheckPath',
                                              'healthyThresholdCount', 'unhealthyThresholdCount',
                                              'healthCheckTimeout', 'healthCheckInterval',
                                              'healthCheckSuccessCode'])

    def should_modify_target_group_attributes(self):
        """Return True if the deregistration delay or stickiness duration changed."""
        return TargetGroupHelper.is_modified(self.deployed, self.previous_deployed,
                                             ['deRegistrationDelay', 'stickinessDuration'])

    def validate_target_group(self):
        """Check the health-check settings against their allowed ranges.

        :raises RuntimeError: if a threshold count, timeout or interval is
            outside its allowed range (bounds inclusive).
        """
        if self.deployed.healthyThresholdCount not in range(2, 11):
            raise RuntimeError("Target group healthy threshold count range should be 2 to 10.")
        if self.deployed.unhealthyThresholdCount not in range(2, 11):
            raise RuntimeError("Target group unhealthy threshold count range should be 2 to 10.")
        if self.deployed.healthCheckTimeout not in range(2, 61):
            raise RuntimeError("Target group health check timeout range should be 2 to 60 seconds.")
        if self.deployed.healthCheckInterval not in range(5, 301):
            raise RuntimeError("Target group health check interval range should be 5 to 300 seconds.")

    def validate_target_group_modification(self):
        """Reject changes to properties that cannot be updated after creation.

        :raises RuntimeError: if the name, region, vpc, protocol or port differs
            from ``self.previous_deployed``.
        """
        if self.__get_target_name(self.deployed) != self.__get_target_name(self.previous_deployed):
            raise RuntimeError("Target group name can not be updated once created.")
        if self.deployed.region != self.previous_deployed.region:
            raise RuntimeError("Target group region can not be updated once created.")
        if self.deployed.vpc != self.previous_deployed.vpc:
            raise RuntimeError("Target group vpc can not be updated once created.")
        if self.deployed.protocol != self.previous_deployed.protocol:
            raise RuntimeError("Target group protocol can not be updated once created.")
        if self.deployed.port != self.previous_deployed.port:
            raise RuntimeError("Target group port can not be updated once created.")

    @staticmethod
    def __get_target_name(deployed):
        """Return ``deployed.targetName`` when set, otherwise ``deployed.name``."""
        return deployed.targetName or deployed.name

    @staticmethod
    def __parse_port(port):
        """Return ``port`` as an int, or None when it is empty, blank or 'null'."""
        if not port:
            return None
        # str() so that non-string ports do not fail on .strip().
        port_text = str(port).strip()
        if not port_text or port_text == 'null':
            return None
        return int(port_text)

    def __get_targets(self, deployed):
        """Build the ``Targets`` list for register/deregister from ``deployed.instances``.

        ``deployed.instances`` maps an instance reference to an optional port.
        References accepted by ``is_starts_with_name`` are resolved to instance
        ids by matching the instances' ``Name`` tag. Other references are used
        as ids unchanged.

        :raises RuntimeError: if a name reference matches none of the
            instances returned by ``ec2_helper.get_instances``.
        """
        targets = []
        named_targets = []

        for instance_ref, port in deployed.instances.items():
            target = {'Id': instance_ref}
            parsed_port = TargetGroupHelper.__parse_port(port)
            if parsed_port is not None:
                target['Port'] = parsed_port
            if self.is_starts_with_name(instance_ref):
                # Holds the instance name for now; replaced by the real id below.
                target['Id'] = self.get_property_name(instance_ref)
                named_targets.append(target)
            targets.append(target)

        if named_targets:
            instances = self.ec2_helper.get_instances([target['Id'] for target in named_targets])
            # Name tag -> instance id. The first instance carrying a name wins.
            # Untagged instances have no 'Tags' key in EC2 responses.
            ids_by_name = {}
            for instance in instances:
                for tag in instance.get('Tags', []):
                    if tag['Key'] == 'Name':
                        ids_by_name.setdefault(tag['Value'], instance['InstanceId'])
                        break
            for target in named_targets:
                instance_id = ids_by_name.get(target['Id'])
                if instance_id is None:
                    raise RuntimeError("Instances with name {} not found in AWS.".format(target['Id']))
                target['Id'] = instance_id

        return targets

    def __get_target_group_params(self):
        """Build keyword arguments for create_target_group / modify_target_group.

        Health-check settings are always included. When ``self.deployed.arn``
        is set, only ``TargetGroupArn`` is added (modify). Otherwise the
        create-only ``VpcId``, ``Name``, ``Protocol`` and ``Port`` are added. A
        name-style vpc reference is resolved to its id through ``vpc_helper``.
        """
        params = {
            'HealthCheckProtocol': self.deployed.healthCheckProtocol.upper(),
            'HealthCheckPort': self.deployed.healthCheckPort,
            'HealthCheckPath': self.deployed.healthCheckPath,
            'HealthCheckIntervalSeconds': self.deployed.healthCheckInterval,
            'HealthCheckTimeoutSeconds': self.deployed.healthCheckTimeout,
            'HealthyThresholdCount': self.deployed.healthyThresholdCount,
            'UnhealthyThresholdCount': self.deployed.unhealthyThresholdCount,
            'Matcher': {'HttpCode': self.deployed.healthCheckSuccessCode}
        }

        if self.deployed.arn:
            params['TargetGroupArn'] = self.deployed.arn
        else:
            if self.is_starts_with_name(self.deployed.vpc):
                vpc_id = self.vpc_helper.get_vpc_id_by_name(self.get_property_name(self.deployed.vpc))
            else:
                vpc_id = self.deployed.vpc
            params['VpcId'] = vpc_id
            params['Name'] = self.__get_target_name(self.deployed)
            params['Protocol'] = self.deployed.protocol.upper()
            params['Port'] = self.deployed.port

        return params

    @staticmethod
    def is_modified(deployed, previous_deployed, property_names):
        """Return True if any of ``property_names`` differs between the two objects.

        Returns False when either object is missing (falsy).
        """
        if not (deployed and previous_deployed):
            return False
        return any(deployed[name] != previous_deployed[name] for name in property_names)

    @staticmethod
    def __print_association_logs(associate, children, child_type, target, target_type):
        """Print a one-line message describing an (dis)association."""
        if associate:
            verb = 'Associating'
            relation = 'to'
        else:
            verb = 'Dissociating'
            relation = 'from'

        # A single parenthesised argument prints identically on Python 2 and 3.
        print("{} {} {} {} {} {}.".format(verb, child_type, ', '.join(map(str, children)),
                                          relation, target_type, target))
