112 lines
3.9 KiB
Python
112 lines
3.9 KiB
Python
import os
|
|
import logging
|
|
import json
|
|
import sys
|
|
from collections import namedtuple
|
|
from tempfile import NamedTemporaryFile, TemporaryDirectory
|
|
from threading import Timer
|
|
from urllib.parse import urlparse
|
|
|
|
import boto3
|
|
from spleeter.separator import Separator
|
|
from spleeter.audio.adapter import get_audio_adapter
|
|
|
|
|
|
# --- Separation settings (all overridable through environment variables) ---
# Model descriptor handed to spleeter's Separator.
SPLEETER_CONFIGURATION = os.environ.get('SPLEETER_CONFIGURATION', 'spleeter:2stems')
# Codec / bitrate forwarded to Separator.separate_to_file for the written stems.
OUTPUT_CODEC = os.environ.get('OUTPUT_CODEC', 'mp3')
OUTPUT_BITRATE = os.environ.get('OUTPUT_BITRATE', '128k')
# Window of the input that is separated (presumably seconds, per spleeter's
# offset/duration arguments -- confirm against the installed spleeter version).
MAX_AUDIO_DURATION = float(os.environ.get('MAX_AUDIO_DURATION', '600'))
AUDIO_START_OFFSET = float(os.environ.get('AUDIO_START_OFFSET', '0'))
# Template for stem file names inside the separation output directory.
OUTPUT_FILENAME_FORMAT = os.environ.get('OUTPUT_FILENAME_FORMAT', '{instrument}.{codec}')
# Multichannel Wiener filtering is always disabled (passed as Separator's MWF flag).
USE_MULTICHANNEL_WIENER_FILTERING = False

# --- AWS resources ---
# SQS queue the worker polls for jobs; no default, must be provided.
QUEUE_NAME = os.environ.get('QUEUE_NAME')
# Delay between polls, parsed as an integer (presumably seconds).
POLLING_INTERVAL = int(os.environ.get('POLLING_INTERVAL', '5'))
# Destination bucket for separated tracks and the region used to build their public URL.
OUTPUT_BUCKET_NAME = os.environ.get('OUTPUT_BUCKET_NAME')
OUTPUT_BUCKET_REGION = os.environ.get('OUTPUT_BUCKET_REGION', 'eu-west-1')
# DynamoDB table whose items record each job's status and output URL.
TRACKS_TABLE_NAME = os.environ.get('TRACKS_TABLE_NAME')
|
|
|
|
|
|
# (bucket_name, key) pair; its fields are passed as keyword arguments to
# boto3's ``s3.Object(...)`` via ``_asdict()``.
S3Entry = namedtuple('S3Entry', ['bucket_name', 'key'])


def parse_s3_url(s3_url: str) -> S3Entry:
    """Split an ``s3://bucket/key`` URL into its bucket name and object key.

    S3 object keys may legitimately contain ``#`` and ``?``. Those characters
    are kept as part of the key instead of being split off as a URL fragment
    or query string. Only the single ``/`` separating the bucket from the key
    is removed, so a key that itself starts with ``/`` survives intact.

    :param s3_url: URL of the form ``s3://<bucket>/<key>``.
    :returns: ``S3Entry(bucket_name, key)``; ``key`` is ``''`` when the URL
        names only a bucket.
    """
    # allow_fragments=False keeps '#...' inside the path rather than dropping it.
    parsed = urlparse(s3_url, allow_fragments=False)
    key = parsed.path
    if parsed.query:
        # '?' has no query meaning in an S3 key; stitch it back on.
        key = f'{key}?{parsed.query}'
    if key.startswith('/'):
        # Drop only the bucket/key separator, not slashes belonging to the key.
        key = key[1:]
    return S3Entry(parsed.netloc, key)
|
|
|
|
|
|
def fetch_separate_and_upload(input_s3_url, output_s3_url, s3=None):
    """Download audio from S3, separate it with Spleeter, and upload the accompaniment stem.

    :param input_s3_url: ``s3://bucket/key`` URL of the source audio file.
    :param output_s3_url: ``s3://bucket/key`` URL the accompaniment stem is uploaded to.
    :param s3: optional boto3 S3 resource; ``boto3.resource('s3')`` is used when omitted.
    """
    if s3 is None:
        s3 = boto3.resource('s3')

    # Both the downloaded input and the separated stems live only for the
    # duration of this call.
    with NamedTemporaryFile() as input_file, TemporaryDirectory() as output_path:
        source = s3.Object(**parse_s3_url(input_s3_url)._asdict())
        source.download_file(input_file.name)

        # NOTE(review): a new Separator is built on every call; if model loading
        # is costly, consider reusing one instance across jobs.
        separator = Separator(
            SPLEETER_CONFIGURATION,
            MWF=USE_MULTICHANNEL_WIENER_FILTERING)
        separator.separate_to_file(
            input_file.name,
            output_path,
            audio_adapter=get_audio_adapter(None),
            offset=AUDIO_START_OFFSET,
            duration=MAX_AUDIO_DURATION,
            codec=OUTPUT_CODEC,
            bitrate=OUTPUT_BITRATE,
            filename_format=OUTPUT_FILENAME_FORMAT,
            synchronous=True)

        logging.info(f'Uploading output to: {output_s3_url}')
        destination = s3.Object(**parse_s3_url(output_s3_url)._asdict())
        # NOTE(review): an 'accompaniment' stem presumably exists only for the
        # 2-stem model (the default SPLEETER_CONFIGURATION); confirm before
        # switching to another configuration.
        stem_path = _stem_output_path(output_path, input_file.name, 'accompaniment')
        destination.upload_file(stem_path)


def _stem_output_path(output_path, input_filename, instrument):
    """Return where ``separate_to_file`` writes ``instrument`` for ``input_filename``.

    Expands ``OUTPUT_FILENAME_FORMAT`` rather than hardcoding
    ``'<instrument>.<codec>'``, so a customized format still locates the stem.
    Assumes spleeter formats ``filename_format`` with the ``filename``
    (basename without extension), ``foldername``, ``instrument`` and ``codec``
    keys, relative to the destination directory -- verify against the
    installed spleeter version.
    """
    filename = os.path.splitext(os.path.basename(input_filename))[0]
    foldername = os.path.basename(os.path.dirname(input_filename))
    relative = OUTPUT_FILENAME_FORMAT.format(
        filename=filename,
        foldername=foldername,
        instrument=instrument,
        codec=OUTPUT_CODEC)
    return os.path.join(output_path, relative)
|
|
|
|
|
|
def poll_for_sqs_message(queue_name: str):
    """Receive pending job messages from SQS, process them, then schedule the next poll.

    The next poll is scheduled with ``threading.Timer`` after
    ``POLLING_INTERVAL`` seconds, even when this poll fails. A transient
    AWS/network error therefore does not permanently stop the worker.
    ``KeyboardInterrupt``/``SystemExit`` are not caught and do stop it.

    :param queue_name: name of the SQS queue to poll.
    """
    try:
        sqs = boto3.client('sqs')
        queue_url = sqs.get_queue_url(QueueName=queue_name)['QueueUrl']
        response = sqs.receive_message(QueueUrl=queue_url)
        messages = response.get('Messages', [])
        if not messages:
            logging.info('No messages in the queue')
        for message in messages:
            _process_message(sqs, queue_url, message)
    except Exception:
        logging.exception('Polling the queue failed')

    # Outside the try on purpose: only ordinary exceptions are swallowed above,
    # so interrupts propagate without arming another timer.
    Timer(POLLING_INTERVAL, poll_for_sqs_message, args=[queue_name]).start()


def _process_message(sqs, queue_url, message):
    """Run one separation job described by an SQS message, then delete the message.

    The message body is parsed as JSON and must provide ``job_id`` and
    ``input_s3_url``. On success the job's DynamoDB item is marked
    ``successful`` with its public download URL. Any failure, including a
    malformed body, is logged and not raised. The message is deleted in every
    case, so a failing job is not received again.
    """
    # NOTE(review): if separation outlasts the queue's visibility timeout, SQS
    # may hand this message to another consumer before it is deleted -- confirm
    # the queue's timeout covers MAX_AUDIO_DURATION worth of processing.
    try:
        body = json.loads(message['Body'])
        job_id = body['job_id']
        # Key extension follows the configured codec rather than assuming mp3.
        output_key = f'track_{job_id}.{OUTPUT_CODEC}'
        output_s3_url = f's3://{OUTPUT_BUCKET_NAME}/{output_key}'
        download_url = f'https://{OUTPUT_BUCKET_NAME}.s3-{OUTPUT_BUCKET_REGION}.amazonaws.com/{output_key}'
        logging.info(f'Start separating {job_id} -> {output_s3_url}')

        fetch_separate_and_upload(
            input_s3_url=body['input_s3_url'],
            output_s3_url=output_s3_url)
        logging.info(f'Processing successful: {output_s3_url}')

        dynamodb = boto3.client('dynamodb')
        dynamodb.update_item(
            TableName=TRACKS_TABLE_NAME,
            Key={'id': {'S': job_id}},
            AttributeUpdates={
                'status': {'Value': {'S': 'successful'}},
                'output_url': {'Value': {'S': download_url}}
            })
    except Exception:
        logging.exception('Processing failed for some reason')
    finally:
        sqs.delete_message(QueueUrl=queue_url, ReceiptHandle=message['ReceiptHandle'])
|
|
|
|
|
|
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    # Fail fast on missing configuration. Otherwise the worker would fail deep
    # inside boto3, or build URLs for a bucket literally named "None".
    required_settings = {
        'QUEUE_NAME': QUEUE_NAME,
        'OUTPUT_BUCKET_NAME': OUTPUT_BUCKET_NAME,
        'TRACKS_TABLE_NAME': TRACKS_TABLE_NAME,
    }
    missing = [name for name, value in required_settings.items() if not value]
    if missing:
        logging.error('Missing required environment variables: %s', ', '.join(missing))
        sys.exit(1)

    # The first poll runs on the main thread; later polls are scheduled by the
    # function itself via threading.Timer.
    poll_for_sqs_message(QUEUE_NAME)
|