Web server release v2 (#99)

* Init App

* init server.py (#93)

* init server.py

* Update requirements.txt

Add requirement

Co-authored-by: auau <auau@test.com>
Co-authored-by: babysor00 <babysor00@gmail.com>

* Run web.py!

Run web.py!

* Restructure README and add instructions for using the web server

* fix training preprocess of vocoder

* Init App

* init server.py (#93)

* init server.py

* Update requirements.txt

Add requirement

Co-authored-by: auau <auau@test.com>
Co-authored-by: babysor00 <babysor00@gmail.com>

* Run web.py!

Run web.py!

* fix training preprocess of vocoder

* Refactor to restful style

Co-authored-by: balala <Ozgay@users.noreply.github.com>
Co-authored-by: auau <auau@test.com>
This commit is contained in:
Vega 2021-09-25 17:07:46 +08:00 committed by GitHub
parent 4acfee2a64
commit 0d0b55d3e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 66 additions and 29 deletions

View File

@ -19,3 +19,4 @@ flask
flask_wtf
flask_cors
gevent==21.8.0
flask_restx

View File

@ -41,7 +41,7 @@ hparams = HParams(
tts_lstm_dims = 1024,
tts_postnet_K = 5,
tts_num_highways = 4,
tts_dropout = 0.2,
tts_dropout = 0.5,
tts_cleaner_names = ["basic_cleaners"],
tts_stop_threshold = -3.4, # Value below which audio generation ends.
# For example, for a range of [-4, 4], this

View File

@ -1,7 +1,7 @@
import os
from web.api import api_blueprint
from pathlib import Path
from gevent import pywsgi as wsgi
from flask import Flask, jsonify, Response, request, render_template
from flask import Flask, jsonify, Response, request, render_template, url_for
from synthesizer.inference import Synthesizer
from encoder import inference as encoder
from vocoder.hifigan import inference as gan_vocoder
@ -17,8 +17,9 @@ from flask_wtf import CSRFProtect
def webApp():
# Init and load config
app = Flask(__name__, instance_relative_config=True)
app.config.from_object("web.config.default")
app.config['RESTPLUS_MASK_SWAGGER'] = False
app.register_blueprint(api_blueprint)
CORS(app) #允许跨域,注释掉此行则禁止跨域请求
csrf = CSRFProtect(app)
@ -29,11 +30,7 @@ def webApp():
# 3. load mel spectrogram of audio
# 4. inference by audio + text + models(encoder, vocoder, synthesizer)
# 5. export result
audio_samples = []
AUDIO_SAMPLES_DIR = app.config.get("AUDIO_SAMPLES_DIR")
if os.path.isdir(AUDIO_SAMPLES_DIR):
audio_samples = list(Path(AUDIO_SAMPLES_DIR).glob("*.wav"))
print("Loaded samples: " + str(len(audio_samples)))
# enc_models_dir = "encoder/saved_models"
# voc_models_di = "vocoder/saved_models"
# encoders = list(Path(enc_models_dir).glob("*.pt"))
@ -47,24 +44,6 @@ def webApp():
encoder.load_model(Path("encoder/saved_models/pretrained.pt"))
gan_vocoder.load_model(Path("vocoder/saved_models/pretrained/g_hifigan.pt"))
# TODO: move to utils
def generate(wav_path):
with open(wav_path, "rb") as fwav:
data = fwav.read(1024)
while data:
yield data
data = fwav.read(1024)
@app.route("/api/audios", methods=["GET"])
def audios():
return jsonify(
{"data": list(a.name for a in audio_samples), "total": len(audio_samples)}
)
@app.route("/api/audios/<name>", methods=["GET"])
def audio_play(name):
return Response(generate(AUDIO_SAMPLES_DIR + name), mimetype="audio/x-wav")
@app.route("/api/models", methods=["GET"])
def models():
return jsonify(
@ -160,7 +139,7 @@ def webApp():
print(f"Web server: http://{host}:{port}")
server = wsgi.WSGIServer((host, port), app)
server.serve_forever()
return app
if __name__ == "__main__":

14
web/api/__init__.py Normal file
View File

@ -0,0 +1,14 @@
from flask import Blueprint
from flask_restx import Api
from .audio import api as audio
# Blueprint that mounts every REST endpoint of this package under the /api URL prefix.
api_blueprint = Blueprint('api', __name__, url_prefix='/api')
# flask-restx Api bound to the blueprint (not directly to the Flask app);
# title, version and description are the metadata passed to flask-restx.
api = Api(
app=api_blueprint,
title='Mocking Bird',
version='1.0',
description='My API'
)
# Register the namespace exported by the sibling ``audio`` module.
api.add_namespace(audio)

43
web/api/audio.py Normal file
View File

@ -0,0 +1,43 @@
import os
from pathlib import Path
from flask_restx import Namespace, Resource, fields
from flask import Response, current_app
# Namespace named 'audios' that groups the audio endpoints declared in this module.
api = Namespace('audios', description='Audios related operations')
# Model used to marshal responses: a single required string field, 'name'.
audio = api.model('Audio', {
'name': fields.String(required=True, description='The audio name'),
})
def generate(wav_path):
    """Stream the file at ``wav_path`` as successive chunks of raw bytes.

    Yields blocks of at most 1024 bytes until the file is exhausted. The file
    is opened in binary mode and closed when the generator finishes.
    """
    chunk_size = 1024
    with open(wav_path, "rb") as wav_file:
        # iter() with a sentinel keeps calling read() until it returns b"",
        # which is what read() gives back at end of file.
        yield from iter(lambda: wav_file.read(chunk_size), b"")
@api.route('/')
class AudioList(Resource):
    # Collection endpoint: lists the *.wav files found in the directory named
    # by the AUDIO_SAMPLES_DIR application setting.

    @api.doc('list_audios')
    @api.marshal_list_with(audio)
    def get(self):
        '''List all audios'''
        samples_dir = current_app.config.get("AUDIO_SAMPLES_DIR")
        # An unset setting or a missing directory simply means there is
        # nothing to list; os.path.isdir(None) would raise TypeError, so the
        # unset case is checked first.
        if not samples_dir or not os.path.isdir(samples_dir):
            return []
        # marshal_list_with() looks up the 'name' field on every returned
        # item. A bare string has no such field and would be marshalled as
        # {"name": null}, so return mappings carrying the file name instead.
        # Sorting makes the listing order deterministic (glob order is not).
        return [
            {"name": wav.name}
            for wav in sorted(Path(samples_dir).glob("*.wav"))
        ]
@api.route('/<name>')
@api.param('name', 'The name of audio')
@api.response(404, 'audio not found')
class Audio(Resource):
    # Item endpoint: streams one audio sample from the directory named by the
    # AUDIO_SAMPLES_DIR application setting.

    @api.doc('get_audio')
    # No marshal_with() here: the method returns a raw streaming Response,
    # and marshalling it with the 'Audio' model would replace the audio
    # stream with a {"name": null} JSON body.
    @api.produces(['audio/x-wav'])
    def get(self, name):
        '''Stream an audio sample given its file name'''
        samples_dir = current_app.config.get("AUDIO_SAMPLES_DIR")
        # Serve plain file names only: anything that carries a directory
        # component is rejected so the request cannot leave the samples dir.
        if not samples_dir or Path(name).name != name:
            api.abort(404)
        # Join with pathlib so the result does not depend on the configured
        # directory ending with a path separator.
        wav_path = Path(samples_dir) / name
        # is_file() (rather than exists()) also rejects directories such as
        # "..", which cannot be opened and streamed.
        if not wav_path.is_file():
            api.abort(404)
        return Response(generate(wav_path), mimetype="audio/x-wav")

View File

@ -4,4 +4,4 @@ HOST = 'localhost'
PORT = 8080
MAX_CONTENT_PATH =1024 * 1024 * 4 # mp3文件大小限定不能超过4M
SECRET_KEY = "mockingbird_key"
WTF_CSRF_SECRET_KEY = "mockingbird_key"
WTF_CSRF_SECRET_KEY = "mockingbird_key"