Replacing the placeholder worker with a real one
This commit is contained in:
109
worker/separator.py
Normal file
109
worker/separator.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import os
|
||||
import logging
|
||||
import json
|
||||
import sys
|
||||
from collections import namedtuple
|
||||
from tempfile import NamedTemporaryFile, TemporaryDirectory
|
||||
from threading import Timer
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import boto3
|
||||
from spleeter.separator import Separator
|
||||
from spleeter.audio.adapter import get_audio_adapter
|
||||
|
||||
|
||||
# --- Runtime configuration, read once from the environment at import time ---

# Spleeter configuration descriptor handed to ``Separator``.
SPLEETER_CONFIGURATION = os.environ.get('SPLEETER_CONFIGURATION', 'spleeter:2stems')

# Codec and bitrate forwarded to ``Separator.separate_to_file``.
OUTPUT_CODEC = os.environ.get('OUTPUT_CODEC', 'mp3')
OUTPUT_BITRATE = os.environ.get('OUTPUT_BITRATE', '128k')

# Forwarded as ``duration=`` / ``offset=`` to ``separate_to_file``
# (presumably seconds -- confirm against the Spleeter documentation).
MAX_AUDIO_DURATION = float(os.environ.get('MAX_AUDIO_DURATION', 600.))
AUDIO_START_OFFSET = float(os.environ.get('AUDIO_START_OFFSET', 0.))

# Template forwarded as ``filename_format=`` to ``separate_to_file``.
OUTPUT_FILENAME_FORMAT = os.environ.get('OUTPUT_FILENAME_FORMAT', '{instrument}.{codec}')

# Name of the SQS queue the worker polls; ``None`` when the variable is unset.
QUEUE_NAME = os.environ.get('QUEUE_NAME')

# Polling interval for the queue (presumably seconds -- confirm).
POLLING_INTERVAL = int(os.environ.get('POLLING_INTERVAL', 5))

# Bucket that receives the separated tracks; ``None`` when unset.
OUTPUT_BUCKET_NAME = os.environ.get('OUTPUT_BUCKET_NAME')

# DynamoDB table updated after a job is processed; ``None`` when unset.
TRACKS_TABLE_NAME = os.environ.get('TRACKS_TABLE_NAME')

# Passed as ``MWF=`` to ``Separator``; not configurable via the environment.
USE_MULTICHANNEL_WIENER_FILTERING = False
|
||||
|
||||
|
||||
# Bucket/key pair; the field names match the keyword arguments that callers
# in this file splat into ``s3.Object(**entry._asdict())``.
S3Entry = namedtuple('S3Entry', ['bucket_name', 'key'])


def parse_s3_url(s3_url: str) -> S3Entry:
    """Split an ``s3://bucket/key`` URL into its bucket name and object key.

    Object keys may contain ``?`` and ``#``.  Generic URL parsing would treat
    those as query/fragment delimiters and silently truncate the key, so
    fragment splitting is disabled and any query part is glued back on.

    Args:
        s3_url: URL of the form ``s3://<bucket>/<key>``.

    Returns:
        ``S3Entry(bucket_name, key)`` with leading slashes removed from
        the key.
    """
    # allow_fragments=False keeps '#...' inside the path (or the query).
    parsed = urlparse(s3_url, allow_fragments=False)
    key = parsed.path.lstrip('/')
    if parsed.query:
        # Everything after the first '?' belongs to the key as well.
        key = f'{key}?{parsed.query}'
    return S3Entry(parsed.netloc, key)
|
||||
|
||||
|
||||
def fetch_separate_and_upload(input_s3_url, output_s3_url, s3=None):
    """Download a track from S3, separate it with Spleeter, upload the result.

    The input object is downloaded to a temporary file, separated into a
    temporary directory, and the ``accompaniment`` stem is uploaded to
    ``output_s3_url``.  Both temporary locations are removed on exit.

    Args:
        input_s3_url: ``s3://bucket/key`` URL of the audio to process.
        output_s3_url: ``s3://bucket/key`` URL that receives the
            accompaniment track.
        s3: Optional boto3 S3 resource; a default one is created when omitted.
    """
    if s3 is None:
        s3 = boto3.resource("s3")

    with NamedTemporaryFile() as input_file, TemporaryDirectory() as output_path:
        input_object = s3.Object(**parse_s3_url(input_s3_url)._asdict())
        input_object.download_file(input_file.name)

        audio_adapter = get_audio_adapter(None)
        separator = Separator(
            SPLEETER_CONFIGURATION,
            MWF=USE_MULTICHANNEL_WIENER_FILTERING)

        separator.separate_to_file(
            input_file.name,
            output_path,
            audio_adapter=audio_adapter,
            offset=AUDIO_START_OFFSET,
            duration=MAX_AUDIO_DURATION,
            codec=OUTPUT_CODEC,
            bitrate=OUTPUT_BITRATE,
            filename_format=OUTPUT_FILENAME_FORMAT,
            synchronous=True)

        # Locate the accompaniment stem using the same template that was
        # handed to Spleeter, instead of assuming the default layout; with the
        # default '{instrument}.{codec}' this is 'accompaniment.<codec>'.
        # NOTE(review): placeholder names mirror Spleeter's filename_format
        # ({filename}, {foldername}, {instrument}, {codec}) -- confirm against
        # the installed Spleeter version.
        input_dir, input_basename = os.path.split(input_file.name)
        relative_output = OUTPUT_FILENAME_FORMAT.format(
            filename=os.path.splitext(input_basename)[0],
            foldername=os.path.basename(input_dir),
            instrument='accompaniment',
            codec=OUTPUT_CODEC)
        output_filename = os.path.join(output_path, relative_output)

        logging.info(f'Uploading output to: {output_s3_url}')
        output_object = s3.Object(**parse_s3_url(output_s3_url)._asdict())
        output_object.upload_file(output_filename)
|
||||
|
||||
|
||||
def poll_for_sqs_message(queue_name: str):
    """Process pending SQS separation jobs once, then schedule the next poll.

    Each message body is expected to be JSON with ``job_id`` and
    ``input_s3_url`` keys.  For every message the track is separated and
    uploaded, and the job's DynamoDB item is marked ``successful`` with its
    output URL.  A message is always deleted after one attempt, whether it
    succeeded or failed; failures are logged with their traceback.

    Args:
        queue_name: Name of the SQS queue to poll.
    """
    sqs = boto3.client('sqs')
    queue_url = sqs.get_queue_url(QueueName=queue_name)['QueueUrl']
    response = sqs.receive_message(QueueUrl=queue_url)

    messages = response.get('Messages')
    if messages is None:
        logging.info('No messages in the queue')
        messages = []

    for message in messages:
        try:
            # Parsing lives inside the try block so that a malformed message
            # is logged and deleted instead of aborting the polling chain.
            body = json.loads(message['Body'])
            job_id = body['job_id']
            # Use the configured codec rather than a hard-coded '.mp3' so the
            # object name matches the content that is actually uploaded.
            output_s3_url = f"s3://{OUTPUT_BUCKET_NAME}/track_{job_id}.{OUTPUT_CODEC}"
            logging.info(f'Start separating {job_id} -> {output_s3_url}')

            fetch_separate_and_upload(
                input_s3_url=body['input_s3_url'],
                output_s3_url=output_s3_url)

            logging.info(f'Processing successful: {output_s3_url}')

            dynamodb = boto3.client('dynamodb')
            dynamodb.update_item(
                TableName=TRACKS_TABLE_NAME,
                Key={'id': {'S': job_id}},
                AttributeUpdates={
                    'status': {'Value': {'S': 'successful'}},
                    'output_url': {'Value': {'S': output_s3_url}},
                })
        except Exception:
            # Not a bare ``except:`` -- KeyboardInterrupt/SystemExit must
            # still be able to stop the worker.
            logging.exception('Processing failed for some reason')
        finally:
            sqs.delete_message(QueueUrl=queue_url, ReceiptHandle=message['ReceiptHandle'])

    # Re-arm the poll using the configured interval.
    Timer(POLLING_INTERVAL, poll_for_sqs_message, args=[queue_name]).start()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    # QUEUE_NAME is None when the environment variable is unset; fail fast
    # with a clear message instead of passing None on to the SQS client.
    if not QUEUE_NAME:
        raise SystemExit('QUEUE_NAME environment variable must be set')
    poll_for_sqs_message(QUEUE_NAME)
|
||||
Reference in New Issue
Block a user