# ============================================================================
# NOTE(review): this artifact arrived as a whitespace-mangled `git diff`
# (newlines collapsed to spaces), so it cannot be applied as a patch.  Below
# is the reconstructed content of each file the patch introduces, with
# concrete fixes applied; every behavioral change is marked with "FIX:".
# Non-Python files and the partial app.py hunk are kept as commented text so
# the reconstruction stays a single well-formed document.
# ============================================================================

# ----------------------------------------------------------------------------
# app.py (CDK stack -- hunk context only; the enclosing SeparatorWorker class
# is not fully visible here, so the excerpt is preserved as a comment)
# ----------------------------------------------------------------------------
#     worker_dir = os.path.abspath(
#         os.path.join(os.path.dirname(__file__), 'worker'))
#     self.service = ecs_patterns.QueueProcessingFargateService(
#         self, 'separator-service',
#         cluster=cluster,
#         cpu=2048,                 # raised from 256 -- spleeter is CPU-heavy
#         memory_limit_mib=8192,    # raised from 512 -- the TF model needs GiBs
#         image=ecs.ContainerImage.from_asset(directory=worker_dir),
#         environment={
#             'TRACKS_TABLE_NAME': tracks_table.table_name,
#             # INPUT_BUCKET_NAME removed: the input S3 URL now arrives in
#             # each SQS message body instead of a fixed bucket.
#             'OUTPUT_BUCKET_NAME': output_bucket.bucket_name
#         })

# ----------------------------------------------------------------------------
# worker/Dockerfile
# ----------------------------------------------------------------------------
#   FROM researchdeezer/spleeter
#   ENV PYTHONDONTWRITEBYTECODE=1
#
#   WORKDIR /app
#   COPY . /app
#   RUN pip install --no-cache-dir -r requirements.txt
#   # Bake the model into the image so containers do not download at runtime.
#   RUN python -m prefetch_models
#
#   ENTRYPOINT ["python", "separator.py"]

# worker/conftest.py -- intentionally empty: its presence makes `worker/` the
# pytest rootdir so `tests/` can import `separator` without path hacks.

# ----------------------------------------------------------------------------
# worker/pytest.ini
# ----------------------------------------------------------------------------
#   [pytest]
#   filterwarnings =
#       ignore::FutureWarning
#       ignore::DeprecationWarning

# ----------------------------------------------------------------------------
# worker/requirements.txt
# NOTE(review): unpinned -- consider pinning boto3/pytest/moto so image builds
# are reproducible (newer moto renamed mock_s3 -> mock_aws; verify version).
# ----------------------------------------------------------------------------
#   boto3
#   pytest
#   moto

# ----------------------------------------------------------------------------
# worker/prefetch_models.py
# ----------------------------------------------------------------------------
'''
Pre-fetches the required spleeter model(s) from GitHub at image-build time.
'''

import os

from spleeter.model.provider import get_default_model_provider

# Models baked into the image; '2stems' = vocals/accompaniment split.
PREFETCH_MODELS = ['2stems']

for model_name in PREFETCH_MODELS:
    # FIX: os.getenv('MODEL_PATH') returned None when the variable was unset,
    # making os.path.join raise an opaque TypeError.  os.environ[...] now
    # fails fast with a KeyError that names the missing variable.
    # (MODEL_PATH is presumably set by the researchdeezer/spleeter base
    # image -- TODO confirm.)
    get_default_model_provider().download(
        model_name,
        os.path.join(os.environ['MODEL_PATH'], model_name))

# ----------------------------------------------------------------------------
# worker/separator.py
# ----------------------------------------------------------------------------
import os
import logging
import json
from collections import namedtuple
from tempfile import NamedTemporaryFile, TemporaryDirectory
from threading import Timer
from urllib.parse import urlparse

import boto3
from spleeter.separator import Separator
from spleeter.audio.adapter import get_audio_adapter
# FIX: dropped unused `import sys`.

# Runtime configuration, all overridable through the task environment.
SPLEETER_CONFIGURATION = os.getenv('SPLEETER_CONFIGURATION', 'spleeter:2stems')
OUTPUT_CODEC = os.getenv('OUTPUT_CODEC', 'mp3')
OUTPUT_BITRATE = os.getenv('OUTPUT_BITRATE', '128k')
MAX_AUDIO_DURATION = float(os.getenv('MAX_AUDIO_DURATION', 600.))   # seconds
AUDIO_START_OFFSET = float(os.getenv('AUDIO_START_OFFSET', 0.))     # seconds
OUTPUT_FILENAME_FORMAT = os.getenv('OUTPUT_FILENAME_FORMAT', '{instrument}.{codec}')
QUEUE_NAME = os.getenv('QUEUE_NAME')
POLLING_INTERVAL = int(os.getenv('POLLING_INTERVAL', 5))            # seconds
OUTPUT_BUCKET_NAME = os.getenv('OUTPUT_BUCKET_NAME')
TRACKS_TABLE_NAME = os.getenv('TRACKS_TABLE_NAME')
USE_MULTICHANNEL_WIENER_FILTERING = False


# (bucket_name, key) pair naming the two parts of an s3:// URL.
S3Entry = namedtuple('S3Entry', ['bucket_name', 'key'])


def parse_s3_url(s3_url: str) -> S3Entry:
    """Split an s3://bucket/key URL into an S3Entry(bucket_name, key)."""
    parsed = urlparse(s3_url)
    return S3Entry(parsed.netloc, parsed.path.lstrip('/'))


def fetch_separate_and_upload(input_s3_url, output_s3_url, s3=None):
    """Download the input track, run spleeter, upload the accompaniment stem.

    :param input_s3_url: s3:// URL of the source audio file.
    :param output_s3_url: s3:// URL the separated accompaniment is written to.
    :param s3: optional boto3 S3 service resource (injected by tests).
    """
    if s3 is None:
        s3 = boto3.resource('s3')

    with NamedTemporaryFile() as input_file, TemporaryDirectory() as output_path:
        input_object = s3.Object(**parse_s3_url(input_s3_url)._asdict())
        input_object.download_file(input_file.name)

        audio_adapter = get_audio_adapter(None)
        separator = Separator(
            SPLEETER_CONFIGURATION,
            MWF=USE_MULTICHANNEL_WIENER_FILTERING)

        separator.separate_to_file(
            input_file.name,
            output_path,
            audio_adapter=audio_adapter,
            offset=AUDIO_START_OFFSET,
            duration=MAX_AUDIO_DURATION,
            codec=OUTPUT_CODEC,
            bitrate=OUTPUT_BITRATE,
            filename_format=OUTPUT_FILENAME_FORMAT,
            synchronous=True)

        logging.info(f'Uploading output to: {output_s3_url}')
        output_object = s3.Object(**parse_s3_url(output_s3_url)._asdict())

        # Only the accompaniment stem is published; the vocals stem is
        # discarded with the temporary directory.
        output_filename = os.path.join(output_path, f'accompaniment.{OUTPUT_CODEC}')
        output_object.upload_file(output_filename)


def _mark_job(job_id, status, output_s3_url=None):
    """Record the job outcome in the tracks DynamoDB table."""
    updates = {'status': {'Value': {'S': status}}}
    if output_s3_url is not None:
        updates['output_url'] = {'Value': {'S': output_s3_url}}
    # NOTE(review): AttributeUpdates is a legacy DynamoDB parameter;
    # consider migrating to UpdateExpression.
    boto3.client('dynamodb').update_item(
        TableName=TRACKS_TABLE_NAME,
        Key={'id': {'S': job_id}},
        AttributeUpdates=updates)


def poll_for_sqs_message(queue_name: str):
    """Process any pending SQS messages, then re-arm a timer for the next poll."""
    sqs = boto3.client('sqs')
    queue_url = sqs.get_queue_url(QueueName=queue_name)['QueueUrl']
    response = sqs.receive_message(QueueUrl=queue_url)

    # receive_message omits the 'Messages' key entirely when the queue is
    # empty, so default to an empty list.
    messages = response.get('Messages', [])
    if not messages:
        logging.info('No messages in the queue')

    for message in messages:
        body = json.loads(message['Body'])
        job_id = body['job_id']
        output_s3_url = f"s3://{OUTPUT_BUCKET_NAME}/track_{job_id}.mp3"
        logging.info(f'Start separating {job_id} -> {output_s3_url}')

        try:
            fetch_separate_and_upload(
                input_s3_url=body['input_s3_url'],
                output_s3_url=output_s3_url)
            logging.info(f'Processing successful: {output_s3_url}')
            _mark_job(job_id, 'successful', output_s3_url)
        # FIX: was a bare `except:`, which also swallowed SystemExit and
        # KeyboardInterrupt.
        except Exception:
            logging.exception('Processing failed for some reason')
            # FIX: failed jobs previously kept their initial status forever
            # even though the message was deleted; record the failure so
            # clients are not left waiting indefinitely.
            try:
                _mark_job(job_id, 'failed')
            except Exception:
                logging.exception('Could not record job failure')
        finally:
            # The message is always deleted: failed jobs are not retried.
            sqs.delete_message(
                QueueUrl=queue_url,
                ReceiptHandle=message['ReceiptHandle'])

    # FIX: the re-poll delay was hard-coded to 5, silently ignoring the
    # POLLING_INTERVAL environment knob defined above.
    Timer(POLLING_INTERVAL, poll_for_sqs_message, args=[queue_name]).start()


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    poll_for_sqs_message(QUEUE_NAME)

# ----------------------------------------------------------------------------
# worker/tests/test_separator.py  (worker/tests/audio_example.mp3 is a binary
# fixture added by the patch and is not reproducible here)
# ----------------------------------------------------------------------------
import os

import boto3
import pytest
from moto import mock_s3

from separator import fetch_separate_and_upload, parse_s3_url


TEST_FILENAME = 'audio_example.mp3'


@pytest.fixture
def mocked_cloud():
    """In-memory S3 with the sample track uploaded to an 'input' bucket."""
    with mock_s3():
        # FIX: the original used region_name='eu-east-1', which is not a real
        # AWS region; current boto3/moto reject bucket creation there.
        # us-east-1 also needs no LocationConstraint on create_bucket.
        s3 = boto3.resource('s3', region_name='us-east-1')
        input_bucket = s3.create_bucket(Bucket='input')

        test_file_local_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), TEST_FILENAME)
        input_bucket.upload_file(Filename=test_file_local_path, Key=TEST_FILENAME)

        yield {
            's3': s3,
            'input_bucket': input_bucket,
            'output_bucket': s3.create_bucket(Bucket='output'),
        }


def test_fetch_separate_and_upload(mocked_cloud):
    fetch_separate_and_upload(
        's3://input/audio_example.mp3',
        's3://output/audio_example.mp3',
        s3=mocked_cloud['s3'])
    assert [e.key for e in mocked_cloud['output_bucket'].objects.all()] == \
        ['audio_example.mp3']


@pytest.mark.parametrize('url,expected_result', [
    ('s3://bucket/key', ('bucket', 'key')),
    ('s3://bucket/with/deep/nested/path.ext',
     ('bucket', 'with/deep/nested/path.ext')),
])
def test_parse_s3_url(url, expected_result):
    assert parse_s3_url(url) == expected_result