mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
Web server release v2 (#99)
* Init App * init server.py (#93) * init server.py * Update requirements.txt Add requirement Co-authored-by: auau <auau@test.com> Co-authored-by: babysor00 <babysor00@gmail.com> * Run web.py! Run web.py! * Restruct readme and add instruction to use web server * fix training preprocess of vocoder * Init App * init server.py (#93) * init server.py * Update requirements.txt Add requirement Co-authored-by: auau <auau@test.com> Co-authored-by: babysor00 <babysor00@gmail.com> * Run web.py! Run web.py! * fix training preprocess of vocoder * Refactor to restful style Co-authored-by: balala <Ozgay@users.noreply.github.com> Co-authored-by: auau <auau@test.com>
This commit is contained in:
parent
4acfee2a64
commit
0d0b55d3e9
|
@ -19,3 +19,4 @@ flask
|
||||||
flask_wtf
|
flask_wtf
|
||||||
flask_cors
|
flask_cors
|
||||||
gevent==21.8.0
|
gevent==21.8.0
|
||||||
|
flask_restx
|
|
@ -41,7 +41,7 @@ hparams = HParams(
|
||||||
tts_lstm_dims = 1024,
|
tts_lstm_dims = 1024,
|
||||||
tts_postnet_K = 5,
|
tts_postnet_K = 5,
|
||||||
tts_num_highways = 4,
|
tts_num_highways = 4,
|
||||||
tts_dropout = 0.2,
|
tts_dropout = 0.5,
|
||||||
tts_cleaner_names = ["basic_cleaners"],
|
tts_cleaner_names = ["basic_cleaners"],
|
||||||
tts_stop_threshold = -3.4, # Value below which audio generation ends.
|
tts_stop_threshold = -3.4, # Value below which audio generation ends.
|
||||||
# For example, for a range of [-4, 4], this
|
# For example, for a range of [-4, 4], this
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import os
|
from web.api import api_blueprint
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from gevent import pywsgi as wsgi
|
from gevent import pywsgi as wsgi
|
||||||
from flask import Flask, jsonify, Response, request, render_template
|
from flask import Flask, jsonify, Response, request, render_template, url_for
|
||||||
from synthesizer.inference import Synthesizer
|
from synthesizer.inference import Synthesizer
|
||||||
from encoder import inference as encoder
|
from encoder import inference as encoder
|
||||||
from vocoder.hifigan import inference as gan_vocoder
|
from vocoder.hifigan import inference as gan_vocoder
|
||||||
|
@ -17,8 +17,9 @@ from flask_wtf import CSRFProtect
|
||||||
def webApp():
|
def webApp():
|
||||||
# Init and load config
|
# Init and load config
|
||||||
app = Flask(__name__, instance_relative_config=True)
|
app = Flask(__name__, instance_relative_config=True)
|
||||||
|
|
||||||
app.config.from_object("web.config.default")
|
app.config.from_object("web.config.default")
|
||||||
|
app.config['RESTPLUS_MASK_SWAGGER'] = False
|
||||||
|
app.register_blueprint(api_blueprint)
|
||||||
|
|
||||||
CORS(app) #允许跨域,注释掉此行则禁止跨域请求
|
CORS(app) #允许跨域,注释掉此行则禁止跨域请求
|
||||||
csrf = CSRFProtect(app)
|
csrf = CSRFProtect(app)
|
||||||
|
@ -29,11 +30,7 @@ def webApp():
|
||||||
# 3. load melspetron of audio
|
# 3. load melspetron of audio
|
||||||
# 4. inference by audio + text + models(encoder, vocoder, synthesizer)
|
# 4. inference by audio + text + models(encoder, vocoder, synthesizer)
|
||||||
# 5. export result
|
# 5. export result
|
||||||
audio_samples = []
|
|
||||||
AUDIO_SAMPLES_DIR = app.config.get("AUDIO_SAMPLES_DIR")
|
|
||||||
if os.path.isdir(AUDIO_SAMPLES_DIR):
|
|
||||||
audio_samples = list(Path(AUDIO_SAMPLES_DIR).glob("*.wav"))
|
|
||||||
print("Loaded samples: " + str(len(audio_samples)))
|
|
||||||
# enc_models_dir = "encoder/saved_models"
|
# enc_models_dir = "encoder/saved_models"
|
||||||
# voc_models_di = "vocoder/saved_models"
|
# voc_models_di = "vocoder/saved_models"
|
||||||
# encoders = list(Path(enc_models_dir).glob("*.pt"))
|
# encoders = list(Path(enc_models_dir).glob("*.pt"))
|
||||||
|
@ -47,24 +44,6 @@ def webApp():
|
||||||
encoder.load_model(Path("encoder/saved_models/pretrained.pt"))
|
encoder.load_model(Path("encoder/saved_models/pretrained.pt"))
|
||||||
gan_vocoder.load_model(Path("vocoder/saved_models/pretrained/g_hifigan.pt"))
|
gan_vocoder.load_model(Path("vocoder/saved_models/pretrained/g_hifigan.pt"))
|
||||||
|
|
||||||
# TODO: move to utils
|
|
||||||
def generate(wav_path):
|
|
||||||
with open(wav_path, "rb") as fwav:
|
|
||||||
data = fwav.read(1024)
|
|
||||||
while data:
|
|
||||||
yield data
|
|
||||||
data = fwav.read(1024)
|
|
||||||
|
|
||||||
@app.route("/api/audios", methods=["GET"])
|
|
||||||
def audios():
|
|
||||||
return jsonify(
|
|
||||||
{"data": list(a.name for a in audio_samples), "total": len(audio_samples)}
|
|
||||||
)
|
|
||||||
|
|
||||||
@app.route("/api/audios/<name>", methods=["GET"])
|
|
||||||
def audio_play(name):
|
|
||||||
return Response(generate(AUDIO_SAMPLES_DIR + name), mimetype="audio/x-wav")
|
|
||||||
|
|
||||||
@app.route("/api/models", methods=["GET"])
|
@app.route("/api/models", methods=["GET"])
|
||||||
def models():
|
def models():
|
||||||
return jsonify(
|
return jsonify(
|
||||||
|
@ -160,7 +139,7 @@ def webApp():
|
||||||
print(f"Web server: http://{host}:{port}")
|
print(f"Web server: http://{host}:{port}")
|
||||||
server = wsgi.WSGIServer((host, port), app)
|
server = wsgi.WSGIServer((host, port), app)
|
||||||
server.serve_forever()
|
server.serve_forever()
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
14
web/api/__init__.py
Normal file
14
web/api/__init__.py
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
from flask import Blueprint
|
||||||
|
from flask_restx import Api
|
||||||
|
from .audio import api as audio
|
||||||
|
|
||||||
|
api_blueprint = Blueprint('api', __name__, url_prefix='/api')
|
||||||
|
|
||||||
|
api = Api(
|
||||||
|
app=api_blueprint,
|
||||||
|
title='Mocking Bird',
|
||||||
|
version='1.0',
|
||||||
|
description='My API'
|
||||||
|
)
|
||||||
|
|
||||||
|
api.add_namespace(audio)
|
43
web/api/audio.py
Normal file
43
web/api/audio.py
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from flask_restx import Namespace, Resource, fields
|
||||||
|
from flask import Response, current_app
|
||||||
|
|
||||||
|
api = Namespace('audios', description='Audios related operations')
|
||||||
|
|
||||||
|
audio = api.model('Audio', {
|
||||||
|
'name': fields.String(required=True, description='The audio name'),
|
||||||
|
})
|
||||||
|
|
||||||
|
def generate(wav_path):
|
||||||
|
with open(wav_path, "rb") as fwav:
|
||||||
|
data = fwav.read(1024)
|
||||||
|
while data:
|
||||||
|
yield data
|
||||||
|
data = fwav.read(1024)
|
||||||
|
|
||||||
|
@api.route('/')
|
||||||
|
class AudioList(Resource):
|
||||||
|
@api.doc('list_audios')
|
||||||
|
@api.marshal_list_with(audio)
|
||||||
|
def get(self):
|
||||||
|
'''List all audios'''
|
||||||
|
audio_samples = []
|
||||||
|
AUDIO_SAMPLES_DIR = current_app.config.get("AUDIO_SAMPLES_DIR")
|
||||||
|
if os.path.isdir(AUDIO_SAMPLES_DIR):
|
||||||
|
audio_samples = list(Path(AUDIO_SAMPLES_DIR).glob("*.wav"))
|
||||||
|
return list(a.name for a in audio_samples)
|
||||||
|
|
||||||
|
@api.route('/<name>')
|
||||||
|
@api.param('name', 'The name of audio')
|
||||||
|
@api.response(404, 'audio not found')
|
||||||
|
class Audio(Resource):
|
||||||
|
@api.doc('get_audio')
|
||||||
|
@api.marshal_with(audio)
|
||||||
|
def get(self, name):
|
||||||
|
'''Fetch a cat given its identifier'''
|
||||||
|
AUDIO_SAMPLES_DIR = current_app.config.get("AUDIO_SAMPLES_DIR")
|
||||||
|
if Path(AUDIO_SAMPLES_DIR + name).exists():
|
||||||
|
return Response(generate(AUDIO_SAMPLES_DIR + name), mimetype="audio/x-wav")
|
||||||
|
api.abort(404)
|
||||||
|
|
|
@ -4,4 +4,4 @@ HOST = 'localhost'
|
||||||
PORT = 8080
|
PORT = 8080
|
||||||
MAX_CONTENT_PATH =1024 * 1024 * 4 # mp3文件大小限定不能超过4M
|
MAX_CONTENT_PATH =1024 * 1024 * 4 # mp3文件大小限定不能超过4M
|
||||||
SECRET_KEY = "mockingbird_key"
|
SECRET_KEY = "mockingbird_key"
|
||||||
WTF_CSRF_SECRET_KEY = "mockingbird_key"
|
WTF_CSRF_SECRET_KEY = "mockingbird_key"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user