Replacing the placeholder worker with a real one
This commit is contained in:
7
app.py
7
app.py
@@ -79,16 +79,15 @@ class SeparatorWorker(core.Stack):
|
|||||||
**kwargs):
|
**kwargs):
|
||||||
super().__init__(scope, id, **kwargs)
|
super().__init__(scope, id, **kwargs)
|
||||||
|
|
||||||
worker_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'worker-placeholder'))
|
worker_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'worker'))
|
||||||
|
|
||||||
self.service = ecs_patterns.QueueProcessingFargateService(self, 'separator-service',
|
self.service = ecs_patterns.QueueProcessingFargateService(self, 'separator-service',
|
||||||
cluster=cluster,
|
cluster=cluster,
|
||||||
cpu=256,
|
cpu=2048,
|
||||||
memory_limit_mib=512,
|
memory_limit_mib=8192,
|
||||||
image=ecs.ContainerImage.from_asset(directory=worker_dir),
|
image=ecs.ContainerImage.from_asset(directory=worker_dir),
|
||||||
environment={
|
environment={
|
||||||
'TRACKS_TABLE_NAME': tracks_table.table_name,
|
'TRACKS_TABLE_NAME': tracks_table.table_name,
|
||||||
'INPUT_BUCKET_NAME': input_bucket.bucket_name,
|
|
||||||
'OUTPUT_BUCKET_NAME': output_bucket.bucket_name
|
'OUTPUT_BUCKET_NAME': output_bucket.bucket_name
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
9
worker/Dockerfile
Normal file
9
worker/Dockerfile
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
FROM researchdeezer/spleeter

ENV PYTHONDONTWRITEBYTECODE=1

WORKDIR /app

# Install dependencies before copying the rest of the source so the pip
# layer is cached across code-only changes (the original copied everything
# first, invalidating the cache on every edit).
COPY requirements.txt /app/
RUN pip install --no-cache-dir -r requirements.txt

# Copy the worker sources, then bake the model files into the image so the
# container does not have to download them at start-up.
COPY . /app
RUN python -m prefetch_models

ENTRYPOINT ["python", "separator.py"]
0
worker/conftest.py
Normal file
0
worker/conftest.py
Normal file
14
worker/prefetch_models.py
Normal file
14
worker/prefetch_models.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
'''
Pre-fetches the required model(s) from GitHub at image-build time so the
worker container does not download them on first use.
'''

import os

from spleeter.model.provider import get_default_model_provider


# Only the 2-stems (vocals / accompaniment) model is needed by the worker.
PREFETCH_MODELS = ['2stems']

# The original passed os.getenv('MODEL_PATH') straight to os.path.join,
# which raises TypeError when the variable is unset. Fall back to
# 'pretrained_models', which is presumably spleeter's default model
# directory — TODO confirm against the base image.
model_root = os.getenv('MODEL_PATH', 'pretrained_models')

for model_name in PREFETCH_MODELS:
    get_default_model_provider().download(
        model_name,
        os.path.join(model_root, model_name))
4
worker/pytest.ini
Normal file
4
worker/pytest.ini
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
[pytest]
|
||||||
|
filterwarnings =
|
||||||
|
ignore::FutureWarning
|
||||||
|
ignore::DeprecationWarning
|
||||||
3
worker/requirements.txt
Normal file
3
worker/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
boto3
|
||||||
|
pytest
|
||||||
|
moto
|
||||||
109
worker/separator.py
Normal file
109
worker/separator.py
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from collections import namedtuple
|
||||||
|
from tempfile import NamedTemporaryFile, TemporaryDirectory
|
||||||
|
from threading import Timer
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
from spleeter.separator import Separator
|
||||||
|
from spleeter.audio.adapter import get_audio_adapter
|
||||||
|
|
||||||
|
|
||||||
|
SPLEETER_CONFIGURATION = os.getenv('SPLEETER_CONFIGURATION', 'spleeter:2stems')
|
||||||
|
OUTPUT_CODEC = os.getenv('OUTPUT_CODEC', 'mp3')
|
||||||
|
OUTPUT_BITRATE = os.getenv('OUTPUT_BITRATE', '128k')
|
||||||
|
MAX_AUDIO_DURATION = float(os.getenv('MAX_AUDIO_DURATION', 600.))
|
||||||
|
AUDIO_START_OFFSET = float(os.getenv('AUDIO_START_OFFSET', 0.))
|
||||||
|
OUTPUT_FILENAME_FORMAT = os.getenv('OUTPUT_FILENAME_FORMAT', '{instrument}.{codec}')
|
||||||
|
QUEUE_NAME = os.getenv('QUEUE_NAME')
|
||||||
|
POLLING_INTERVAL = int(os.getenv('POLLING_INTERVAL', 5))
|
||||||
|
OUTPUT_BUCKET_NAME = os.getenv('OUTPUT_BUCKET_NAME')
|
||||||
|
TRACKS_TABLE_NAME = os.getenv('TRACKS_TABLE_NAME')
|
||||||
|
USE_MULTICHANNEL_WIENER_FILTERING = False
|
||||||
|
|
||||||
|
|
||||||
|
# Lightweight record describing an object's location in S3.
S3Entry = namedtuple('S3Entry', ['bucket_name', 'key'])


def parse_s3_url(s3_url: str) -> S3Entry:
    """Split an ``s3://bucket/key`` URL into its bucket name and object key."""
    parsed = urlparse(s3_url)
    bucket, key = parsed.netloc, parsed.path.lstrip('/')
    return S3Entry(bucket, key)
|
|
||||||
|
|
||||||
|
def fetch_separate_and_upload(input_s3_url, output_s3_url, s3=None):
    """Download a track from S3, run spleeter separation, upload the result.

    Downloads *input_s3_url* to a temp file, separates it with the configured
    spleeter model, and uploads the accompaniment stem to *output_s3_url*.
    An S3 resource can be injected via *s3* (used by the tests with moto).
    """
    if s3 is None:
        s3 = boto3.resource("s3")

    with NamedTemporaryFile() as input_file, TemporaryDirectory() as output_path:
        # Fetch the source track into a local temp file.
        source = s3.Object(**parse_s3_url(input_s3_url)._asdict())
        source.download_file(input_file.name)

        audio_adapter = get_audio_adapter(None)
        separator = Separator(
            SPLEETER_CONFIGURATION,
            MWF=USE_MULTICHANNEL_WIENER_FILTERING)

        # Synchronous so the stems are on disk before we upload.
        separator.separate_to_file(
            input_file.name,
            output_path,
            audio_adapter=audio_adapter,
            offset=AUDIO_START_OFFSET,
            duration=MAX_AUDIO_DURATION,
            codec=OUTPUT_CODEC,
            bitrate=OUTPUT_BITRATE,
            filename_format=OUTPUT_FILENAME_FORMAT,
            synchronous=True)

        logging.info(f'Uploading output to: {output_s3_url}')
        destination = s3.Object(**parse_s3_url(output_s3_url)._asdict())

        # Only the accompaniment stem is uploaded; the vocals stem is discarded.
        local_result = os.path.join(output_path, f'accompaniment.{OUTPUT_CODEC}')
        destination.upload_file(local_result)
|
|
||||||
|
|
||||||
|
def poll_for_sqs_message(queue_name: str):
    """Poll the job queue once, process any received messages, re-schedule.

    For each message: runs separation on the referenced input track, uploads
    the result, and marks the job 'successful' in the tracks DynamoDB table.
    Messages are always deleted afterwards — even on failure — so a poison
    message cannot block the queue forever.
    """
    sqs = boto3.client('sqs')
    queue_url = sqs.get_queue_url(QueueName=queue_name)['QueueUrl']
    response = sqs.receive_message(QueueUrl=queue_url)

    # receive_message omits the 'Messages' key entirely when the queue is empty.
    messages = response.get('Messages', [])
    if not messages:
        logging.info('No messages in the queue')

    for message in messages:
        body = json.loads(message['Body'])
        job_id = body['job_id']
        output_s3_url = f"s3://{OUTPUT_BUCKET_NAME}/track_{job_id}.mp3"
        logging.info(f'Start separating {job_id} -> {output_s3_url}')

        try:
            fetch_separate_and_upload(
                input_s3_url=body['input_s3_url'],
                output_s3_url=output_s3_url)

            logging.info(f'Processing successful: {output_s3_url}')

            dynamodb = boto3.client('dynamodb')
            # NOTE(review): AttributeUpdates is a legacy parameter; consider
            # migrating to UpdateExpression eventually.
            dynamodb.update_item(
                TableName=TRACKS_TABLE_NAME,
                Key={'id': {'S': job_id}},
                AttributeUpdates={
                    'status': {'Value': {'S': 'successful'}},
                    'output_url': {'Value': {'S': output_s3_url}}
                })
        except Exception:
            # Was a bare `except:`, which would also swallow SystemExit and
            # KeyboardInterrupt; narrowed to Exception. Failures are logged,
            # not re-raised, so one bad job does not kill the worker.
            logging.exception('Processing failed for some reason')
        finally:
            # Delete unconditionally so a failing job is not retried forever.
            sqs.delete_message(QueueUrl=queue_url, ReceiptHandle=message['ReceiptHandle'])

    # Re-schedule using the configured interval; the original hard-coded 5
    # seconds despite defining POLLING_INTERVAL for exactly this purpose.
    Timer(POLLING_INTERVAL, poll_for_sqs_message, args=[queue_name]).start()
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Entry point: configure logging, then kick off the polling loop.
    logging.basicConfig(level=logging.INFO)
    poll_for_sqs_message(QUEUE_NAME)
BIN
worker/tests/audio_example.mp3
Normal file
BIN
worker/tests/audio_example.mp3
Normal file
Binary file not shown.
39
worker/tests/test_separator.py
Normal file
39
worker/tests/test_separator.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
import pytest
|
||||||
|
from moto import mock_s3
|
||||||
|
|
||||||
|
from separator import fetch_separate_and_upload, parse_s3_url
|
||||||
|
|
||||||
|
|
||||||
|
TEST_FILENAME = 'audio_example.mp3'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mocked_cloud():
    """Mocked S3 with an 'input' bucket pre-loaded with the sample track and
    an empty 'output' bucket, yielded as a dict for the tests."""
    with mock_s3():
        # 'eu-east-1' is not a real AWS region; us-east-1 is, and it also
        # allows create_bucket without a LocationConstraint.
        s3 = boto3.resource('s3', region_name='us-east-1')
        input_bucket = s3.create_bucket(Bucket='input')

        test_file_local_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), TEST_FILENAME)
        input_bucket.upload_file(Filename=test_file_local_path, Key=TEST_FILENAME)

        yield {
            's3': s3,
            'input_bucket': input_bucket,
            'output_bucket': s3.create_bucket(Bucket='output')
        }
||||||
|
|
||||||
|
def test_fetch_separate_and_upload(mocked_cloud):
    """End-to-end: separate the sample track and check the output object lands."""
    fetch_separate_and_upload('s3://input/audio_example.mp3', 's3://output/audio_example.mp3', s3=mocked_cloud['s3'])
    uploaded_keys = [obj.key for obj in mocked_cloud['output_bucket'].objects.all()]
    assert uploaded_keys == ['audio_example.mp3']
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('url,expected_result', [
    ('s3://bucket/key', ('bucket', 'key')),
    ('s3://bucket/with/deep/nested/path.ext', ('bucket', 'with/deep/nested/path.ext')),
])
def test_parse_s3_url(url, expected_result):
    """parse_s3_url splits bucket and key, including deeply nested keys."""
    result = parse_s3_url(url)
    assert result == expected_result
Reference in New Issue
Block a user