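# Required env vars:
# $CLUSTER: name of the ECS cluster (e.g. `myapp-prod`)
# $TASK_DEFINITION: name and revision of the task definition (e.g. `mytask:1`)
# $SUBNETS: comma-separated list of subnets to place the new task in (e.g. `subnet-12345678,subnet-abcdef12`)
# $SECURITY_GROUPS: comma-separated list of security groups to be used for the new task (e.g. `sg-12345678`)
#
# Optional env vars:
# $LAUNCH_TYPE: ECS launch type (default `FARGATE`)
# $COUNT: number of tasks to start (default `1`; only the first started task is monitored)
# $ASSIGN_PUBLIC_IP: whether to assign a public IP to the task (default `ENABLED`)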

import boto3
import os
import traceback

# AWS clients are built once when the module is imported, not per invocation.
code_pipeline, ecs = boto3.client('codepipeline'), boto3.client('ecs')


def lambda_handler(event, context):
    """Entry point for a CodePipeline action that runs a one-off ECS task.

    On the first call (no ``continuationToken`` in the job data) the task is
    started and the job is deferred, using the task ARN as the continuation
    token. On later calls the task is described: while it has not STOPPED the
    job is deferred again; once stopped, the job succeeds only if the task
    stopped because its essential container exited and the first container's
    exit code is 0.

    Args:
        event: Lambda event carrying the ``CodePipeline.job`` payload.
        context: Lambda context object (unused).

    Returns:
        The string ``"Complete."``.

    Raises:
        Exception: re-raises any error that occurs before the CodePipeline job
            id could be read, since there is no job to report the failure to.
    """
    job_id = None
    try:
        job_id = event['CodePipeline.job']['id']
        job_data = event['CodePipeline.job']['data']
        print('- CodePipeline JobID: %s' % job_id)

        if 'continuationToken' in job_data:
            # Not the first time we are called: check on the one-off task.
            _check_task(job_id, job_data['continuationToken'])
        else:
            # First time we are called: start the one-off task.
            _start_task(job_id)

    except Exception as e:
        # If any other exceptions which we didn't expect are raised
        # then fail the job and log the exception message.
        print('ERROR: Function failed due to exception')
        print(e)
        traceback.print_exc()
        if job_id is None:
            # The event had no job id, so there is no CodePipeline job to fail;
            # surface the original error instead of a NameError on job_id.
            raise
        _fail_job(job_id, 'Function exception: ' + str(e))

    return "Complete."


def _fail_job(job_id, message):
    """Mark CodePipeline job ``job_id`` as failed with ``message``."""
    code_pipeline.put_job_failure_result(
        jobId=job_id,
        failureDetails={
            # FailureDetails.message must be 1-5000 characters long.
            'message': (str(message) or 'Unknown failure')[:5000],
            'type': 'JobFailed'
        }
    )


def _env_list(name):
    """Return env var ``name`` split on commas, with blanks stripped/dropped.

    Raises:
        ValueError: if the variable is unset or yields no entries, so the
            pipeline gets a readable message instead of an AttributeError.
    """
    raw = os.getenv(name) or ''
    items = [item.strip() for item in raw.split(',') if item.strip()]
    if not items:
        raise ValueError(
            'Environment variable %s must be a non-empty comma-separated list' % name
        )
    return items


def _start_task(job_id):
    """Run the one-off ECS task and defer the job until the task stops."""
    print('-- No ContinuationToken, starting task')
    response = ecs.run_task(
        cluster=os.getenv('CLUSTER'),
        launchType=os.getenv('LAUNCH_TYPE', 'FARGATE'),
        taskDefinition=os.getenv('TASK_DEFINITION'),
        count=int(os.getenv('COUNT', 1)),
        platformVersion='LATEST',
        networkConfiguration={
            'awsvpcConfiguration': {
                'subnets': _env_list('SUBNETS'),
                'assignPublicIp': os.getenv('ASSIGN_PUBLIC_IP', 'ENABLED'),
                'securityGroups': _env_list('SECURITY_GROUPS'),
            },
        }
    )
    print('- Task: %s' % str(response))

    tasks = response.get('tasks')
    if not tasks:
        # RunTask reports placement/capacity problems in `failures` rather
        # than raising, leaving `tasks` empty.
        print('-- Task could not be started. Failing CodePipeline job')
        _fail_job(job_id, 'Task could not be started: %s' % response.get('failures'))
        return

    # Check the status of the task later. Only the first task's ARN is
    # tracked, even when COUNT > 1.
    continuation_token = tasks[0]['taskArn']
    print('-- Task just triggered. Continuing at a later time')
    code_pipeline.put_job_success_result(jobId=job_id, continuationToken=continuation_token)


def _check_task(job_id, task_arn):
    """Describe the tracked task and defer, fail, or succeed the job."""
    print('- ContinuationToken: %s' % task_arn)
    response = ecs.describe_tasks(
        cluster=os.getenv('CLUSTER'),
        tasks=[task_arn]
    )

    tasks = response.get('tasks')
    if not tasks:
        # DescribeTasks lists unknown/expired ARNs under `failures` instead.
        print('-- Task could not be described. Failing CodePipeline job')
        _fail_job(job_id, 'Task %s could not be described: %s'
                  % (task_arn, response.get('failures')))
        return

    task = tasks[0]
    print('- Task: %s' % str(task))

    if task['lastStatus'] != 'STOPPED':
        # Not yet finished
        print('-- Task not yet finished, continuing at a later time')
        code_pipeline.put_job_success_result(jobId=job_id, continuationToken=task_arn)
        return

    stopped_reason = task.get('stoppedReason', '')
    if stopped_reason != 'Essential container in task exited':
        # Finished unsuccessfully, for some reason
        print('-- Task failed for unknown reason. Failing CodePipeline job')
        _fail_job(job_id, stopped_reason or 'Task stopped without a stoppedReason')
        return

    # Only the first container is inspected; `exitCode` may be missing when
    # the container never ran, which is treated as a failure.
    exit_code = task['containers'][0].get('exitCode')
    if exit_code != 0:
        # Finished unsuccessfully
        print('-- Task failed because of container failed. Failing CodePipeline job')
        _fail_job(job_id, 'Task exited with exit code %s' % exit_code)
        return

    # Finished successfully
    print('-- Task succeeded. Succeeding CodePipeline job')
    code_pipeline.put_job_success_result(jobId=job_id)