Reformatting sources with black

This commit is contained in:
2020-11-22 10:52:48 +01:00
parent 84b4ad5305
commit 8b5fbf3ddf
7 changed files with 361 additions and 329 deletions

51
app.py
View File

@@ -9,48 +9,51 @@ from aws_cdk import core
from once.once_stack import OnceStack, CustomDomainStack
ONCE_CONFIG_FILE = os.getenv('ONCE_CONFIG_FILE', os.path.expanduser('~/.once'))
ONCE_CONFIG_FILE = os.getenv("ONCE_CONFIG_FILE", os.path.expanduser("~/.once"))
SECRET_KEY = os.getenv('SECRET_KEY')
CUSTOM_DOMAIN = os.getenv('CUSTOM_DOMAIN')
HOSTED_ZONE_NAME = os.getenv('HOSTED_ZONE_NAME')
HOSTED_ZONE_ID = os.getenv('HOSTED_ZONE_ID')
SECRET_KEY = os.getenv("SECRET_KEY")
CUSTOM_DOMAIN = os.getenv("CUSTOM_DOMAIN")
HOSTED_ZONE_NAME = os.getenv("HOSTED_ZONE_NAME")
HOSTED_ZONE_ID = os.getenv("HOSTED_ZONE_ID")
def generate_random_key() -> str:
return base64.b64encode(os.urandom(128)).decode('utf-8')
return base64.b64encode(os.urandom(128)).decode("utf-8")
def generate_config(secret_key: Optional[str] = None,
def generate_config(
secret_key: Optional[str] = None,
custom_domain: str = None,
hosted_zone_name: str = None,
hosted_zone_id: str = None) -> configparser.ConfigParser:
hosted_zone_id: str = None,
) -> configparser.ConfigParser:
config = configparser.ConfigParser()
config['once'] = {
'secret_key': secret_key or generate_random_key(),
config["once"] = {
"secret_key": secret_key or generate_random_key(),
}
config['deployment'] = {}
config["deployment"] = {}
if all([custom_domain, hosted_zone_name, hosted_zone_id]):
config['once']['base_url'] = f'https://{custom_domain}'
config['deployment'] = {
'custom_domain': custom_domain,
'hosted_zone_name': hosted_zone_name,
'hosted_zone_id': hosted_zone_id
config["once"]["base_url"] = f"https://{custom_domain}"
config["deployment"] = {
"custom_domain": custom_domain,
"hosted_zone_name": hosted_zone_name,
"hosted_zone_id": hosted_zone_id,
}
return config
def get_config(config_gile: str = ONCE_CONFIG_FILE) -> configparser.ConfigParser:
if not os.path.exists(ONCE_CONFIG_FILE):
print(f'Generating configuration file at {ONCE_CONFIG_FILE}')
with open(ONCE_CONFIG_FILE, 'w') as config_file:
print(f"Generating configuration file at {ONCE_CONFIG_FILE}")
with open(ONCE_CONFIG_FILE, "w") as config_file:
config = generate_config(
secret_key=SECRET_KEY,
custom_domain=CUSTOM_DOMAIN,
hosted_zone_name=HOSTED_ZONE_NAME,
hosted_zone_id=HOSTED_ZONE_ID)
hosted_zone_id=HOSTED_ZONE_ID,
)
config.write(config_file)
else:
config = configparser.ConfigParser()
@@ -61,14 +64,14 @@ def get_config(config_gile: str = ONCE_CONFIG_FILE) -> configparser.ConfigParser
def main():
config = get_config()
kwargs = {'secret_key': config['once']['secret_key']}
if config.has_section('deployment'):
kwargs.update(config['deployment'])
kwargs = {"secret_key": config["once"]["secret_key"]}
if config.has_section("deployment"):
kwargs.update(config["deployment"])
app = core.App()
once = OnceStack(app, 'once', **kwargs)
once = OnceStack(app, "once", **kwargs)
app.synth()
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@@ -1,6 +1,6 @@
'''
"""
Simple command to share one-time files
'''
"""
import os
import base64
@@ -17,9 +17,9 @@ import requests
from pygments import highlight, lexers, formatters
ONCE_CONFIG_FILE = os.getenv('ONCE_CONFIG_FILE', os.path.expanduser('~/.once'))
ONCE_SIGNATURE_HEADER = 'x-once-signature'
ONCE_TIMESTAMP_FORMAT = '%Y%m%d%H%M%S%f'
ONCE_CONFIG_FILE = os.getenv("ONCE_CONFIG_FILE", os.path.expanduser("~/.once"))
ONCE_SIGNATURE_HEADER = "x-once-signature"
ONCE_TIMESTAMP_FORMAT = "%Y%m%d%H%M%S%f"
def highlight_json(obj):
@@ -33,7 +33,7 @@ def echo_obj(obj):
def get_config(config_file: str = ONCE_CONFIG_FILE) -> configparser.ConfigParser:
if not os.path.exists(config_file):
raise ValueError(f'Config file not found at {config_file}')
raise ValueError(f"Config file not found at {config_file}")
config = configparser.ConfigParser()
config.read(ONCE_CONFIG_FILE)
return config
@@ -41,59 +41,57 @@ def get_config(config_file: str = ONCE_CONFIG_FILE) -> configparser.ConfigParser
def api_req(method: str, url: str, verbose: bool = False, **kwargs):
config = get_config()
if not config.has_option('once', 'base_url'):
raise ValueError(f'Configuration file at {ONCE_CONFIG_FILE} misses `base_url` option')
if not config.has_option("once", "base_url"):
raise ValueError(f"Configuration file at {ONCE_CONFIG_FILE} misses `base_url` option")
base_url = os.getenv('ONCE_API_URL', config['once']['base_url'])
secret_key = base64.b64decode(os.getenv('ONCE_SECRET_KEY', config['once']['secret_key']))
base_url = os.getenv("ONCE_API_URL", config["once"]["base_url"])
secret_key = base64.b64decode(os.getenv("ONCE_SECRET_KEY", config["once"]["secret_key"]))
method = method.lower()
if method not in ['get', 'post']:
if method not in ["get", "post"]:
raise ValueError(f'Unsupported HTTP method "{method}"')
actual_url = urljoin(base_url, url)
if verbose:
print(f'{method.upper()} {actual_url}')
print(f"{method.upper()} {actual_url}")
req = requests.Request(method=method, url=actual_url, **kwargs).prepare()
plain_text = req.path_url.encode('utf-8')
plain_text = req.path_url.encode("utf-8")
hmac_obj = hmac.new(secret_key, msg=plain_text, digestmod=hashlib.sha256)
req.headers[ONCE_SIGNATURE_HEADER] = base64.b64encode(hmac_obj.digest())
response = requests.Session().send(req)
if verbose:
print(f'Server response status: {response.status_code}')
print(f"Server response status: {response.status_code}")
echo_obj(response.json())
return response
@click.command('share')
@click.argument('file', type=click.File(mode='rb'), required=True)
@click.option('--verbose', '-v', is_flag=True, default=False, help='Enables verbose output.')
@click.command("share")
@click.argument("file", type=click.File(mode="rb"), required=True)
@click.option("--verbose", "-v", is_flag=True, default=False, help="Enables verbose output.")
def share(file: click.File, verbose: bool):
entry = api_req('GET', '/',
params={
'f': quote_plus(os.path.basename(file.name)),
't': datetime.utcnow().strftime(ONCE_TIMESTAMP_FORMAT)
},
verbose=verbose).json()
entry = api_req(
"GET",
"/",
params={"f": quote_plus(os.path.basename(file.name)), "t": datetime.utcnow().strftime(ONCE_TIMESTAMP_FORMAT)},
verbose=verbose,
).json()
once_url = entry['once_url']
upload_data = entry['presigned_post']
files = {'file': file}
once_url = entry["once_url"]
upload_data = entry["presigned_post"]
files = {"file": file}
upload_started = time.time()
response = requests.post(upload_data['url'],
data=upload_data['fields'],
files=files)
response = requests.post(upload_data["url"], data=upload_data["fields"], files=files)
upload_time = time.time() - upload_started
print(f"File uploaded in {upload_time}s")
print(f"File can be downloaded once at: {once_url}")
if __name__ == '__main__':
if __name__ == "__main__":
share()

View File

@@ -9,16 +9,16 @@ from boto3.dynamodb.conditions import Key
def is_debug_enabled() -> bool:
value = os.getenv('DEBUG', 'false').lower()
if value in ['false', '0']:
value = os.getenv("DEBUG", "false").lower()
if value in ["false", "0"]:
return False
else:
return bool(value)
DEBUG = is_debug_enabled()
FILES_BUCKET = os.getenv('FILES_BUCKET')
FILES_TABLE_NAME = os.getenv('FILES_TABLE_NAME')
FILES_BUCKET = os.getenv("FILES_BUCKET")
FILES_TABLE_NAME = os.getenv("FILES_TABLE_NAME")
log = logging.getLogger()
@@ -29,28 +29,27 @@ else:
def on_event(event, context):
log.debug(f'Event received: {event}')
log.debug(f'Context is: {context}')
log.debug(f'Debug mode is {DEBUG}')
log.debug(f"Event received: {event}")
log.debug(f"Context is: {context}")
log.debug(f"Debug mode is {DEBUG}")
log.debug(f'Files bucket is "{FILES_BUCKET}"')
dynamodb = boto3.client('dynamodb')
dynamodb = boto3.client("dynamodb")
response = dynamodb.scan(
TableName=FILES_TABLE_NAME,
Select='ALL_ATTRIBUTES',
FilterExpression='deleted = :deleted',
ExpressionAttributeValues={
':deleted': {'BOOL': True}
})
Select="ALL_ATTRIBUTES",
FilterExpression="deleted = :deleted",
ExpressionAttributeValues={":deleted": {"BOOL": True}},
)
s3 = boto3.client('s3')
for item in response['Items']:
object_name = item['object_name']['S']
log.info(f'Deleting file {object_name}')
s3 = boto3.client("s3")
for item in response["Items"]:
object_name = item["object_name"]["S"]
log.info(f"Deleting file {object_name}")
try:
s3.delete_object(Bucket=FILES_BUCKET, Key=object_name)
except:
log.exception('Could not delete file {object_name}')
log.exception("Could not delete file {object_name}")
response = dynamodb.delete_item(TableName=FILES_TABLE_NAME, Key={'id': item['id']})
log.debug(f'dynamodb delete item: {response}')
response = dynamodb.delete_item(TableName=FILES_TABLE_NAME, Key={"id": item["id"]})
log.debug(f"dynamodb delete item: {response}")

View File

@@ -9,31 +9,36 @@ import boto3
def is_debug_enabled() -> bool:
value = os.getenv('DEBUG', 'false').lower()
if value in ['false', '0']:
value = os.getenv("DEBUG", "false").lower()
if value in ["false", "0"]:
return False
else:
return bool(value)
DEBUG = is_debug_enabled()
FILES_BUCKET = os.getenv('FILES_BUCKET')
FILES_TABLE_NAME = os.getenv('FILES_TABLE_NAME')
PRESIGNED_URL_EXPIRES_IN = int(os.getenv('PRESIGNED_URL_EXPIRES_IN', 20))
MASKED_USER_AGENTS = os.getenv('MASKED_USER_AGENTS', ','.join([
'^Facebook.*',
'^Google.*',
'^Instagram.*',
'^LinkedIn.*',
'^Outlook.*',
'^Reddit.*',
'^Slack.*',
'^Skype.*',
'^SnapChat.*',
'^Telegram.*',
'^Twitter.*',
'^WhatsApp.*'
])).split(',')
FILES_BUCKET = os.getenv("FILES_BUCKET")
FILES_TABLE_NAME = os.getenv("FILES_TABLE_NAME")
PRESIGNED_URL_EXPIRES_IN = int(os.getenv("PRESIGNED_URL_EXPIRES_IN", 20))
MASKED_USER_AGENTS = os.getenv(
"MASKED_USER_AGENTS",
",".join(
[
"^Facebook.*",
"^Google.*",
"^Instagram.*",
"^LinkedIn.*",
"^Outlook.*",
"^Reddit.*",
"^Slack.*",
"^Skype.*",
"^SnapChat.*",
"^Telegram.*",
"^Twitter.*",
"^WhatsApp.*",
]
),
).split(",")
log = logging.getLogger()
@@ -44,55 +49,45 @@ else:
def on_event(event, context):
log.debug(f'Event received: {event}')
log.debug(f'Context is: {context}')
log.debug(f'Debug mode is {DEBUG}')
log.debug(f"Event received: {event}")
log.debug(f"Context is: {context}")
log.debug(f"Debug mode is {DEBUG}")
log.debug(f'Files bucket is "{FILES_BUCKET}"')
entry_id = event['pathParameters']['entry_id']
filename = urllib.parse.unquote_plus(event['pathParameters']['filename'])
object_name = f'{entry_id}/{filename}'
entry_id = event["pathParameters"]["entry_id"]
filename = urllib.parse.unquote_plus(event["pathParameters"]["filename"])
object_name = f"{entry_id}/{filename}"
dynamodb = boto3.client('dynamodb')
entry = dynamodb.get_item(
TableName=FILES_TABLE_NAME,
Key={'id': {'S': entry_id}})
dynamodb = boto3.client("dynamodb")
entry = dynamodb.get_item(TableName=FILES_TABLE_NAME, Key={"id": {"S": entry_id}})
log.debug(f'Matched Dynamodb entry: {entry}')
log.debug(f"Matched Dynamodb entry: {entry}")
if 'Item' not in entry or 'deleted' in entry['Item']:
error_message = f'Entry not found: {object_name}'
if "Item" not in entry or "deleted" in entry["Item"]:
error_message = f"Entry not found: {object_name}"
log.info(error_message)
return {'statusCode': 404, 'body': error_message}
return {"statusCode": 404, "body": error_message}
# Some rich clients try to get a preview of any link pasted
# into text controls.
user_agent = event['headers'].get('user-agent', '')
user_agent = event["headers"].get("user-agent", "")
is_masked_agent = any([re.match(agent, user_agent) for agent in MASKED_USER_AGENTS])
if is_masked_agent:
log.info('Serving possible link preview. Download prevented.')
return {
'statusCode': 200,
'headers': {}
}
log.info("Serving possible link preview. Download prevented.")
return {"statusCode": 200, "headers": {}}
s3 = boto3.client('s3')
s3 = boto3.client("s3")
download_url = s3.generate_presigned_url(
'get_object',
Params={'Bucket': FILES_BUCKET, 'Key': object_name},
ExpiresIn=PRESIGNED_URL_EXPIRES_IN)
"get_object", Params={"Bucket": FILES_BUCKET, "Key": object_name}, ExpiresIn=PRESIGNED_URL_EXPIRES_IN
)
dynamodb.update_item(
TableName=FILES_TABLE_NAME,
Key={'id': {'S': entry_id}},
UpdateExpression='SET deleted = :deleted',
ExpressionAttributeValues={':deleted': {'BOOL': True}})
Key={"id": {"S": entry_id}},
UpdateExpression="SET deleted = :deleted",
ExpressionAttributeValues={":deleted": {"BOOL": True}},
)
log.info(f'Entry {object_name} marked as deleted')
log.info(f"Entry {object_name} marked as deleted")
return {
'statusCode': 301,
'headers': {
'Location': download_url
}
}
return {"statusCode": 301, "headers": {"Location": download_url}}

View File

@@ -17,25 +17,25 @@ from botocore.exceptions import ClientError
def is_debug_enabled() -> bool:
value = os.getenv('DEBUG', 'false').lower()
if value in ['false', '0']:
value = os.getenv("DEBUG", "false").lower()
if value in ["false", "0"]:
return False
else:
return bool(value)
DEBUG = is_debug_enabled()
APP_URL = os.getenv('APP_URL')
EXPIRATION_TIMEOUT = int(os.getenv('EXPIRATION_TIMEOUT', 60*5))
FILES_BUCKET = os.getenv('FILES_BUCKET')
FILES_TABLE_NAME = os.getenv('FILES_TABLE_NAME')
S3_REGION_NAME = os.getenv('S3_REGION_NAME', 'eu-west-1')
S3_SIGNATURE_VERSION = os.getenv('S3_SIGNATURE_VERSION', 's3v4')
SECRET_KEY = base64.b64decode(os.getenv('SECRET_KEY'))
SIGNATURE_HEADER = os.getenv('SIGNATURE_HEADER', 'x-once-signature')
SIGNATURE_TIME_TOLERANCE = int(os.getenv('SIGNATURE_TIME_TOLERANCE', 5))
TIMESTAMP_FORMAT_STRING = os.getenv('TIMESTAMP_FORMAT_STRING', '%d%m%Y%H%M%S')
TIMESTAMP_PARAMETER_FORMAT = '%Y%m%d%H%M%S%f'
APP_URL = os.getenv("APP_URL")
EXPIRATION_TIMEOUT = int(os.getenv("EXPIRATION_TIMEOUT", 60 * 5))
FILES_BUCKET = os.getenv("FILES_BUCKET")
FILES_TABLE_NAME = os.getenv("FILES_TABLE_NAME")
S3_REGION_NAME = os.getenv("S3_REGION_NAME", "eu-west-1")
S3_SIGNATURE_VERSION = os.getenv("S3_SIGNATURE_VERSION", "s3v4")
SECRET_KEY = base64.b64decode(os.getenv("SECRET_KEY"))
SIGNATURE_HEADER = os.getenv("SIGNATURE_HEADER", "x-once-signature")
SIGNATURE_TIME_TOLERANCE = int(os.getenv("SIGNATURE_TIME_TOLERANCE", 5))
TIMESTAMP_FORMAT_STRING = os.getenv("TIMESTAMP_FORMAT_STRING", "%d%m%Y%H%M%S")
TIMESTAMP_PARAMETER_FORMAT = "%Y%m%d%H%M%S%f"
log = logging.getLogger()
@@ -53,39 +53,32 @@ class UnauthorizedError(Exception):
pass
def create_presigned_post(bucket_name: str, object_name: str,
fields=None, conditions=None, expiration=3600) -> Dict:
'''
def create_presigned_post(bucket_name: str, object_name: str, fields=None, conditions=None, expiration=3600) -> Dict:
"""
Generate a presigned URL S3 POST request to upload a file
'''
s3_client = boto3.client('s3',
region_name=S3_REGION_NAME,
config=Config(signature_version=S3_SIGNATURE_VERSION))
"""
s3_client = boto3.client("s3", region_name=S3_REGION_NAME, config=Config(signature_version=S3_SIGNATURE_VERSION))
return s3_client.generate_presigned_post(
bucket_name, object_name,
Fields=fields,
Conditions=conditions,
ExpiresIn=expiration)
bucket_name, object_name, Fields=fields, Conditions=conditions, ExpiresIn=expiration
)
def validate_signature(event: Dict, secret_key: bytes) -> bool:
canonicalized_url = event['rawPath']
if 'queryStringParameters' in event:
qs = urlencode(event['queryStringParameters'], quote_via=quote_plus)
canonicalized_url = f'{canonicalized_url}?{qs}'
canonicalized_url = event["rawPath"]
if "queryStringParameters" in event:
qs = urlencode(event["queryStringParameters"], quote_via=quote_plus)
canonicalized_url = f"{canonicalized_url}?{qs}"
plain_text = canonicalized_url.encode('utf-8')
log.debug(f'Plain text: {plain_text}')
plain_text = canonicalized_url.encode("utf-8")
log.debug(f"Plain text: {plain_text}")
encoded_signature = event['headers'][SIGNATURE_HEADER]
log.debug(f'Received signature: {encoded_signature}')
encoded_signature = event["headers"][SIGNATURE_HEADER]
log.debug(f"Received signature: {encoded_signature}")
signature_value = base64.b64decode(encoded_signature)
hmac_obj = hmac.new(secret_key,
msg=plain_text,
digestmod=hashlib.sha256)
hmac_obj = hmac.new(secret_key, msg=plain_text, digestmod=hashlib.sha256)
calculated_signature = hmac_obj.digest()
return calculated_signature == signature_value
@@ -99,68 +92,63 @@ def validate_timestamp(timestamp: str, current_time: datetime=None) -> bool:
file_loading_time = datetime.strptime(timestamp, TIMESTAMP_PARAMETER_FORMAT)
return current_time - file_loading_time <= timedelta(seconds=SIGNATURE_TIME_TOLERANCE)
except:
log.error(f'Could not validate timestamp {timestamp} according to the format: {TIMESTAMP_PARAMETER_FORMAT}')
log.error(f"Could not validate timestamp {timestamp} according to the format: {TIMESTAMP_PARAMETER_FORMAT}")
return False
def on_event(event, context):
log.debug(f'Event received: {event}')
log.debug(f'Context is: {context}')
log.debug(f'Requests library version: {requests.__version__}')
log.debug(f"Event received: {event}")
log.debug(f"Context is: {context}")
log.debug(f"Requests library version: {requests.__version__}")
log.debug(f'Debug mode is {DEBUG}')
log.debug(f"Debug mode is {DEBUG}")
log.debug(f'App URL is "{APP_URL}"')
log.debug(f'Files bucket is "{FILES_BUCKET}"')
log.debug(f'Files Dynamodb table name is "{FILES_TABLE_NAME}"')
log.debug(f'S3 region name is: "{S3_REGION_NAME}"')
log.debug(f'S3 signature algorithm version is "{S3_SIGNATURE_VERSION}"')
log.debug(f'Pre-signed urls will expire after {EXPIRATION_TIMEOUT} seconds')
log.debug(f"Pre-signed urls will expire after {EXPIRATION_TIMEOUT} seconds")
q = event.get('queryStringParameters', {})
filename = unquote_plus(q.get('f'))
timestamp = unquote_plus(q.get('t'))
q = event.get("queryStringParameters", {})
filename = unquote_plus(q.get("f"))
timestamp = unquote_plus(q.get("t"))
response_code = 200
response = {}
try:
if filename is None:
raise BadRequestError('Provide a valid value for the `f` query parameter')
raise BadRequestError("Provide a valid value for the `f` query parameter")
if timestamp is None:
raise BadRequestError('Please provide a valid value for the `t` query parameter')
raise BadRequestError("Please provide a valid value for the `t` query parameter")
if not validate_timestamp(timestamp):
log.error('Request timestamp is not valid')
raise UnauthorizedError('Your request cannot be authorized')
log.error("Request timestamp is not valid")
raise UnauthorizedError("Your request cannot be authorized")
if not validate_signature(event, SECRET_KEY):
log.error('Request signature is not valid')
raise UnauthorizedError('Your request cannot be authorized')
log.error("Request signature is not valid")
raise UnauthorizedError("Your request cannot be authorized")
domain = string.ascii_uppercase + string.ascii_lowercase + string.digits
entry_id = ''.join(random.choice(domain) for _ in range(6))
object_name = f'{entry_id}/{filename}'
response['once_url'] = f'{APP_URL}{entry_id}/{quote(filename)}'
entry_id = "".join(random.choice(domain) for _ in range(6))
object_name = f"{entry_id}/{filename}"
response["once_url"] = f"{APP_URL}{entry_id}/{quote(filename)}"
dynamodb = boto3.client('dynamodb')
dynamodb.put_item(
TableName=FILES_TABLE_NAME,
Item={
'id': {'S': entry_id},
'object_name': {'S': object_name}
})
dynamodb = boto3.client("dynamodb")
dynamodb.put_item(TableName=FILES_TABLE_NAME, Item={"id": {"S": entry_id}, "object_name": {"S": object_name}})
log.debug(f'Creating pre-signed post for {object_name} on '
f'{FILES_BUCKET} (expiration={EXPIRATION_TIMEOUT})')
log.debug(
f"Creating pre-signed post for {object_name} on " f"{FILES_BUCKET} (expiration={EXPIRATION_TIMEOUT})"
)
presigned_post = create_presigned_post(
bucket_name=FILES_BUCKET,
object_name=object_name,
expiration=EXPIRATION_TIMEOUT)
bucket_name=FILES_BUCKET, object_name=object_name, expiration=EXPIRATION_TIMEOUT
)
log.info(f'Authorized upload request for {object_name}')
log.debug(f'Presigned-Post response: {presigned_post}')
response['presigned_post'] = presigned_post
log.info(f"Authorized upload request for {object_name}")
log.debug(f"Presigned-Post response: {presigned_post}")
response["presigned_post"] = presigned_post
except BadRequestError as e:
response_code = 400
response = dict(message=str(e))
@@ -172,7 +160,7 @@ def on_event(event, context):
response = dict(message=str(e))
finally:
return {
'statusCode': response_code,
'headers': {'Content-Type': 'application/json'},
'body': json.dumps(response)
"statusCode": response_code,
"headers": {"Content-Type": "application/json"},
"body": json.dumps(response),
}

View File

@@ -5,6 +5,7 @@ import jsii
from aws_cdk import (
core,
aws_apigatewayv2 as apigw,
aws_apigatewayv2_integrations as integrations,
aws_certificatemanager as certmgr,
aws_cloudformation as cfn,
aws_dynamodb as dynamodb,
@@ -14,13 +15,14 @@ from aws_cdk import(
aws_logs as logs,
aws_route53 as route53,
aws_route53_targets as route53_targets,
aws_s3 as s3)
aws_s3 as s3,
)
from .utils import make_python_zip_bundle
BASE_PATH = os.path.dirname(os.path.abspath(__file__))
LOG_RETENTION = getattr(logs.RetentionDays, os.getenv('LOG_RETENTION', 'TWO_WEEKS'))
LOG_RETENTION = getattr(logs.RetentionDays, os.getenv("LOG_RETENTION", "TWO_WEEKS"))
@jsii.implements(route53.IAliasRecordTarget)
@@ -28,149 +30,184 @@ class ApiGatewayV2Domain(object):
def __init__(self, domain_name: apigw.CfnDomainName):
self.domain_name = domain_name
@jsii.member(jsii_name='bind')
@jsii.member(jsii_name="bind")
def bind(self, _record: route53.IRecordSet) -> route53.AliasRecordTargetConfig:
return {
'dnsName': self.domain_name.get_att('RegionalDomainName').to_string(),
'hostedZoneId': self.domain_name.get_att('RegionalHostedZoneId').to_string()
"dnsName": self.domain_name.get_att("RegionalDomainName").to_string(),
"hostedZoneId": self.domain_name.get_att("RegionalHostedZoneId").to_string(),
}
class CustomDomainStack(cfn.NestedStack):
def __init__(self, scope: core.Construct, id: str,
def __init__(
self,
scope: core.Construct,
id: str,
hosted_zone_id: str,
hosted_zone_name: str,
domain_name: str,
api: apigw.HttpApi):
api: apigw.HttpApi,
):
super().__init__(scope, id)
hosted_zone = route53.HostedZone.from_hosted_zone_attributes(self, id='dns-hosted-zone',
hosted_zone_id=hosted_zone_id,
zone_name=hosted_zone_name)
hosted_zone = route53.HostedZone.from_hosted_zone_attributes(
self, id="dns-hosted-zone", hosted_zone_id=hosted_zone_id, zone_name=hosted_zone_name
)
certificate = certmgr.DnsValidatedCertificate(self, 'tls-certificate',
certificate = certmgr.DnsValidatedCertificate(
self,
"tls-certificate",
domain_name=domain_name,
hosted_zone=hosted_zone,
validation_method=certmgr.ValidationMethod.DNS)
validation_method=certmgr.ValidationMethod.DNS,
)
custom_domain = apigw.CfnDomainName(self, 'custom-domain',
custom_domain = apigw.CfnDomainName(
self,
"custom-domain",
domain_name=domain_name,
domain_name_configurations=[
apigw.CfnDomainName.DomainNameConfigurationProperty(
certificate_arn=certificate.certificate_arn)])
apigw.CfnDomainName.DomainNameConfigurationProperty(certificate_arn=certificate.certificate_arn)
],
)
custom_domain.node.add_dependency(api)
custom_domain.node.add_dependency(certificate)
api_mapping = apigw.CfnApiMapping(self, 'custom-domain-mapping',
api_id=api.http_api_id,
domain_name=domain_name,
stage='$default')
api_mapping = apigw.CfnApiMapping(
self, "custom-domain-mapping", api_id=api.http_api_id, domain_name=domain_name, stage="$default"
)
api_mapping.node.add_dependency(custom_domain)
route53.ARecord(self, 'custom-domain-record',
route53.ARecord(
self,
"custom-domain-record",
target=route53.RecordTarget.from_alias(ApiGatewayV2Domain(custom_domain)),
zone=hosted_zone,
record_name=domain_name)
record_name=domain_name,
)
class OnceStack(core.Stack):
def __init__(self, scope: core.Construct, id: str,
def __init__(
self,
scope: core.Construct,
id: str,
secret_key: str,
custom_domain: Optional[str] = None,
hosted_zone_id: Optional[str] = None,
hosted_zone_name: Optional[str] = None,
**kwargs) -> None:
**kwargs,
) -> None:
super().__init__(scope, id, **kwargs)
self.files_bucket = s3.Bucket(self, 'files-bucket',
bucket_name='once-shared-files',
self.files_bucket = s3.Bucket(
self,
"files-bucket",
bucket_name="once-shared-files",
block_public_access=s3.BlockPublicAccess.BLOCK_ALL,
encryption=s3.BucketEncryption.S3_MANAGED,
removal_policy=core.RemovalPolicy.DESTROY)
removal_policy=core.RemovalPolicy.DESTROY,
)
self.files_table = dynamodb.Table(self, 'once-files-table',
table_name='once-files',
partition_key=dynamodb.Attribute(name='id', type=dynamodb.AttributeType.STRING),
self.files_table = dynamodb.Table(
self,
"once-files-table",
table_name="once-files",
partition_key=dynamodb.Attribute(name="id", type=dynamodb.AttributeType.STRING),
billing_mode=dynamodb.BillingMode.PAY_PER_REQUEST,
removal_policy=core.RemovalPolicy.DESTROY)
removal_policy=core.RemovalPolicy.DESTROY,
)
self.api = apigw.HttpApi(self, 'once-api', api_name='once-api')
self.api = apigw.HttpApi(self, "once-api", api_name="once-api")
api_url = self.api.url
if custom_domain is not None:
api_url = f'https://{custom_domain}/'
api_url = f"https://{custom_domain}/"
core.CfnOutput(self, 'base-url', value=api_url)
core.CfnOutput(self, "base-url", value=api_url)
self.get_upload_ticket_function = lambda_.Function(self, 'get-upload-ticket-function',
function_name='once-get-upload-ticket',
description='Returns a pre-signed request to share a file',
self.get_upload_ticket_function = lambda_.Function(
self,
"get-upload-ticket-function",
function_name="once-get-upload-ticket",
description="Returns a pre-signed request to share a file",
runtime=lambda_.Runtime.PYTHON_3_7,
code=make_python_zip_bundle(os.path.join(BASE_PATH, 'get-upload-ticket')),
handler='handler.on_event',
code=make_python_zip_bundle(os.path.join(BASE_PATH, "get-upload-ticket")),
handler="handler.on_event",
log_retention=LOG_RETENTION,
environment={
'APP_URL': api_url,
'FILES_TABLE_NAME': self.files_table.table_name,
'FILES_BUCKET': self.files_bucket.bucket_name,
'SECRET_KEY': secret_key
})
"APP_URL": api_url,
"FILES_TABLE_NAME": self.files_table.table_name,
"FILES_BUCKET": self.files_bucket.bucket_name,
"SECRET_KEY": secret_key,
},
)
self.files_bucket.grant_put(self.get_upload_ticket_function)
self.files_table.grant_read_write_data(self.get_upload_ticket_function)
self.download_and_delete_function = lambda_.Function(self, 'download-and-delete-function',
function_name='once-download-and-delete',
description='Serves a file from S3 and deletes it as soon as it has been successfully transferred',
self.download_and_delete_function = lambda_.Function(
self,
"download-and-delete-function",
function_name="once-download-and-delete",
description="Serves a file from S3 and deletes it as soon as it has been successfully transferred",
runtime=lambda_.Runtime.PYTHON_3_7,
code=lambda_.Code.from_asset(os.path.join(BASE_PATH, 'download-and-delete')),
handler='handler.on_event',
code=lambda_.Code.from_asset(os.path.join(BASE_PATH, "download-and-delete")),
handler="handler.on_event",
log_retention=LOG_RETENTION,
environment={
'FILES_BUCKET': self.files_bucket.bucket_name,
'FILES_TABLE_NAME': self.files_table.table_name
})
"FILES_BUCKET": self.files_bucket.bucket_name,
"FILES_TABLE_NAME": self.files_table.table_name,
},
)
self.files_bucket.grant_read(self.download_and_delete_function)
self.files_bucket.grant_delete(self.download_and_delete_function)
self.files_table.grant_read_write_data(self.download_and_delete_function)
get_upload_ticket_integration = apigw.LambdaProxyIntegration(handler=self.get_upload_ticket_function)
self.api.add_routes(
path='/',
methods=[apigw.HttpMethod.GET],
integration=get_upload_ticket_integration)
get_upload_ticket_integration = integrations.LambdaProxyIntegration(handler=self.get_upload_ticket_function)
self.api.add_routes(path="/", methods=[apigw.HttpMethod.GET], integration=get_upload_ticket_integration)
download_and_delete_integration = apigw.LambdaProxyIntegration(handler=self.download_and_delete_function)
download_and_delete_integration = integrations.LambdaProxyIntegration(
handler=self.download_and_delete_function
)
self.api.add_routes(
path='/{entry_id}/{filename}',
methods=[apigw.HttpMethod.GET],
integration=download_and_delete_integration)
path="/{entry_id}/{filename}", methods=[apigw.HttpMethod.GET], integration=download_and_delete_integration
)
self.cleanup_function = lambda_.Function(self, 'delete-served-files-function',
function_name='once-delete-served-files',
description='Deletes files from S3 once they have been marked as deleted in DynamoDB',
self.cleanup_function = lambda_.Function(
self,
"delete-served-files-function",
function_name="once-delete-served-files",
description="Deletes files from S3 once they have been marked as deleted in DynamoDB",
runtime=lambda_.Runtime.PYTHON_3_7,
code=lambda_.Code.from_asset(os.path.join(BASE_PATH, 'delete-served-files')),
handler='handler.on_event',
code=lambda_.Code.from_asset(os.path.join(BASE_PATH, "delete-served-files")),
handler="handler.on_event",
log_retention=LOG_RETENTION,
environment={
'FILES_BUCKET': self.files_bucket.bucket_name,
'FILES_TABLE_NAME': self.files_table.table_name
})
"FILES_BUCKET": self.files_bucket.bucket_name,
"FILES_TABLE_NAME": self.files_table.table_name,
},
)
self.files_bucket.grant_delete(self.cleanup_function)
self.files_table.grant_read_write_data(self.cleanup_function)
events.Rule(self, 'once-delete-served-files-rule',
events.Rule(
self,
"once-delete-served-files-rule",
schedule=events.Schedule.rate(core.Duration.hours(24)),
targets=[targets.LambdaFunction(self.cleanup_function)])
targets=[targets.LambdaFunction(self.cleanup_function)],
)
if custom_domain is not None:
self.custom_domain_stack = CustomDomainStack(self, 'custom-domain',
self.custom_domain_stack = CustomDomainStack(
self,
"custom-domain",
api=self.api,
domain_name=custom_domain,
hosted_zone_id=hosted_zone_id,
hosted_zone_name=hosted_zone_name)
hosted_zone_name=hosted_zone_name,
)

View File

@@ -12,22 +12,24 @@ from aws_cdk import aws_lambda as _lambda
class MissingPrerequisiteCommand(Exception):
'''A required system command is missing'''
"""A required system command is missing"""
def add_folder_to_zip(zip_obj: zipfile.ZipFile, folder: str, ignore_names: List[str] = [], ignore_dotfiles: bool = True):
def add_folder_to_zip(
zip_obj: zipfile.ZipFile, folder: str, ignore_names: List[str] = [], ignore_dotfiles: bool = True
):
for root, dirs, files in os.walk(folder):
if ignore_dotfiles:
dirs[:] = [d for d in dirs if not d.startswith('.')]
files[:] = [f for f in files if not f.startswith('.')]
dirs[:] = [d for d in dirs if not d.startswith(".")]
files[:] = [f for f in files if not f.startswith(".")]
dirs[:] = [d for d in dirs if d not in ignore_names]
files[:] = [f for f in files if f not in ignore_names]
logging.debug(f'FILES: {files}, DIRS: {dirs}')
logging.debug(f"FILES: {files}, DIRS: {dirs}")
if root == folder:
archive_folder_name = ''
archive_folder_name = ""
else:
archive_folder_name = os.path.relpath(root, folder)
zip_obj.write(root, arcname=archive_folder_name)
@@ -38,86 +40,96 @@ def add_folder_to_zip(zip_obj: zipfile.ZipFile, folder: str, ignore_names: List[
zip_obj.write(f, arcname=d)
def execute_shell_command(command: Union[str, List[str]],
env: Union[Dict, None] = None) -> str:
def execute_shell_command(command: Union[str, List[str]], env: Union[Dict, None] = None) -> str:
if isinstance(command, list):
command = ' '.join(command)
command = " ".join(command)
logging.debug(f'Executing command: {command}')
logging.debug(f"Executing command: {command}")
completed_process = subprocess.run(command,
env=env,
shell=True,
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
completed_process = subprocess.run(
command, env=env, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
logging.debug(completed_process)
return completed_process.stdout.strip().decode('utf-8')
return completed_process.stdout.strip().decode("utf-8")
def locate_command(command: str) -> str:
path = execute_shell_command(['which', command])
path = execute_shell_command(["which", command])
if path is None:
raise MissingPrerequisiteCommand(f'Unable to find "{command}"')
return path
def make_python_zip_bundle(input_path: str,
python_version: str = '3.7',
build_folder: str = '.build',
requirements_file: str = 'requirements.txt',
output_bundle_name: str = 'bundle.zip') -> _lambda.AssetCode:
'''
def make_python_zip_bundle(
input_path: str,
python_version: str = "3.7",
build_folder: str = ".build",
requirements_file: str = "requirements.txt",
output_bundle_name: str = "bundle.zip",
) -> _lambda.AssetCode:
"""
Builds an lambda AssetCode bundling python dependencies along with the code.
The bundle is built using docker and the target lambda runtime image.
'''
"""
build_path = os.path.abspath(os.path.join(input_path, build_folder))
asset_path = os.path.join(build_path, output_bundle_name)
# checks if it's required to build a new zip file
if not os.path.exists(asset_path) or os.path.getmtime(asset_path) < get_folder_latest_mtime(input_path):
docker = locate_command('docker')
lambda_runtime_docker_image = f'lambci/lambda:build-python{python_version}'
docker = locate_command("docker")
lambda_runtime_docker_image = f"lambci/lambda:build-python{python_version}"
# cleans the target folder
logging.debug(f'Cleaning folder: {build_path}')
logging.debug(f"Cleaning folder: {build_path}")
shutil.rmtree(build_path, ignore_errors=True)
# builds requirements using target runtime
build_log = execute_shell_command(command=[
'docker', 'run', '--rm',
'-v', f'{input_path}:/app',
'-w', '/app',
build_log = execute_shell_command(
command=[
"docker",
"run",
"--rm",
"-v",
f"{input_path}:/app",
"-w",
"/app",
lambda_runtime_docker_image,
'pip', 'install',
'-r', requirements_file,
'-t', build_folder])
"pip",
"install",
"-r",
requirements_file,
"-t",
build_folder,
]
)
logging.info(build_log)
# creates the zip archive
logging.debug(f'Deleting file: {asset_path}')
logging.debug(f"Deleting file: {asset_path}")
shutil.rmtree(asset_path, ignore_errors=True)
logging.debug(f'Creating bundle: {asset_path}')
with zipfile.ZipFile(asset_path, 'w', zipfile.ZIP_DEFLATED) as zip_obj:
add_folder_to_zip(zip_obj, input_path, ignore_names=[output_bundle_name, '__pycache__'])
add_folder_to_zip(zip_obj, build_path, ignore_names=[output_bundle_name, '__pycache__'], ignore_dotfiles=False)
logging.debug(f"Creating bundle: {asset_path}")
with zipfile.ZipFile(asset_path, "w", zipfile.ZIP_DEFLATED) as zip_obj:
add_folder_to_zip(zip_obj, input_path, ignore_names=[output_bundle_name, "__pycache__"])
add_folder_to_zip(
zip_obj, build_path, ignore_names=[output_bundle_name, "__pycache__"], ignore_dotfiles=False
)
logging.info(f'Lambda bundle created at {asset_path}')
logging.info(f"Lambda bundle created at {asset_path}")
source_hash = get_folder_checksum(input_path)
logging.debug(f'Source folder hash {input_path} -> {source_hash}')
logging.debug(f"Source folder hash {input_path} -> {source_hash}")
return _lambda.AssetCode.from_asset(asset_path, source_hash=source_hash)
def get_folder_checksum(path: str, ignore_dotfiles: bool = True,
chunk_size: int = 4096,
digest_method: hashlib._hashlib.HASH = hashlib.md5) -> str:
def get_folder_checksum(
path: str, ignore_dotfiles: bool = True, chunk_size: int = 4096, digest_method: hashlib._hashlib.HASH = hashlib.md5
) -> str:
def _hash_file(filename: str) -> bytes:
with open(filename, mode='rb', buffering=0) as fp:
with open(filename, mode="rb", buffering=0) as fp:
hash_func = digest_method()
buffer = fp.read(chunk_size)
while len(buffer) > 0:
@@ -125,10 +137,10 @@ def get_folder_checksum(path: str, ignore_dotfiles: bool = True,
buffer = fp.read(chunk_size)
return hash_func.digest()
folder_hash = b''
folder_hash = b""
for root, dirs, files in os.walk(path):
files = [f for f in files if not f.startswith('.')]
dirs[:] = [d for d in dirs if not d.startswith('.')]
files = [f for f in files if not f.startswith(".")]
dirs[:] = [d for d in dirs if not d.startswith(".")]
for file_name in sorted(files):
file_path = os.path.join(root, file_name)
@@ -142,8 +154,8 @@ def get_folder_latest_mtime(path: str, ignore_dotfiles: bool = True) -> float:
latest_mtime = None
for root, dirs, files in os.walk(path):
if ignore_dotfiles:
files = [f for f in files if not f.startswith('.')]
dirs[:] = [d for d in dirs if not d.startswith('.')]
files = [f for f in files if not f.startswith(".")]
dirs[:] = [d for d in dirs if not d.startswith(".")]
for file_name in files:
file_path = os.path.join(root, file_name)