diff --git a/web/__init__.py b/web/__init__.py index 17f186e..0f2fe4f 100644 --- a/web/__init__.py +++ b/web/__init__.py @@ -1,14 +1,14 @@ from web.api import api_blueprint from pathlib import Path from gevent import pywsgi as wsgi -from flask import Flask, jsonify, Response, request, render_template, url_for +from flask import Flask, Response, request, render_template from synthesizer.inference import Synthesizer from encoder import inference as encoder from vocoder.hifigan import inference as gan_vocoder from vocoder.wavernn import inference as rnn_vocoder import numpy as np import re -from scipy.io.wavfile import write, read +from scipy.io.wavfile import write import io import base64 from flask_cors import CORS @@ -21,40 +21,16 @@ def webApp(): app.config['RESTPLUS_MASK_SWAGGER'] = False app.register_blueprint(api_blueprint) - CORS(app) #允许跨域,注释掉此行则禁止跨域请求 + # CORS(app) #允许跨域,注释掉此行则禁止跨域请求 csrf = CSRFProtect(app) csrf.init_app(app) - # API For Non-Trainer - # 1. list sample audio files - # 2. record / upload / select audio files - # 3. load melspetron of audio - # 4. inference by audio + text + models(encoder, vocoder, synthesizer) - # 5. export result - - # enc_models_dir = "encoder/saved_models" - # voc_models_di = "vocoder/saved_models" - # encoders = list(Path(enc_models_dir).glob("*.pt")) - # vocoders = list(Path(voc_models_di).glob("**/*.pt")) + syn_models_dirt = "synthesizer/saved_models" synthesizers = list(Path(syn_models_dirt).glob("**/*.pt")) - # print("Loaded encoder models: " + str(len(encoders))) - # print("Loaded vocoder models: " + str(len(vocoders))) - print("Loaded synthesizer models: " + str(len(synthesizers))) synthesizers_cache = {} encoder.load_model(Path("encoder/saved_models/pretrained.pt")) gan_vocoder.load_model(Path("vocoder/saved_models/pretrained/g_hifigan.pt")) - @app.route("/api/models", methods=["GET"]) - def models(): - return jsonify( - { - # "encoder": list(e.name for e in encoders), - # "vocoder": list(e.name for e in vocoders), - "synthesizers": - list({"name": e.name, "path": str(e)} for e in synthesizers), - } - ) - def pcm2float(sig, dtype='float32'): """Convert PCM signal to floating point with a range from -1 to 1. Use dtype='float32' for single precision. diff --git a/web/api/__init__.py b/web/api/__init__.py index db524d7..a0c8726 100644 --- a/web/api/__init__.py +++ b/web/api/__init__.py @@ -1,6 +1,7 @@ from flask import Blueprint from flask_restx import Api from .audio import api as audio +from .synthesizer import api as synthesizer api_blueprint = Blueprint('api', __name__, url_prefix='/api') @@ -12,3 +13,4 @@ api = Api( ) api.add_namespace(audio) +api.add_namespace(synthesizer) \ No newline at end of file diff --git a/web/api/synthesizer.py b/web/api/synthesizer.py new file mode 100644 index 0000000..23963b3 --- /dev/null +++ b/web/api/synthesizer.py @@ -0,0 +1,23 @@ +from pathlib import Path +from flask_restx import Namespace, Resource, fields + +api = Namespace('synthesizers', description='Synthesizers related operations') + +synthesizer = api.model('Synthesizer', { + 'name': fields.String(required=True, description='The synthesizer name'), + 'path': fields.String(required=True, description='The synthesizer path'), +}) + +synthesizers_cache = {} +syn_models_dirt = "synthesizer/saved_models" +synthesizers = list(Path(syn_models_dirt).glob("**/*.pt")) +print("Loaded synthesizer models: " + str(len(synthesizers))) + +@api.route('/') +class SynthesizerList(Resource): + @api.doc('list_synthesizers') + @api.marshal_list_with(synthesizer) + def get(self): + '''List all synthesizers''' + return list({"name": e.name, "path": str(e)} for e in synthesizers) +