Replacing the placeholder worker with a real one

This commit is contained in:
2020-04-16 14:24:17 +02:00
parent 50a4c09c3c
commit ca0c55f052
9 changed files with 181 additions and 4 deletions

7
app.py
View File

@@ -79,16 +79,15 @@ class SeparatorWorker(core.Stack):
**kwargs):
super().__init__(scope, id, **kwargs)
worker_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'worker-placeholder'))
worker_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'worker'))
self.service = ecs_patterns.QueueProcessingFargateService(self, 'separator-service',
cluster=cluster,
cpu=256,
memory_limit_mib=512,
cpu=2048,
memory_limit_mib=8192,
image=ecs.ContainerImage.from_asset(directory=worker_dir),
environment={
'TRACKS_TABLE_NAME': tracks_table.table_name,
'INPUT_BUCKET_NAME': input_bucket.bucket_name,
'OUTPUT_BUCKET_NAME': output_bucket.bucket_name
})

9
worker/Dockerfile Normal file
View File

@@ -0,0 +1,9 @@
# Worker image for the separator service.
# The public Spleeter image already bundles the spleeter package and its
# heavy TensorFlow/ffmpeg dependencies, so only the thin worker layer is
# added here.
FROM researchdeezer/spleeter
# Keep the image free of .pyc files written at runtime.
ENV PYTHONDONTWRITEBYTECODE=1
WORKDIR /app
COPY . /app
RUN pip install --no-cache-dir -r requirements.txt
# Bake the pretrained model weights into the image at build time so a
# container does not download them on its first job.  prefetch_models
# reads MODEL_PATH -- presumably set by the base image; confirm.
RUN python -m prefetch_models
ENTRYPOINT ["python", "separator.py"]

0
worker/conftest.py Normal file
View File

14
worker/prefetch_models.py Normal file
View File

@@ -0,0 +1,14 @@
'''
Pre-fetch the required Spleeter model(s) from GitHub.

Run as a Docker build step (``python -m prefetch_models``, see
worker/Dockerfile) so the model weights are baked into the image
instead of being downloaded on the first separation job.

Requires the ``MODEL_PATH`` environment variable (target directory for
the downloaded models).  Previously a missing ``MODEL_PATH`` crashed
inside ``os.path.join(None, ...)`` with an opaque TypeError; now it
fails fast with an explicit error.
'''
import os

from spleeter.model.provider import get_default_model_provider

# Model configurations to prefetch; must cover the worker's
# SPLEETER_CONFIGURATION (default 'spleeter:2stems' in separator.py).
PREFETCH_MODELS = ['2stems']

model_path = os.getenv('MODEL_PATH')
if not model_path:
    raise RuntimeError('MODEL_PATH environment variable must be set')

# Hoisted out of the loop: one provider instance serves all downloads.
provider = get_default_model_provider()
for model_name in PREFETCH_MODELS:
    provider.download(model_name, os.path.join(model_path, model_name))

4
worker/pytest.ini Normal file
View File

@@ -0,0 +1,4 @@
[pytest]
# TensorFlow / Spleeter emit a large volume of FutureWarning and
# DeprecationWarning noise on import; silence both so genuine test
# failures stay visible.
filterwarnings =
    ignore::FutureWarning
    ignore::DeprecationWarning

3
worker/requirements.txt Normal file
View File

@@ -0,0 +1,3 @@
# Python dependencies installed on top of the researchdeezer/spleeter
# base image (which already provides spleeter and TensorFlow, so they
# are intentionally not listed here).
boto3
# Test-only dependencies: unit tests run against moto-mocked AWS.
pytest
moto

109
worker/separator.py Normal file
View File

@@ -0,0 +1,109 @@
import os
import logging
import json
import sys
from collections import namedtuple
from tempfile import NamedTemporaryFile, TemporaryDirectory
from threading import Timer
from urllib.parse import urlparse
import boto3
from spleeter.separator import Separator
from spleeter.audio.adapter import get_audio_adapter
# --- Worker configuration (all overridable via environment variables) ---
# Spleeter model configuration; must match a model baked into the image
# by prefetch_models (see worker/Dockerfile).
SPLEETER_CONFIGURATION = os.getenv('SPLEETER_CONFIGURATION', 'spleeter:2stems')
OUTPUT_CODEC = os.getenv('OUTPUT_CODEC', 'mp3')
OUTPUT_BITRATE = os.getenv('OUTPUT_BITRATE', '128k')
# Cap on processed audio length (seconds) to bound per-job runtime.
MAX_AUDIO_DURATION = float(os.getenv('MAX_AUDIO_DURATION', 600.))
AUDIO_START_OFFSET = float(os.getenv('AUDIO_START_OFFSET', 0.))
# Per-stem output naming passed to spleeter's separate_to_file.
OUTPUT_FILENAME_FORMAT = os.getenv('OUTPUT_FILENAME_FORMAT', '{instrument}.{codec}')
QUEUE_NAME = os.getenv('QUEUE_NAME')  # SQS queue polled for jobs
POLLING_INTERVAL = int(os.getenv('POLLING_INTERVAL', 5))  # seconds between polls
OUTPUT_BUCKET_NAME = os.getenv('OUTPUT_BUCKET_NAME')  # destination S3 bucket
TRACKS_TABLE_NAME = os.getenv('TRACKS_TABLE_NAME')  # DynamoDB job-status table
USE_MULTICHANNEL_WIENER_FILTERING = False
# Lightweight (bucket_name, key) pair describing an S3 object location.
S3Entry = namedtuple('S3Entry', ['bucket_name', 'key'])


def parse_s3_url(s3_url: str) -> S3Entry:
    """Split an ``s3://bucket/key`` URL into an ``S3Entry``.

    The bucket is the URL's network location; the key is the path with
    its leading slash removed (nested keys are kept intact).
    """
    parsed = urlparse(s3_url)
    bucket, key = parsed.netloc, parsed.path.lstrip('/')
    return S3Entry(bucket_name=bucket, key=key)
def fetch_separate_and_upload(input_s3_url, output_s3_url, s3=None):
    """Download a track from S3, run Spleeter on it, upload the result.

    Only the accompaniment stem is uploaded, to ``output_s3_url``.
    ``s3`` may be injected (e.g. a moto-mocked resource in tests);
    defaults to a fresh boto3 S3 resource.
    """
    if s3 is None:
        s3 = boto3.resource("s3")
    with NamedTemporaryFile() as local_input, TemporaryDirectory() as work_dir:
        source = parse_s3_url(input_s3_url)
        s3.Object(**source._asdict()).download_file(local_input.name)
        separator = Separator(
            SPLEETER_CONFIGURATION,
            MWF=USE_MULTICHANNEL_WIENER_FILTERING)
        separator.separate_to_file(
            local_input.name,
            work_dir,
            audio_adapter=get_audio_adapter(None),
            offset=AUDIO_START_OFFSET,
            duration=MAX_AUDIO_DURATION,
            codec=OUTPUT_CODEC,
            bitrate=OUTPUT_BITRATE,
            filename_format=OUTPUT_FILENAME_FORMAT,
            synchronous=True)
        logging.info(f'Uploading output to: {output_s3_url}')
        destination = parse_s3_url(output_s3_url)
        accompaniment_path = os.path.join(work_dir, f'accompaniment.{OUTPUT_CODEC}')
        s3.Object(**destination._asdict()).upload_file(accompaniment_path)
def poll_for_sqs_message(queue_name: str):
    """Poll the job queue once, process any messages, then re-schedule itself.

    Each message body is JSON with ``job_id`` and ``input_s3_url``.  On a
    successful separation the track's DynamoDB row is updated to
    'successful' along with the output URL.  The message is always deleted,
    even on failure, so a poison message cannot loop forever --
    NOTE(review): failed jobs are never marked in DynamoDB; confirm that
    is intended.

    Fixes over the previous revision:
    - the re-poll delay now honours POLLING_INTERVAL instead of a
      hard-coded 5 seconds;
    - the bare ``except:`` (which also swallowed SystemExit and
      KeyboardInterrupt, making the worker hard to stop) is narrowed to
      ``except Exception``.
    """
    sqs = boto3.client('sqs')
    queue_url = sqs.get_queue_url(QueueName=queue_name)['QueueUrl']
    response = sqs.receive_message(QueueUrl=queue_url)
    try:
        messages = response['Messages']
    except KeyError:  # receive_message omits 'Messages' on an empty queue
        logging.info('No messages in the queue')
        messages = []
    for message in messages:
        body = json.loads(message['Body'])
        job_id = body['job_id']
        output_s3_url = f"s3://{OUTPUT_BUCKET_NAME}/track_{job_id}.mp3"
        logging.info(f'Start separating {job_id} -> {output_s3_url}')
        try:
            fetch_separate_and_upload(
                input_s3_url=body['input_s3_url'],
                output_s3_url=output_s3_url)
            logging.info(f'Processing successful: {output_s3_url}')
            dynamodb = boto3.client('dynamodb')
            dynamodb.update_item(
                TableName=TRACKS_TABLE_NAME,
                Key={'id': {'S': job_id}},
                AttributeUpdates={
                    'status': {'Value': {'S': 'successful'}},
                    'output_url': {'Value': {'S': output_s3_url}}
                })
        except Exception:
            logging.exception('Processing failed for some reason')
        finally:
            # Deliberate best-effort: delete regardless of outcome so the
            # message is never redelivered.
            sqs.delete_message(QueueUrl=queue_url, ReceiptHandle=message['ReceiptHandle'])
    # Schedule the next poll on a fresh timer thread.
    Timer(POLLING_INTERVAL, poll_for_sqs_message, args=[queue_name]).start()
if __name__ == '__main__':
    # Container entry point (see worker/Dockerfile): log to stdout and
    # start the self-rescheduling SQS polling loop.
    logging.basicConfig(level=logging.INFO)
    poll_for_sqs_message(QUEUE_NAME)

Binary file not shown.

View File

@@ -0,0 +1,39 @@
import os
import boto3
import pytest
from moto import mock_s3
from separator import fetch_separate_and_upload, parse_s3_url
# Sample audio file shipped alongside this test module (the commit also
# adds a binary file, "not shown" in the diff -- presumably this mp3).
TEST_FILENAME = 'audio_example.mp3'
@pytest.fixture
def mocked_cloud():
    """Yield a moto-mocked S3 environment.

    Provides an 'input' bucket pre-loaded with the sample track and an
    empty 'output' bucket, keyed as 's3' / 'input_bucket' /
    'output_bucket'.
    """
    with mock_s3():
        # NOTE(review): 'eu-east-1' is not a real AWS region; moto
        # tolerates it, but a real region would be more realistic.
        resource = boto3.resource('s3', region_name='eu-east-1')
        fixture_dir = os.path.dirname(os.path.abspath(__file__))
        sample_path = os.path.join(fixture_dir, TEST_FILENAME)
        source_bucket = resource.create_bucket(Bucket='input')
        source_bucket.upload_file(Filename=sample_path, Key=TEST_FILENAME)
        yield {
            's3': resource,
            'input_bucket': source_bucket,
            'output_bucket': resource.create_bucket(Bucket='output')
        }
def test_fetch_separate_and_upload(mocked_cloud):
    """End-to-end against mocked S3: the separated track must land in the
    output bucket under the expected key."""
    fetch_separate_and_upload(
        's3://input/audio_example.mp3',
        's3://output/audio_example.mp3',
        s3=mocked_cloud['s3'])
    uploaded_keys = [obj.key for obj in mocked_cloud['output_bucket'].objects.all()]
    assert uploaded_keys == ['audio_example.mp3']
@pytest.mark.parametrize('url,expected_result', [
    ('s3://bucket/key', ('bucket', 'key')),
    ('s3://bucket/with/deep/nested/path.ext', ('bucket', 'with/deep/nested/path.ext')),
])
def test_parse_s3_url(url, expected_result):
    """parse_s3_url must split bucket and (possibly nested) key."""
    bucket_name, key = parse_s3_url(url)
    assert (bucket_name, key) == expected_result