Change synthesizer read to restful style

pull/102/head
babysor00 2021-09-26 10:01:50 +08:00
parent 0d0b55d3e9
commit 4d9e460063
3 changed files with 29 additions and 28 deletions

View File

@ -1,14 +1,14 @@
from web.api import api_blueprint
from pathlib import Path
from gevent import pywsgi as wsgi
from flask import Flask, jsonify, Response, request, render_template, url_for
from flask import Flask, Response, request, render_template
from synthesizer.inference import Synthesizer
from encoder import inference as encoder
from vocoder.hifigan import inference as gan_vocoder
from vocoder.wavernn import inference as rnn_vocoder
import numpy as np
import re
from scipy.io.wavfile import write, read
from scipy.io.wavfile import write
import io
import base64
from flask_cors import CORS
@ -21,40 +21,16 @@ def webApp():
app.config['RESTPLUS_MASK_SWAGGER'] = False
app.register_blueprint(api_blueprint)
CORS(app) #允许跨域,注释掉此行则禁止跨域请求
# CORS(app) #允许跨域,注释掉此行则禁止跨域请求
csrf = CSRFProtect(app)
csrf.init_app(app)
# API For Non-Trainer
# 1. list sample audio files
# 2. record / upload / select audio files
# 3. load melspetron of audio
# 4. inference by audio + text + models(encoder, vocoder, synthesizer)
# 5. export result
# enc_models_dir = "encoder/saved_models"
# voc_models_di = "vocoder/saved_models"
# encoders = list(Path(enc_models_dir).glob("*.pt"))
# vocoders = list(Path(voc_models_di).glob("**/*.pt"))
syn_models_dirt = "synthesizer/saved_models"
synthesizers = list(Path(syn_models_dirt).glob("**/*.pt"))
# print("Loaded encoder models: " + str(len(encoders)))
# print("Loaded vocoder models: " + str(len(vocoders)))
print("Loaded synthesizer models: " + str(len(synthesizers)))
synthesizers_cache = {}
encoder.load_model(Path("encoder/saved_models/pretrained.pt"))
gan_vocoder.load_model(Path("vocoder/saved_models/pretrained/g_hifigan.pt"))
@app.route("/api/models", methods=["GET"])
def models():
return jsonify(
{
# "encoder": list(e.name for e in encoders),
# "vocoder": list(e.name for e in vocoders),
"synthesizers":
list({"name": e.name, "path": str(e)} for e in synthesizers),
}
)
def pcm2float(sig, dtype='float32'):
"""Convert PCM signal to floating point with a range from -1 to 1.
Use dtype='float32' for single precision.

View File

@ -1,6 +1,7 @@
from flask import Blueprint
from flask_restx import Api
from .audio import api as audio
from .synthesizer import api as synthesizer
api_blueprint = Blueprint('api', __name__, url_prefix='/api')
@ -12,3 +13,4 @@ api = Api(
)
api.add_namespace(audio)
api.add_namespace(synthesizer)

23
web/api/synthesizer.py Normal file
View File

@ -0,0 +1,23 @@
from pathlib import Path
from flask_restx import Namespace, Resource, fields
api = Namespace('synthesizers', description='Synthesizers related operations')
synthesizer = api.model('Synthesizer', {
'name': fields.String(required=True, description='The synthesizer name'),
'path': fields.String(required=True, description='The synthesizer path'),
})
synthesizers_cache = {}
syn_models_dirt = "synthesizer/saved_models"
synthesizers = list(Path(syn_models_dirt).glob("**/*.pt"))
print("Loaded synthesizer models: " + str(len(synthesizers)))
@api.route('/')
class SynthesizerList(Resource):
@api.doc('list_synthesizers')
@api.marshal_list_with(synthesizer)
def get(self):
'''List all synthesizers'''
return list({"name": e.name, "path": str(e)} for e in synthesizers)