diff --git a/app.py b/app.py index 6948331..9188b37 100644 --- a/app.py +++ b/app.py @@ -9,48 +9,51 @@ from aws_cdk import core from once.once_stack import OnceStack, CustomDomainStack -ONCE_CONFIG_FILE = os.getenv('ONCE_CONFIG_FILE', os.path.expanduser('~/.once')) +ONCE_CONFIG_FILE = os.getenv("ONCE_CONFIG_FILE", os.path.expanduser("~/.once")) -SECRET_KEY = os.getenv('SECRET_KEY') -CUSTOM_DOMAIN = os.getenv('CUSTOM_DOMAIN') -HOSTED_ZONE_NAME = os.getenv('HOSTED_ZONE_NAME') -HOSTED_ZONE_ID = os.getenv('HOSTED_ZONE_ID') +SECRET_KEY = os.getenv("SECRET_KEY") +CUSTOM_DOMAIN = os.getenv("CUSTOM_DOMAIN") +HOSTED_ZONE_NAME = os.getenv("HOSTED_ZONE_NAME") +HOSTED_ZONE_ID = os.getenv("HOSTED_ZONE_ID") def generate_random_key() -> str: - return base64.b64encode(os.urandom(128)).decode('utf-8') + return base64.b64encode(os.urandom(128)).decode("utf-8") -def generate_config(secret_key: Optional[str] = None, +def generate_config( + secret_key: Optional[str] = None, custom_domain: str = None, hosted_zone_name: str = None, - hosted_zone_id: str = None) -> configparser.ConfigParser: + hosted_zone_id: str = None, +) -> configparser.ConfigParser: config = configparser.ConfigParser() - config['once'] = { - 'secret_key': secret_key or generate_random_key(), + config["once"] = { + "secret_key": secret_key or generate_random_key(), } - config['deployment'] = {} + config["deployment"] = {} if all([custom_domain, hosted_zone_name, hosted_zone_id]): - config['once']['base_url'] = f'https://{custom_domain}' - config['deployment'] = { - 'custom_domain': custom_domain, - 'hosted_zone_name': hosted_zone_name, - 'hosted_zone_id': hosted_zone_id + config["once"]["base_url"] = f"https://{custom_domain}" + config["deployment"] = { + "custom_domain": custom_domain, + "hosted_zone_name": hosted_zone_name, + "hosted_zone_id": hosted_zone_id, } return config def get_config(config_gile: str = ONCE_CONFIG_FILE) -> configparser.ConfigParser: if not os.path.exists(ONCE_CONFIG_FILE): - print(f'Generating configuration file at {ONCE_CONFIG_FILE}') - with open(ONCE_CONFIG_FILE, 'w') as config_file: + print(f"Generating configuration file at {ONCE_CONFIG_FILE}") + with open(ONCE_CONFIG_FILE, "w") as config_file: config = generate_config( secret_key=SECRET_KEY, custom_domain=CUSTOM_DOMAIN, hosted_zone_name=HOSTED_ZONE_NAME, - hosted_zone_id=HOSTED_ZONE_ID) + hosted_zone_id=HOSTED_ZONE_ID, + ) config.write(config_file) else: config = configparser.ConfigParser() @@ -61,14 +64,14 @@ def get_config(config_gile: str = ONCE_CONFIG_FILE) -> configparser.ConfigParser def main(): config = get_config() - kwargs = {'secret_key': config['once']['secret_key']} - if config.has_section('deployment'): - kwargs.update(config['deployment']) + kwargs = {"secret_key": config["once"]["secret_key"]} + if config.has_section("deployment"): + kwargs.update(config["deployment"]) app = core.App() - once = OnceStack(app, 'once', **kwargs) + once = OnceStack(app, "once", **kwargs) app.synth() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/client/__init__.py b/client/__init__.py index ed2c7c4..7479eeb 100644 --- a/client/__init__.py +++ b/client/__init__.py @@ -1,6 +1,6 @@ -''' +""" Simple command to share one-time files -''' +""" import os import base64 @@ -17,9 +17,9 @@ import requests from pygments import highlight, lexers, formatters -ONCE_CONFIG_FILE = os.getenv('ONCE_CONFIG_FILE', os.path.expanduser('~/.once')) -ONCE_SIGNATURE_HEADER = 'x-once-signature' -ONCE_TIMESTAMP_FORMAT = '%Y%m%d%H%M%S%f' +ONCE_CONFIG_FILE = os.getenv("ONCE_CONFIG_FILE", os.path.expanduser("~/.once")) +ONCE_SIGNATURE_HEADER = "x-once-signature" +ONCE_TIMESTAMP_FORMAT = "%Y%m%d%H%M%S%f" def highlight_json(obj): @@ -33,7 +33,7 @@ def echo_obj(obj): def get_config(config_file: str = ONCE_CONFIG_FILE) -> configparser.ConfigParser: if not os.path.exists(config_file): - raise ValueError(f'Config file not found at {config_file}') + raise ValueError(f"Config file not found at {config_file}") config = configparser.ConfigParser() config.read(ONCE_CONFIG_FILE) return config @@ -41,59 +41,57 @@ def get_config(config_file: str = ONCE_CONFIG_FILE) -> configparser.ConfigParser def api_req(method: str, url: str, verbose: bool = False, **kwargs): config = get_config() - if not config.has_option('once', 'base_url'): - raise ValueError(f'Configuration file at {ONCE_CONFIG_FILE} misses `base_url` option') + if not config.has_option("once", "base_url"): + raise ValueError(f"Configuration file at {ONCE_CONFIG_FILE} misses `base_url` option") - base_url = os.getenv('ONCE_API_URL', config['once']['base_url']) - secret_key = base64.b64decode(os.getenv('ONCE_SECRET_KEY', config['once']['secret_key'])) + base_url = os.getenv("ONCE_API_URL", config["once"]["base_url"]) + secret_key = base64.b64decode(os.getenv("ONCE_SECRET_KEY", config["once"]["secret_key"])) method = method.lower() - if method not in ['get', 'post']: + if method not in ["get", "post"]: raise ValueError(f'Unsupported HTTP method "{method}"') actual_url = urljoin(base_url, url) if verbose: - print(f'{method.upper()} {actual_url}') + print(f"{method.upper()} {actual_url}") req = requests.Request(method=method, url=actual_url, **kwargs).prepare() - plain_text = req.path_url.encode('utf-8') + plain_text = req.path_url.encode("utf-8") hmac_obj = hmac.new(secret_key, msg=plain_text, digestmod=hashlib.sha256) req.headers[ONCE_SIGNATURE_HEADER] = base64.b64encode(hmac_obj.digest()) response = requests.Session().send(req) if verbose: - print(f'Server response status: {response.status_code}') + print(f"Server response status: {response.status_code}") echo_obj(response.json()) return response -@click.command('share') -@click.argument('file', type=click.File(mode='rb'), required=True) -@click.option('--verbose', '-v', is_flag=True, default=False, help='Enables verbose output.') +@click.command("share") +@click.argument("file", type=click.File(mode="rb"), required=True) +@click.option("--verbose", "-v", is_flag=True, default=False, help="Enables verbose output.") def share(file: click.File, verbose: bool): - entry = api_req('GET', '/', - params={ - 'f': quote_plus(os.path.basename(file.name)), - 't': datetime.utcnow().strftime(ONCE_TIMESTAMP_FORMAT) - }, - verbose=verbose).json() + entry = api_req( + "GET", + "/", + params={"f": quote_plus(os.path.basename(file.name)), "t": datetime.utcnow().strftime(ONCE_TIMESTAMP_FORMAT)}, + verbose=verbose, + ).json() - once_url = entry['once_url'] - upload_data = entry['presigned_post'] - files = {'file': file} + once_url = entry["once_url"] + upload_data = entry["presigned_post"] + files = {"file": file} upload_started = time.time() - response = requests.post(upload_data['url'], - data=upload_data['fields'], - files=files) + response = requests.post(upload_data["url"], data=upload_data["fields"], files=files) upload_time = time.time() - upload_started - print(f"File uploaded in {upload_time}s") + print(f"File uploaded in {upload_time}s") print(f"File can be downloaded once at: {once_url}") -if __name__ == '__main__': +if __name__ == "__main__": share() diff --git a/once/delete-served-files/handler.py b/once/delete-served-files/handler.py index f5ab457..18db1d7 100644 --- a/once/delete-served-files/handler.py +++ b/once/delete-served-files/handler.py @@ -9,16 +9,16 @@ from boto3.dynamodb.conditions import Key def is_debug_enabled() -> bool: - value = os.getenv('DEBUG', 'false').lower() - if value in ['false', '0']: + value = os.getenv("DEBUG", "false").lower() + if value in ["false", "0"]: return False else: return bool(value) DEBUG = is_debug_enabled() -FILES_BUCKET = os.getenv('FILES_BUCKET') -FILES_TABLE_NAME = os.getenv('FILES_TABLE_NAME') +FILES_BUCKET = os.getenv("FILES_BUCKET") +FILES_TABLE_NAME = os.getenv("FILES_TABLE_NAME") log = logging.getLogger() @@ -29,28 +29,27 @@ else: def on_event(event, context): - log.debug(f'Event received: {event}') - log.debug(f'Context is: {context}') - log.debug(f'Debug mode is {DEBUG}') + log.debug(f"Event received: {event}") + log.debug(f"Context is: {context}") + log.debug(f"Debug mode is {DEBUG}") log.debug(f'Files bucket is "{FILES_BUCKET}"') - dynamodb = boto3.client('dynamodb') + dynamodb = boto3.client("dynamodb") response = dynamodb.scan( TableName=FILES_TABLE_NAME, - Select='ALL_ATTRIBUTES', - FilterExpression='deleted = :deleted', - ExpressionAttributeValues={ - ':deleted': {'BOOL': True} - }) + Select="ALL_ATTRIBUTES", + FilterExpression="deleted = :deleted", + ExpressionAttributeValues={":deleted": {"BOOL": True}}, + ) - s3 = boto3.client('s3') - for item in response['Items']: - object_name = item['object_name']['S'] - log.info(f'Deleting file {object_name}') + s3 = boto3.client("s3") + for item in response["Items"]: + object_name = item["object_name"]["S"] + log.info(f"Deleting file {object_name}") try: s3.delete_object(Bucket=FILES_BUCKET, Key=object_name) except: - log.exception('Could not delete file {object_name}') + log.exception("Could not delete file {object_name}") - response = dynamodb.delete_item(TableName=FILES_TABLE_NAME, Key={'id': item['id']}) - log.debug(f'dynamodb delete item: {response}') + response = dynamodb.delete_item(TableName=FILES_TABLE_NAME, Key={"id": item["id"]}) + log.debug(f"dynamodb delete item: {response}") diff --git a/once/download-and-delete/handler.py b/once/download-and-delete/handler.py index 072e9dd..c618b9e 100644 --- a/once/download-and-delete/handler.py +++ b/once/download-and-delete/handler.py @@ -9,31 +9,36 @@ import boto3 def is_debug_enabled() -> bool: - value = os.getenv('DEBUG', 'false').lower() - if value in ['false', '0']: + value = os.getenv("DEBUG", "false").lower() + if value in ["false", "0"]: return False else: return bool(value) DEBUG = is_debug_enabled() -FILES_BUCKET = os.getenv('FILES_BUCKET') -FILES_TABLE_NAME = os.getenv('FILES_TABLE_NAME') -PRESIGNED_URL_EXPIRES_IN = int(os.getenv('PRESIGNED_URL_EXPIRES_IN', 20)) -MASKED_USER_AGENTS = os.getenv('MASKED_USER_AGENTS', ','.join([ - '^Facebook.*', - '^Google.*', - '^Instagram.*', - '^LinkedIn.*', - '^Outlook.*', - '^Reddit.*', - '^Slack.*', - '^Skype.*', - '^SnapChat.*', - '^Telegram.*', - '^Twitter.*', - '^WhatsApp.*' -])).split(',') +FILES_BUCKET = os.getenv("FILES_BUCKET") +FILES_TABLE_NAME = os.getenv("FILES_TABLE_NAME") +PRESIGNED_URL_EXPIRES_IN = int(os.getenv("PRESIGNED_URL_EXPIRES_IN", 20)) +MASKED_USER_AGENTS = os.getenv( + "MASKED_USER_AGENTS", + ",".join( + [ + "^Facebook.*", + "^Google.*", + "^Instagram.*", + "^LinkedIn.*", + "^Outlook.*", + "^Reddit.*", + "^Slack.*", + "^Skype.*", + "^SnapChat.*", + "^Telegram.*", + "^Twitter.*", + "^WhatsApp.*", + ] + ), +).split(",") log = logging.getLogger() @@ -44,55 +49,45 @@ else: def on_event(event, context): - log.debug(f'Event received: {event}') - log.debug(f'Context is: {context}') - log.debug(f'Debug mode is {DEBUG}') + log.debug(f"Event received: {event}") + log.debug(f"Context is: {context}") + log.debug(f"Debug mode is {DEBUG}") log.debug(f'Files bucket is "{FILES_BUCKET}"') - entry_id = event['pathParameters']['entry_id'] - filename = urllib.parse.unquote_plus(event['pathParameters']['filename']) - object_name = f'{entry_id}/{filename}' + entry_id = event["pathParameters"]["entry_id"] + filename = urllib.parse.unquote_plus(event["pathParameters"]["filename"]) + object_name = f"{entry_id}/{filename}" - dynamodb = boto3.client('dynamodb') - entry = dynamodb.get_item( - TableName=FILES_TABLE_NAME, - Key={'id': {'S': entry_id}}) + dynamodb = boto3.client("dynamodb") + entry = dynamodb.get_item(TableName=FILES_TABLE_NAME, Key={"id": {"S": entry_id}}) - log.debug(f'Matched Dynamodb entry: {entry}') + log.debug(f"Matched Dynamodb entry: {entry}") - if 'Item' not in entry or 'deleted' in entry['Item']: - error_message = f'Entry not found: {object_name}' + if "Item" not in entry or "deleted" in entry["Item"]: + error_message = f"Entry not found: {object_name}" log.info(error_message) - return {'statusCode': 404, 'body': error_message} + return {"statusCode": 404, "body": error_message} # Some rich clients try to get a preview of any link pasted # into text controls. - user_agent = event['headers'].get('user-agent', '') + user_agent = event["headers"].get("user-agent", "") is_masked_agent = any([re.match(agent, user_agent) for agent in MASKED_USER_AGENTS]) if is_masked_agent: - log.info('Serving possible link preview. Download prevented.') - return { - 'statusCode': 200, - 'headers': {} - } + log.info("Serving possible link preview. Download prevented.") + return {"statusCode": 200, "headers": {}} - s3 = boto3.client('s3') + s3 = boto3.client("s3") download_url = s3.generate_presigned_url( - 'get_object', - Params={'Bucket': FILES_BUCKET, 'Key': object_name}, - ExpiresIn=PRESIGNED_URL_EXPIRES_IN) + "get_object", Params={"Bucket": FILES_BUCKET, "Key": object_name}, ExpiresIn=PRESIGNED_URL_EXPIRES_IN + ) dynamodb.update_item( TableName=FILES_TABLE_NAME, - Key={'id': {'S': entry_id}}, - UpdateExpression='SET deleted = :deleted', - ExpressionAttributeValues={':deleted': {'BOOL': True}}) + Key={"id": {"S": entry_id}}, + UpdateExpression="SET deleted = :deleted", + ExpressionAttributeValues={":deleted": {"BOOL": True}}, + ) - log.info(f'Entry {object_name} marked as deleted') + log.info(f"Entry {object_name} marked as deleted") - return { - 'statusCode': 301, - 'headers': { - 'Location': download_url - } - } + return {"statusCode": 301, "headers": {"Location": download_url}} diff --git a/once/get-upload-ticket/handler.py b/once/get-upload-ticket/handler.py index c849174..64e80ac 100644 --- a/once/get-upload-ticket/handler.py +++ b/once/get-upload-ticket/handler.py @@ -17,25 +17,25 @@ from botocore.exceptions import ClientError def is_debug_enabled() -> bool: - value = os.getenv('DEBUG', 'false').lower() - if value in ['false', '0']: + value = os.getenv("DEBUG", "false").lower() + if value in ["false", "0"]: return False else: return bool(value) DEBUG = is_debug_enabled() -APP_URL = os.getenv('APP_URL') -EXPIRATION_TIMEOUT = int(os.getenv('EXPIRATION_TIMEOUT', 60*5)) -FILES_BUCKET = os.getenv('FILES_BUCKET') -FILES_TABLE_NAME = os.getenv('FILES_TABLE_NAME') -S3_REGION_NAME = os.getenv('S3_REGION_NAME', 'eu-west-1') -S3_SIGNATURE_VERSION = os.getenv('S3_SIGNATURE_VERSION', 's3v4') -SECRET_KEY = base64.b64decode(os.getenv('SECRET_KEY')) -SIGNATURE_HEADER = os.getenv('SIGNATURE_HEADER', 'x-once-signature') -SIGNATURE_TIME_TOLERANCE = int(os.getenv('SIGNATURE_TIME_TOLERANCE', 5)) -TIMESTAMP_FORMAT_STRING = os.getenv('TIMESTAMP_FORMAT_STRING', '%d%m%Y%H%M%S') -TIMESTAMP_PARAMETER_FORMAT = '%Y%m%d%H%M%S%f' +APP_URL = os.getenv("APP_URL") +EXPIRATION_TIMEOUT = int(os.getenv("EXPIRATION_TIMEOUT", 60 * 5)) +FILES_BUCKET = os.getenv("FILES_BUCKET") +FILES_TABLE_NAME = os.getenv("FILES_TABLE_NAME") +S3_REGION_NAME = os.getenv("S3_REGION_NAME", "eu-west-1") +S3_SIGNATURE_VERSION = os.getenv("S3_SIGNATURE_VERSION", "s3v4") +SECRET_KEY = base64.b64decode(os.getenv("SECRET_KEY")) +SIGNATURE_HEADER = os.getenv("SIGNATURE_HEADER", "x-once-signature") +SIGNATURE_TIME_TOLERANCE = int(os.getenv("SIGNATURE_TIME_TOLERANCE", 5)) +TIMESTAMP_FORMAT_STRING = os.getenv("TIMESTAMP_FORMAT_STRING", "%d%m%Y%H%M%S") +TIMESTAMP_PARAMETER_FORMAT = "%Y%m%d%H%M%S%f" log = logging.getLogger() @@ -53,45 +53,38 @@ class UnauthorizedError(Exception): pass -def create_presigned_post(bucket_name: str, object_name: str, - fields=None, conditions=None, expiration=3600) -> Dict: - ''' +def create_presigned_post(bucket_name: str, object_name: str, fields=None, conditions=None, expiration=3600) -> Dict: + """ Generate a presigned URL S3 POST request to upload a file - ''' - s3_client = boto3.client('s3', - region_name=S3_REGION_NAME, - config=Config(signature_version=S3_SIGNATURE_VERSION)) + """ + s3_client = boto3.client("s3", region_name=S3_REGION_NAME, config=Config(signature_version=S3_SIGNATURE_VERSION)) return s3_client.generate_presigned_post( - bucket_name, object_name, - Fields=fields, - Conditions=conditions, - ExpiresIn=expiration) + bucket_name, object_name, Fields=fields, Conditions=conditions, ExpiresIn=expiration + ) def validate_signature(event: Dict, secret_key: bytes) -> bool: - canonicalized_url = event['rawPath'] - if 'queryStringParameters' in event: - qs = urlencode(event['queryStringParameters'], quote_via=quote_plus) - canonicalized_url = f'{canonicalized_url}?{qs}' + canonicalized_url = event["rawPath"] + if "queryStringParameters" in event: + qs = urlencode(event["queryStringParameters"], quote_via=quote_plus) + canonicalized_url = f"{canonicalized_url}?{qs}" - plain_text = canonicalized_url.encode('utf-8') - log.debug(f'Plain text: {plain_text}') + plain_text = canonicalized_url.encode("utf-8") + log.debug(f"Plain text: {plain_text}") - encoded_signature = event['headers'][SIGNATURE_HEADER] - log.debug(f'Received signature: {encoded_signature}') + encoded_signature = event["headers"][SIGNATURE_HEADER] + log.debug(f"Received signature: {encoded_signature}") signature_value = base64.b64decode(encoded_signature) - hmac_obj = hmac.new(secret_key, - msg=plain_text, - digestmod=hashlib.sha256) + hmac_obj = hmac.new(secret_key, msg=plain_text, digestmod=hashlib.sha256) calculated_signature = hmac_obj.digest() return calculated_signature == signature_value -def validate_timestamp(timestamp: str, current_time: datetime=None) -> bool: +def validate_timestamp(timestamp: str, current_time: datetime = None) -> bool: if current_time is None: current_time = datetime.utcnow() @@ -99,68 +92,63 @@ def validate_timestamp(timestamp: str, current_time: datetime=None) -> bool: file_loading_time = datetime.strptime(timestamp, TIMESTAMP_PARAMETER_FORMAT) return current_time - file_loading_time <= timedelta(seconds=SIGNATURE_TIME_TOLERANCE) except: - log.error(f'Could not validate timestamp {timestamp} according to the format: {TIMESTAMP_PARAMETER_FORMAT}') + log.error(f"Could not validate timestamp {timestamp} according to the format: {TIMESTAMP_PARAMETER_FORMAT}") return False def on_event(event, context): - log.debug(f'Event received: {event}') - log.debug(f'Context is: {context}') - log.debug(f'Requests library version: {requests.__version__}') + log.debug(f"Event received: {event}") + log.debug(f"Context is: {context}") + log.debug(f"Requests library version: {requests.__version__}") - log.debug(f'Debug mode is {DEBUG}') + log.debug(f"Debug mode is {DEBUG}") log.debug(f'App URL is "{APP_URL}"') log.debug(f'Files bucket is "{FILES_BUCKET}"') log.debug(f'Files Dynamodb table name is "{FILES_TABLE_NAME}"') log.debug(f'S3 region name is: "{S3_REGION_NAME}"') log.debug(f'S3 signature algorithm version is "{S3_SIGNATURE_VERSION}"') - log.debug(f'Pre-signed urls will expire after {EXPIRATION_TIMEOUT} seconds') + log.debug(f"Pre-signed urls will expire after {EXPIRATION_TIMEOUT} seconds") - q = event.get('queryStringParameters', {}) - filename = unquote_plus(q.get('f')) - timestamp = unquote_plus(q.get('t')) + q = event.get("queryStringParameters", {}) + filename = unquote_plus(q.get("f")) + timestamp = unquote_plus(q.get("t")) response_code = 200 response = {} try: if filename is None: - raise BadRequestError('Provide a valid value for the `f` query parameter') + raise BadRequestError("Provide a valid value for the `f` query parameter") if timestamp is None: - raise BadRequestError('Please provide a valid value for the `t` query parameter') + raise BadRequestError("Please provide a valid value for the `t` query parameter") if not validate_timestamp(timestamp): - log.error('Request timestamp is not valid') - raise UnauthorizedError('Your request cannot be authorized') + log.error("Request timestamp is not valid") + raise UnauthorizedError("Your request cannot be authorized") if not validate_signature(event, SECRET_KEY): - log.error('Request signature is not valid') - raise UnauthorizedError('Your request cannot be authorized') + log.error("Request signature is not valid") + raise UnauthorizedError("Your request cannot be authorized") domain = string.ascii_uppercase + string.ascii_lowercase + string.digits - entry_id = ''.join(random.choice(domain) for _ in range(6)) - object_name = f'{entry_id}/{filename}' - response['once_url'] = f'{APP_URL}{entry_id}/{quote(filename)}' + entry_id = "".join(random.choice(domain) for _ in range(6)) + object_name = f"{entry_id}/{filename}" + response["once_url"] = f"{APP_URL}{entry_id}/{quote(filename)}" - dynamodb = boto3.client('dynamodb') - dynamodb.put_item( - TableName=FILES_TABLE_NAME, - Item={ - 'id': {'S': entry_id}, - 'object_name': {'S': object_name} - }) + dynamodb = boto3.client("dynamodb") + dynamodb.put_item(TableName=FILES_TABLE_NAME, Item={"id": {"S": entry_id}, "object_name": {"S": object_name}}) - log.debug(f'Creating pre-signed post for {object_name} on ' - f'{FILES_BUCKET} (expiration={EXPIRATION_TIMEOUT})') + log.debug( + f"Creating pre-signed post for {object_name} on " f"{FILES_BUCKET} (expiration={EXPIRATION_TIMEOUT})" + ) presigned_post = create_presigned_post( - bucket_name=FILES_BUCKET, - object_name=object_name, - expiration=EXPIRATION_TIMEOUT) + bucket_name=FILES_BUCKET, object_name=object_name, expiration=EXPIRATION_TIMEOUT + ) - log.info(f'Authorized upload request for {object_name}') - log.debug(f'Presigned-Post response: {presigned_post}') - response['presigned_post'] = presigned_post + log.info(f"Authorized upload request for {object_name}") + log.debug(f"Presigned-Post response: {presigned_post}") + response["presigned_post"] = presigned_post except BadRequestError as e: response_code = 400 response = dict(message=str(e)) @@ -172,7 +160,7 @@ def on_event(event, context): response = dict(message=str(e)) finally: return { - 'statusCode': response_code, - 'headers': {'Content-Type': 'application/json'}, - 'body': json.dumps(response) + "statusCode": response_code, + "headers": {"Content-Type": "application/json"}, + "body": json.dumps(response), } diff --git a/once/once_stack.py b/once/once_stack.py index 49c87c9..9fa783d 100644 --- a/once/once_stack.py +++ b/once/once_stack.py @@ -2,9 +2,10 @@ import os from typing import Optional import jsii -from aws_cdk import( +from aws_cdk import ( core, aws_apigatewayv2 as apigw, + aws_apigatewayv2_integrations as integrations, aws_certificatemanager as certmgr, aws_cloudformation as cfn, aws_dynamodb as dynamodb, @@ -14,13 +15,14 @@ from aws_cdk import( aws_logs as logs, aws_route53 as route53, aws_route53_targets as route53_targets, - aws_s3 as s3) + aws_s3 as s3, +) from .utils import make_python_zip_bundle BASE_PATH = os.path.dirname(os.path.abspath(__file__)) -LOG_RETENTION = getattr(logs.RetentionDays, os.getenv('LOG_RETENTION', 'TWO_WEEKS')) +LOG_RETENTION = getattr(logs.RetentionDays, os.getenv("LOG_RETENTION", "TWO_WEEKS")) @jsii.implements(route53.IAliasRecordTarget) @@ -28,149 +30,184 @@ class ApiGatewayV2Domain(object): def __init__(self, domain_name: apigw.CfnDomainName): self.domain_name = domain_name - @jsii.member(jsii_name='bind') + @jsii.member(jsii_name="bind") def bind(self, _record: route53.IRecordSet) -> route53.AliasRecordTargetConfig: return { - 'dnsName': self.domain_name.get_att('RegionalDomainName').to_string(), - 'hostedZoneId': self.domain_name.get_att('RegionalHostedZoneId').to_string() + "dnsName": self.domain_name.get_att("RegionalDomainName").to_string(), + "hostedZoneId": self.domain_name.get_att("RegionalHostedZoneId").to_string(), } class CustomDomainStack(cfn.NestedStack): - def __init__(self, scope: core.Construct, id: str, + def __init__( + self, + scope: core.Construct, + id: str, hosted_zone_id: str, hosted_zone_name: str, domain_name: str, - api: apigw.HttpApi): + api: apigw.HttpApi, + ): super().__init__(scope, id) - hosted_zone = route53.HostedZone.from_hosted_zone_attributes(self, id='dns-hosted-zone', - hosted_zone_id=hosted_zone_id, - zone_name=hosted_zone_name) + hosted_zone = route53.HostedZone.from_hosted_zone_attributes( + self, id="dns-hosted-zone", hosted_zone_id=hosted_zone_id, zone_name=hosted_zone_name + ) - certificate = certmgr.DnsValidatedCertificate(self, 'tls-certificate', + certificate = certmgr.DnsValidatedCertificate( + self, + "tls-certificate", domain_name=domain_name, hosted_zone=hosted_zone, - validation_method=certmgr.ValidationMethod.DNS) + validation_method=certmgr.ValidationMethod.DNS, + ) - custom_domain = apigw.CfnDomainName(self, 'custom-domain', + custom_domain = apigw.CfnDomainName( + self, + "custom-domain", domain_name=domain_name, domain_name_configurations=[ - apigw.CfnDomainName.DomainNameConfigurationProperty( - certificate_arn=certificate.certificate_arn)]) + apigw.CfnDomainName.DomainNameConfigurationProperty(certificate_arn=certificate.certificate_arn) + ], + ) custom_domain.node.add_dependency(api) custom_domain.node.add_dependency(certificate) - api_mapping = apigw.CfnApiMapping(self, 'custom-domain-mapping', - api_id=api.http_api_id, - domain_name=domain_name, - stage='$default') + api_mapping = apigw.CfnApiMapping( + self, "custom-domain-mapping", api_id=api.http_api_id, domain_name=domain_name, stage="$default" + ) api_mapping.node.add_dependency(custom_domain) - route53.ARecord(self, 'custom-domain-record', + route53.ARecord( + self, + "custom-domain-record", target=route53.RecordTarget.from_alias(ApiGatewayV2Domain(custom_domain)), zone=hosted_zone, - record_name=domain_name) + record_name=domain_name, + ) class OnceStack(core.Stack): - def __init__(self, scope: core.Construct, id: str, - secret_key: str, - custom_domain: Optional[str] = None, - hosted_zone_id: Optional[str] = None, - hosted_zone_name: Optional[str] = None, - **kwargs) -> None: + def __init__( + self, + scope: core.Construct, + id: str, + secret_key: str, + custom_domain: Optional[str] = None, + hosted_zone_id: Optional[str] = None, + hosted_zone_name: Optional[str] = None, + **kwargs, + ) -> None: super().__init__(scope, id, **kwargs) - self.files_bucket = s3.Bucket(self, 'files-bucket', - bucket_name='once-shared-files', + self.files_bucket = s3.Bucket( + self, + "files-bucket", + bucket_name="once-shared-files", block_public_access=s3.BlockPublicAccess.BLOCK_ALL, encryption=s3.BucketEncryption.S3_MANAGED, - removal_policy=core.RemovalPolicy.DESTROY) + removal_policy=core.RemovalPolicy.DESTROY, + ) - self.files_table = dynamodb.Table(self, 'once-files-table', - table_name='once-files', - partition_key=dynamodb.Attribute(name='id', type=dynamodb.AttributeType.STRING), + self.files_table = dynamodb.Table( + self, + "once-files-table", + table_name="once-files", + partition_key=dynamodb.Attribute(name="id", type=dynamodb.AttributeType.STRING), billing_mode=dynamodb.BillingMode.PAY_PER_REQUEST, - removal_policy=core.RemovalPolicy.DESTROY) + removal_policy=core.RemovalPolicy.DESTROY, + ) - self.api = apigw.HttpApi(self, 'once-api', api_name='once-api') + self.api = apigw.HttpApi(self, "once-api", api_name="once-api") api_url = self.api.url if custom_domain is not None: - api_url = f'https://{custom_domain}/' + api_url = f"https://{custom_domain}/" - core.CfnOutput(self, 'base-url', value=api_url) + core.CfnOutput(self, "base-url", value=api_url) - self.get_upload_ticket_function = lambda_.Function(self, 'get-upload-ticket-function', - function_name='once-get-upload-ticket', - description='Returns a pre-signed request to share a file', + self.get_upload_ticket_function = lambda_.Function( + self, + "get-upload-ticket-function", + function_name="once-get-upload-ticket", + description="Returns a pre-signed request to share a file", runtime=lambda_.Runtime.PYTHON_3_7, - code=make_python_zip_bundle(os.path.join(BASE_PATH, 'get-upload-ticket')), - handler='handler.on_event', + code=make_python_zip_bundle(os.path.join(BASE_PATH, "get-upload-ticket")), + handler="handler.on_event", log_retention=LOG_RETENTION, environment={ - 'APP_URL': api_url, - 'FILES_TABLE_NAME': self.files_table.table_name, - 'FILES_BUCKET': self.files_bucket.bucket_name, - 'SECRET_KEY': secret_key - }) + "APP_URL": api_url, + "FILES_TABLE_NAME": self.files_table.table_name, + "FILES_BUCKET": self.files_bucket.bucket_name, + "SECRET_KEY": secret_key, + }, + ) self.files_bucket.grant_put(self.get_upload_ticket_function) self.files_table.grant_read_write_data(self.get_upload_ticket_function) - self.download_and_delete_function = lambda_.Function(self, 'download-and-delete-function', - function_name='once-download-and-delete', - description='Serves a file from S3 and deletes it as soon as it has been successfully transferred', + self.download_and_delete_function = lambda_.Function( + self, + "download-and-delete-function", + function_name="once-download-and-delete", + description="Serves a file from S3 and deletes it as soon as it has been successfully transferred", runtime=lambda_.Runtime.PYTHON_3_7, - code=lambda_.Code.from_asset(os.path.join(BASE_PATH, 'download-and-delete')), - handler='handler.on_event', + code=lambda_.Code.from_asset(os.path.join(BASE_PATH, "download-and-delete")), + handler="handler.on_event", log_retention=LOG_RETENTION, environment={ - 'FILES_BUCKET': self.files_bucket.bucket_name, - 'FILES_TABLE_NAME': self.files_table.table_name - }) + "FILES_BUCKET": self.files_bucket.bucket_name, + "FILES_TABLE_NAME": self.files_table.table_name, + }, + ) self.files_bucket.grant_read(self.download_and_delete_function) self.files_bucket.grant_delete(self.download_and_delete_function) self.files_table.grant_read_write_data(self.download_and_delete_function) - get_upload_ticket_integration = apigw.LambdaProxyIntegration(handler=self.get_upload_ticket_function) - self.api.add_routes( - path='/', - methods=[apigw.HttpMethod.GET], - integration=get_upload_ticket_integration) + get_upload_ticket_integration = integrations.LambdaProxyIntegration(handler=self.get_upload_ticket_function) + self.api.add_routes(path="/", methods=[apigw.HttpMethod.GET], integration=get_upload_ticket_integration) - download_and_delete_integration = apigw.LambdaProxyIntegration(handler=self.download_and_delete_function) + download_and_delete_integration = integrations.LambdaProxyIntegration( + handler=self.download_and_delete_function + ) self.api.add_routes( - path='/{entry_id}/{filename}', - methods=[apigw.HttpMethod.GET], - integration=download_and_delete_integration) + path="/{entry_id}/{filename}", methods=[apigw.HttpMethod.GET], integration=download_and_delete_integration + ) - self.cleanup_function = lambda_.Function(self, 'delete-served-files-function', - function_name='once-delete-served-files', - description='Deletes files from S3 once they have been marked as deleted in DynamoDB', + self.cleanup_function = lambda_.Function( + self, + "delete-served-files-function", + function_name="once-delete-served-files", + description="Deletes files from S3 once they have been marked as deleted in DynamoDB", runtime=lambda_.Runtime.PYTHON_3_7, - code=lambda_.Code.from_asset(os.path.join(BASE_PATH, 'delete-served-files')), - handler='handler.on_event', + code=lambda_.Code.from_asset(os.path.join(BASE_PATH, "delete-served-files")), + handler="handler.on_event", log_retention=LOG_RETENTION, environment={ - 'FILES_BUCKET': self.files_bucket.bucket_name, - 'FILES_TABLE_NAME': self.files_table.table_name - }) + "FILES_BUCKET": self.files_bucket.bucket_name, + "FILES_TABLE_NAME": self.files_table.table_name, + }, + ) self.files_bucket.grant_delete(self.cleanup_function) self.files_table.grant_read_write_data(self.cleanup_function) - events.Rule(self, 'once-delete-served-files-rule', + events.Rule( + self, + "once-delete-served-files-rule", schedule=events.Schedule.rate(core.Duration.hours(24)), - targets=[targets.LambdaFunction(self.cleanup_function)]) + targets=[targets.LambdaFunction(self.cleanup_function)], + ) if custom_domain is not None: - self.custom_domain_stack = CustomDomainStack(self, 'custom-domain', + self.custom_domain_stack = CustomDomainStack( + self, + "custom-domain", api=self.api, domain_name=custom_domain, hosted_zone_id=hosted_zone_id, - hosted_zone_name=hosted_zone_name) + hosted_zone_name=hosted_zone_name, + ) diff --git a/once/utils.py b/once/utils.py index 9db1c9c..1703bf8 100644 --- a/once/utils.py +++ b/once/utils.py @@ -12,22 +12,24 @@ from aws_cdk import aws_lambda as _lambda class MissingPrerequisiteCommand(Exception): - '''A required system command is missing''' + """A required system command is missing""" -def add_folder_to_zip(zip_obj: zipfile.ZipFile, folder: str, ignore_names: List[str] = [], ignore_dotfiles: bool = True): +def add_folder_to_zip( + zip_obj: zipfile.ZipFile, folder: str, ignore_names: List[str] = [], ignore_dotfiles: bool = True +): for root, dirs, files in os.walk(folder): if ignore_dotfiles: - dirs[:] = [d for d in dirs if not d.startswith('.')] - files[:] = [f for f in files if not f.startswith('.')] + dirs[:] = [d for d in dirs if not d.startswith(".")] + files[:] = [f for f in files if not f.startswith(".")] dirs[:] = [d for d in dirs if d not in ignore_names] files[:] = [f for f in files if f not in ignore_names] - logging.debug(f'FILES: {files}, DIRS: {dirs}') + logging.debug(f"FILES: {files}, DIRS: {dirs}") if root == folder: - archive_folder_name = '' + archive_folder_name = "" else: archive_folder_name = os.path.relpath(root, folder) zip_obj.write(root, arcname=archive_folder_name) @@ -38,86 +40,96 @@ def add_folder_to_zip(zip_obj: zipfile.ZipFile, folder: str, ignore_names: List[ zip_obj.write(f, arcname=d) -def execute_shell_command(command: Union[str, List[str]], - env: Union[Dict, None] = None) -> str: +def execute_shell_command(command: Union[str, List[str]], env: Union[Dict, None] = None) -> str: if isinstance(command, list): - command = ' '.join(command) + command = " ".join(command) - logging.debug(f'Executing command: {command}') + logging.debug(f"Executing command: {command}") - completed_process = subprocess.run(command, - env=env, - shell=True, - check=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + completed_process = subprocess.run( + command, env=env, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) logging.debug(completed_process) - return completed_process.stdout.strip().decode('utf-8') + return completed_process.stdout.strip().decode("utf-8") def locate_command(command: str) -> str: - path = execute_shell_command(['which', command]) + path = execute_shell_command(["which", command]) if path is None: raise MissingPrerequisiteCommand(f'Unable to find "{command}"') return path -def make_python_zip_bundle(input_path: str, - python_version: str = '3.7', - build_folder: str = '.build', - requirements_file: str = 'requirements.txt', - output_bundle_name: str = 'bundle.zip') -> _lambda.AssetCode: - ''' +def make_python_zip_bundle( + input_path: str, + python_version: str = "3.7", + build_folder: str = ".build", + requirements_file: str = "requirements.txt", + output_bundle_name: str = "bundle.zip", +) -> _lambda.AssetCode: + """ Builds an lambda AssetCode bundling python dependencies along with the code. The bundle is built using docker and the target lambda runtime image. - ''' + """ build_path = os.path.abspath(os.path.join(input_path, build_folder)) asset_path = os.path.join(build_path, output_bundle_name) # checks if it's required to build a new zip file if not os.path.exists(asset_path) or os.path.getmtime(asset_path) < get_folder_latest_mtime(input_path): - docker = locate_command('docker') - lambda_runtime_docker_image = f'lambci/lambda:build-python{python_version}' + docker = locate_command("docker") + lambda_runtime_docker_image = f"lambci/lambda:build-python{python_version}" # cleans the target folder - logging.debug(f'Cleaning folder: {build_path}') + logging.debug(f"Cleaning folder: {build_path}") shutil.rmtree(build_path, ignore_errors=True) # builds requirements using target runtime - build_log = execute_shell_command(command=[ - 'docker', 'run', '--rm', - '-v', f'{input_path}:/app', - '-w', '/app', - lambda_runtime_docker_image, - 'pip', 'install', - '-r', requirements_file, - '-t', build_folder]) + build_log = execute_shell_command( + command=[ + "docker", + "run", + "--rm", + "-v", + f"{input_path}:/app", + "-w", + "/app", + lambda_runtime_docker_image, + "pip", + "install", + "-r", + requirements_file, + "-t", + build_folder, + ] + ) logging.info(build_log) # creates the zip archive - logging.debug(f'Deleting file: {asset_path}') + logging.debug(f"Deleting file: {asset_path}") shutil.rmtree(asset_path, ignore_errors=True) - logging.debug(f'Creating bundle: {asset_path}') - with zipfile.ZipFile(asset_path, 'w', zipfile.ZIP_DEFLATED) as zip_obj: - add_folder_to_zip(zip_obj, input_path, ignore_names=[output_bundle_name, '__pycache__']) - add_folder_to_zip(zip_obj, build_path, ignore_names=[output_bundle_name, '__pycache__'], ignore_dotfiles=False) + logging.debug(f"Creating bundle: {asset_path}") + with zipfile.ZipFile(asset_path, "w", zipfile.ZIP_DEFLATED) as zip_obj: + add_folder_to_zip(zip_obj, input_path, ignore_names=[output_bundle_name, "__pycache__"]) + add_folder_to_zip( + zip_obj, build_path, ignore_names=[output_bundle_name, "__pycache__"], ignore_dotfiles=False + ) - logging.info(f'Lambda bundle created at {asset_path}') + logging.info(f"Lambda bundle created at {asset_path}") source_hash = get_folder_checksum(input_path) - logging.debug(f'Source folder hash {input_path} -> {source_hash}') + logging.debug(f"Source folder hash {input_path} -> {source_hash}") return _lambda.AssetCode.from_asset(asset_path, source_hash=source_hash) -def get_folder_checksum(path: str, ignore_dotfiles: bool = True, - chunk_size: int = 4096, - digest_method: hashlib._hashlib.HASH = hashlib.md5) -> str: +def get_folder_checksum( + path: str, ignore_dotfiles: bool = True, chunk_size: int = 4096, digest_method: hashlib._hashlib.HASH = hashlib.md5 +) -> str: def _hash_file(filename: str) -> bytes: - with open(filename, mode='rb', buffering=0) as fp: + with open(filename, mode="rb", buffering=0) as fp: hash_func = digest_method() buffer = fp.read(chunk_size) while len(buffer) > 0: @@ -125,10 +137,10 @@ def get_folder_checksum(path: str, ignore_dotfiles: bool = True, buffer = fp.read(chunk_size) return hash_func.digest() - folder_hash = b'' + folder_hash = b"" for root, dirs, files in os.walk(path): - files = [f for f in files if not f.startswith('.')] - dirs[:] = [d for d in dirs if not d.startswith('.')] + files = [f for f in files if not f.startswith(".")] + dirs[:] = [d for d in dirs if not d.startswith(".")] for file_name in sorted(files): file_path = os.path.join(root, file_name) @@ -142,8 +154,8 @@ def get_folder_latest_mtime(path: str, ignore_dotfiles: bool = True) -> float: latest_mtime = None for root, dirs, files in os.walk(path): if ignore_dotfiles: - files = [f for f in files if not f.startswith('.')] - dirs[:] = [d for d in dirs if not d.startswith('.')] + files = [f for f in files if not f.startswith(".")] + dirs[:] = [d for d in dirs if not d.startswith(".")] for file_name in files: file_path = os.path.join(root, file_name)