from collections import defaultdict

from ec2.ec2_helper import EC2Helper
from ec2.vpc.vpc_helper import VPCHelper


class RouteHelper(EC2Helper):
    """Route-table operations for the route table described by ``deployed``.

    Wraps the EC2 client calls used to create and delete a route table,
    associate and disassociate subnets, replace the VPC's main route table,
    and add or remove individual routes.

    ``self.ec2_client``, ``self.deployed`` and the helper methods
    ``get_subnet_id_list``, ``is_starts_with_name``, ``get_property_name``
    and ``get_*_id_by_id_or_name`` are not defined in this class; they are
    presumably provided by ``EC2Helper`` -- confirm there.

    Console output uses the single-argument ``print(...)`` form on purpose:
    it prints the same text under Python 2's print statement and is also
    valid Python 3.
    """

    def __init__(self, deployed):
        """Initialise the base helper and a ``VPCHelper`` for ``deployed``."""
        super(RouteHelper, self).__init__(deployed)
        self.vpc_helper = VPCHelper(deployed)

    def create_route_table(self):
        """Create a route table in the VPC resolved by ``get_vpc_id``.

        Returns:
            The ``RouteTable`` entry of the ``create_route_table`` response.

        Raises:
            RuntimeError: if the response carries no route table id.
        """
        vpc_id = self.get_vpc_id()
        # Announce the operation before issuing the request so the log reads
        # in the order things actually happen.
        print("Creating route table on VPC %s" % vpc_id)
        response = self.ec2_client.create_route_table(VpcId=vpc_id)
        # Use .get so a missing key ends in the RuntimeError below rather
        # than in an unrelated KeyError.
        route_table = response.get('RouteTable') or {}
        route_table_id = route_table.get('RouteTableId')
        if not route_table_id:
            raise RuntimeError("Not able to create route table")
        print("Route table created with Id: %s" % route_table_id)
        return route_table

    def delete_route_table(self):
        """Delete the route table ``deployed.routeTableId``.

        Returns:
            The raw ``delete_route_table`` response.
        """
        route_table_id = self.deployed.routeTableId
        print("Destroying route table with Id: %s" % route_table_id)
        return self.ec2_client.delete_route_table(RouteTableId=route_table_id)

    def associate_subnets(self, subnets):
        """Associate every subnet in ``subnets`` with the deployed route table.

        ``subnets`` is translated to subnet ids with ``get_subnet_id_list``;
        an empty or ``None`` value is a no-op.
        """
        if not subnets:
            return
        for subnet_id in self.get_subnet_id_list(subnets) or []:
            self.associate_subnet(subnet_id)

    def disassociate_subnets(self, subnets):
        """Remove the associations between ``subnets`` and the deployed table.

        Only subnets that currently have an association with the deployed
        route table are touched; an empty or ``None`` value is a no-op.
        """
        if not subnets:
            return
        subnet_ids = self.get_subnet_id_list(subnets)
        if not subnet_ids:
            return
        associations = self.find_associations_for_subnets(subnet_ids)
        for subnet_id, association_id in associations.items():
            print("Disassociating subnet: %s with association id: %s" % (subnet_id, association_id))
            self.ec2_client.disassociate_route_table(AssociationId=association_id)
            print("Disassociated subnet: %s with association id: %s" % (subnet_id, association_id))

    def find_associations_for_subnets(self, subnet_ids):
        """Map each requested subnet id to its association id on the table.

        Args:
            subnet_ids: subnet ids to look up on ``deployed.routeTableId``.

        Returns:
            A dict ``{subnet_id: route_table_association_id}`` containing only
            the requested subnets that are associated with the deployed route
            table. Empty when ``subnet_ids`` is empty or nothing matches.
        """
        association_ids = {}
        if not subnet_ids:
            return association_ids
        response = self.ec2_client.describe_route_tables(
            RouteTableIds=[self.deployed.routeTableId],
            Filters=[{'Name': 'association.subnet-id', 'Values': subnet_ids}]
        )
        requested = set(subnet_ids)
        # The filter selects route tables, not individual associations: a
        # matching table is returned with all of its associations, including
        # a main association that has no 'SubnetId'. Keep only the entries
        # that belong to the requested subnets.
        for route_table in response.get('RouteTables', []):
            for association in route_table.get('Associations', []):
                subnet_id = association.get('SubnetId')
                if subnet_id in requested:
                    association_ids[subnet_id] = association['RouteTableAssociationId']
        return association_ids

    def associate_subnet(self, subnet_id):
        """Associate one subnet with the deployed route table.

        Raises:
            Exception: if the response carries no association id.
        """
        route_table_id = self.deployed.routeTableId
        print("Associating subnet %s with route table %s" % (subnet_id, route_table_id))
        response = self.ec2_client.associate_route_table(
            SubnetId=subnet_id,
            RouteTableId=route_table_id
        )
        association_id = response['AssociationId']
        if not association_id:
            raise Exception("Could not associate subnet %s with route table %s" % (subnet_id, route_table_id))
        print("Subnet %s associated with route table %s with association id %s"
              % (subnet_id, route_table_id, association_id))

    def get_vpc_id(self):
        """Return the VPC id for ``deployed.vpc``.

        When ``is_starts_with_name`` accepts the value, it is treated as a
        name reference and resolved through the VPC helper; otherwise the
        value is returned unchanged as the id.
        """
        vpc = self.deployed.vpc
        if self.is_starts_with_name(vpc):
            return self.vpc_helper.get_vpc_id_by_name(self.get_property_name(vpc))
        return vpc

    def find_main_route_table(self):
        """Return the main route-table association of the deployed VPC.

        Returns:
            The association entry (a dict from the ``describe_route_tables``
            response) whose ``Main`` flag is set.

        Raises:
            IndexError: if no main association is found. The type matches
                what the previous ``[0]`` lookup raised, so existing handlers
                keep working.
        """
        vpc_id = self.get_vpc_id()
        response = self.ec2_client.describe_route_tables(
            Filters=[
                {'Name': 'vpc-id', 'Values': [vpc_id]},
                {'Name': 'association.main', 'Values': ['true']}
            ]
        )
        # An explicit loop instead of filter(...)[0]: it does not depend on
        # filter() returning a list, and it tolerates an empty result.
        for route_table in response.get('RouteTables', []):
            for association in route_table.get('Associations', []):
                if association.get('Main'):
                    return association
        raise IndexError("No main route table association found for VPC %s" % vpc_id)

    def replace_main_table(self, new_main_table):
        """Make ``new_main_table`` the main route table of the deployed VPC.

        Args:
            new_main_table: id of the route table that should become main.

        Returns:
            The id of the previous main route table when a replacement was
            made; ``None`` when no id was given or the table is already main.
        """
        if not new_main_table:
            # Nothing to switch to; skip the lookup entirely.
            print("No new main route table given; main table left unchanged")
            return None

        main_association = self.find_main_route_table()
        current_main_table = main_association['RouteTableId']
        if current_main_table == new_main_table:
            print("Route table id: %s is already the main table of the VPC" % new_main_table)
            return None

        print("Changing main table from (id: %s) to (id: %s)" % (current_main_table, new_main_table))
        response = self.ec2_client.replace_route_table_association(
            AssociationId=main_association['RouteTableAssociationId'],
            RouteTableId=new_main_table
        )
        print("Main table changed from (id: %s) to (id: %s) with Association Id: %s"
              % (current_main_table, new_main_table, response['NewAssociationId']))
        return current_main_table

    def create_routes(self, routes):
        """Add every route in ``routes`` to the deployed route table."""
        for route in routes or []:
            self.add_route(route=route)

    def get_create_route_params(self, route):
        """Build the keyword arguments for ``create_route`` from ``route``.

        Gateway, instance, network-interface and peering-connection values
        are passed through the matching ``get_*_id_by_id_or_name`` helper.
        Entries whose value is ``None`` are dropped from the result.
        """
        candidates = {
            'RouteTableId': self.deployed.routeTableId,
            'DestinationCidrBlock': route.ipv4Address,
            'GatewayId': self.get_internet_gateway_id_by_id_or_name(route.gatewayId),
            'DestinationIpv6CidrBlock': route.ipv6Address,
            'EgressOnlyInternetGatewayId': route.egressOnlyInternetGatewayId,
            'InstanceId': self.get_instance_id_by_id_or_name(route.instanceId),
            'NetworkInterfaceId': self.get_network_interface_id_by_id_or_name(route.networkInterfaceId),
            'VpcPeeringConnectionId': self.get_vpc_peering_connection_id_by_id_or_name(
                route.vpcPeeringConnectionId),
            'NatGatewayId': route.natGatewayId,
        }
        return dict((key, value) for key, value in candidates.items() if value is not None)

    def add_route(self, route):
        """Create a single route in the deployed route table."""
        params = self.get_create_route_params(route)
        print("Adding route in route table (Id:%s)" % self.deployed.routeTableId)
        self.ec2_client.create_route(**params)

    def delete_routes(self, routes):
        """Delete every route in ``routes`` from the deployed route table."""
        for route in routes or []:
            self.delete_route(
                routeTableId=self.deployed.routeTableId,
                ipv4_cidr=route.ipv4Address,
                ipv6_cidr=route.ipv6Address
            )

    def delete_route(self, routeTableId, ipv4_cidr, ipv6_cidr):
        """Delete one route, identified by its destination CIDR.

        The IPv4 destination is used when it is set; otherwise the IPv6
        destination is used.

        NOTE(review): when both are empty the request is still sent with the
        empty IPv6 destination; presumably the EC2 client rejects it --
        confirm whether callers can reach that case.
        """
        if ipv4_cidr:
            destination_cidr = ipv4_cidr
            destination = {'DestinationCidrBlock': ipv4_cidr}
        else:
            destination_cidr = ipv6_cidr
            destination = {'DestinationIpv6CidrBlock': ipv6_cidr}
        print("Deleting route with destination: %s, for route table (ID: %s)" % (destination_cidr, routeTableId))
        self.ec2_client.delete_route(RouteTableId=routeTableId, **destination)
