mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
Change synthesizer read to restful style
This commit is contained in:
parent
0d0b55d3e9
commit
4d9e460063
|
@ -1,14 +1,14 @@
|
||||||
from web.api import api_blueprint
|
from web.api import api_blueprint
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from gevent import pywsgi as wsgi
|
from gevent import pywsgi as wsgi
|
||||||
from flask import Flask, jsonify, Response, request, render_template, url_for
|
from flask import Flask, Response, request, render_template
|
||||||
from synthesizer.inference import Synthesizer
|
from synthesizer.inference import Synthesizer
|
||||||
from encoder import inference as encoder
|
from encoder import inference as encoder
|
||||||
from vocoder.hifigan import inference as gan_vocoder
|
from vocoder.hifigan import inference as gan_vocoder
|
||||||
from vocoder.wavernn import inference as rnn_vocoder
|
from vocoder.wavernn import inference as rnn_vocoder
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import re
|
import re
|
||||||
from scipy.io.wavfile import write, read
|
from scipy.io.wavfile import write
|
||||||
import io
|
import io
|
||||||
import base64
|
import base64
|
||||||
from flask_cors import CORS
|
from flask_cors import CORS
|
||||||
|
@ -21,40 +21,16 @@ def webApp():
|
||||||
app.config['RESTPLUS_MASK_SWAGGER'] = False
|
app.config['RESTPLUS_MASK_SWAGGER'] = False
|
||||||
app.register_blueprint(api_blueprint)
|
app.register_blueprint(api_blueprint)
|
||||||
|
|
||||||
CORS(app) #允许跨域,注释掉此行则禁止跨域请求
|
# CORS(app) #允许跨域,注释掉此行则禁止跨域请求
|
||||||
csrf = CSRFProtect(app)
|
csrf = CSRFProtect(app)
|
||||||
csrf.init_app(app)
|
csrf.init_app(app)
|
||||||
# API For Non-Trainer
|
|
||||||
# 1. list sample audio files
|
|
||||||
# 2. record / upload / select audio files
|
|
||||||
# 3. load melspetron of audio
|
|
||||||
# 4. inference by audio + text + models(encoder, vocoder, synthesizer)
|
|
||||||
# 5. export result
|
|
||||||
|
|
||||||
# enc_models_dir = "encoder/saved_models"
|
|
||||||
# voc_models_di = "vocoder/saved_models"
|
|
||||||
# encoders = list(Path(enc_models_dir).glob("*.pt"))
|
|
||||||
# vocoders = list(Path(voc_models_di).glob("**/*.pt"))
|
|
||||||
syn_models_dirt = "synthesizer/saved_models"
|
syn_models_dirt = "synthesizer/saved_models"
|
||||||
synthesizers = list(Path(syn_models_dirt).glob("**/*.pt"))
|
synthesizers = list(Path(syn_models_dirt).glob("**/*.pt"))
|
||||||
# print("Loaded encoder models: " + str(len(encoders)))
|
|
||||||
# print("Loaded vocoder models: " + str(len(vocoders)))
|
|
||||||
print("Loaded synthesizer models: " + str(len(synthesizers)))
|
|
||||||
synthesizers_cache = {}
|
synthesizers_cache = {}
|
||||||
encoder.load_model(Path("encoder/saved_models/pretrained.pt"))
|
encoder.load_model(Path("encoder/saved_models/pretrained.pt"))
|
||||||
gan_vocoder.load_model(Path("vocoder/saved_models/pretrained/g_hifigan.pt"))
|
gan_vocoder.load_model(Path("vocoder/saved_models/pretrained/g_hifigan.pt"))
|
||||||
|
|
||||||
@app.route("/api/models", methods=["GET"])
|
|
||||||
def models():
|
|
||||||
return jsonify(
|
|
||||||
{
|
|
||||||
# "encoder": list(e.name for e in encoders),
|
|
||||||
# "vocoder": list(e.name for e in vocoders),
|
|
||||||
"synthesizers":
|
|
||||||
list({"name": e.name, "path": str(e)} for e in synthesizers),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def pcm2float(sig, dtype='float32'):
|
def pcm2float(sig, dtype='float32'):
|
||||||
"""Convert PCM signal to floating point with a range from -1 to 1.
|
"""Convert PCM signal to floating point with a range from -1 to 1.
|
||||||
Use dtype='float32' for single precision.
|
Use dtype='float32' for single precision.
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from flask import Blueprint
|
from flask import Blueprint
|
||||||
from flask_restx import Api
|
from flask_restx import Api
|
||||||
from .audio import api as audio
|
from .audio import api as audio
|
||||||
|
from .synthesizer import api as synthesizer
|
||||||
|
|
||||||
api_blueprint = Blueprint('api', __name__, url_prefix='/api')
|
api_blueprint = Blueprint('api', __name__, url_prefix='/api')
|
||||||
|
|
||||||
|
@ -12,3 +13,4 @@ api = Api(
|
||||||
)
|
)
|
||||||
|
|
||||||
api.add_namespace(audio)
|
api.add_namespace(audio)
|
||||||
|
api.add_namespace(synthesizer)
|
23
web/api/synthesizer.py
Normal file
23
web/api/synthesizer.py
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
from pathlib import Path
|
||||||
|
from flask_restx import Namespace, Resource, fields
|
||||||
|
|
||||||
|
api = Namespace('synthesizers', description='Synthesizers related operations')
|
||||||
|
|
||||||
|
synthesizer = api.model('Synthesizer', {
|
||||||
|
'name': fields.String(required=True, description='The synthesizer name'),
|
||||||
|
'path': fields.String(required=True, description='The synthesizer path'),
|
||||||
|
})
|
||||||
|
|
||||||
|
synthesizers_cache = {}
|
||||||
|
syn_models_dirt = "synthesizer/saved_models"
|
||||||
|
synthesizers = list(Path(syn_models_dirt).glob("**/*.pt"))
|
||||||
|
print("Loaded synthesizer models: " + str(len(synthesizers)))
|
||||||
|
|
||||||
|
@api.route('/')
|
||||||
|
class SynthesizerList(Resource):
|
||||||
|
@api.doc('list_synthesizers')
|
||||||
|
@api.marshal_list_with(synthesizer)
|
||||||
|
def get(self):
|
||||||
|
'''List all synthesizers'''
|
||||||
|
return list({"name": e.name, "path": str(e)} for e in synthesizers)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user