multi-mel gan inference with MNN #197

Open
bringtree opened this issue Apr 23, 2021 · 3 comments
bringtree commented Apr 23, 2021

step 1: torch2onnx

#!/ssd4/exec/huangps/anaconda3/envs/melgan/bin/python

import torch
import numpy as np
from model.generator import Generator
from utils.hparams import HParam, load_hparam_str
from utils.pqmf import PQMF
import wave



checkpoint = torch.load('./chkpt/hps/hps_13efcb4_0600.pt')
hp = load_hparam_str(checkpoint['hp_str'])

vocoder = Generator(hp.audio.n_mel_channels, hp.model.n_residual_layers,
                    ratios=hp.model.generator_ratio, mult=hp.model.mult,
                    out_band=hp.model.out_channels).cuda()
vocoder.load_state_dict(checkpoint['model_g'])
# note: Generator overrides eval(); inference=True would also remove
# weight norm (per the upstream melgan Generator)
vocoder.eval(inference=False)

# vocoder.inference(mel)

mel = np.load("/ssd5/exec/huangps/melgan/datasets/LJSpeech-1.1/mels/LJ001-0001.npy")
mel = torch.from_numpy(mel).to(device='cuda', dtype=torch.float32)
mel = mel.unsqueeze(0)
dummy_input = mel
input_names = [ "mel" ]
output_names = [ "output" ]

dynamic_axes = {
    "mel" : {0: "batch_size", 2: "seq_len"}
}

torch.onnx.export(vocoder, dummy_input, "melgan.onnx", verbose=True,
                  input_names=input_names, output_names=output_names,
                  dynamic_axes=dynamic_axes)



# MAX_WAV_VALUE = 32768.0
# with torch.no_grad():
#     mel = mel.detach()
#     if len(mel.shape) == 2:
#         mel = mel.unsqueeze(0)
#     mel = mel.cuda()
#     audio = vocoder.inference(mel)
#     # For multi-band inference
#     if hp.model.out_channels > 1:
#         pqmf = PQMF()
#         audio = pqmf.synthesis(audio).view(-1)
    
#     audio = audio.squeeze() # collapse all dimension except time axis
#     audio = audio[:-(hp.audio.hop_length*10)]
#     audio = MAX_WAV_VALUE * audio
#     audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE-1)
#     audio = audio.short()
#     audio = audio.cpu().detach().numpy()

# print(audio.shape)
# print(audio[:10])
# with wave.open('1.wav', 'wb') as wavfile:
#     wavfile.setparams((1, 2, 22050, 0, 'NONE', 'NONE'))
#     wavfile.writeframes(audio)
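
Before moving to the runtime check in step 2, it can be worth a quick structural check of the exported graph with the onnx package (a minimal sketch, assuming onnx is installed alongside onnxruntime):

import onnx

onnx_model = onnx.load("melgan.onnx")
onnx.checker.check_model(onnx_model)  # raises if the exported graph is malformed
print(onnx.helper.printable_graph(onnx_model.graph))  # inspect ops and shapes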
step 2: validate the onnx model

import onnxruntime
import numpy as np
import torch
import wave
from utils.hparams import HParam, load_hparam_str
from utils.pqmf import PQMF

# checkpoint = torch.load('./chkpt/hps/hps_13efcb4_0600.pt')
# hp = load_hparam_str(checkpoint['hp_str'])

sess = onnxruntime.InferenceSession('./melgan.onnx', None)

input_names = [ "mel" ]
output_names = [ "output" ]

mel = np.load("/ssd5/exec/huangps/melgan/datasets/LJSpeech-1.1/mels/LJ001-0001.npy")

mel = mel.reshape([1,80,-1])
audio = sess.run(output_names, {'mel': mel})
print(audio)
# audio = torch.from_numpy(audio[0]).to(device='cpu', dtype=torch.float32)



# MAX_WAV_VALUE = 32768.0
# with torch.no_grad():
    
#     pqmf = PQMF()
#     audio = pqmf.synthesis(audio).view(-1)
#     audio = audio.squeeze()
    
#     audio = audio[:-(256*10)]
#     audio = MAX_WAV_VALUE * audio
#     audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE-1)
#     audio = audio.short()
#     audio = audio.cpu().detach().numpy()


# print(audio.shape)
# print(audio[:10])
# with wave.open('1.wav', 'wb') as wavfile:
#     wavfile.setparams((1, 2, 22050, 0, 'NONE', 'NONE'))
#     wavfile.writeframes(audio)
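
print(audio) only eyeballs the result; to actually validate the export, one can compare the ONNX Runtime output against the PyTorch model on the same mel. A rough sketch reusing the step-1 setup (the 1e-3 tolerance is an assumption, not a verified bound):

import numpy as np
import torch
import onnxruntime
from model.generator import Generator
from utils.hparams import load_hparam_str

checkpoint = torch.load('./chkpt/hps/hps_13efcb4_0600.pt')
hp = load_hparam_str(checkpoint['hp_str'])
vocoder = Generator(hp.audio.n_mel_channels, hp.model.n_residual_layers,
                    ratios=hp.model.generator_ratio, mult=hp.model.mult,
                    out_band=hp.model.out_channels).cuda()
vocoder.load_state_dict(checkpoint['model_g'])
vocoder.eval(inference=False)

mel = np.load("/ssd5/exec/huangps/melgan/datasets/LJSpeech-1.1/mels/LJ001-0001.npy")
mel = mel.reshape([1, 80, -1]).astype(np.float32)

# reference output from the same forward() that was exported in step 1
with torch.no_grad():
    ref = vocoder(torch.from_numpy(mel).cuda()).cpu().numpy()

sess = onnxruntime.InferenceSession('./melgan.onnx', None)
onnx_out = sess.run(["output"], {"mel": mel})[0]

diff = np.abs(ref - onnx_out).max()
print("max abs diff:", diff)
assert diff < 1e-3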


bringtree commented Apr 23, 2021

step 3: export the MNN model

 ./MNNConvert -f ONNX --modelFile melgan.onnx --MNNModel melgan.mnn --bizCode biz

numpy to bin:

https://blog.csdn.net/guyuealian/article/details/106422400
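
For reference, the conversion amounts to dumping the array as raw row-major float32, which is exactly the layout the step-4 C++ code reads back (a minimal sketch; the output filename matches the one used below):

import numpy as np

mel = np.load("/ssd5/exec/huangps/melgan/datasets/LJSpeech-1.1/mels/LJ001-0001.npy")
print(mel.shape)  # expect (80, 832) for this clip
mel.astype(np.float32).tofile("LJ001-0001.bin")  # raw float32, no header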


step 4: C++ inference verification

//
//  vocoder.cpp
//  MNN
//
//

#include <math.h>
#include <fstream>
#include <iostream>
#include <memory>
#include <MNN/Interpreter.hpp>


#define MNN_OPEN_TIME_TRACE
#include <MNN/AutoTime.hpp>

using namespace MNN;

int main(int argc, char *argv[]) {


    const auto melganModel = "/Users/peisonghuang/MNN/demo/model/melgan.mnn";
    const auto inputFileName = "/Users/peisonghuang/MNN/demo/model/LJ001-0001.bin";

    // create net and session
    auto mnnNet = std::shared_ptr<MNN::Interpreter>(MNN::Interpreter::createFromFile(melganModel));

    MNN::ScheduleConfig netConfig;
    netConfig.type = MNN_FORWARD_CPU;
    netConfig.numThread = 1;
    auto session = mnnNet->createSession(netConfig);
    auto input = mnnNet->getSessionInput(session, "mel");
    mnnNet->resizeTensor(input, {1, 80, 832});
    mnnNet->resizeSession(session);

    // read the 80x832 float32 mel dumped from numpy in step 3
    {
        MNN::Tensor givenTensor(input, Tensor::CAFFE);

        std::ifstream inputFile(inputFileName, std::ios::in | std::ios::binary);

        float fnum[80][832] = {0};
        inputFile.read((char *) &fnum, sizeof fnum);
        inputFile.close();

        for (int i = 0; i < 80; i++) {
            for (int j = 0; j < 832; j++) {
                givenTensor.host<float>()[i * 832 + j] = static_cast<float_t>(fnum[i][j]);
            }
        }
        input->copyFromHostTensor(&givenTensor);
    }

    // run...
    {
        AUTOTIME;
        mnnNet->runSession(session);
    }

    // get output: 4 sub-bands of 53248 samples each (832 frames * 256 hop / 4 bands)
    {
        auto outputTensor = mnnNet->getSessionOutput(session, "output");
        auto nchwTensor = new Tensor(outputTensor, Tensor::CAFFE);
        outputTensor->copyToHostTensor(nchwTensor);
        // print the first 3 samples of each sub-band as a sanity check
        for (int i = 0; i < 4; i++) {
            for (int j = 0; j < 3; j++)
                std::cout << nchwTensor->host<float>()[i * 53248 + j] << " ";
            std::cout << std::endl;
        }
        delete nchwTensor;  // copyToHostTensor does not take ownership
    }
    return 0;
}
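
For completeness, one way to build and run this against an MNN checkout (the include and library paths are placeholders for wherever MNN was built):

g++ -std=c++11 vocoder.cpp \
    -I/path/to/MNN/include \
    -L/path/to/MNN/build -lMNN \
    -o vocoder
./vocoder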
