Change synthesizer read to restful style

This commit is contained in:
babysor00 2021-09-26 10:01:50 +08:00
parent 0d0b55d3e9
commit 4d9e460063
3 changed files with 29 additions and 28 deletions

View File

@ -1,14 +1,14 @@
from web.api import api_blueprint from web.api import api_blueprint
from pathlib import Path from pathlib import Path
from gevent import pywsgi as wsgi from gevent import pywsgi as wsgi
from flask import Flask, jsonify, Response, request, render_template, url_for from flask import Flask, Response, request, render_template
from synthesizer.inference import Synthesizer from synthesizer.inference import Synthesizer
from encoder import inference as encoder from encoder import inference as encoder
from vocoder.hifigan import inference as gan_vocoder from vocoder.hifigan import inference as gan_vocoder
from vocoder.wavernn import inference as rnn_vocoder from vocoder.wavernn import inference as rnn_vocoder
import numpy as np import numpy as np
import re import re
from scipy.io.wavfile import write, read from scipy.io.wavfile import write
import io import io
import base64 import base64
from flask_cors import CORS from flask_cors import CORS
@ -21,40 +21,16 @@ def webApp():
app.config['RESTPLUS_MASK_SWAGGER'] = False app.config['RESTPLUS_MASK_SWAGGER'] = False
app.register_blueprint(api_blueprint) app.register_blueprint(api_blueprint)
CORS(app) #允许跨域,注释掉此行则禁止跨域请求 # CORS(app) #允许跨域,注释掉此行则禁止跨域请求
csrf = CSRFProtect(app) csrf = CSRFProtect(app)
csrf.init_app(app) csrf.init_app(app)
# API For Non-Trainer
# 1. list sample audio files
# 2. record / upload / select audio files
# 3. load melspetron of audio
# 4. inference by audio + text + models(encoder, vocoder, synthesizer)
# 5. export result
# enc_models_dir = "encoder/saved_models"
# voc_models_di = "vocoder/saved_models"
# encoders = list(Path(enc_models_dir).glob("*.pt"))
# vocoders = list(Path(voc_models_di).glob("**/*.pt"))
syn_models_dirt = "synthesizer/saved_models" syn_models_dirt = "synthesizer/saved_models"
synthesizers = list(Path(syn_models_dirt).glob("**/*.pt")) synthesizers = list(Path(syn_models_dirt).glob("**/*.pt"))
# print("Loaded encoder models: " + str(len(encoders)))
# print("Loaded vocoder models: " + str(len(vocoders)))
print("Loaded synthesizer models: " + str(len(synthesizers)))
synthesizers_cache = {} synthesizers_cache = {}
encoder.load_model(Path("encoder/saved_models/pretrained.pt")) encoder.load_model(Path("encoder/saved_models/pretrained.pt"))
gan_vocoder.load_model(Path("vocoder/saved_models/pretrained/g_hifigan.pt")) gan_vocoder.load_model(Path("vocoder/saved_models/pretrained/g_hifigan.pt"))
@app.route("/api/models", methods=["GET"])
def models():
return jsonify(
{
# "encoder": list(e.name for e in encoders),
# "vocoder": list(e.name for e in vocoders),
"synthesizers":
list({"name": e.name, "path": str(e)} for e in synthesizers),
}
)
def pcm2float(sig, dtype='float32'): def pcm2float(sig, dtype='float32'):
"""Convert PCM signal to floating point with a range from -1 to 1. """Convert PCM signal to floating point with a range from -1 to 1.
Use dtype='float32' for single precision. Use dtype='float32' for single precision.

View File

@ -1,6 +1,7 @@
from flask import Blueprint from flask import Blueprint
from flask_restx import Api from flask_restx import Api
from .audio import api as audio from .audio import api as audio
from .synthesizer import api as synthesizer
api_blueprint = Blueprint('api', __name__, url_prefix='/api') api_blueprint = Blueprint('api', __name__, url_prefix='/api')
@ -12,3 +13,4 @@ api = Api(
) )
api.add_namespace(audio) api.add_namespace(audio)
api.add_namespace(synthesizer)

23
web/api/synthesizer.py Normal file
View File

@ -0,0 +1,23 @@
from pathlib import Path
from flask_restx import Namespace, Resource, fields
api = Namespace('synthesizers', description='Synthesizers related operations')
synthesizer = api.model('Synthesizer', {
'name': fields.String(required=True, description='The synthesizer name'),
'path': fields.String(required=True, description='The synthesizer path'),
})
synthesizers_cache = {}
syn_models_dirt = "synthesizer/saved_models"
synthesizers = list(Path(syn_models_dirt).glob("**/*.pt"))
print("Loaded synthesizer models: " + str(len(synthesizers)))
@api.route('/')
class SynthesizerList(Resource):
@api.doc('list_synthesizers')
@api.marshal_list_with(synthesizer)
def get(self):
'''List all synthesizers'''
return list({"name": e.name, "path": str(e)} for e in synthesizers)