Replacing the placeholder worker with a real one

app.py

@@ -79,16 +79,15 @@ class SeparatorWorker(core.Stack):
                  **kwargs):
         super().__init__(scope, id, **kwargs)
 
-        worker_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'worker-placeholder'))
+        worker_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'worker'))
 
         self.service = ecs_patterns.QueueProcessingFargateService(self, 'separator-service',
             cluster=cluster,
-            cpu=256,
-            memory_limit_mib=512,
+            cpu=2048,
+            memory_limit_mib=8192,
             image=ecs.ContainerImage.from_asset(directory=worker_dir),
             environment={
                 'TRACKS_TABLE_NAME': tracks_table.table_name,
                 'INPUT_BUCKET_NAME': input_bucket.bucket_name,
                 'OUTPUT_BUCKET_NAME': output_bucket.bucket_name
             })
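The task is resized from the placeholder's 256 CPU units / 512 MiB to 2048 / 8192, since Spleeter's TensorFlow separation needs far more headroom, and the container image asset now points at the real worker/ directory. QueueProcessingFargateService provisions its own SQS queue and, per the aws-ecs-patterns docs, exposes its name to the container as the QUEUE_NAME environment variable, which worker/separator.py below reads. The hunk shows no IAM wiring; if it is not granted elsewhere in the stack, a sketch along these lines (reusing app.py's input_bucket, output_bucket and tracks_table; hypothetical, not part of this commit) would be needed:

    # Hypothetical grants, assuming the resource objects from app.py are in scope.
    input_bucket.grant_read(self.service.task_definition.task_role)
    output_bucket.grant_write(self.service.task_definition.task_role)
    tracks_table.grant_write_data(self.service.task_definition.task_role)
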
worker/Dockerfile (new file)

@@ -0,0 +1,9 @@
FROM researchdeezer/spleeter
ENV PYTHONDONTWRITEBYTECODE=1

WORKDIR /app
COPY . /app
RUN pip install --no-cache-dir -r requirements.txt
RUN python -m prefetch_models

ENTRYPOINT ["python", "separator.py"]

worker/conftest.py (new, empty file)

(The empty conftest.py makes pytest put worker/ on sys.path, so the tests can import separator directly.)

worker/prefetch_models.py (new file)

@@ -0,0 +1,14 @@
'''
Pre-fetches the required models from GitHub.
'''

import os
from spleeter.model.provider import get_default_model_provider


PREFETCH_MODELS = ['2stems']

for model_name in PREFETCH_MODELS:
    get_default_model_provider().download(
        model_name,
        os.path.join(os.getenv('MODEL_PATH'), model_name))
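
The download target is built from the MODEL_PATH environment variable, which the researchdeezer/spleeter base image presumably sets; run outside that image, os.getenv returns None and os.path.join raises a TypeError. A minimal guard sketch (hypothetical, not in the commit):

    import os

    # Fail fast with a clear message when MODEL_PATH is missing,
    # e.g. when running outside the spleeter base image.
    model_path = os.getenv('MODEL_PATH')
    if model_path is None:
        raise RuntimeError('MODEL_PATH is not set; run inside the spleeter image or export it')
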
worker/pytest.ini (new file)

@@ -0,0 +1,4 @@
[pytest]
filterwarnings =
    ignore::FutureWarning
    ignore::DeprecationWarning

worker/requirements.txt (new file)

@@ -0,0 +1,3 @@
boto3
pytest
moto

worker/separator.py (new file)

@@ -0,0 +1,109 @@
import os
import logging
import json
import sys
from collections import namedtuple
from tempfile import NamedTemporaryFile, TemporaryDirectory
from threading import Timer
from urllib.parse import urlparse

import boto3
from spleeter.separator import Separator
from spleeter.audio.adapter import get_audio_adapter


SPLEETER_CONFIGURATION = os.getenv('SPLEETER_CONFIGURATION', 'spleeter:2stems')
OUTPUT_CODEC = os.getenv('OUTPUT_CODEC', 'mp3')
OUTPUT_BITRATE = os.getenv('OUTPUT_BITRATE', '128k')
MAX_AUDIO_DURATION = float(os.getenv('MAX_AUDIO_DURATION', 600.))
AUDIO_START_OFFSET = float(os.getenv('AUDIO_START_OFFSET', 0.))
OUTPUT_FILENAME_FORMAT = os.getenv('OUTPUT_FILENAME_FORMAT', '{instrument}.{codec}')
QUEUE_NAME = os.getenv('QUEUE_NAME')
POLLING_INTERVAL = int(os.getenv('POLLING_INTERVAL', 5))
OUTPUT_BUCKET_NAME = os.getenv('OUTPUT_BUCKET_NAME')
TRACKS_TABLE_NAME = os.getenv('TRACKS_TABLE_NAME')
USE_MULTICHANNEL_WIENER_FILTERING = False


S3Entry = namedtuple('S3Entry', ['bucket_name', 'key'])


def parse_s3_url(s3_url: str) -> S3Entry:
    o = urlparse(s3_url)
    return S3Entry(o.netloc, o.path.lstrip('/'))


def fetch_separate_and_upload(input_s3_url, output_s3_url, s3=None):
    if s3 is None:
        s3 = boto3.resource("s3")

    with NamedTemporaryFile() as input_file, TemporaryDirectory() as output_path:
        input_object = s3.Object(**parse_s3_url(input_s3_url)._asdict())
        input_object.download_file(input_file.name)

        audio_adapter = get_audio_adapter(None)
        separator = Separator(
            SPLEETER_CONFIGURATION,
            MWF=USE_MULTICHANNEL_WIENER_FILTERING)

        separator.separate_to_file(
            input_file.name,
            output_path,
            audio_adapter=audio_adapter,
            offset=AUDIO_START_OFFSET,
            duration=MAX_AUDIO_DURATION,
            codec=OUTPUT_CODEC,
            bitrate=OUTPUT_BITRATE,
            filename_format=OUTPUT_FILENAME_FORMAT,
            synchronous=True)

        logging.info(f'Uploading output to: {output_s3_url}')
        output_object = s3.Object(**parse_s3_url(output_s3_url)._asdict())

        output_filename = os.path.join(output_path, f'accompaniment.{OUTPUT_CODEC}')
        output_object.upload_file(output_filename)


def poll_for_sqs_message(queue_name: str):
    sqs = boto3.client('sqs')
    queue_url = sqs.get_queue_url(QueueName=queue_name)['QueueUrl']
    response = sqs.receive_message(QueueUrl=queue_url)

    try:
        messages = response['Messages']
    except KeyError:
        logging.info('No messages in the queue')
        messages = []

    for message in messages:
        body = json.loads(message['Body'])
        job_id = body['job_id']
        output_s3_url = f"s3://{OUTPUT_BUCKET_NAME}/track_{job_id}.mp3"
        logging.info(f'Start separating {job_id} -> {output_s3_url}')

        try:
            fetch_separate_and_upload(
                input_s3_url=body['input_s3_url'],
                output_s3_url=output_s3_url)

            logging.info(f'Processing successful: {output_s3_url}')

            dynamodb = boto3.client('dynamodb')
            dynamodb.update_item(
                TableName=TRACKS_TABLE_NAME,
                Key={'id': {'S': job_id}},
                AttributeUpdates={
                    'status': {'Value': {'S': 'successful'}},
                    'output_url': {'Value': {'S': output_s3_url}}
                })
        except Exception:
            logging.exception('Processing failed for some reason')
        finally:
            sqs.delete_message(QueueUrl=queue_url, ReceiptHandle=message['ReceiptHandle'])

    Timer(POLLING_INTERVAL, poll_for_sqs_message, args=[queue_name]).start()


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    poll_for_sqs_message(QUEUE_NAME)
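
Each receive_message call here returns immediately (no WaitTimeSeconds), and Timer re-schedules the next poll every POLLING_INTERVAL seconds. SQS long polling would cut down on empty responses and API calls; a minimal, self-contained sketch of the alternative call (the queue name is a placeholder):

    import boto3

    sqs = boto3.client('sqs')
    queue_url = sqs.get_queue_url(QueueName='example-queue')['QueueUrl']  # placeholder name

    # Wait up to 20 seconds (the SQS maximum) for a message instead of
    # returning immediately when the queue is empty.
    response = sqs.receive_message(
        QueueUrl=queue_url,
        WaitTimeSeconds=20,
        MaxNumberOfMessages=1)
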
worker/tests/audio_example.mp3 (new binary file, not shown)

worker/tests/test_separator.py (new file)

@@ -0,0 +1,39 @@
import os

import boto3
import pytest
from moto import mock_s3

from separator import fetch_separate_and_upload, parse_s3_url


TEST_FILENAME = 'audio_example.mp3'


@pytest.fixture
def mocked_cloud():
    with mock_s3():
        s3 = boto3.resource('s3', region_name='us-east-1')
        input_bucket = s3.create_bucket(Bucket='input')

        test_file_local_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), TEST_FILENAME)
        input_bucket.upload_file(Filename=test_file_local_path, Key=TEST_FILENAME)

        yield {
            's3': s3,
            'input_bucket': input_bucket,
            'output_bucket': s3.create_bucket(Bucket='output')
        }


def test_fetch_separate_and_upload(mocked_cloud):
    fetch_separate_and_upload('s3://input/audio_example.mp3', 's3://output/audio_example.mp3', s3=mocked_cloud['s3'])
    assert [e.key for e in mocked_cloud['output_bucket'].objects.all()] == ['audio_example.mp3']


@pytest.mark.parametrize('url,expected_result', [
    ('s3://bucket/key', ('bucket', 'key')),
    ('s3://bucket/with/deep/nested/path.ext', ('bucket', 'with/deep/nested/path.ext')),
])
def test_parse_s3_url(url, expected_result):
    assert parse_s3_url(url) == expected_result
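
The expected values in test_parse_s3_url are plain tuples even though parse_s3_url returns an S3Entry; that works because a namedtuple compares equal to an ordinary tuple with the same values:

    from collections import namedtuple

    S3Entry = namedtuple('S3Entry', ['bucket_name', 'key'])

    # Field names do not affect equality; only the values and their order do.
    assert S3Entry('bucket', 'key') == ('bucket', 'key')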