from commons.aws_helper import AWSHelper
from ec2.ec2_helper import EC2Helper
from botocore.exceptions import ClientError
import time

class LambdaHelper(AWSHelper):
    """Create, update, tag and delete the Lambda function described by ``deployed``.

    Every function setting is read from attributes of ``self.deployed``
    (``functionName``/``name``, ``runtime``, ``handler``, ``role``, ...).
    Optional settings are only sent to AWS when the corresponding attribute
    is truthy.
    """

    def __init__(self, deployed):
        super(LambdaHelper, self).__init__(deployed)
        self.lambda_client = self.get_aws_client(region=deployed.region, resource_name='lambda')
        # Used to turn deployed.subnets / deployed.securityGroups into the ID
        # lists placed in VpcConfig (see get_lambda_params).
        self.ec2_helper = EC2Helper(deployed)

    def _get_function_name(self):
        """Return ``deployed.functionName`` if set, otherwise ``deployed.name``."""
        return self.deployed.functionName if self.deployed.functionName else self.deployed.name

    def _read_zip_file(self):
        """Return the raw bytes of the package at ``deployed.file.path``.

        The file is read inside a ``with`` block so the handle is always closed.
        """
        with open(self.deployed.file.path, 'rb') as zip_file:
            return zip_file.read()

    def get_lambda_params(self):
        """Build the keyword arguments shared by create and configuration-update calls.

        Returns:
            dict: always contains ``FunctionName``, ``Runtime``, ``Handler`` and
            ``Role``; ``Description``, ``Timeout``, ``Environment``, ``KMSKeyArn``,
            ``MemorySize``, ``VpcConfig``, ``DeadLetterConfig`` and
            ``TracingConfig`` are added only when the matching deployed
            attribute is truthy.
        """
        param_dict = {'FunctionName': self._get_function_name(),
                      'Runtime': self.deployed.runtime,
                      'Handler': self.deployed.handler,
                      'Role': self.deployed.role}
        if self.deployed.description:
            param_dict['Description'] = self.deployed.description
        if self.deployed.timeout:
            param_dict['Timeout'] = self.deployed.timeout
        if self.deployed.environment:
            param_dict['Environment'] = {'Variables': self.deployed.environment}
        if self.deployed.kmsKeyArn:
            param_dict['KMSKeyArn'] = self.deployed.kmsKeyArn
        if self.deployed.memorySize:
            param_dict['MemorySize'] = self.deployed.memorySize

        # Subnets and security groups are independent; VpcConfig is only sent
        # when at least one of them is configured.
        vpc_config = {}
        if self.deployed.subnets:
            vpc_config['SubnetIds'] = self.ec2_helper.get_subnet_id_list(self.deployed.subnets)
        if self.deployed.securityGroups:
            vpc_config['SecurityGroupIds'] = self.ec2_helper.get_security_group_id_list(
                self.deployed.securityGroups)
        if vpc_config:
            param_dict['VpcConfig'] = vpc_config

        if self.deployed.deadLetterConfig:
            param_dict['DeadLetterConfig'] = {'TargetArn': self.deployed.deadLetterConfig}
        if self.deployed.tracingConfig:
            param_dict['TracingConfig'] = {'Mode': self.deployed.tracingConfig}
        return param_dict

    def add_create_only_params(self, param_dict):
        """Add ``Publish`` and ``Tags`` (when truthy) to ``param_dict`` in place.

        These are only valid for ``create_function``, not for configuration
        updates. Returns the same dict for convenience.
        """
        if self.deployed.publish:
            param_dict['Publish'] = self.deployed.publish
        if self.deployed.lambdaTags:
            param_dict['Tags'] = self.deployed.lambdaTags
        return param_dict

    def _create_function(self, code):
        """Call ``create_function`` with the shared and create-only params plus ``code``."""
        print("Creating lambda function...")
        lambda_params = self.add_create_only_params(self.get_lambda_params())
        lambda_params['Code'] = code
        return self.lambda_client.create_function(**lambda_params)

    def create_function_with_zip(self):
        """Create the function from the local zip package at ``deployed.file.path``."""
        return self._create_function({'ZipFile': self._read_zip_file()})

    def create_function_with_s3_bucket(self):
        """Create the function from ``deployed.bucketName``/``deployed.s3Key``.

        ``S3ObjectVersion`` is included only when ``deployed.s3ObjectVersion`` is set.
        """
        code = {'S3Bucket': self.deployed.bucketName,
                'S3Key': self.deployed.s3Key}
        if self.deployed.s3ObjectVersion:
            code['S3ObjectVersion'] = self.deployed.s3ObjectVersion
        return self._create_function(code)

    def update_function_code_zip(self):
        """Replace the function's code with the local zip at ``deployed.file.path``."""
        return self.lambda_client.update_function_code(
            FunctionName=self._get_function_name(),
            ZipFile=self._read_zip_file(),
            Publish=self.deployed.publish
        )

    def update_function_code_s3(self):
        """Replace the function's code with the object at ``deployed.bucketName``/``deployed.s3Key``."""
        lambda_params = {'FunctionName': self._get_function_name(),
                         'Publish': self.deployed.publish,
                         'S3Bucket': self.deployed.bucketName,
                         'S3Key': self.deployed.s3Key}
        if self.deployed.s3ObjectVersion:
            lambda_params['S3ObjectVersion'] = self.deployed.s3ObjectVersion
        return self.lambda_client.update_function_code(**lambda_params)

    def add_tags_to_function(self):
        """Apply ``deployed.lambdaTags`` to the function identified by ``deployed.functionARN``."""
        return self.lambda_client.tag_resource(
            Resource=self.deployed.functionARN,
            Tags=self.deployed.lambdaTags
        )

    def update_function_configurations(self, retry_interval=5, max_attempts=None):
        """Push the settings from ``get_lambda_params`` to the existing function.

        A ``ResourceConflictException`` is treated as "another update is still
        in progress": the call is retried after sleeping ``retry_interval``
        seconds.

        Args:
            retry_interval: seconds to sleep between retries (default 5).
            max_attempts: total number of attempts before the conflict error is
                re-raised; ``None`` (default) retries indefinitely.

        Returns:
            The ``update_function_configuration`` response.

        Raises:
            ClientError: any error other than ``ResourceConflictException``, or
            the conflict error once ``max_attempts`` is exhausted.
        """
        function_name = self._get_function_name()
        print("Updating lambda function configurations...")
        lambda_params = self.get_lambda_params()
        attempt = 0
        while True:
            attempt += 1
            try:
                update_res = self.lambda_client.update_function_configuration(**lambda_params)
            except ClientError as ce:
                if ce.response['Error']['Code'] != 'ResourceConflictException':
                    # Bare raise keeps the original traceback (``raise ce`` loses it on Python 2).
                    raise
                if max_attempts is not None and attempt >= max_attempts:
                    raise
                print("Waiting for lambda function [{0}] to be updated".format(function_name))
                time.sleep(retry_interval)
            else:
                print("Lambda function [{0}] updated successfully".format(function_name))
                return update_res

    def delete_function(self):
        """Delete the function and return the ``delete_function`` response."""
        print("Deleting lambda function...")
        return self.lambda_client.delete_function(FunctionName=self._get_function_name())

    def get_lambda_function(self, function_name):
        """Return the ``Configuration`` section of ``get_function`` for ``function_name``."""
        return self.lambda_client.get_function(FunctionName=function_name)['Configuration']
