mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
167 lines
5.9 KiB
Python
167 lines
5.9 KiB
Python
|
import os
|
||
|
from pathlib import Path
|
||
|
from gevent import pywsgi as wsgi
|
||
|
from flask import Flask, jsonify, Response, request, render_template
|
||
|
from synthesizer.inference import Synthesizer
|
||
|
from encoder import inference as encoder
|
||
|
from vocoder.hifigan import inference as gan_vocoder
|
||
|
from vocoder.wavernn import inference as rnn_vocoder
|
||
|
import numpy as np
|
||
|
import re
|
||
|
from scipy.io.wavfile import write, read
|
||
|
import io
|
||
|
import base64
|
||
|
from flask_cors import CORS
|
||
|
from flask_wtf import CSRFProtect
|
||
|
|
||
|
# Sentence delimiters (ASCII and full-width) on which the input text is cut
# into short segments before synthesis.
_PUNCTUATION = '!,。、,'
_PUNCTUATION_RE = re.compile(r'[{}]+'.format(_PUNCTUATION))


# TODO: move to utils
def generate(wav_path, chunk_size=1024):
    """Yield the content of the file at *wav_path* as successive byte chunks.

    Parameters
    ----------
    wav_path : str or os.PathLike
        File to stream; it is opened in binary mode.
    chunk_size : int, optional
        Maximum number of bytes per yielded chunk (1024 by default).

    Yields
    ------
    bytes
        Non-empty chunks in file order; iteration stops at end of file.
    """
    with open(wav_path, "rb") as fwav:
        # iter(callable, sentinel) keeps reading until read() returns b"" (EOF).
        yield from iter(lambda: fwav.read(chunk_size), b"")


def pcm2float(sig, dtype='float32'):
    """Convert an integer PCM signal to floating point in the range -1 to 1.

    Parameters
    ----------
    sig : array_like
        Input array, must have a signed or unsigned integer dtype.
    dtype : data type, optional
        Desired floating point data type ('float32' by default).

    Returns
    -------
    numpy.ndarray
        Normalized floating point data.

    Raises
    ------
    TypeError
        If *sig* is not an integer array or *dtype* is not a float type.
    """
    sig = np.asarray(sig)
    if sig.dtype.kind not in 'iu':
        raise TypeError("'sig' must be an array of integers")
    dtype = np.dtype(dtype)
    if dtype.kind != 'f':
        raise TypeError("'dtype' must be a floating point type")

    info = np.iinfo(sig.dtype)
    abs_max = 2 ** (info.bits - 1)
    # offset is 0 for signed types and the mid-point for unsigned ones, so
    # both end up centred on zero before scaling.
    offset = info.min + abs_max
    return (sig.astype(dtype) - offset) / abs_max


def split_sentences(raw_text):
    """Split *raw_text* into clean, non-empty segments.

    The text is cut on newlines and on every run of the delimiters in
    ``_PUNCTUATION``. Each piece is stripped *before* the emptiness check so
    that whitespace-only pieces are dropped instead of being kept as ``""``.

    Parameters
    ----------
    raw_text : str
        Text as received from the client.

    Returns
    -------
    list of str
        Stripped, non-empty segments in their original order.
    """
    segments = []
    for line in raw_text.split("\n"):
        for piece in _PUNCTUATION_RE.split(line):
            piece = piece.strip()
            if piece:
                segments.append(piece)
    return segments


def webApp():
    """Build the Flask application, register its routes and serve it.

    Configuration is read from ``web.config.default`` (``AUDIO_SAMPLES_DIR``,
    ``HOST`` and ``PORT`` are used here). The encoder and HiFi-GAN vocoder
    checkpoints are loaded once at start-up; synthesizer models are loaded on
    first use and cached per path.

    Routes
    ------
    GET  /api/audios          names of the discovered sample ``*.wav`` files
    GET  /api/audios/<name>   stream one of those samples
    GET  /api/models          discovered synthesizer checkpoints
    POST /api/synthesize      form fields ``upfile_b64``, ``text`` and the
                              optional ``synt_path``; returns a WAV body
    GET  /                    the ``index.html`` template

    Returns
    -------
    flask.Flask
        The application object. ``serve_forever()`` is called first, so the
        return only happens once the server has stopped.
    """
    # Init and load config
    app = Flask(__name__, instance_relative_config=True)
    app.config.from_object("web.config.default")

    CORS(app)  # Allow cross-origin requests; comment this line out to forbid them
    # CSRFProtect(app) already initialises the extension for `app`, so no
    # separate csrf.init_app(app) call is needed.
    csrf = CSRFProtect(app)

    # API For Non-Trainer
    # 1. list sample audio files
    # 2. record / upload / select audio files
    # 3. load melspetron of audio
    # 4. inference by audio + text + models(encoder, vocoder, synthesizer)
    # 5. export result
    AUDIO_SAMPLES_DIR = app.config.get("AUDIO_SAMPLES_DIR")
    audio_samples = []
    # Guard against a missing config value: os.path.isdir(None) raises TypeError.
    if AUDIO_SAMPLES_DIR and os.path.isdir(AUDIO_SAMPLES_DIR):
        audio_samples = list(Path(AUDIO_SAMPLES_DIR).glob("*.wav"))
    print("Loaded samples: " + str(len(audio_samples)))
    # Only files discovered above may be served by name.
    samples_by_name = {sample.name: sample for sample in audio_samples}

    # Encoder / vocoder discovery is disabled; fixed pretrained checkpoints
    # are loaded below instead.
    syn_models_dirt = "synthesizer/saved_models"
    synthesizers = list(Path(syn_models_dirt).glob("**/*.pt"))
    print("Loaded synthesizer models: " + str(len(synthesizers)))
    # Allow-list of loadable checkpoints, keyed by the same string that
    # /api/models reports as "path".
    synthesizers_by_path = {str(model): model for model in synthesizers}
    synthesizers_cache = {}
    encoder.load_model(Path("encoder/saved_models/pretrained.pt"))
    gan_vocoder.load_model(Path("vocoder/saved_models/pretrained/g_hifigan.pt"))

    def bad_request(message, status=400):
        """Return a small JSON error payload with the given HTTP status."""
        return jsonify({"error": message}), status

    @app.route("/api/audios", methods=["GET"])
    def audios():
        """List the file names of the discovered audio samples."""
        return jsonify(
            {"data": [a.name for a in audio_samples], "total": len(audio_samples)}
        )

    @app.route("/api/audios/<name>", methods=["GET"])
    def audio_play(name):
        """Stream one discovered sample; unknown names get a 404."""
        sample = samples_by_name.get(name)
        if sample is None:
            # Looking the name up (instead of concatenating it onto the
            # directory) keeps requests inside the sample list and reports a
            # missing file before streaming starts.
            return Response("audio sample not found", status=404,
                            mimetype="text/plain")
        return Response(generate(sample), mimetype="audio/x-wav")

    @app.route("/api/models", methods=["GET"])
    def models():
        """List the discovered synthesizer checkpoints (name and path)."""
        return jsonify(
            {
                "synthesizers":
                    [{"name": e.name, "path": str(e)} for e in synthesizers],
            }
        )

    # Cache for synthesizer
    @csrf.exempt
    @app.route("/api/synthesize", methods=["POST"])
    def synthesize():
        """Synthesize speech for the posted text in the posted voice.

        Responds with 400 for malformed input, 503 when no synthesizer model
        is available, and otherwise with the generated audio as a WAV body.
        """
        # TODO Implementation with json to support more platform

        # Load synthesizer
        requested = request.form.get("synt_path")
        if requested:
            # Normalise through Path so the key has the same form as the
            # paths reported by /api/models.
            synt_key = str(Path(requested))
        elif synthesizers:
            synt_key = str(synthesizers[0])
            print("NO synthsizer is specified, try default first one.")
        else:
            return bad_request("no synthesizer model is available", 503)

        # NOTE(security): the path comes from the client, so only checkpoints
        # discovered at start-up may be loaded.
        model_path = synthesizers_by_path.get(synt_key)
        if model_path is None:
            return bad_request("unknown synthesizer model: " + synt_key)

        current_synt = synthesizers_cache.get(synt_key)
        if current_synt is None:
            current_synt = Synthesizer(model_path)
            synthesizers_cache[synt_key] = current_synt
        print("using synthesizer model: " + str(synt_key))

        # Validate the cheap inputs before running any model.
        # Load input text
        texts = split_sentences(request.form["text"])
        if not texts:
            return bad_request("'text' contains nothing to synthesize")

        # Load input wav
        wav_base64 = request.form["upfile_b64"]
        try:
            pcm_bytes = base64.b64decode(bytes(wav_base64, 'utf-8'))
            # The payload is interpreted as raw 16-bit PCM samples.
            wav = pcm2float(np.frombuffer(pcm_bytes, dtype=np.int16),
                            dtype=np.float32)
        except ValueError:
            # Covers invalid base64 (binascii.Error) and a byte count that is
            # not a multiple of the int16 sample size.
            return bad_request("'upfile_b64' is not valid base64 16-bit PCM")

        # 16000 is presumably the sample rate of the upload — TODO confirm
        encoder_wav = encoder.preprocess_wav(wav, 16000)
        embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True)

        # synthesize and vocode
        embeds = [embed] * len(texts)  # same speaker embedding for every segment
        specs = current_synt.synthesize_spectrograms(texts, embeds)
        spec = np.concatenate(specs, axis=1)
        wav = gan_vocoder.infer_waveform(spec)

        # Return cooked wav
        out = io.BytesIO()
        write(out, Synthesizer.sample_rate, wav)
        # Send the whole buffer as bytes rather than iterating the stream.
        return Response(out.getvalue(), mimetype="audio/wav")

    @app.route('/', methods=['GET'])
    def index():
        """Serve the single-page front end."""
        return render_template("index.html")

    host = app.config.get("HOST")
    port = app.config.get("PORT")
    print(f"Web server: http://{host}:{port}")
    server = wsgi.WSGIServer((host, port), app)
    server.serve_forever()

    return app
|
||
|
|
||
|
if __name__ == "__main__":
    # Script entry point: build the application and start serving it.
    webApp()
|