Replacing the placeholder worker with a real one

This commit is contained in:
2020-04-16 14:24:17 +02:00
parent 50a4c09c3c
commit ca0c55f052
9 changed files with 181 additions and 4 deletions

109
worker/separator.py Normal file
View File

@@ -0,0 +1,109 @@
import os
import logging
import json
import sys
from collections import namedtuple
from tempfile import NamedTemporaryFile, TemporaryDirectory
from threading import Timer
from urllib.parse import urlparse
import boto3
from spleeter.separator import Separator
from spleeter.audio.adapter import get_audio_adapter
# --- Worker configuration (all overridable via environment variables) ---

# Spleeter model configuration to load, e.g. 'spleeter:2stems'.
SPLEETER_CONFIGURATION = os.getenv('SPLEETER_CONFIGURATION', 'spleeter:2stems')
# Codec and bitrate of the exported stems.
OUTPUT_CODEC = os.getenv('OUTPUT_CODEC', 'mp3')
OUTPUT_BITRATE = os.getenv('OUTPUT_BITRATE', '128k')
# Process at most this many seconds of audio, starting at the offset below.
MAX_AUDIO_DURATION = float(os.getenv('MAX_AUDIO_DURATION', 600.))
AUDIO_START_OFFSET = float(os.getenv('AUDIO_START_OFFSET', 0.))
# Naming pattern Spleeter uses for the files it writes into the output dir.
OUTPUT_FILENAME_FORMAT = os.getenv('OUTPUT_FILENAME_FORMAT', '{instrument}.{codec}')
# SQS queue polled for separation jobs (no default: must be set in the env).
QUEUE_NAME = os.getenv('QUEUE_NAME')
# Seconds between successive queue polls.
POLLING_INTERVAL = int(os.getenv('POLLING_INTERVAL', 5))
# S3 bucket that receives the separated output (no default: must be set).
OUTPUT_BUCKET_NAME = os.getenv('OUTPUT_BUCKET_NAME')
# DynamoDB table where per-job status is recorded (no default: must be set).
TRACKS_TABLE_NAME = os.getenv('TRACKS_TABLE_NAME')
# Whether Spleeter applies multichannel Wiener filtering; fixed off here.
USE_MULTICHANNEL_WIENER_FILTERING = False
# (bucket_name, key) pair identifying a single object in S3.
S3Entry = namedtuple('S3Entry', ['bucket_name', 'key'])


def parse_s3_url(s3_url: str) -> S3Entry:
    """Split an ``s3://bucket/path/to/key`` URL into bucket name and key."""
    parsed = urlparse(s3_url)
    return S3Entry(bucket_name=parsed.netloc, key=parsed.path.lstrip('/'))
def fetch_separate_and_upload(input_s3_url, output_s3_url, s3=None):
    """Download a track from S3, separate it with Spleeter, upload the result.

    Parameters
    ----------
    input_s3_url:
        ``s3://bucket/key`` URL of the source audio file.
    output_s3_url:
        ``s3://bucket/key`` URL under which the accompaniment stem is stored.
    s3:
        Optional boto3 S3 resource (injectable for testing); a default
        resource is created when omitted.
    """
    if s3 is None:
        s3 = boto3.resource("s3")
    # Work entirely on local disk: the input lives in a temp file, Spleeter
    # writes its stems into a temp directory; both are removed on exit.
    with NamedTemporaryFile() as input_file, TemporaryDirectory() as output_path:
        input_object = s3.Object(**parse_s3_url(input_s3_url)._asdict())
        input_object.download_file(input_file.name)
        # None -> presumably Spleeter's default audio adapter; TODO confirm
        # against the spleeter version pinned for this worker.
        audio_adapter = get_audio_adapter(None)
        separator = Separator(
            SPLEETER_CONFIGURATION,
            MWF=USE_MULTICHANNEL_WIENER_FILTERING)
        separator.separate_to_file(
            input_file.name,
            output_path,
            audio_adapter=audio_adapter,
            offset=AUDIO_START_OFFSET,
            duration=MAX_AUDIO_DURATION,
            codec=OUTPUT_CODEC,
            bitrate=OUTPUT_BITRATE,
            filename_format=OUTPUT_FILENAME_FORMAT,
            synchronous=True)
        logging.info(f'Uploading output to: {output_s3_url}')
        output_object = s3.Object(**parse_s3_url(output_s3_url)._asdict())
        # Only the accompaniment stem is uploaded; every other stem is
        # discarded with the temp directory.  Assumes the configured model
        # produces an 'accompaniment' stem (true for the default
        # 'spleeter:2stems') -- TODO confirm for other configurations.
        output_filename = os.path.join(output_path, f'accompaniment.{OUTPUT_CODEC}')
        output_object.upload_file(output_filename)
def poll_for_sqs_message(queue_name: str):
    """Poll *queue_name* once for separation jobs, process them, re-arm.

    Each message body is JSON carrying ``job_id`` and ``input_s3_url``.
    The separated accompaniment is uploaded to OUTPUT_BUCKET_NAME and the
    job status is recorded in the TRACKS_TABLE_NAME DynamoDB table.  After
    the received batch is handled, a Timer schedules the next poll in
    POLLING_INTERVAL seconds (previously hard-coded to 5).
    """
    sqs = boto3.client('sqs')
    queue_url = sqs.get_queue_url(QueueName=queue_name)['QueueUrl']
    response = sqs.receive_message(QueueUrl=queue_url)
    # 'Messages' is absent (not empty) when the queue has nothing for us.
    messages = response.get('Messages', [])
    if not messages:
        logging.info('No messages in the queue')
    for message in messages:
        body = json.loads(message['Body'])
        job_id = body['job_id']
        output_s3_url = f"s3://{OUTPUT_BUCKET_NAME}/track_{job_id}.mp3"
        logging.info(f'Start separating {job_id} -> {output_s3_url}')
        try:
            fetch_separate_and_upload(
                input_s3_url=body['input_s3_url'],
                output_s3_url=output_s3_url)
            logging.info(f'Processing successful: {output_s3_url}')
            _update_track_status(job_id, 'successful', output_s3_url)
        except Exception:
            # Was a bare ``except:`` which also swallowed SystemExit and
            # KeyboardInterrupt, preventing clean shutdown of the worker.
            logging.exception('Processing failed for some reason')
            _record_failure(job_id)
        finally:
            # Delete even on failure so a poison message is not redelivered
            # and retried forever.
            sqs.delete_message(
                QueueUrl=queue_url,
                ReceiptHandle=message['ReceiptHandle'])
    Timer(POLLING_INTERVAL, poll_for_sqs_message, args=[queue_name]).start()


def _update_track_status(job_id: str, status: str, output_url=None):
    """Record *status* (and, if given, *output_url*) for *job_id* in DynamoDB."""
    updates = {'status': {'Value': {'S': status}}}
    if output_url is not None:
        updates['output_url'] = {'Value': {'S': output_url}}
    boto3.client('dynamodb').update_item(
        TableName=TRACKS_TABLE_NAME,
        Key={'id': {'S': job_id}},
        AttributeUpdates=updates)


def _record_failure(job_id: str):
    """Best-effort 'failed' status write; never raises from the error path."""
    try:
        _update_track_status(job_id, 'failed')
    except Exception:
        logging.exception('Could not record failure status for %s', job_id)
if __name__ == '__main__':
    # Worker entry point: configure root logging, then start the polling
    # loop, which re-schedules itself via a threading.Timer.
    logging.basicConfig(level=logging.INFO)
    poll_for_sqs_message(QUEUE_NAME)