fix gpu-onnx infer #2562

Merged: 10 commits, merged on Jul 9, 2024
21 changes: 14 additions & 7 deletions runtime/gpu/model_repo/scoring/1/model.py
@@ -75,7 +75,6 @@ def initialize(self, args):
     def init_ctc_rescore(self, parameters):
         num_processes = multiprocessing.cpu_count()
         cutoff_prob = 0.9999
-        blank_id = 0
         alpha = 2.0
         beta = 1.0
         bidecoder = 0
@@ -104,8 +103,12 @@ def init_ctc_rescore(self, parameters):
         self.num_processes = num_processes
         self.cutoff_prob = cutoff_prob
-        self.blank_id = blank_id
-        _, vocab = self.load_vocab(vocab_path)
+        ret = self.load_vocab(vocab_path)
+        id2vocab, vocab, space_id, blank_id, sos_eos = ret
+        self.space_id = space_id if space_id else -1
+        self.blank_id = blank_id if blank_id else 0
+        self.eos = self.sos = sos_eos if sos_eos else len(vocab) - 1
 
         if lm_path and os.path.exists(lm_path):
             self.lm = Scorer(alpha, beta, lm_path, vocab)
             print("Successfully load language model!")
@@ -125,24 +128,28 @@ def init_ctc_rescore(self, parameters):
         )
         self.vocabulary = vocab
         self.bidecoder = bidecoder
-        sos = eos = len(vocab) - 1
-        self.sos = sos
-        self.eos = eos
 
     def load_vocab(self, vocab_file):
         """
         load lang_char.txt
         """
         id2vocab = {}
+        space_id, blank_id, sos_eos = None, None, None
         with open(vocab_file, "r", encoding="utf-8") as f:
             for line in f:
                 line = line.strip()
                 char, id = line.split()
                 id2vocab[int(id)] = char
+                if char == " ":
+                    space_id = int(id)
+                elif char == "<blank>":
+                    blank_id = int(id)
+                elif char == "<sos/eos>":
+                    sos_eos = int(id)
         vocab = [0] * len(id2vocab)
         for id, char in id2vocab.items():
             vocab[id] = char
-        return id2vocab, vocab
+        return (id2vocab, vocab, space_id, blank_id, sos_eos)
 
     def load_hotwords(self, hotwords_file):
         """
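For reference, the new five-field return value of load_vocab above can be exercised standalone. The sketch below reproduces the parsing logic and the fallback defaults from init_ctc_rescore against a toy units file; the file contents and temp-file handling are illustrative assumptions, not part of this PR.

import os
import tempfile

def load_vocab(vocab_file):
    # Same parsing logic as the patched scoring model's load_vocab.
    id2vocab = {}
    space_id, blank_id, sos_eos = None, None, None
    with open(vocab_file, "r", encoding="utf-8") as f:
        for line in f:
            char, idx = line.strip().split()
            id2vocab[int(idx)] = char
            if char == " ":
                space_id = int(idx)
            elif char == "<blank>":
                blank_id = int(idx)
            elif char == "<sos/eos>":
                sos_eos = int(idx)
    vocab = [0] * len(id2vocab)
    for idx, char in id2vocab.items():
        vocab[idx] = char
    return id2vocab, vocab, space_id, blank_id, sos_eos

# Toy units file (hypothetical); real deployments ship their own lang_char.txt.
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write("<blank> 0\na 1\nb 2\n<sos/eos> 3\n")
    path = f.name
id2vocab, vocab, space_id, blank_id, sos_eos = load_vocab(path)
os.unlink(path)

# Fallbacks mirroring init_ctc_rescore when a special token is absent.
space_id = space_id if space_id else -1
blank_id = blank_id if blank_id else 0
sos = eos = sos_eos if sos_eos else len(vocab) - 1
print(vocab, space_id, blank_id, sos)  # ['<blank>', 'a', 'b', '<sos/eos>'] -1 0 3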
3 changes: 2 additions & 1 deletion runtime/gpu/model_repo_stateful/wenet/1/wenet_onnx_model.py
Collaborator:
Hi, is this change to cur_mask_len necessary?

Contributor (Author):
The original code has cases where hist_enc is None, and len(None) raises an error. During testing, hist_enc did indeed turn out to be None in some cases.

if hist_enc is None:
  cur_enc = cur_encoder_out[idx]
.....
cur_mask_len = int(len(hist_enc) + seq_lens[idx])

Collaborator:
Thanks! Didn't realize we changed the initial history cache. FYI @yuekaizhang, didn't we find this issue before?

@@ -183,11 +183,12 @@ def infer(self, batch_log_probs, batch_log_probs_idx, seq_lens,
             hist_enc = batch_encoder_hist[idx]
             if hist_enc is None:
                 cur_enc = cur_encoder_out[idx]
+                cur_mask_len = int(0 + seq_lens[idx])
             else:
                 cur_enc = torch.cat([hist_enc, cur_encoder_out[idx]],
                                     axis=0)
+                cur_mask_len = int(len(hist_enc) + seq_lens[idx])
             rescore_encoder_hist.append(cur_enc)
-            cur_mask_len = int(len(hist_enc) + seq_lens[idx])
             rescore_encoder_lens.append(cur_mask_len)
             rescore_hyps.append(score_hyps[idx])
             if cur_enc.shape[0] > max_length:
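To see why the hunk above moves cur_mask_len into both branches, here is a minimal repro with made-up tensor shapes; before the fix, the first chunk of a stream (hist_enc is None) hit len(None).

import torch

# Made-up shapes; in the real model these come from the ONNX encoder outputs.
seq_lens = [5, 5]
cur_encoder_out = torch.randn(2, 5, 256)
batch_encoder_hist = [None, torch.randn(3, 256)]  # first utterance: no history yet

for idx in range(len(batch_encoder_hist)):
    hist_enc = batch_encoder_hist[idx]
    if hist_enc is None:
        cur_enc = cur_encoder_out[idx]
        cur_mask_len = int(0 + seq_lens[idx])  # fixed: no len(None)
    else:
        cur_enc = torch.cat([hist_enc, cur_encoder_out[idx]], axis=0)
        cur_mask_len = int(len(hist_enc) + seq_lens[idx])
    # Before the fix, cur_mask_len was computed once after the if/else, so the
    # None branch raised: TypeError: object of type 'NoneType' has no len()
    print(cur_enc.shape, cur_mask_len)  # torch.Size([5, 256]) 5, then torch.Size([8, 256]) 8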
2 changes: 1 addition & 1 deletion wenet/bin/export_onnx_gpu.py
@@ -1200,7 +1200,7 @@ def export_rescoring_decoder(model, configs, args, logger, decoder_onnx_path,
         configs['cmvn_conf'] = {}
     else:
         assert configs['cmvn'] == "global_cmvn"
-        assert configs['cmvn']['cmvn_conf'] is not None
+        assert configs['cmvn_conf'] is not None
     configs['cmvn_conf']["cmvn_file"] = args.cmvn_file
     if (args.reverse_weight != -1.0
             and "reverse_weight" in configs["model_conf"]):
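The one-line fix above is easier to see with a toy config; the dict below is a hypothetical stand-in for the YAML wenet loads, with keys assumed from the surrounding code.

# Hypothetical config fragment shaped like the loaded YAML.
configs = {"cmvn": "global_cmvn", "cmvn_conf": {"cmvn_file": None}}

assert configs["cmvn"] == "global_cmvn"
# The old assertion indexed the string "global_cmvn" with a string key:
#   configs["cmvn"]["cmvn_conf"]  ->  TypeError: string indices must be integers
# The corrected assertion checks the sibling top-level key instead:
assert configs["cmvn_conf"] is not None
configs["cmvn_conf"]["cmvn_file"] = "global_cmvn.json"  # i.e., args.cmvn_file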
18 changes: 11 additions & 7 deletions wenet/bin/recognize_onnx_gpu.py
Collaborator:
Can you remove the blank lines? Or just use one blank line instead of two lines? Thanks!

Contributor (Author):
Removed.

@@ -62,7 +62,6 @@
               'https://github.com/Slyne/ctc_decoder.git')
     sys.exit(1)
 
-
 def get_args():
     parser = argparse.ArgumentParser(description='recognize with your model')
     parser.add_argument('--config', required=True, help='config file')
@@ -106,10 +105,8 @@ def get_args():
                         action='store_true',
                         help='whether to export fp16 model, default false')
     args = parser.parse_args()
-    print(args)
     return args
 
-
 def main():
     args = get_args()
     logging.basicConfig(level=logging.DEBUG,
@@ -122,6 +119,7 @@ def main():
         configs = override_config(configs, args.override_config)
 
     reverse_weight = configs["model_conf"].get("reverse_weight", 0.0)
+    special_tokens = configs.get('tokenizer_conf', {}).get('special_tokens', None)
     test_conf = copy.deepcopy(configs['dataset_conf'])
     test_conf['filter_conf']['max_length'] = 102400
     test_conf['filter_conf']['min_length'] = 0
@@ -145,7 +143,6 @@ def main():
                            tokenizer,
                            test_conf,
                            partition=False)
-
     test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
 
     # Init asr model from configs
@@ -171,10 +168,18 @@ def main():
             assert len(arr) == 2
             char_dict[int(arr[1])] = arr[0]
             vocabulary.append(arr[0])
-    eos = sos = len(char_dict) - 1
 
+    vocab_size = len(char_dict)
+    sos = (vocab_size - 1 if special_tokens is None else
+           special_tokens.get("<sos>", vocab_size - 1))
+    eos = (vocab_size - 1 if special_tokens is None else
+           special_tokens.get("<eos>", vocab_size - 1))
+
     with torch.no_grad(), open(args.result_file, 'w') as fout:
         for _, batch in enumerate(test_data_loader):
-            keys, feats, _, feats_lengths, _ = batch
+            keys = batch['keys']
+            feats = batch['feats']
+            feats_lengths = batch['feats_lengths']
             feats, feats_lengths = feats.numpy(), feats_lengths.numpy()
             if args.fp16:
                 feats = feats.astype(np.float16)
@@ -288,6 +293,5 @@ def main():
             logging.info('{} {}'.format(key, content))
             fout.write('{} {}\n'.format(key, content))
 
-
 if __name__ == '__main__':
     main()
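Taken together, the sos/eos hunk above makes the decoder token ids configurable instead of always assuming the last vocabulary entry. A hedged sketch of the lookup, with invented config dicts for illustration:

def resolve_sos_eos(configs, vocab_size):
    # Same fallback chain as the patched main(): prefer explicit special
    # tokens from tokenizer_conf, else use the last vocabulary entry.
    special_tokens = configs.get('tokenizer_conf', {}).get('special_tokens', None)
    sos = (vocab_size - 1 if special_tokens is None else
           special_tokens.get("<sos>", vocab_size - 1))
    eos = (vocab_size - 1 if special_tokens is None else
           special_tokens.get("<eos>", vocab_size - 1))
    return sos, eos

# Legacy config without tokenizer_conf: falls back to vocab_size - 1.
print(resolve_sos_eos({}, 5000))  # (4999, 4999)
# Config with explicit ids (values invented for illustration).
print(resolve_sos_eos(
    {'tokenizer_conf': {'special_tokens': {'<sos>': 2, '<eos>': 2}}}, 5000))  # (2, 2)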