# karaoke-me/worker/separator.py — SQS-driven worker that runs Spleeter
# source separation on tracks stored in S3.
import os
import logging
import json
import sys
from collections import namedtuple
from tempfile import NamedTemporaryFile, TemporaryDirectory
from threading import Timer
from urllib.parse import urlparse
import boto3
from spleeter.separator import Separator
from spleeter.audio.adapter import get_audio_adapter
# --- Worker configuration (every value overridable via environment) ---

# Spleeter model; '2stems' separates into vocals + accompaniment.
SPLEETER_CONFIGURATION = os.getenv('SPLEETER_CONFIGURATION', 'spleeter:2stems')
OUTPUT_CODEC = os.getenv('OUTPUT_CODEC', 'mp3')
OUTPUT_BITRATE = os.getenv('OUTPUT_BITRATE', '128k')
# Hard cap on processed audio length (seconds) to bound worker runtime.
MAX_AUDIO_DURATION = float(os.getenv('MAX_AUDIO_DURATION', 600.))
AUDIO_START_OFFSET = float(os.getenv('AUDIO_START_OFFSET', 0.))
# Template for the per-instrument files Spleeter writes into the output dir.
OUTPUT_FILENAME_FORMAT = os.getenv('OUTPUT_FILENAME_FORMAT', '{instrument}.{codec}')
QUEUE_NAME = os.getenv('QUEUE_NAME')
# Seconds between successive SQS polls.
POLLING_INTERVAL = int(os.getenv('POLLING_INTERVAL', 10))
OUTPUT_BUCKET_NAME = os.getenv('OUTPUT_BUCKET_NAME')
OUTPUT_BUCKET_REGION = os.getenv('OUTPUT_BUCKET_REGION', 'eu-west-1')
TRACKS_TABLE_NAME = os.getenv('TRACKS_TABLE_NAME')
# Multichannel Wiener filtering gives better quality but is slower; off here.
USE_MULTICHANNEL_WIENER_FILTERING = False
# Lightweight container for an S3 bucket/key pair.
S3Entry = namedtuple('S3Entry', ['bucket_name', 'key'])


def parse_s3_url(s3_url: str) -> S3Entry:
    """Split an ``s3://bucket/key`` URL into an ``S3Entry(bucket_name, key)``."""
    parsed = urlparse(s3_url)
    bucket = parsed.netloc
    key = parsed.path.lstrip('/')
    return S3Entry(bucket_name=bucket, key=key)
def fetch_separate_and_upload(input_s3_url, output_s3_url, s3=None):
    """Download a track from S3, run Spleeter separation, upload the result.

    Parameters
    ----------
    input_s3_url : str
        ``s3://bucket/key`` URL of the source audio file.
    output_s3_url : str
        ``s3://bucket/key`` URL where the accompaniment track is uploaded.
    s3 : optional
        boto3 S3 service resource; a default one is created when omitted.
    """
    if s3 is None:
        s3 = boto3.resource("s3")
    # Temp file and directory are removed automatically when the context exits.
    with NamedTemporaryFile() as input_file, TemporaryDirectory() as output_path:
        input_object = s3.Object(**parse_s3_url(input_s3_url)._asdict())
        input_object.download_file(input_file.name)
        audio_adapter = get_audio_adapter(None)
        separator = Separator(
            SPLEETER_CONFIGURATION,
            MWF=USE_MULTICHANNEL_WIENER_FILTERING)
        separator.separate_to_file(
            input_file.name,
            output_path,
            audio_adapter=audio_adapter,
            offset=AUDIO_START_OFFSET,
            duration=MAX_AUDIO_DURATION,
            codec=OUTPUT_CODEC,
            bitrate=OUTPUT_BITRATE,
            filename_format=OUTPUT_FILENAME_FORMAT,
            synchronous=True)
        logging.info(f'Uploading output to: {output_s3_url}')
        output_object = s3.Object(**parse_s3_url(output_s3_url)._asdict())
        # Fix: derive the local filename from OUTPUT_FILENAME_FORMAT, the same
        # template Spleeter was told to write with — previously hard-coded to
        # 'accompaniment.{codec}', which broke uploads for any custom format.
        output_filename = os.path.join(
            output_path,
            OUTPUT_FILENAME_FORMAT.format(
                instrument='accompaniment', codec=OUTPUT_CODEC))
        output_object.upload_file(output_filename)
def poll_for_sqs_message(queue_name: str):
    """Poll SQS once for separation jobs, process them, then reschedule itself.

    Each message body is JSON carrying ``job_id`` and ``input_s3_url``.  On
    success the track's DynamoDB item is marked ``successful`` with its public
    download URL.  The message is deleted even on failure, so a poison message
    cannot be retried forever.  The function re-arms itself with a Timer so
    polling continues every POLLING_INTERVAL seconds.
    """
    sqs = boto3.client('sqs')
    queue_url = sqs.get_queue_url(QueueName=queue_name)['QueueUrl']
    response = sqs.receive_message(QueueUrl=queue_url)
    # 'Messages' is absent (not empty) when the queue has nothing to deliver.
    messages = response.get('Messages', [])
    if not messages:
        logging.info('No messages in the queue')
    for message in messages:
        body = json.loads(message['Body'])
        job_id = body['job_id']
        output_s3_url = f"s3://{OUTPUT_BUCKET_NAME}/track_{job_id}.mp3"
        download_url = f'https://{OUTPUT_BUCKET_NAME}.s3-{OUTPUT_BUCKET_REGION}.amazonaws.com/track_{job_id}.mp3'
        logging.info(f'Start separating {job_id} -> {output_s3_url}')
        try:
            fetch_separate_and_upload(
                input_s3_url=body['input_s3_url'],
                output_s3_url=output_s3_url)
            logging.info(f'Processing successful: {output_s3_url}')
            dynamodb = boto3.client('dynamodb')
            dynamodb.update_item(
                TableName=TRACKS_TABLE_NAME,
                Key={'id': {'S': job_id}},
                AttributeUpdates={
                    'status': {'Value': {'S': 'successful'}},
                    'output_url': {'Value': {'S': download_url}}
                })
        # Fix: was a bare `except:`, which also swallowed SystemExit and
        # KeyboardInterrupt; catch only ordinary errors.
        except Exception:
            logging.exception('Processing failed for some reason')
        finally:
            # Delete unconditionally: failed jobs are dropped, not redelivered.
            sqs.delete_message(QueueUrl=queue_url, ReceiptHandle=message['ReceiptHandle'])
    # Schedule the next poll operation
    Timer(POLLING_INTERVAL, poll_for_sqs_message, args=[queue_name]).start()
if __name__ == '__main__':
    # Entry point: configure root logging, then start the self-rescheduling
    # SQS poll loop (poll_for_sqs_message re-arms itself via a Timer).
    logging.basicConfig(level=logging.INFO)
    poll_for_sqs_message(QUEUE_NAME)