diff --git a/requirements.txt b/requirements.txt index f8aec04..27cac28 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,3 +19,4 @@ flask flask_wtf flask_cors gevent==21.8.0 +flask_restx \ No newline at end of file diff --git a/synthesizer/hparams.py b/synthesizer/hparams.py index a779c69..897b6d4 100644 --- a/synthesizer/hparams.py +++ b/synthesizer/hparams.py @@ -41,7 +41,7 @@ hparams = HParams( tts_lstm_dims = 1024, tts_postnet_K = 5, tts_num_highways = 4, - tts_dropout = 0.2, + tts_dropout = 0.5, tts_cleaner_names = ["basic_cleaners"], tts_stop_threshold = -3.4, # Value below which audio generation ends. # For example, for a range of [-4, 4], this diff --git a/web/__init__.py b/web/__init__.py index 91e071c..17f186e 100644 --- a/web/__init__.py +++ b/web/__init__.py @@ -1,7 +1,7 @@ -import os +from web.api import api_blueprint from pathlib import Path from gevent import pywsgi as wsgi -from flask import Flask, jsonify, Response, request, render_template +from flask import Flask, jsonify, Response, request, render_template, url_for from synthesizer.inference import Synthesizer from encoder import inference as encoder from vocoder.hifigan import inference as gan_vocoder @@ -17,8 +17,9 @@ from flask_wtf import CSRFProtect def webApp(): # Init and load config app = Flask(__name__, instance_relative_config=True) - app.config.from_object("web.config.default") + app.config['RESTPLUS_MASK_SWAGGER'] = False + app.register_blueprint(api_blueprint) CORS(app) #允许跨域,注释掉此行则禁止跨域请求 csrf = CSRFProtect(app) @@ -29,11 +30,7 @@ def webApp(): # 3. load melspetron of audio # 4. inference by audio + text + models(encoder, vocoder, synthesizer) # 5. export result - audio_samples = [] - AUDIO_SAMPLES_DIR = app.config.get("AUDIO_SAMPLES_DIR") - if os.path.isdir(AUDIO_SAMPLES_DIR): - audio_samples = list(Path(AUDIO_SAMPLES_DIR).glob("*.wav")) - print("Loaded samples: " + str(len(audio_samples))) + # enc_models_dir = "encoder/saved_models" # voc_models_di = "vocoder/saved_models" # encoders = list(Path(enc_models_dir).glob("*.pt")) @@ -47,24 +44,6 @@ def webApp(): encoder.load_model(Path("encoder/saved_models/pretrained.pt")) gan_vocoder.load_model(Path("vocoder/saved_models/pretrained/g_hifigan.pt")) - # TODO: move to utils - def generate(wav_path): - with open(wav_path, "rb") as fwav: - data = fwav.read(1024) - while data: - yield data - data = fwav.read(1024) - - @app.route("/api/audios", methods=["GET"]) - def audios(): - return jsonify( - {"data": list(a.name for a in audio_samples), "total": len(audio_samples)} - ) - - @app.route("/api/audios/", methods=["GET"]) - def audio_play(name): - return Response(generate(AUDIO_SAMPLES_DIR + name), mimetype="audio/x-wav") - @app.route("/api/models", methods=["GET"]) def models(): return jsonify( @@ -160,7 +139,7 @@ def webApp(): print(f"Web server: http://{host}:{port}") server = wsgi.WSGIServer((host, port), app) server.serve_forever() - + return app if __name__ == "__main__": diff --git a/web/api/__init__.py b/web/api/__init__.py new file mode 100644 index 0000000..db524d7 --- /dev/null +++ b/web/api/__init__.py @@ -0,0 +1,14 @@ +from flask import Blueprint +from flask_restx import Api +from .audio import api as audio + +api_blueprint = Blueprint('api', __name__, url_prefix='/api') + +api = Api( + app=api_blueprint, + title='Mocking Bird', + version='1.0', + description='My API' +) + +api.add_namespace(audio) diff --git a/web/api/audio.py b/web/api/audio.py new file mode 100644 index 0000000..b30e5dd --- /dev/null +++ b/web/api/audio.py @@ -0,0 +1,43 @@ +import os +from pathlib import Path +from flask_restx import Namespace, Resource, fields +from flask import Response, current_app + +api = Namespace('audios', description='Audios related operations') + +audio = api.model('Audio', { + 'name': fields.String(required=True, description='The audio name'), +}) + +def generate(wav_path): + with open(wav_path, "rb") as fwav: + data = fwav.read(1024) + while data: + yield data + data = fwav.read(1024) + +@api.route('/') +class AudioList(Resource): + @api.doc('list_audios') + @api.marshal_list_with(audio) + def get(self): + '''List all audios''' + audio_samples = [] + AUDIO_SAMPLES_DIR = current_app.config.get("AUDIO_SAMPLES_DIR") + if os.path.isdir(AUDIO_SAMPLES_DIR): + audio_samples = list(Path(AUDIO_SAMPLES_DIR).glob("*.wav")) + return list(a.name for a in audio_samples) + +@api.route('/') +@api.param('name', 'The name of audio') +@api.response(404, 'audio not found') +class Audio(Resource): + @api.doc('get_audio') + @api.marshal_with(audio) + def get(self, name): + '''Fetch a cat given its identifier''' + AUDIO_SAMPLES_DIR = current_app.config.get("AUDIO_SAMPLES_DIR") + if Path(AUDIO_SAMPLES_DIR + name).exists(): + return Response(generate(AUDIO_SAMPLES_DIR + name), mimetype="audio/x-wav") + api.abort(404) + \ No newline at end of file diff --git a/web/config/default.py b/web/config/default.py index e75d266..7892ae8 100644 --- a/web/config/default.py +++ b/web/config/default.py @@ -4,4 +4,4 @@ HOST = 'localhost' PORT = 8080 MAX_CONTENT_PATH =1024 * 1024 * 4 # mp3文件大小限定不能超过4M SECRET_KEY = "mockingbird_key" -WTF_CSRF_SECRET_KEY = "mockingbird_key" \ No newline at end of file +WTF_CSRF_SECRET_KEY = "mockingbird_key"