From 8a55b0aba625e817d94daa1f4af14baad6db4a4a Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 23 Nov 2016 16:22:57 -0500 Subject: [PATCH 001/213] Bug fix in nnet3-latgen-faster which missed uttspk option --- src/nnet3bin/nnet3-latgen-faster.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/nnet3bin/nnet3-latgen-faster.cc b/src/nnet3bin/nnet3-latgen-faster.cc index 5a090acb5b5..e0f21e723e7 100644 --- a/src/nnet3bin/nnet3-latgen-faster.cc +++ b/src/nnet3bin/nnet3-latgen-faster.cc @@ -65,6 +65,8 @@ int main(int argc, char *argv[]) { po.Register("ivectors", &ivector_rspecifier, "Rspecifier for " "iVectors as vectors (i.e. not estimated online); per utterance " "by default, or per speaker if you provide the --utt2spk option."); + po.Register("utt2spk", &utt2spk_rspecifier, "Rspecifier for " + "utt2spk option used to get ivectors per speaker"); po.Register("online-ivectors", &online_ivector_rspecifier, "Rspecifier for " "iVectors estimated online, as matrices. If you supply this," " you must set the --online-ivector-period option."); From 12c619f180ce7ebed2790cdaef847399126ccf4f Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 23 Nov 2016 16:28:57 -0500 Subject: [PATCH 002/213] Bug fix in sparse-matrix.cc --- src/matrix/sparse-matrix.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/matrix/sparse-matrix.cc b/src/matrix/sparse-matrix.cc index 477d36f190a..2ef909f66dd 100644 --- a/src/matrix/sparse-matrix.cc +++ b/src/matrix/sparse-matrix.cc @@ -714,6 +714,7 @@ void GeneralMatrix::Compress() { void GeneralMatrix::Uncompress() { if (cmat_.NumRows() != 0) { + mat_.Resize(cmat_.NumRows(), cmat_.NumCols(), kUndefined); cmat_.CopyToMat(&mat_); cmat_.Clear(); } From 21e6e9f5855fdf3279222dc4a27b51be48811b8c Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 23 Nov 2016 21:40:46 -0500 Subject: [PATCH 003/213] asr_diarization: Adding get_frame_shift.sh --- egs/wsj/s5/utils/data/get_frame_shift.sh | 30 ++++++++++++++---------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/egs/wsj/s5/utils/data/get_frame_shift.sh b/egs/wsj/s5/utils/data/get_frame_shift.sh index d032c9c17fa..f5a3bac9009 100755 --- a/egs/wsj/s5/utils/data/get_frame_shift.sh +++ b/egs/wsj/s5/utils/data/get_frame_shift.sh @@ -38,23 +38,27 @@ if [ ! -s $dir/utt2dur ]; then utils/data/get_utt2dur.sh $dir 1>&2 fi -if [ ! -f $dir/feats.scp ]; then - echo "$0: $dir/feats.scp does not exist" 1>&2 - exit 1 -fi +if [ ! -f $dir/frame_shift ]; then + if [ ! 
-f $dir/feats.scp ]; then + echo "$0: $dir/feats.scp does not exist" 1>&2 + exit 1 + fi -temp=$(mktemp /tmp/tmp.XXXX) + temp=$(mktemp /tmp/tmp.XXXX) -feat-to-len "scp:head -n 10 $dir/feats.scp|" ark,t:- > $temp + feat-to-len "scp:head -n 10 $dir/feats.scp|" ark,t:- > $temp -if [ -z $temp ]; then - echo "$0: error running feat-to-len" 1>&2 - exit 1 -fi + if [ -z $temp ]; then + echo "$0: error running feat-to-len" 1>&2 + exit 1 + fi -head -n 10 $dir/utt2dur | paste - $temp | \ - awk '{ dur += $2; frames += $4; } END { shift = dur / frames; if (shift > 0.01 && shift < 0.0102) shift = 0.01; print shift; }' || exit 1; + frame_shift=$(head -n 10 $dir/utt2dur | paste - $temp | awk '{ dur += $2; frames += $4; } END { shift = dur / frames; if (shift > 0.01 && shift < 0.0102) shift = 0.01; print shift; }') || exit 1; + + echo $frame_shift > $dir/frame_shift + rm $temp +fi -rm $temp +cat $dir/frame_shift exit 0 From ecdae90d76bb191ff33544879c73c60eeba476d3 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 24 Nov 2016 01:46:18 -0500 Subject: [PATCH 004/213] Pass --no-text option to validate data dir in speed perturbation --- egs/wsj/s5/utils/data/perturb_data_dir_speed_3way.sh | 6 +++++- egs/wsj/s5/utils/perturb_data_dir_speed.sh | 7 ++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/egs/wsj/s5/utils/data/perturb_data_dir_speed_3way.sh b/egs/wsj/s5/utils/data/perturb_data_dir_speed_3way.sh index c575166534e..4b12a94eee9 100755 --- a/egs/wsj/s5/utils/data/perturb_data_dir_speed_3way.sh +++ b/egs/wsj/s5/utils/data/perturb_data_dir_speed_3way.sh @@ -43,5 +43,9 @@ utils/data/combine_data.sh $destdir ${srcdir} ${destdir}_speed0.9 ${destdir}_spe rm -r ${destdir}_speed0.9 ${destdir}_speed1.1 echo "$0: generated 3-way speed-perturbed version of data in $srcdir, in $destdir" -utils/validate_data_dir.sh --no-feats $destdir +if [ -f $srcdir/text ]; then + utils/validate_data_dir.sh --no-feats $destdir +else + utils/validate_data_dir.sh --no-feats --no-text $destdir +fi diff --git a/egs/wsj/s5/utils/perturb_data_dir_speed.sh b/egs/wsj/s5/utils/perturb_data_dir_speed.sh index 20ff86755eb..e3d56d58b9c 100755 --- a/egs/wsj/s5/utils/perturb_data_dir_speed.sh +++ b/egs/wsj/s5/utils/perturb_data_dir_speed.sh @@ -112,4 +112,9 @@ cat $srcdir/utt2dur | utils/apply_map.pl -f 1 $destdir/utt_map | \ rm $destdir/spk_map $destdir/utt_map 2>/dev/null echo "$0: generated speed-perturbed version of data in $srcdir, in $destdir" -utils/validate_data_dir.sh --no-feats $destdir + +if [ -f $srcdir/text ]; then + utils/validate_data_dir.sh --no-feats $destdir +else + utils/validate_data_dir.sh --no-feats --no-text $destdir +fi From 5b7f150de474e1a58511b5c5a4e481254300eb7f Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 23 Nov 2016 16:23:33 -0500 Subject: [PATCH 005/213] Print Cuda profile in nnet3-compute --- src/nnet3bin/nnet3-compute.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/nnet3bin/nnet3-compute.cc b/src/nnet3bin/nnet3-compute.cc index 9305ef7e6b6..d46220c7ffd 100644 --- a/src/nnet3bin/nnet3-compute.cc +++ b/src/nnet3bin/nnet3-compute.cc @@ -159,6 +159,9 @@ int main(int argc, char *argv[]) { num_success++; } +#if HAVE_CUDA==1 + CuDevice::Instantiate().PrintProfile(); +#endif double elapsed = timer.Elapsed(); KALDI_LOG << "Time taken "<< elapsed << "s: real-time factor assuming 100 frames/sec is " From a1a5e0e863a0250529959c294462da580490acfe Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Sun, 6 Nov 2016 10:37:19 -0500 Subject: [PATCH 006/213] asr_diarization: Fix stats 
printing --- src/nnet3/nnet-component-itf.cc | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/nnet3/nnet-component-itf.cc b/src/nnet3/nnet-component-itf.cc index 00dd802e091..f94843b725e 100644 --- a/src/nnet3/nnet-component-itf.cc +++ b/src/nnet3/nnet-component-itf.cc @@ -323,7 +323,7 @@ std::string NonlinearComponent::Info() const { stream << ", self-repair-upper-threshold=" << self_repair_upper_threshold_; if (self_repair_scale_ != 0.0) stream << ", self-repair-scale=" << self_repair_scale_; - if (count_ > 0 && value_sum_.Dim() == dim_ && deriv_sum_.Dim() == dim_) { + if (count_ > 0 && value_sum_.Dim() == dim_) { stream << ", count=" << std::setprecision(3) << count_ << std::setprecision(6); stream << ", self-repaired-proportion=" @@ -333,10 +333,12 @@ std::string NonlinearComponent::Info() const { Vector value_avg(value_avg_dbl); value_avg.Scale(1.0 / count_); stream << ", value-avg=" << SummarizeVector(value_avg); - Vector deriv_avg_dbl(deriv_sum_); - Vector deriv_avg(deriv_avg_dbl); - deriv_avg.Scale(1.0 / count_); - stream << ", deriv-avg=" << SummarizeVector(deriv_avg); + if (deriv_sum_.Dim() == dim_) { + Vector deriv_avg_dbl(deriv_sum_); + Vector deriv_avg(deriv_avg_dbl); + deriv_avg.Scale(1.0 / count_); + stream << ", deriv-avg=" << SummarizeVector(deriv_avg); + } } return stream.str(); } From 5d162053e299a5c04264b1fc2ae6bed2c270c8be Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 24 Nov 2016 01:27:59 -0500 Subject: [PATCH 007/213] asr_diarization: Add --skip-dims option to apply-cmvn-sliding --- src/featbin/apply-cmvn-sliding.cc | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/featbin/apply-cmvn-sliding.cc b/src/featbin/apply-cmvn-sliding.cc index 4a6d02d16cd..105319761b5 100644 --- a/src/featbin/apply-cmvn-sliding.cc +++ b/src/featbin/apply-cmvn-sliding.cc @@ -35,10 +35,13 @@ int main(int argc, char *argv[]) { "Useful for speaker-id; see also apply-cmvn-online\n" "\n" "Usage: apply-cmvn-sliding [options] \n"; - + + std::string skip_dims_str; ParseOptions po(usage); SlidingWindowCmnOptions opts; opts.Register(&po); + po.Register("skip-dims", &skip_dims_str, "Dimensions for which to skip " + "normalization: colon-separated list of integers, e.g. 13:14:15)"); po.Read(argc, argv); @@ -47,15 +50,24 @@ int main(int argc, char *argv[]) { exit(1); } + std::vector skip_dims; // optionally use "fake" + // (zero-mean/unit-variance) stats for some + // dims to disable normalization. 
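    // A hedged usage sketch (file names here are illustrative, not taken from
    // the sources): the new option is given as a colon-separated list, e.g.
    //   apply-cmvn-sliding --skip-dims=13:14:15 scp:data/train/feats.scp ark:-
    // and SplitStringToIntegers below parses "13:14:15" into the integer
    // vector {13, 14, 15}; the listed dimensions are the ones the comment
    // above describes as getting "fake" zero-mean/unit-variance stats.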
+ if (!SplitStringToIntegers(skip_dims_str, ":", false, &skip_dims)) { + KALDI_ERR << "Bad --skip-dims option (should be colon-separated list of " + << "integers)"; + } + + int32 num_done = 0, num_err = 0; - + std::string feat_rspecifier = po.GetArg(1); std::string feat_wspecifier = po.GetArg(2); SequentialBaseFloatMatrixReader feat_reader(feat_rspecifier); BaseFloatMatrixWriter feat_writer(feat_wspecifier); - - for (;!feat_reader.Done(); feat_reader.Next()) { + + for (; !feat_reader.Done(); feat_reader.Next()) { std::string utt = feat_reader.Key(); Matrix feat(feat_reader.Value()); if (feat.NumRows() == 0) { @@ -67,7 +79,7 @@ int main(int argc, char *argv[]) { feat.NumCols(), kUndefined); SlidingWindowCmn(opts, feat, &cmvn_feat); - + feat_writer.Write(utt, cmvn_feat); num_done++; } From 4c5cd5438b6166abe3fdee40489fc840d7c950b1 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Tue, 22 Nov 2016 11:23:08 -0500 Subject: [PATCH 008/213] asr_diarization: Adding length-tolerace to extract ivector scripts --- .../s5/steps/online/nnet2/extract_ivectors.sh | 8 +++---- src/bin/weight-post.cc | 23 ++++++++++++------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/egs/wsj/s5/steps/online/nnet2/extract_ivectors.sh b/egs/wsj/s5/steps/online/nnet2/extract_ivectors.sh index f27baecd673..2f55053efd5 100755 --- a/egs/wsj/s5/steps/online/nnet2/extract_ivectors.sh +++ b/egs/wsj/s5/steps/online/nnet2/extract_ivectors.sh @@ -172,8 +172,8 @@ if [ $sub_speaker_frames -gt 0 ]; then feat-to-len scp:$data/feats.scp ark,t:- > $dir/utt_counts || exit 1; fi if ! [ $(wc -l <$dir/utt_counts) -eq $(wc -l <$data/feats.scp) ]; then - echo "$0: error getting per-utterance counts." - exit 0; + echo "$0: error getting per-utterance counts. Number of lines in $dir/utt_counts differs from $data/feats.scp" + exit 1; fi cat $data/spk2utt | python -c " import sys @@ -229,8 +229,8 @@ if [ $stage -le 2 ]; then if [ ! 
-z "$ali_or_decode_dir" ]; then $cmd JOB=1:$nj $dir/log/extract_ivectors.JOB.log \ gmm-global-get-post --n=$num_gselect --min-post=$min_post $srcdir/final.dubm "$gmm_feats" ark:- \| \ - weight-post ark:- "ark,s,cs:gunzip -c $dir/weights.gz|" ark:- \| \ - ivector-extract --acoustic-weight=$posterior_scale --compute-objf-change=true \ + weight-post --length-tolerance=1 ark:- "ark,s,cs:gunzip -c $dir/weights.gz|" ark:- \| \ + ivector-extract --length-tolerance=1 --acoustic-weight=$posterior_scale --compute-objf-change=true \ --max-count=$max_count --spk2utt=ark:$this_sdata/JOB/spk2utt \ $srcdir/final.ie "$feats" ark,s,cs:- ark,t:$dir/ivectors_spk.JOB.ark || exit 1; else diff --git a/src/bin/weight-post.cc b/src/bin/weight-post.cc index d536896eaaa..bbaad465195 100644 --- a/src/bin/weight-post.cc +++ b/src/bin/weight-post.cc @@ -26,32 +26,38 @@ int main(int argc, char *argv[]) { try { using namespace kaldi; - typedef kaldi::int32 int32; + typedef kaldi::int32 int32; + + int32 length_tolerance = 2; const char *usage = "Takes archives (typically per-utterance) of posteriors and per-frame weights,\n" "and weights the posteriors by the per-frame weights\n" "\n" "Usage: weight-post \n"; - + ParseOptions po(usage); + + po.Register("length-tolerance", &length_tolerance, + "Tolerate this many frames of length mismatch"); + po.Read(argc, argv); if (po.NumArgs() != 3) { po.PrintUsage(); exit(1); } - + std::string post_rspecifier = po.GetArg(1), weights_rspecifier = po.GetArg(2), post_wspecifier = po.GetArg(3); SequentialPosteriorReader posterior_reader(post_rspecifier); RandomAccessBaseFloatVectorReader weights_reader(weights_rspecifier); - PosteriorWriter post_writer(post_wspecifier); - + PosteriorWriter post_writer(post_wspecifier); + int32 num_done = 0, num_err = 0; - + for (; !posterior_reader.Done(); posterior_reader.Next()) { std::string key = posterior_reader.Key(); Posterior post = posterior_reader.Value(); @@ -61,7 +67,8 @@ int main(int argc, char *argv[]) { continue; } const Vector &weights = weights_reader.Value(key); - if (weights.Dim() != static_cast(post.size())) { + if (std::abs(weights.Dim() - static_cast(post.size())) > + length_tolerance) { KALDI_WARN << "Weights for utterance " << key << " have wrong size, " << weights.Dim() << " vs. " << post.size(); @@ -71,7 +78,7 @@ int main(int argc, char *argv[]) { for (size_t i = 0; i < post.size(); i++) { if (weights(i) == 0.0) post[i].clear(); for (size_t j = 0; j < post[i].size(); j++) - post[i][j].second *= weights(i); + post[i][j].second *= i < weights.Dim() ? 
weights(i) : 0.0; } post_writer.Write(key, post); num_done++; From a71da1a4c59714bc0862e6b9a54e23197ee7c3bb Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Fri, 25 Nov 2016 15:45:40 -0500 Subject: [PATCH 009/213] asr_diarization: Adding --do-average option to matrix-sum-rows --- src/bin/matrix-sum-rows.cc | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/bin/matrix-sum-rows.cc b/src/bin/matrix-sum-rows.cc index 7e60483eef2..ee6504ba2b1 100644 --- a/src/bin/matrix-sum-rows.cc +++ b/src/bin/matrix-sum-rows.cc @@ -34,9 +34,13 @@ int main(int argc, char *argv[]) { "e.g.: matrix-sum-rows ark:- ark:- | vector-sum ark:- sum.vec\n" "See also: matrix-sum, vector-sum\n"; + bool do_average = false; ParseOptions po(usage); + po.Register("do-average", &do_average, + "Do average instead of sum"); + po.Read(argc, argv); if (po.NumArgs() != 2) { @@ -45,28 +49,28 @@ int main(int argc, char *argv[]) { } std::string rspecifier = po.GetArg(1); std::string wspecifier = po.GetArg(2); - + SequentialBaseFloatMatrixReader mat_reader(rspecifier); BaseFloatVectorWriter vec_writer(wspecifier); - + int32 num_done = 0; int64 num_rows_done = 0; - + for (; !mat_reader.Done(); mat_reader.Next()) { std::string key = mat_reader.Key(); Matrix mat(mat_reader.Value()); Vector vec(mat.NumCols()); - vec.AddRowSumMat(1.0, mat, 0.0); + vec.AddRowSumMat(!do_average ? 1.0 : 1.0 / mat.NumRows(), mat, 0.0); // Do the summation in double, to minimize roundoff. Vector float_vec(vec); vec_writer.Write(key, float_vec); num_done++; num_rows_done += mat.NumRows(); } - + KALDI_LOG << "Summed rows " << num_done << " matrices, " << num_rows_done << " rows in total."; - + return (num_done != 0 ? 0 : 1); } catch(const std::exception &e) { std::cerr << e.what(); From 8fa2b211a218473362c907f24312c0c7275fcc0d Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Fri, 23 Sep 2016 23:15:58 -0400 Subject: [PATCH 010/213] asr_diarization: Added weight-pdf-post, vector-to-feat, kaldi-matrix softmax per row, copy-matrix apply-log, matrix-add-offset, matrix-dot-product --- src/bin/Makefile | 3 +- src/bin/copy-matrix.cc | 37 ++++++- src/bin/matrix-add-offset.cc | 84 ++++++++++++++++ src/bin/matrix-dot-product.cc | 183 ++++++++++++++++++++++++++++++++++ src/bin/vector-scale.cc | 37 +++++-- src/bin/weight-matrix.cc | 84 ++++++++++++++++ src/bin/weight-pdf-post.cc | 154 ++++++++++++++++++++++++++++ src/featbin/Makefile | 3 +- src/featbin/extract-column.cc | 84 ++++++++++++++++ src/featbin/vector-to-feat.cc | 100 +++++++++++++++++++ src/matrix/kaldi-matrix.cc | 9 ++ src/matrix/kaldi-matrix.h | 5 + 12 files changed, 767 insertions(+), 16 deletions(-) create mode 100644 src/bin/matrix-add-offset.cc create mode 100644 src/bin/matrix-dot-product.cc create mode 100644 src/bin/weight-matrix.cc create mode 100644 src/bin/weight-pdf-post.cc create mode 100644 src/featbin/extract-column.cc create mode 100644 src/featbin/vector-to-feat.cc diff --git a/src/bin/Makefile b/src/bin/Makefile index 687040889b3..3dc59fe8112 100644 --- a/src/bin/Makefile +++ b/src/bin/Makefile @@ -24,7 +24,8 @@ BINFILES = align-equal align-equal-compiled acc-tree-stats \ matrix-logprob matrix-sum \ build-pfile-from-ali get-post-on-ali tree-info am-info \ vector-sum matrix-sum-rows est-pca sum-lda-accs sum-mllt-accs \ - transform-vec align-text matrix-dim + transform-vec align-text matrix-dim weight-pdf-post weight-matrix \ + matrix-add-offset matrix-dot-product OBJFILES = diff --git a/src/bin/copy-matrix.cc b/src/bin/copy-matrix.cc index 
d7b8181c64c..56f2e51d90f 100644 --- a/src/bin/copy-matrix.cc +++ b/src/bin/copy-matrix.cc @@ -36,16 +36,30 @@ int main(int argc, char *argv[]) { " e.g.: copy-matrix --binary=false 1.mat -\n" " copy-matrix ark:2.trans ark,t:-\n" "See also: copy-feats\n"; - + bool binary = true; + bool apply_log = false; + bool apply_exp = false; + bool apply_softmax_per_row = false; + BaseFloat apply_power = 1.0; BaseFloat scale = 1.0; + ParseOptions po(usage); po.Register("binary", &binary, "Write in binary mode (only relevant if output is a wxfilename)"); po.Register("scale", &scale, "This option can be used to scale the matrices being copied."); - + po.Register("apply-log", &apply_log, + "This option can be used to apply log on the matrices. " + "Must be avoided if matrix has negative quantities."); + po.Register("apply-exp", &apply_exp, + "This option can be used to apply exp on the matrices"); + po.Register("apply-power", &apply_power, + "This option can be used to apply a power on the matrices"); + po.Register("apply-softmax-per-row", &apply_softmax_per_row, + "This option can be used to apply softmax per row of the matrices"); + po.Read(argc, argv); if (po.NumArgs() != 2) { @@ -53,6 +67,10 @@ int main(int argc, char *argv[]) { exit(1); } + if ( (apply_log && apply_exp) || (apply_softmax_per_row && apply_exp) || + (apply_softmax_per_row && apply_log) ) + KALDI_ERR << "Only one of apply-log, apply-exp and " + << "apply-softmax-per-row can be given"; std::string matrix_in_fn = po.GetArg(1), matrix_out_fn = po.GetArg(2); @@ -68,11 +86,15 @@ int main(int argc, char *argv[]) { if (in_is_rspecifier != out_is_wspecifier) KALDI_ERR << "Cannot mix archives with regular files (copying matrices)"; - + if (!in_is_rspecifier) { Matrix mat; ReadKaldiObject(matrix_in_fn, &mat); if (scale != 1.0) mat.Scale(scale); + if (apply_log) mat.ApplyLog(); + if (apply_exp) mat.ApplyExp(); + if (apply_softmax_per_row) mat.ApplySoftMaxPerRow(); + if (apply_power != 1.0) mat.ApplyPow(apply_power); Output ko(matrix_out_fn, binary); mat.Write(ko.Stream(), binary); KALDI_LOG << "Copied matrix to " << matrix_out_fn; @@ -82,9 +104,14 @@ int main(int argc, char *argv[]) { BaseFloatMatrixWriter writer(matrix_out_fn); SequentialBaseFloatMatrixReader reader(matrix_in_fn); for (; !reader.Done(); reader.Next(), num_done++) { - if (scale != 1.0) { + if (scale != 1.0 || apply_log || apply_exp || + apply_power != 1.0 || apply_softmax_per_row) { Matrix mat(reader.Value()); - mat.Scale(scale); + if (scale != 1.0) mat.Scale(scale); + if (apply_log) mat.ApplyLog(); + if (apply_exp) mat.ApplyExp(); + if (apply_softmax_per_row) mat.ApplySoftMaxPerRow(); + if (apply_power != 1.0) mat.ApplyPow(apply_power); writer.Write(reader.Key(), mat); } else { writer.Write(reader.Key(), reader.Value()); diff --git a/src/bin/matrix-add-offset.cc b/src/bin/matrix-add-offset.cc new file mode 100644 index 00000000000..90f72ba3254 --- /dev/null +++ b/src/bin/matrix-add-offset.cc @@ -0,0 +1,84 @@ +// bin/matrix-add-offset.cc + +// Copyright 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Add an offset vector to the rows of matrices in a table.\n" + "\n" + "Usage: matrix-add-offset [options] " + " \n" + "e.g.: matrix-add-offset log_post.mat neg_priors.vec log_like.mat\n" + "See also: matrix-sum-rows, matrix-sum, vector-sum\n"; + + + ParseOptions po(usage); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + std::string rspecifier = po.GetArg(1); + std::string vector_rxfilename = po.GetArg(2); + std::string wspecifier = po.GetArg(3); + + SequentialBaseFloatMatrixReader mat_reader(rspecifier); + BaseFloatMatrixWriter mat_writer(wspecifier); + + int32 num_done = 0; + + Vector vec; + { + bool binary_in; + Input ki(vector_rxfilename, &binary_in); + vec.Read(ki.Stream(), binary_in); + } + + for (; !mat_reader.Done(); mat_reader.Next()) { + std::string key = mat_reader.Key(); + Matrix mat(mat_reader.Value()); + if (vec.Dim() != mat.NumCols()) { + KALDI_ERR << "Mismatch in vector dimension and " + << "number of columns in matrix; " + << vec.Dim() << " vs " << mat.NumCols(); + } + mat.AddVecToRows(1.0, vec); + mat_writer.Write(key, mat); + num_done++; + } + + KALDI_LOG << "Added offset to " << num_done << " matrices."; + + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/bin/matrix-dot-product.cc b/src/bin/matrix-dot-product.cc new file mode 100644 index 00000000000..a292cab9a40 --- /dev/null +++ b/src/bin/matrix-dot-product.cc @@ -0,0 +1,183 @@ +// bin/matrix-dot-product.cc + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Get element-wise dot product of matrices. Always returns a matrix " + "that is the same size as the first matrix.\n" + "If there is a mismatch in number of rows, the utterance is skipped, " + "unless the mismatch is within a tolerance. 
If the second matrix has " + "number of rows that is larger than the first matrix by less than the " + "specified tolerance, then a submatrix of the second matrix is " + "multiplied element-wise with the first matrix.\n" + "\n" + "Usage: matrix-dot-product [options] " + "[ ...] " + "\n" + " e.g.: matrix-dot-product ark:1.weights ark:2.weights " + "ark:combine.weights\n" + "or \n" + "Usage: matrix-dot-product [options] " + "[ ...] " + "\n" + " e.g.: matrix-sum --binary=false 1.mat 2.mat product.mat\n" + "See also: matrix-sum, matrix-sum-rows\n"; + + bool binary = true; + int32 length_tolerance = 0; + + ParseOptions po(usage); + + po.Register("binary", &binary, "If true, write output as binary (only " + "relevant for usage types two or three"); + po.Register("length-tolerance", &length_tolerance, + "Tolerance length mismatch of this many frames"); + + po.Read(argc, argv); + + if (po.NumArgs() < 2) { + po.PrintUsage(); + exit(1); + } + + int32 N = po.NumArgs(); + std::string matrix_in_fn1 = po.GetArg(1), + matrix_out_fn = po.GetArg(N); + + if (ClassifyWspecifier(matrix_out_fn, NULL, NULL, NULL) != kNoWspecifier) { + // output to table. + + // Output matrix + BaseFloatMatrixWriter matrix_writer(matrix_out_fn); + + // Input matrices + SequentialBaseFloatMatrixReader matrix_reader1(matrix_in_fn1); + std::vector + matrix_readers(N-2, + static_cast(NULL)); + std::vector matrix_in_fns(N-2); + for (int32 i = 2; i < N; ++i) { + matrix_readers[i-2] = new RandomAccessBaseFloatMatrixReader( + po.GetArg(i)); + matrix_in_fns[i-2] = po.GetArg(i); + } + int32 n_utts = 0, n_total_matrices = 0, + n_success = 0, n_missing = 0, n_other_errors = 0; + + for (; !matrix_reader1.Done(); matrix_reader1.Next()) { + std::string key = matrix_reader1.Key(); + Matrix matrix1 = matrix_reader1.Value(); + matrix_reader1.FreeCurrent(); + n_utts++; + n_total_matrices++; + + Matrix matrix_out(matrix1); + + int32 i = 0; + for (i = 0; i < N-2; ++i) { + bool failed = false; // Indicates failure for this key. 
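          // Worked example of the tolerance rule (matrix sizes are
          // hypothetical): invoked as, say,
          //   matrix-dot-product --length-tolerance=2 ark:1.weights ark:2.weights ark:product.weights
          // a 100x40 primary matrix and a 102x40 secondary matrix are
          // multiplied element-wise over the first 100 rows, while a 103x40
          // secondary matrix exceeds the tolerance and the utterance is
          // skipped with a warning.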
+ if (matrix_readers[i]->HasKey(key)) { + const Matrix &matrix2 = matrix_readers[i]->Value(key); + n_total_matrices++; + if (SameDim(matrix2, matrix_out)) { + matrix_out.MulElements(matrix2); + } else { + KALDI_WARN << "Dimension mismatch for utterance " << key + << " : " << matrix2.NumRows() << " by " + << matrix2.NumCols() << " for " + << "system " << (i + 2) << ", rspecifier: " + << matrix_in_fns[i] << " vs " << matrix_out.NumRows() + << " by " << matrix_out.NumCols() + << " primary matrix, rspecifier:" << matrix_in_fn1; + if (matrix2.NumRows() - matrix_out.NumRows() <= + length_tolerance) { + KALDI_WARN << "Tolerated length mismatch for key " << key; + matrix_out.MulElements(matrix2.Range(0, matrix_out.NumRows(), + 0, matrix2.NumCols())); + } else { + KALDI_WARN << "Skipping key " << key; + failed = true; + n_other_errors++; + } + } + } else { + KALDI_WARN << "No matrix found for utterance " << key << " for " + << "system " << (i + 2) << ", rspecifier: " + << matrix_in_fns[i]; + failed = true; + n_missing++; + } + + if (failed) break; + } + + if (i != N-2) // Skipping utterance + continue; + + matrix_writer.Write(key, matrix_out); + n_success++; + } + + KALDI_LOG << "Processed " << n_utts << " utterances: with a total of " + << n_total_matrices << " matrices across " << (N-1) + << " different systems."; + KALDI_LOG << "Produced output for " << n_success << " utterances; " + << n_missing << " total missing matrices and skipped " + << n_other_errors << "matrices."; + + DeletePointers(&matrix_readers); + + return (n_success != 0 && n_missing < (n_success - n_missing)) ? 0 : 1; + } else { + for (int32 i = 1; i < N; i++) { + if (ClassifyRspecifier(po.GetArg(i), NULL, NULL) != kNoRspecifier) { + KALDI_ERR << "Wrong usage: if last argument is not " + << "table, the other arguments must not be tables."; + } + } + + Matrix mat1; + ReadKaldiObject(po.GetArg(1), &mat1); + + for (int32 i = 2; i < N; i++) { + Matrix mat; + ReadKaldiObject(po.GetArg(i), &mat); + + mat1.MulElements(mat); + } + + WriteKaldiObject(mat1, po.GetArg(N), binary); + KALDI_LOG << "Multiplied " << (po.NumArgs() - 1) << " matrices; " + << "wrote product to " << PrintableWxfilename(po.GetArg(N)); + + return 0; + } + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/bin/vector-scale.cc b/src/bin/vector-scale.cc index 60d4d3121d2..ea68ae31ad0 100644 --- a/src/bin/vector-scale.cc +++ b/src/bin/vector-scale.cc @@ -30,11 +30,14 @@ int main(int argc, char *argv[]) { const char *usage = "Scale a set of vectors in a Table (useful for speaker vectors and " "per-frame weights)\n" - "Usage: vector-scale [options] \n"; + "Usage: vector-scale [options] \n"; ParseOptions po(usage); BaseFloat scale = 1.0; + bool binary = false; + po.Register("binary", &binary, "If true, write output as binary " + "not relevant for archives"); po.Register("scale", &scale, "Scaling factor for vectors"); po.Read(argc, argv); @@ -43,17 +46,33 @@ int main(int argc, char *argv[]) { exit(1); } - std::string rspecifier = po.GetArg(1); - std::string wspecifier = po.GetArg(2); + std::string vector_in_fn = po.GetArg(1); + std::string vector_out_fn = po.GetArg(2); - BaseFloatVectorWriter vec_writer(wspecifier); - - SequentialBaseFloatVectorReader vec_reader(rspecifier); - for (; !vec_reader.Done(); vec_reader.Next()) { - Vector vec(vec_reader.Value()); + if (ClassifyWspecifier(vector_in_fn, NULL, NULL, NULL) != kNoWspecifier) { + if (ClassifyRspecifier(vector_in_fn, NULL, NULL) == kNoRspecifier) { + KALDI_ERR << "Cannot 
mix archives and regular files"; + } + BaseFloatVectorWriter vec_writer(vector_out_fn); + SequentialBaseFloatVectorReader vec_reader(vector_in_fn); + for (; !vec_reader.Done(); vec_reader.Next()) { + Vector vec(vec_reader.Value()); + vec.Scale(scale); + vec_writer.Write(vec_reader.Key(), vec); + } + } else { + if (ClassifyRspecifier(vector_in_fn, NULL, NULL) != kNoRspecifier) { + KALDI_ERR << "Cannot mix archives and regular files"; + } + bool binary_in; + Input ki(vector_in_fn, &binary_in); + Vector vec; + vec.Read(ki.Stream(), binary_in); vec.Scale(scale); - vec_writer.Write(vec_reader.Key(), vec); + Output ko(vector_out_fn, binary); + vec.Write(ko.Stream(), binary); } + return 0; } catch(const std::exception &e) { std::cerr << e.what(); diff --git a/src/bin/weight-matrix.cc b/src/bin/weight-matrix.cc new file mode 100644 index 00000000000..c6823b8da29 --- /dev/null +++ b/src/bin/weight-matrix.cc @@ -0,0 +1,84 @@ +// bin/weight-matrix.cc + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "base/kaldi-common.h" +#include "util/common-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + typedef kaldi::int32 int32; + + const char *usage = + "Takes archives (typically per-utterance) of features and " + "per-frame weights,\n" + "and weights the features by the per-frame weights\n" + "\n" + "Usage: weight-matrix " + "\n"; + + ParseOptions po(usage); + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string matrix_rspecifier = po.GetArg(1), + weights_rspecifier = po.GetArg(2), + matrix_wspecifier = po.GetArg(3); + + SequentialBaseFloatMatrixReader matrix_reader(matrix_rspecifier); + RandomAccessBaseFloatVectorReader weights_reader(weights_rspecifier); + BaseFloatMatrixWriter matrix_writer(matrix_wspecifier); + + int32 num_done = 0, num_err = 0; + + for (; !matrix_reader.Done(); matrix_reader.Next()) { + std::string key = matrix_reader.Key(); + Matrix mat = matrix_reader.Value(); + if (!weights_reader.HasKey(key)) { + KALDI_WARN << "No weight vectors for utterance " << key; + num_err++; + continue; + } + const Vector &weights = weights_reader.Value(key); + if (weights.Dim() != mat.NumRows()) { + KALDI_WARN << "Weights for utterance " << key + << " have wrong size, " << weights.Dim() + << " vs. " << mat.NumRows(); + num_err++; + continue; + } + mat.MulRowsVec(weights); + matrix_writer.Write(key, mat); + num_done++; + } + KALDI_LOG << "Applied per-frame weights for " << num_done + << " matrices; errors on " << num_err; + return (num_done != 0 ? 
0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/bin/weight-pdf-post.cc b/src/bin/weight-pdf-post.cc new file mode 100644 index 00000000000..c7477a046c8 --- /dev/null +++ b/src/bin/weight-pdf-post.cc @@ -0,0 +1,154 @@ +// bin/weight-pdf-post.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "gmm/am-diag-gmm.h" +#include "hmm/transition-model.h" +#include "hmm/hmm-utils.h" +#include "hmm/posterior.h" + +namespace kaldi { + +void WeightPdfPost(const ConstIntegerSet &pdf_set, + BaseFloat pdf_scale, + Posterior *post) { + for (size_t i = 0; i < post->size(); i++) { + std::vector > this_post; + this_post.reserve((*post)[i].size()); + for (size_t j = 0; j < (*post)[i].size(); j++) { + int32 pdf_id = (*post)[i][j].first; + BaseFloat weight = (*post)[i][j].second; + if (pdf_set.count(pdf_id) != 0) { // is a silence. + if (pdf_scale != 0.0) + this_post.push_back(std::make_pair(pdf_id, weight*pdf_scale)); + } else { + this_post.push_back(std::make_pair(pdf_id, weight)); + } + } + (*post)[i].swap(this_post); + } +} + +void WeightPdfPostDistributed(const ConstIntegerSet &pdf_set, + BaseFloat pdf_scale, + Posterior *post) { + for (size_t i = 0; i < post->size(); i++) { + std::vector > this_post; + this_post.reserve((*post)[i].size()); + BaseFloat sil_weight = 0.0, nonsil_weight = 0.0; + for (size_t j = 0; j < (*post)[i].size(); j++) { + int32 pdf_id = (*post)[i][j].first; + BaseFloat weight = (*post)[i][j].second; + if (pdf_set.count(pdf_id) != 0) + sil_weight += weight; + else + nonsil_weight += weight; + } + // This "distributed" weighting approach doesn't make sense if we have + // negative weights. + KALDI_ASSERT(sil_weight >= 0.0 && nonsil_weight >= 0.0); + if (sil_weight + nonsil_weight == 0.0) continue; + BaseFloat frame_scale = (sil_weight * pdf_scale + nonsil_weight) / + (sil_weight + nonsil_weight); + if (frame_scale != 0.0) { + for (size_t j = 0; j < (*post)[i].size(); j++) { + int32 pdf_id = (*post)[i][j].first; + BaseFloat weight = (*post)[i][j].second; + this_post.push_back(std::make_pair(pdf_id, weight * frame_scale)); + } + } + (*post)[i].swap(this_post); + } +} + +} // namespace kaldi + +int main(int argc, char *argv[]) { + using namespace kaldi; + typedef kaldi::int32 int32; + try { + const char *usage = + "Apply weight to specific pdfs or tids in posts\n" + "Usage: weight-pdf-post [options] " + " \n" + "e.g.:\n" + " weight-pdf-post 0.00001 0:2 ark:1.post ark:nosil.post\n"; + + ParseOptions po(usage); + + bool distribute = false; + + po.Register("distribute", &distribute, "If true, rather than weighting the " + "individual posteriors, apply the weighting to the " + "whole frame: " + "i.e. 
on time t, scale all posterior entries by " + "p(sil)*silence-weight + p(non-sil)*1.0"); + + po.Read(argc, argv); + + if (po.NumArgs() != 4) { + po.PrintUsage(); + exit(1); + } + + std::string pdf_weight_str = po.GetArg(1), + pdfs_str = po.GetArg(2), + posteriors_rspecifier = po.GetArg(3), + posteriors_wspecifier = po.GetArg(4); + + BaseFloat pdf_weight = 0.0; + if (!ConvertStringToReal(pdf_weight_str, &pdf_weight)) + KALDI_ERR << "Invalid pdf-weight parameter: expected float, got \"" + << pdf_weight << '"'; + std::vector pdfs; + if (!SplitStringToIntegers(pdfs_str, ":", false, &pdfs)) + KALDI_ERR << "Invalid pdf string string " << pdfs_str; + if (pdfs.empty()) + KALDI_WARN <<"No pdf specified, this will have no effect"; + ConstIntegerSet pdf_set(pdfs); // faster lookup. + + int32 num_posteriors = 0; + SequentialPosteriorReader posterior_reader(posteriors_rspecifier); + PosteriorWriter posterior_writer(posteriors_wspecifier); + + for (; !posterior_reader.Done(); posterior_reader.Next()) { + num_posteriors++; + // Posterior is vector > > + Posterior post = posterior_reader.Value(); + // Posterior is vector > > + if (distribute) + WeightPdfPostDistributed(pdf_set, + pdf_weight, &post); + else + WeightPdfPost(pdf_set, + pdf_weight, &post); + + posterior_writer.Write(posterior_reader.Key(), post); + } + KALDI_LOG << "Done " << num_posteriors << " posteriors."; + return (num_posteriors != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/featbin/Makefile b/src/featbin/Makefile index dc2bea215d8..e1a9a1ebe0d 100644 --- a/src/featbin/Makefile +++ b/src/featbin/Makefile @@ -15,7 +15,8 @@ BINFILES = compute-mfcc-feats compute-plp-feats compute-fbank-feats \ process-kaldi-pitch-feats compare-feats wav-to-duration add-deltas-sdc \ compute-and-process-kaldi-pitch-feats modify-cmvn-stats wav-copy \ wav-reverberate append-vector-to-feats detect-sinusoids shift-feats \ - concat-feats append-post-to-feats post-to-feats + concat-feats append-post-to-feats post-to-feats vector-to-feat \ + extract-column OBJFILES = diff --git a/src/featbin/extract-column.cc b/src/featbin/extract-column.cc new file mode 100644 index 00000000000..7fa6644af03 --- /dev/null +++ b/src/featbin/extract-column.cc @@ -0,0 +1,84 @@ +// featbin/extract-column.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace std; + + const char *usage = + "Extract a column out of a matrix. \n" + "This is most useful to extract log-energies \n" + "from feature files\n" + "\n" + "Usage: extract-column [options] --column-index= " + " \n" + " e.g. 
extract-column ark:feats-in.ark ark:energies.ark\n" + "See also: select-feats, subset-feats, subsample-feats, extract-rows\n"; + + ParseOptions po(usage); + + int32 column_index = 0; + + po.Register("column-index", &column_index, + "Index of column to extract"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + string feat_rspecifier = po.GetArg(1); + string vector_wspecifier = po.GetArg(2); + + SequentialBaseFloatMatrixReader reader(feat_rspecifier); + BaseFloatVectorWriter writer(vector_wspecifier); + + int32 num_done = 0, num_err = 0; + + string line; + + for (; !reader.Done(); reader.Next(), num_done++) { + const Matrix& feats(reader.Value()); + Vector col(feats.NumRows()); + if (column_index >= feats.NumCols()) { + KALDI_ERR << "Column index " << column_index << " is " + << "not less than number of columns " << feats.NumCols(); + } + col.CopyColFromMat(feats, column_index); + writer.Write(reader.Key(), col); + } + + KALDI_LOG << "Processed " << num_done << " matrices successfully; " + << "errors on " << num_err; + + return (num_done > 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/featbin/vector-to-feat.cc b/src/featbin/vector-to-feat.cc new file mode 100644 index 00000000000..1fe521db864 --- /dev/null +++ b/src/featbin/vector-to-feat.cc @@ -0,0 +1,100 @@ +// featbin/vector-to-feat.cc + +// Copyright 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Convert a vector into a single feature so that it can be appended \n" + "to other feature matrices\n" + "Usage: vector-to-feats \n" + "or: vector-to-feats \n" + "e.g.: vector-to-feats scp:weights.scp ark:weight_feats.ark\n" + " or: vector-to-feats weight_vec feat_mat\n" + "See also: copy-feats, copy-matrix, paste-feats, \n" + "subsample-feats, splice-feats\n"; + + ParseOptions po(usage); + bool compress = false, binary = true; + + po.Register("binary", &binary, "Binary-mode output (not relevant if writing " + "to archive)"); + po.Register("compress", &compress, "If true, write output in compressed form" + "(only currently supported for wxfilename, i.e. 
archive/script," + "output)"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + int32 num_done = 0; + + if (ClassifyRspecifier(po.GetArg(1), NULL, NULL) != kNoRspecifier) { + std::string vector_rspecifier = po.GetArg(1); + std::string feature_wspecifier = po.GetArg(2); + + SequentialBaseFloatVectorReader vector_reader(vector_rspecifier); + BaseFloatMatrixWriter feat_writer(feature_wspecifier); + CompressedMatrixWriter compressed_feat_writer(feature_wspecifier); + + for (; !vector_reader.Done(); vector_reader.Next(), ++num_done) { + const Vector &vec = vector_reader.Value(); + Matrix feat(vec.Dim(), 1); + feat.CopyColFromVec(vec, 0); + + if (!compress) + feat_writer.Write(vector_reader.Key(), feat); + else + compressed_feat_writer.Write(vector_reader.Key(), + CompressedMatrix(feat)); + } + KALDI_LOG << "Converted " << num_done << " vectors into features"; + return (num_done != 0 ? 0 : 1); + } + + KALDI_ASSERT(!compress && "Compression not yet supported for single files"); + + std::string vector_rxfilename = po.GetArg(1), + feature_wxfilename = po.GetArg(2); + + Vector vec; + ReadKaldiObject(vector_rxfilename, &vec); + + Matrix feat(vec.Dim(), 1); + feat.CopyColFromVec(vec, 0); + + WriteKaldiObject(feat, feature_wxfilename, binary); + + KALDI_LOG << "Converted vector " << PrintableRxfilename(vector_rxfilename) + << " to " << PrintableWxfilename(feature_wxfilename); + return 0; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/matrix/kaldi-matrix.cc b/src/matrix/kaldi-matrix.cc index 34003e8a550..4c3948ba2f5 100644 --- a/src/matrix/kaldi-matrix.cc +++ b/src/matrix/kaldi-matrix.cc @@ -2533,6 +2533,15 @@ Real MatrixBase::ApplySoftMax() { return max + Log(sum); } +template +void MatrixBase::ApplySoftMaxPerRow() { + for (MatrixIndexT i = 0; i < num_rows_; i++) { + Row(i).ApplySoftMax(); + kaldi::ApproxEqual(Row(i).Sum(), 1.0); + } + KALDI_ASSERT(Max() <= 1.0 && Min() >= 0.0); +} + template void MatrixBase::Tanh(const MatrixBase &src) { KALDI_ASSERT(SameDim(*this, src)); diff --git a/src/matrix/kaldi-matrix.h b/src/matrix/kaldi-matrix.h index e254fcad118..dccd52a9af4 100644 --- a/src/matrix/kaldi-matrix.h +++ b/src/matrix/kaldi-matrix.h @@ -453,6 +453,11 @@ class MatrixBase { /// Apply soft-max to the collection of all elements of the /// matrix and return normalizer (log sum of exponentials). Real ApplySoftMax(); + + /// Softmax nonlinearity + /// Y = Softmax(X) : Yij = e^Xij / sum_k(e^Xik), done to each row + /// for each row, the max value is first subtracted for good numerical stability + void ApplySoftMaxPerRow(); /// Set each element to the sigmoid of the corresponding element of "src". 
void Sigmoid(const MatrixBase &src); From fb43c8ca9ff2910313ee49ca5a143e11ef6c9b7d Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Sun, 25 Sep 2016 15:02:24 -0400 Subject: [PATCH 011/213] asr_diarization: Modify subsegment_feats and add fix_subsegmented_feats.pl --- .../s5/utils/data/fix_subsegmented_feats.pl | 79 +++++++++++++++++++ egs/wsj/s5/utils/data/get_subsegment_feats.sh | 46 +++++++++++ egs/wsj/s5/utils/data/subsegment_data_dir.sh | 38 ++++++--- src/util/kaldi-holder.cc | 44 ++++++++++- src/util/kaldi-holder.h | 5 ++ 5 files changed, 201 insertions(+), 11 deletions(-) create mode 100755 egs/wsj/s5/utils/data/fix_subsegmented_feats.pl create mode 100755 egs/wsj/s5/utils/data/get_subsegment_feats.sh diff --git a/egs/wsj/s5/utils/data/fix_subsegmented_feats.pl b/egs/wsj/s5/utils/data/fix_subsegmented_feats.pl new file mode 100755 index 00000000000..bd8aeb8e409 --- /dev/null +++ b/egs/wsj/s5/utils/data/fix_subsegmented_feats.pl @@ -0,0 +1,79 @@ +#!/usr/bin/env perl + +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +use warnings; + +# This script modifies the feats ranges and ensures that they don't +# exceed the max number of frames supplied in utt2max_frames. +# utt2max_frames can be computed by using +# steps/segmentation/get_reco2num_frames.sh +# cut -d ' ' -f 1,2 /segments | utils/apply_map.pl -f 2 /reco2num_frames > /utt2max_frames + +(scalar @ARGV == 1) or die "Usage: fix_subsegmented_feats.pl "; + +my $utt2max_frames_file = $ARGV[0]; + +open MAX_FRAMES, $utt2max_frames_file or die "fix_subsegmented_feats.pl: Could not open file $utt2max_frames_file"; + +my %utt2max_frames; + +while () { + chomp; + my @F = split; + + (scalar @F == 2) or die "fix_subsegmented_feats.pl: Invalid line $_ in $utt2max_frames_file"; + + $utt2max_frames{$F[0]} = $F[1]; +} + +while () { + my $line = $_; + + if (m/\[([^][]*)\]\[([^][]*)\]\s*$/) { + print ("fix_subsegmented_feats.pl: this script only supports single indices"); + exit(1); + } + + my $before_range = ""; + my $range = ""; + + if (m/^(.*)\[([^][]*)\]\s*$/) { + $before_range = $1; + $range = $2; + } else { + print; + next; + } + + my @F = split(/ /, $before_range); + my $utt = shift @F; + defined $utt2max_frames{$utt} or die "fix_subsegmented_feats.pl: Could not find key $utt in $utt2num_frames_file.\nError with line $line"; + + if ($range !~ m/^(\d*):(\d*)([,]?.*)$/) { + print STDERR "fix_subsegmented_feats.pl: could not make sense of input line $_"; + exit(1); + } + + my $row_start = $1; + my $row_end = $2; + my $col_range = $3; + + if ($row_end >= $utt2max_frames{$utt}) { + print STDERR "Fixed row_end for $utt from $row_end to $utt2max_frames{$utt}-1\n"; + $row_end = $utt2max_frames{$utt} - 1; + } + + if ($row_start ne "") { + $range = "$row_start:$row_end"; + } else { + $range = ""; + } + + if ($col_range ne "") { + $range .= ",$col_range"; + } + + print ("$utt " . join(" ", @F) . "[" . $range . "]\n"); +} diff --git a/egs/wsj/s5/utils/data/get_subsegment_feats.sh b/egs/wsj/s5/utils/data/get_subsegment_feats.sh new file mode 100755 index 00000000000..6baba68eedd --- /dev/null +++ b/egs/wsj/s5/utils/data/get_subsegment_feats.sh @@ -0,0 +1,46 @@ +#! /bin/bash + +# Copyright 2016 Johns Hopkins University (Author: Dan Povey) +# 2016 Vimal Manohar +# Apache 2.0. + +if [ $# -ne 4 ]; then + echo "This scripts gets subsegmented_feats (by adding ranges to data/feats.scp) " + echo "for the subsegments file. 
This is does one part of the " + echo "functionality in subsegment_data_dir.sh, which additionally " + echo "creates a new subsegmented data directory." + echo "Usage: $0 " + echo " e.g.: $0 data/train/feats.scp 0.01 0.015 subsegments" + exit 1 +fi + +feats=$1 +frame_shift=$2 +frame_overlap=$3 +subsegments=$4 + +# The subsegments format is . +# e.g. 'utt_foo-1 utt_foo 7.21 8.93' +# The first awk command replaces this with the format: +# +# e.g. 'utt_foo-1 utt_foo 721 893' +# and the apply_map.pl command replaces 'utt_foo' (the 2nd field) with its corresponding entry +# from the original wav.scp, so we get a line like: +# e.g. 'utt_foo-1 foo-bar.ark:514231 721 892' +# Note: the reason we subtract one from the last time is that it's going to +# represent the 'last' frame, not the 'end' frame [i.e. not one past the last], +# in the matlab-like, but zero-indexed [first:last] notion. For instance, a segment with 1 frame +# would have start-time 0.00 and end-time 0.01, which would become the frame range +# [0:0] +# The second awk command turns this into something like +# utt_foo-1 foo-bar.ark:514231[721:892] +# It has to be a bit careful because the format actually allows for more general things +# like pipes that might contain spaces, so it has to be able to produce output like the +# following: +# utt_foo-1 some command|[721:892] +# Lastly, utils/data/normalize_data_range.pl will only do something nontrivial if +# the original data-dir already had data-ranges in square brackets. +awk -v s=$frame_shift -v fovlp=$frame_overlap '{print $1, $2, int(($3/s)+0.5), int(($4-fovlp)/s+0.5);}' <$subsegments| \ + utils/apply_map.pl -f 2 $feats | \ + awk '{p=NF-1; for (n=1;n " + echo " $0 [options] [] " echo "This script sub-segments a data directory. is to" echo "have lines of the form " echo "and is of the form ... ." echo "This script appropriately combines the with the original" echo "segments file, if necessary, and if not, creates a segments file." + echo " is an optional argument." echo "e.g.:" echo " $0 data/train [options] exp/tri3b_resegment/segments exp/tri3b_resegment/text data/train_resegmented" echo " Options:" @@ -50,11 +51,23 @@ export LC_ALL=C srcdir=$1 subsegments=$2 -new_text=$3 -dir=$4 +no_text=true +if [ $# -eq 4 ]; then + new_text=$3 + dir=$4 + no_text=false -for f in "$subsegments" "$new_text" "$srcdir/utt2spk"; do + if [ ! -f "$new_text" ]; then + echo "$0: no such file $new_text" + exit 1 + fi + +else + dir=$3 +fi + +for f in "$subsegments" "$srcdir/utt2spk"; do if [ ! -f "$f" ]; then echo "$0: no such file $f" exit 1; @@ -65,9 +78,11 @@ if ! mkdir -p $dir; then echo "$0: failed to create directory $dir" fi -if ! cmp <(awk '{print $1}' <$subsegments) <(awk '{print $1}' <$new_text); then - echo "$0: expected the first fields of the files $subsegments and $new_text to be identical" - exit 1 +if ! $no_text; then + if ! cmp <(awk '{print $1}' <$subsegments) <(awk '{print $1}' <$new_text); then + echo "$0: expected the first fields of the files $subsegments and $new_text to be identical" + exit 1 + fi fi # create the utt2spk in $dir @@ -86,8 +101,11 @@ awk '{print $1, $2}' < $subsegments > $dir/new2old_utt utils/apply_map.pl -f 2 $srcdir/utt2spk < $dir/new2old_utt >$dir/utt2spk # .. and the new spk2utt file. utils/utt2spk_to_spk2utt.pl <$dir/utt2spk >$dir/spk2utt -# the new text file is just what the user provides. -cp $new_text $dir/text + +if ! $no_text; then + # the new text file is just what the user provides. 
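  # (its utterance-ids have already been checked above to be identical to the
  #  first field of the subsegments file, so it can be copied unchanged)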
+ cp $new_text $dir/text +fi # copy the source wav.scp cp $srcdir/wav.scp $dir diff --git a/src/util/kaldi-holder.cc b/src/util/kaldi-holder.cc index a26bdf2ce29..a86f09a2030 100644 --- a/src/util/kaldi-holder.cc +++ b/src/util/kaldi-holder.cc @@ -34,7 +34,7 @@ bool ExtractObjectRange(const Matrix &input, const std::string &range, SplitStringToVector(range, ",", false, &splits); if (!((splits.size() == 1 && !splits[0].empty()) || (splits.size() == 2 && !splits[0].empty() && !splits[1].empty()))) { - KALDI_ERR << "Invalid range specifier: " << range; + KALDI_ERR << "Invalid range specifier for matrix: " << range; return false; } std::vector row_range, col_range; @@ -75,6 +75,48 @@ template bool ExtractObjectRange(const Matrix &, const std::string &, template bool ExtractObjectRange(const Matrix &, const std::string &, Matrix *); +template +bool ExtractObjectRange(const Vector &input, const std::string &range, + Vector *output) { + if (range.empty()) { + KALDI_ERR << "Empty range specifier."; + return false; + } + std::vector splits; + SplitStringToVector(range, ",", false, &splits); + if (!((splits.size() == 1 && !splits[0].empty()))) { + KALDI_ERR << "Invalid range specifier for vector: " << range; + return false; + } + std::vector index_range; + bool status = true; + if (splits[0] != ":") + status = SplitStringToIntegers(splits[0], ":", false, &index_range); + + if (index_range.size() == 0) { + index_range.push_back(0); + index_range.push_back(input.Dim() - 1); + } + + if (!(status && index_range.size() == 2 && + index_range[0] >= 0 && index_range[0] <= index_range[1] && + index_range[1] < input.Dim())) { + KALDI_ERR << "Invalid range specifier: " << range + << " for vector of size " << input.Dim(); + return false; + } + int32 size = index_range[1] - index_range[0] + 1; + output->Resize(size, kUndefined); + output->CopyFromVec(input.Range(index_range[0], size)); + return true; +} + +// template instantiation +template bool ExtractObjectRange(const Vector &, const std::string &, + Vector *); +template bool ExtractObjectRange(const Vector &, const std::string &, + Vector *); + bool ExtractRangeSpecifier(const std::string &rxfilename_with_range, std::string *data_rxfilename, std::string *range) { diff --git a/src/util/kaldi-holder.h b/src/util/kaldi-holder.h index 06d7ec8e745..9ab148387ee 100644 --- a/src/util/kaldi-holder.h +++ b/src/util/kaldi-holder.h @@ -242,6 +242,11 @@ template bool ExtractObjectRange(const Matrix &input, const std::string &range, Matrix *output); +/// The template is specialized types Vector and Vector. 
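/// For example (values are illustrative): the range string "0:9" extracts
/// elements 0 through 9 of the input vector, and ":" selects the whole
/// vector; comma-separated row/column ranges such as "0:9,0:12" are only
/// meaningful for the Matrix overload above.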
+template +bool ExtractObjectRange(const Vector &input, const std::string &range, + Vector *output); + // In SequentialTableReaderScriptImpl and RandomAccessTableReaderScriptImpl, for // cases where the scp contained 'range specifiers' (things in square brackets From d7e0b7f47050bc126b45d9d45f55b833b37aa8ea Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Tue, 30 Aug 2016 16:48:35 -0400 Subject: [PATCH 012/213] asr_diarization: Utility scripts get_reco2utt, get_utt2dur and get_segments_for_data, get_utt2num_frames, get_reco2num_frames, get_reco2dur, convert_ali_to_vec, quantize_vector, convert_rttm_to_utt2spk_and_segments, get_frame_shift_from_config --- .../steps/segmentation/convert_ali_to_vec.pl | 17 +++ .../convert_rttm_to_utt2spk_and_segments.py | 79 +++++++++++++ .../convert_utt2spk_and_segments_to_rttm.py | 65 +++++++++++ .../get_frame_shift_info_from_config.pl | 21 ++++ .../s5/steps/segmentation/quantize_vector.pl | 28 +++++ .../utils/data/convert_data_dir_to_whole.sh | 108 ++++++++++++++++++ egs/wsj/s5/utils/data/get_reco2dur.sh | 87 ++++++++++++++ egs/wsj/s5/utils/data/get_reco2num_frames.sh | 28 +++++ egs/wsj/s5/utils/data/get_reco2utt.sh | 21 ++++ .../s5/utils/data/get_segments_for_data.sh | 2 +- egs/wsj/s5/utils/data/get_utt2dur.sh | 2 +- egs/wsj/s5/utils/data/get_utt2num_frames.sh | 42 +++++++ 12 files changed, 498 insertions(+), 2 deletions(-) create mode 100755 egs/wsj/s5/steps/segmentation/convert_ali_to_vec.pl create mode 100755 egs/wsj/s5/steps/segmentation/convert_rttm_to_utt2spk_and_segments.py create mode 100755 egs/wsj/s5/steps/segmentation/convert_utt2spk_and_segments_to_rttm.py create mode 100755 egs/wsj/s5/steps/segmentation/get_frame_shift_info_from_config.pl create mode 100755 egs/wsj/s5/steps/segmentation/quantize_vector.pl create mode 100755 egs/wsj/s5/utils/data/convert_data_dir_to_whole.sh create mode 100755 egs/wsj/s5/utils/data/get_reco2dur.sh create mode 100755 egs/wsj/s5/utils/data/get_reco2num_frames.sh create mode 100755 egs/wsj/s5/utils/data/get_reco2utt.sh create mode 100755 egs/wsj/s5/utils/data/get_utt2num_frames.sh diff --git a/egs/wsj/s5/steps/segmentation/convert_ali_to_vec.pl b/egs/wsj/s5/steps/segmentation/convert_ali_to_vec.pl new file mode 100755 index 00000000000..c0d1a9eeae2 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/convert_ali_to_vec.pl @@ -0,0 +1,17 @@ +#! /usr/bin/perl + +# Converts a kaldi integer vector in text format to +# a kaldi vector in text format by adding a pair +# of square brackets around the data. +# Assumes the first column to be the utterance id. + +while (<>) { + chomp; + my @F = split; + + printf ("$F[0] [ "); + for (my $i = 1; $i <= $#F; $i++) { + printf ("$F[$i] "); + } + print ("]"); +} diff --git a/egs/wsj/s5/steps/segmentation/convert_rttm_to_utt2spk_and_segments.py b/egs/wsj/s5/steps/segmentation/convert_rttm_to_utt2spk_and_segments.py new file mode 100755 index 00000000000..23dc5a14f09 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/convert_rttm_to_utt2spk_and_segments.py @@ -0,0 +1,79 @@ +#! 
/usr/bin/env python + +"""This script converts an RTTM with +speaker info into kaldi utt2spk and segments""" + +import argparse + +def get_args(): + parser = argparse.ArgumentParser( + description="""This script converts an RTTM with + speaker info into kaldi utt2spk and segments""") + parser.add_argument("--use-reco-id-as-spkr", type=str, + choices=["true", "false"], + help="Use the recording ID based on RTTM and " + "reco2file_and_channel as the speaker") + parser.add_argument("rttm_file", type=str, + help="""Input RTTM file. + The format of the RTTM file is + """ + """ """) + parser.add_argument("reco2file_and_channel", type=str, + help="""Input reco2file_and_channel. + The format is .""") + parser.add_argument("utt2spk", type=str, + help="Output utt2spk file") + parser.add_argument("segments", type=str, + help="Output segments file") + + args = parser.parse_args() + + args.use_reco_id_as_spkr = bool(args.use_reco_id_as_spkr == "true") + + return args + +def main(): + args = get_args() + + file_and_channel2reco = {} + for line in open(args.reco2file_and_channel): + parts = line.strip().split() + file_and_channel2reco[(parts[1], parts[2])] = parts[0] + + utt2spk_writer = open(args.utt2spk, 'w') + segments_writer = open(args.segments, 'w') + for line in open(args.rttm_file): + parts = line.strip().split() + if parts[0] != "SPEAKER": + continue + + file_id = parts[1] + channel = parts[2] + + try: + reco = file_and_channel2reco[(file_id, channel)] + except KeyError as e: + raise Exception("Could not find recording with " + "(file_id, channel) " + "= ({0},{1}) in {2}: {3}\n".format( + file_id, channel, + args.reco2file_and_channel, str(e))) + + start_time = float(parts[3]) + end_time = start_time + float(parts[4]) + + if args.use_reco_id_as_spkr: + spkr = reco + else: + spkr = parts[7] + + st = int(start_time * 100) + end = int(end_time * 100) + utt = "{0}-{1:06d}-{2:06d}".format(spkr, st, end) + + utt2spk_writer.write("{0} {1}\n".format(utt, spkr)) + segments_writer.write("{0} {1} {2:7.2f} {3:7.2f}\n".format( + utt, reco, start_time, end_time)) + +if __name__ == '__main__': + main() diff --git a/egs/wsj/s5/steps/segmentation/convert_utt2spk_and_segments_to_rttm.py b/egs/wsj/s5/steps/segmentation/convert_utt2spk_and_segments_to_rttm.py new file mode 100755 index 00000000000..1443259286b --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/convert_utt2spk_and_segments_to_rttm.py @@ -0,0 +1,65 @@ +#! /usr/bin/env python + +"""This script converts kaldi-style utt2spk and segments to an RTTM""" + +import argparse + +def get_args(): + parser = argparse.ArgumentParser( + description="""This script converts kaldi-style utt2spk and + segments to an RTTM""") + + parser.add_argument("utt2spk", type=str, + help="Input utt2spk file") + parser.add_argument("segments", type=str, + help="Input segments file") + parser.add_argument("reco2file_and_channel", type=str, + help="""Input reco2file_and_channel. 
+ The format is .""") + parser.add_argument("rttm_file", type=str, + help="Output RTTM file") + + args = parser.parse_args() + return args + +def main(): + args = get_args() + + reco2file_and_channel = {} + for line in open(args.reco2file_and_channel): + parts = line.strip().split() + reco2file_and_channel[parts[0]] = (parts[1], parts[2]) + + utt2spk = {} + with open(args.utt2spk, 'r') as utt2spk_reader: + for line in utt2spk_reader: + parts = line.strip().split() + utt2spk[parts[0]] = parts[1] + + with open(args.rttm_file, 'w') as rttm_writer: + for line in open(args.segments, 'r'): + parts = line.strip().split() + + utt = parts[0] + spkr = utt2spk[utt] + + reco = parts[1] + + try: + file_id, channel = reco2file_and_channel[reco] + except KeyError as e: + raise Exception("Could not find recording {0} in {1}: " + "{2}\n".format(reco, + args.reco2file_and_channel, + str(e))) + + start_time = float(parts[2]) + duration = float(parts[3]) - start_time + + rttm_writer.write("SPEAKER {0} {1} {2:7.2f} {3:7.2f} " + " {4} \n".format( + file_id, channel, start_time, + duration, spkr)) + +if __name__ == '__main__': + main() diff --git a/egs/wsj/s5/steps/segmentation/get_frame_shift_info_from_config.pl b/egs/wsj/s5/steps/segmentation/get_frame_shift_info_from_config.pl new file mode 100755 index 00000000000..79a42aa9852 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/get_frame_shift_info_from_config.pl @@ -0,0 +1,21 @@ +#! /usr/bin/perl +use strict; +use warnings; + +# This script parses a features config file such as conf/mfcc.conf +# and returns the pair of values frame_shift and frame_overlap in seconds. + +my $frame_shift = 0.01; +my $frame_overlap = 0.015; + +while (<>) { + if (m/--frame-length=(\d+)/) { + $frame_shift = $1 / 1000; + } + + if (m/--window-length=(\d+)/) { + $frame_overlap = $1 / 1000 - $frame_shift; + } +} + +print "$frame_shift $frame_overlap\n"; diff --git a/egs/wsj/s5/steps/segmentation/quantize_vector.pl b/egs/wsj/s5/steps/segmentation/quantize_vector.pl new file mode 100755 index 00000000000..0bccebade4c --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/quantize_vector.pl @@ -0,0 +1,28 @@ +#!/usr/bin/perl + +# This script convert per-frame speech probabilities into +# 0-1 labels. + +@ARGV <= 1 or die "Usage: quantize_vector.pl [threshold]"; + +my $t = 0.5; + +if (scalar @ARGV == 1) { + $t = $ARGV[0]; +} + +while () { + chomp; + my @F = split; + + my $str = "$F[0]"; + for (my $i = 2; $i < $#F; $i++) { + if ($F[$i] >= $t) { + $str = "$str 1"; + } else { + $str = "$str 0"; + } + } + + print ("$str\n"); +} diff --git a/egs/wsj/s5/utils/data/convert_data_dir_to_whole.sh b/egs/wsj/s5/utils/data/convert_data_dir_to_whole.sh new file mode 100755 index 00000000000..f55f60c4774 --- /dev/null +++ b/egs/wsj/s5/utils/data/convert_data_dir_to_whole.sh @@ -0,0 +1,108 @@ +#! /bin/bash + +# This scripts converts a data directory into a "whole" data directory +# by removing the segments and using the recordings themselves as +# utterances + +set -o pipefail + +. path.sh + +cmd=run.pl +stage=-1 + +. parse_options.sh + +if [ $# -ne 2 ]; then + echo "Usage: convert_data_dir_to_whole.sh " + echo " e.g.: convert_data_dir_to_whole.sh data/dev data/dev_whole" + exit 1 +fi + +data=$1 +dir=$2 + +if [ ! -f $data/segments ]; then + # Data directory already does not contain segments. So just copy it. 
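+  # For example (hypothetical paths), on a source directory without a
+  # segments file,
+  #   utils/data/convert_data_dir_to_whole.sh data/train data/train_whole
+  # reduces to a plain copy of data/train.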
+ utils/copy_data_dir.sh $data $dir + exit 0 +fi + +mkdir -p $dir +cp $data/wav.scp $dir +cp $data/reco2file_and_channel $dir +rm -f $dir/{utt2spk,text} || true + +[ -f $data/stm ] && cp $data/stm $dir +[ -f $data/glm ] && cp $data/glm $dir + +text_files= +[ -f $data/text ] && text_files="$data/text $dir/text" + +# Combine utt2spk and text from the segments into utt2spk and text for the whole +# recording. +cat $data/segments | perl -e ' +if (scalar @ARGV == 4) { + ($utt2spk_in, $utt2spk_out, $text_in, $text_out) = @ARGV; +} elsif (scalar @ARGV == 2) { + ($utt2spk_in, $utt2spk_out) = @ARGV; +} else { + die "Unexpected number of arguments"; +} + +if (defined $text_in) { + open(TI, "<$text_in") || die "Error: fail to open $text_in\n"; + open(TO, ">$text_out") || die "Error: fail to open $text_out\n"; +} +open(UI, "<$utt2spk_in") || die "Error: fail to open $utt2spk_in\n"; +open(UO, ">$utt2spk_out") || die "Error: fail to open $utt2spk_out\n"; + +my %file2utt = (); +while () { + chomp; + my @col = split; + @col >= 4 or die "bad line $_\n"; + + if (! defined $file2utt{$col[1]}) { + $file2utt{$col[1]} = []; + } + push @{$file2utt{$col[1]}}, $col[0]; +} + +my %text = (); +my %utt2spk = (); + +while () { + chomp; + my @col = split; + $utt2spk{$col[0]} = $col[1]; +} + +if (defined $text_in) { + while () { + chomp; + my @col = split; + @col >= 1 or die "bad line $_\n"; + + my $utt = shift @col; + $text{$utt} = join(" ", @col); + } +} + +foreach $file (keys %file2utt) { + my @utts = @{$file2utt{$file}}; + #print STDERR $file . " " . join(" ", @utts) . "\n"; + print UO "$file $file\n"; + + if (defined $text_in) { + $text_line = ""; + print TO "$file $text_line\n"; + } +} +' $data/utt2spk $dir/utt2spk $text_files + +sort -u $dir/utt2spk > $dir/utt2spk.tmp +mv $dir/utt2spk.tmp $dir/utt2spk +utils/utt2spk_to_spk2utt.pl $dir/utt2spk > $dir/spk2utt + +utils/fix_data_dir.sh $dir diff --git a/egs/wsj/s5/utils/data/get_reco2dur.sh b/egs/wsj/s5/utils/data/get_reco2dur.sh new file mode 100755 index 00000000000..7d2ccb71769 --- /dev/null +++ b/egs/wsj/s5/utils/data/get_reco2dur.sh @@ -0,0 +1,87 @@ +#!/bin/bash + +# Copyright 2016 Johns Hopkins University (author: Daniel Povey) +# Apache 2.0 + +# This script operates on a data directory, such as in data/train/, and adds the +# reco2dur file if it does not already exist. The file 'reco2dur' maps from +# utterance to the duration of the utterance in seconds. This script works it +# out from the 'segments' file, or, if not present, from the wav.scp file (it +# first tries interrogating the headers, and if this fails, it reads the wave +# files in entirely.) + +frame_shift=0.01 + +. utils/parse_options.sh +. ./path.sh + +if [ $# != 1 ]; then + echo "Usage: $0 [options] " + echo "e.g.:" + echo " $0 data/train" + echo " Options:" + echo " --frame-shift # frame shift in seconds. Only relevant when we are" + echo " # getting duration from feats.scp (default: 0.01). " + exit 1 +fi + +export LC_ALL=C + +data=$1 + +if [ -s $data/reco2dur ] && \ + [ $(cat $data/wav.scp | wc -l) -eq $(cat $data/reco2dur | wc -l) ]; then + echo "$0: $data/reco2dur already exists with the expected length. We won't recompute it." + exit 0; +fi + +# if the wav.scp contains only lines of the form +# utt1 /foo/bar/sph2pipe -f wav /baz/foo.sph | +if cat $data/wav.scp | perl -e ' + while (<>) { s/\|\s*$/ |/; # make sure final | is preceded by space. 
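+    # A line is accepted only if it has exactly six whitespace-separated
+    # fields ending in "|", e.g. (hypothetical paths):
+    #   utt1 /path/to/sph2pipe -f wav /path/to/utt1.sph |
+    # The duration is then computed as sample_count / sample_rate, both
+    # parsed from the NIST SPHERE header below.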
+ @A = split; if (!($#A == 5 && $A[1] =~ m/sph2pipe$/ && + $A[2] eq "-f" && $A[3] eq "wav" && $A[5] eq "|")) { exit(1); } + $utt = $A[0]; $sphere_file = $A[4]; + + if (!open(F, "<$sphere_file")) { die "Error opening sphere file $sphere_file"; } + $sample_rate = -1; $sample_count = -1; + for ($n = 0; $n <= 30; $n++) { + $line = ; + if ($line =~ m/sample_rate -i (\d+)/) { $sample_rate = $1; } + if ($line =~ m/sample_count -i (\d+)/) { $sample_count = $1; } + if ($line =~ m/end_head/) { break; } + } + close(F); + if ($sample_rate == -1 || $sample_count == -1) { + die "could not parse sphere header from $sphere_file"; + } + $duration = $sample_count * 1.0 / $sample_rate; + print "$utt $duration\n"; + } ' > $data/reco2dur; then + echo "$0: successfully obtained utterance lengths from sphere-file headers" +else + echo "$0: could not get utterance lengths from sphere-file headers, using wav-to-duration" + if ! command -v wav-to-duration >/dev/null; then + echo "$0: wav-to-duration is not on your path" + exit 1; + fi + + read_entire_file=false + if cat $data/wav.scp | grep -q 'sox.*speed'; then + read_entire_file=true + echo "$0: reading from the entire wav file to fix the problem caused by sox commands with speed perturbation. It is going to be slow." + echo "... It is much faster if you call get_reco2dur.sh *before* doing the speed perturbation via e.g. perturb_data_dir_speed.sh or " + echo "... perturb_data_dir_speed_3way.sh." + fi + + if ! wav-to-duration --read-entire-file=$read_entire_file scp:$data/wav.scp ark,t:$data/reco2dur 2>&1 | grep -v 'nonzero return status'; then + echo "$0: there was a problem getting the durations; moving $data/reco2dur to $data/.backup/" + mkdir -p $data/.backup/ + mv $data/reco2dur $data/.backup/ + fi +fi + +echo "$0: computed $data/reco2dur" + +exit 0 + diff --git a/egs/wsj/s5/utils/data/get_reco2num_frames.sh b/egs/wsj/s5/utils/data/get_reco2num_frames.sh new file mode 100755 index 00000000000..03ab7b40616 --- /dev/null +++ b/egs/wsj/s5/utils/data/get_reco2num_frames.sh @@ -0,0 +1,28 @@ +#! /bin/bash + +cmd=run.pl +nj=4 + +frame_shift=0.01 +frame_overlap=0.015 + +. utils/parse_options.sh + +if [ $# -ne 1 ]; then + echo "Usage: $0 " + exit 1 +fi + +data=$1 + +if [ -f $data/reco2num_frames ]; then + echo "$0: $data/reco2num_frames already present!" + exit 0; +fi + +utils/data/get_reco2dur.sh $data +awk -v fs=$frame_shift -v fovlp=$frame_overlap \ + '{print $1" "int( ($2 - fovlp) / fs)}' $data/reco2dur > $data/reco2num_frames + +echo "$0: Computed and wrote $data/reco2num_frames" + diff --git a/egs/wsj/s5/utils/data/get_reco2utt.sh b/egs/wsj/s5/utils/data/get_reco2utt.sh new file mode 100755 index 00000000000..6c30f812cfe --- /dev/null +++ b/egs/wsj/s5/utils/data/get_reco2utt.sh @@ -0,0 +1,21 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + +if [ $# -ne 1 ]; then + echo "This script creates a reco2utt file in the data directory, " + echo "which is analogous to spk2utt file but with the first column " + echo "as recording instead of speaker." + echo "Usage: get_reco2utt.sh " + echo " e.g.: get_reco2utt.sh data/train" + exit 1 +fi + +data=$1 + +if [ ! 
-s $data/segments ]; then + utils/data/get_segments_for_data.sh $data > $data/segments +fi + +cut -d ' ' -f 1,2 $data/segments | utils/utt2spk_to_spk2utt.pl > $data/reco2utt diff --git a/egs/wsj/s5/utils/data/get_segments_for_data.sh b/egs/wsj/s5/utils/data/get_segments_for_data.sh index 694acc6a256..7adc4c465d3 100755 --- a/egs/wsj/s5/utils/data/get_segments_for_data.sh +++ b/egs/wsj/s5/utils/data/get_segments_for_data.sh @@ -19,7 +19,7 @@ fi data=$1 -if [ ! -f $data/utt2dur ]; then +if [ ! -s $data/utt2dur ]; then utils/data/get_utt2dur.sh $data 1>&2 || exit 1; fi diff --git a/egs/wsj/s5/utils/data/get_utt2dur.sh b/egs/wsj/s5/utils/data/get_utt2dur.sh index f14fc2c5e81..c415e8dfb81 100755 --- a/egs/wsj/s5/utils/data/get_utt2dur.sh +++ b/egs/wsj/s5/utils/data/get_utt2dur.sh @@ -35,7 +35,7 @@ if [ -s $data/utt2dur ] && \ exit 0; fi -if [ -f $data/segments ]; then +if [ -s $data/segments ]; then echo "$0: working out $data/utt2dur from $data/segments" cat $data/segments | awk '{len=$4-$3; print $1, len;}' > $data/utt2dur elif [ -f $data/wav.scp ]; then diff --git a/egs/wsj/s5/utils/data/get_utt2num_frames.sh b/egs/wsj/s5/utils/data/get_utt2num_frames.sh new file mode 100755 index 00000000000..e2921601ec9 --- /dev/null +++ b/egs/wsj/s5/utils/data/get_utt2num_frames.sh @@ -0,0 +1,42 @@ +#! /bin/bash + +cmd=run.pl +nj=4 + +frame_shift=0.01 +frame_overlap=0.015 + +. utils/parse_options.sh + +if [ $# -ne 1 ]; then + echo "This script writes a file utt2num_frames with the " + echo "number of frames in each utterance as measured based on the " + echo "duration of the utterances (in utt2dur) and the specified " + echo "frame_shift and frame_overlap." + echo "Usage: $0 " + exit 1 +fi + +data=$1 + +if [ -f $data/utt2num_frames ]; then + echo "$0: $data/utt2num_frames already present!" + exit 0; +fi + +if [ ! -f $data/feats.scp ]; then + utils/data/get_utt2dur.sh $data + awk -v fs=$frame_shift -v fovlp=$frame_overlap \ + '{print $1" "int( ($2 - fovlp) / fs)}' $data/utt2dur > $data/utt2num_frames + exit 0 +fi + +utils/split_data.sh $data $nj || exit 1 +$cmd JOB=1:$nj $data/log/get_utt2num_frames.JOB.log \ + feat-to-len scp:$data/split${nj}/JOB/feats.scp ark,t:$data/split$nj/JOB/utt2num_frames || exit 1 + +for n in `seq $nj`; do + cat $data/split$nj/$n/utt2num_frames +done > $data/utt2num_frames + +echo "$0: Computed and wrote $data/utt2num_frames" From 4c646132d8ffc3b880bc180c170ed5f59a369c6e Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Fri, 25 Nov 2016 17:33:49 -0500 Subject: [PATCH 013/213] asr_diarization: SAD post-processing --- egs/wsj/s5/steps/segmentation/get_sad_map.py | 156 ++++++++++++++++++ .../internal/convert_ali_to_vad.sh | 54 ++++++ .../internal/post_process_segments.sh | 129 +++++++++++++++ .../post_process_sad_to_segments.sh | 103 ++++++++++++ 4 files changed, 442 insertions(+) create mode 100755 egs/wsj/s5/steps/segmentation/get_sad_map.py create mode 100755 egs/wsj/s5/steps/segmentation/internal/convert_ali_to_vad.sh create mode 100755 egs/wsj/s5/steps/segmentation/internal/post_process_segments.sh create mode 100755 egs/wsj/s5/steps/segmentation/post_process_sad_to_segments.sh diff --git a/egs/wsj/s5/steps/segmentation/get_sad_map.py b/egs/wsj/s5/steps/segmentation/get_sad_map.py new file mode 100755 index 00000000000..9160503c7ad --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/get_sad_map.py @@ -0,0 +1,156 @@ +#! /usr/bin/env python + +"""This script prints a mapping from phones to speech +activity labels +0 for silence, 1 for speech, 2 for noise and 3 for OOV. 
+Other labels can be optionally defined. +e.g. If 1, 2 and 3 are silence phones, 4, 5 and 6 are speech phones, +the SAD map would be +1 0 +2 0 +3 0 +4 1 +5 1 +6 1. +The silence and speech are read from the phones/silence.txt and +phones/nonsilence.txt from the lang directory. +An initial SAD map can be provided using --init-sad-map to override +the above default mapping of phones. This is useful to say map + or noise phones to separate SAD labels. +""" + +import argparse + + +class StrToBoolAction(argparse.Action): + """ A custom action to convert bools from shell format i.e., true/false + to python format i.e., True/False """ + def __call__(self, parser, namespace, values, option_string=None): + try: + if values == "true": + setattr(namespace, self.dest, True) + elif values == "true": + setattr(namespace, self.dest, False) + else: + raise ValueError + except ValueError: + raise Exception("Unknown value {0} for --{1}".format(values, + self.dest)) + + +class NullstrToNoneAction(argparse.Action): + """ A custom action to convert empty strings passed by shell + to None in python. This is necessary as shell scripts print null + strings when a variable is not specified. We could use the more apt + None in python. """ + def __call__(self, parser, namespace, values, option_string=None): + if values.strip() == "": + setattr(namespace, self.dest, None) + else: + setattr(namespace, self.dest, values) + + +def get_args(): + parser = argparse.ArgumentParser( + description="""This script prints a mapping from phones to speech + activity labels + 0 for silence, 1 for speech, 2 for noise and 3 for OOV. + Other labels can be optionally defined. + e.g. If 1, 2 and 3 are silence phones, 4, 5 and 6 are speech phones, + the SAD map would be + 1 0 + 2 0 + 3 0 + 4 1 + 5 1 + 6 1. + The silence and speech are read from the phones/silence.txt and + phones/nonsilence.txt from the lang directory. + An initial SAD map can be provided using --init-sad-map to override + the above default mapping of phones. This is useful to say map + or noise phones to separate SAD labels. + """) + + parser.add_argument("--init-sad-map", type=str, action=NullstrToNoneAction, + help="""Initial SAD map that will be used to override + the default mapping using phones/silence.txt and + phones/nonsilence.txt. Does not need to specify labels + for all the phones. + e.g. + 3 + 2""") + + noise_group = parser.add_mutually_exclusive_group() + noise_group.add_argument("--noise-phones-file", type=str, + action=NullstrToNoneAction, + help="Map noise phones from file to label 2") + noise_group.add_argument("--noise-phones-list", type=str, + action=NullstrToNoneAction, + help="A colon-separated list of noise phones to " + "map to label 2") + parser.add_argument("--unk", type=str, action=NullstrToNoneAction, + help="""UNK phone, if provided will be mapped to + label 3""") + + parser.add_argument("--map-noise-to-sil", type=str, + action=StrToBoolAction, + choices=["true", "false"], default=False, + help="""Map noise phones to silence before writing the + map. i.e. anything with label 2 is mapped to + label 0.""") + parser.add_argument("--map-unk-to-speech", type=str, + action=StrToBoolAction, + choices=["true", "false"], default=False, + help="""Map UNK phone to speech before writing the map + i.e. 
anything with label 3 is mapped to label 1.""") + + parser.add_argument("lang_dir") + + args = parser.parse_args() + + return args + + +def main(): + args = get_args() + + sad_map = {} + + for line in open('{0}/phones/nonsilence.txt'.format(args.lang_dir)): + parts = line.strip().split() + sad_map[parts[0]] = 1 + + for line in open('{0}/phones/silence.txt'.format(args.lang_dir)): + parts = line.strip().split() + sad_map[parts[0]] = 0 + + if args.init_sad_map is not None: + for line in open(args.init_sad_map): + parts = line.strip().split() + try: + sad_map[parts[0]] = int(parts[1]) + except Exception: + raise Exception("Invalid line " + line) + + if args.unk is not None: + sad_map[args.unk] = 3 + + noise_phones = {} + if args.noise_phones_file is not None: + for line in open(args.noise_phones_file): + parts = line.strip().split() + noise_phones[parts[0]] = 1 + + if args.noise_phones_list is not None: + for x in args.noise_phones_list.split(":"): + noise_phones[x] = 1 + + for x, l in sad_map.iteritems(): + if l == 2 and args.map_noise_to_sil: + l = 0 + if l == 3 and args.map_unk_to_speech: + l = 1 + print ("{0} {1}".format(x, l)) + +if __name__ == "__main__": + main() diff --git a/egs/wsj/s5/steps/segmentation/internal/convert_ali_to_vad.sh b/egs/wsj/s5/steps/segmentation/internal/convert_ali_to_vad.sh new file mode 100755 index 00000000000..353e6d4664e --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/internal/convert_ali_to_vad.sh @@ -0,0 +1,54 @@ +#! /bin/bash + +set -o pipefail +set -e +set -u + +. path.sh + +cmd=run.pl + +frame_shift=0.01 +frame_subsampling_factor=1 + +. parse_options.sh + +if [ $# -ne 4 ]; then + echo "This script converts the alignment in the alignment directory " + echo "to speech activity segments based on the provided phone-map." + echo "Usage: $0 exp/tri3_ali data/lang/phones/sad.map exp/tri3_ali_vad" + exit 1 +fi + +ali_dir=$1 +phone_map=$2 +dir=$3 + +for f in $phone_map $ali_dir/ali.1.gz; do + [ ! -f $f ] && echo "$0: Could not find $f" && exit 1 +done + +mkdir -p $dir + +nj=`cat $ali_dir/num_jobs` || exit 1 +echo $nj > $dir/num_jobs + +if [ -f $ali_dir/frame_subsampling_factor ]; then + frame_subsampling_factor=`cat $ali_dir/frame_subsampling_factor` +fi + +ali_frame_shift=`perl -e "print ($frame_shift * $frame_subsampling_factor);"` +ali_frame_overlap=`perl -e "print ($ali_frame_shift * 1.5);"` + +dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $dir ${PWD}` + +$cmd JOB=1:$nj $dir/log/get_sad.JOB.log \ + segmentation-init-from-ali \ + "ark:gunzip -c ${ali_dir}/ali.JOB.gz | ali-to-phones --per-frame ${ali_dir}/final.mdl ark:- ark:- |" \ + ark:- \| segmentation-copy --label-map=$phone_map ark:- ark:- \| \ + segmentation-post-process --merge-adjacent-segments ark:- \ + ark,scp:$dir/sad_seg.JOB.ark,$dir/sad_seg.JOB.scp + +for n in `seq $nj`; do + cat $dir/sad_seg.$n.scp +done | sort -k1,1 > $dir/sad_seg.scp diff --git a/egs/wsj/s5/steps/segmentation/internal/post_process_segments.sh b/egs/wsj/s5/steps/segmentation/internal/post_process_segments.sh new file mode 100755 index 00000000000..c2750b4a895 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/internal/post_process_segments.sh @@ -0,0 +1,129 @@ +#! /bin/bash + +# Copyright 2015-16 Vimal Manohar +# Apache 2.0. + +set -e +set -o pipefail +set -u + +. 
path.sh + +cmd=run.pl +stage=-10 + +# General segmentation options +pad_length=50 # Pad speech segments by this many frames on either side +max_blend_length=10 # Maximum duration of speech that will be removed as part + # of smoothing process. This is only if there are no other + # speech segments nearby. +max_intersegment_length=50 # Merge nearby speech segments if the silence + # between them is less than this many frames. +post_pad_length=50 # Pad speech segments by this many frames on either side + # after the merging process using max_intersegment_length +max_segment_length=1000 # Segments that are longer than this are split into + # overlapping frames. +overlap_length=100 # Overlapping frames when segments are split. + # See the above option. +min_silence_length=30 # Min silence length at which to split very long segments + +frame_shift=0.01 + +. utils/parse_options.sh + +if [ $# -ne 3 ]; then + echo "This script post-processes a speech activity segmentation to create " + echo "a kaldi-style data directory." + echo "See the comments for the kind of post-processing options." + echo "Usage: $0 " + echo " e.g.: $0 data/dev_aspire_whole exp/vad_dev_aspire data/dev_aspire_seg" + exit 1 +fi + +data_dir=$1 +dir=$2 +segmented_data_dir=$3 + +for f in $dir/orig_segmentation.1.gz $data_dir/segments; do + if [ ! -f $f ]; then + echo "$0: Could not find $f" + exit 1 + fi +done + +cat < $dir/segmentation.conf +pad_length=$pad_length # Pad speech segments by this many frames on either side +max_blend_length=$max_blend_length # Maximum duration of speech that will be removed as part + # of smoothing process. This is only if there are no other + # speech segments nearby. +max_intersegment_length=$max_intersegment_length # Merge nearby speech segments if the silence + # between them is less than this many frames. +post_pad_length=$post_pad_length # Pad speech segments by this many frames on either side + # after the merging process using max_intersegment_length +max_segment_length=$max_segment_length # Segments that are longer than this are split into + # overlapping frames. +overlap_length=$overlap_length # Overlapping frames when segments are split. + # See the above option. 
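+# (A rough worked example, assuming the default frame_shift of 0.01s: the
+# defaults above come to about 0.5s of padding, a 0.1s maximum blend length,
+# a 0.5s maximum inter-segment gap, 10s maximum segment length and 1s of
+# overlap between split segments.)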
+min_silence_length=$min_silence_length # Min silence length at which to split very long segments + +frame_shift=$frame_shift +EOF + +nj=`cat $dir/num_jobs` || exit 1 + +if [ $stage -le 1 ]; then + rm -r $segmented_data_dir || true + utils/data/convert_data_dir_to_whole.sh $data_dir $segmented_data_dir || exit 1 + rm $segmented_data_dir/text +fi + +if [ $stage -le 2 ]; then + # Post-process the orignal SAD segmentation using the following steps: + # 1) blend short speech segments of less than $max_blend_length frames + # into silence + # 2) Remove all silence frames and widen speech segments by padding + # $pad_length frames + # 3) Merge adjacent segments that have an intersegment length of less than + # $max_intersegment_length frames + # 4) Widen speech segments again after merging + # 5) Split segments into segments of $max_segment_length at the point where + # the original segmentation had silence + # 6) Split segments into overlapping segments of max length + # $max_segment_length and overlap $overlap_length + # 7) Convert segmentation to kaldi segments and utt2spk + $cmd JOB=1:$nj $dir/log/post_process_segmentation.JOB.log \ + gunzip -c $dir/orig_segmentation.JOB.gz \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=0 ark:- ark:- \| \ + segmentation-post-process --max-blend-length=$max_blend_length --blend-short-segments-class=1 ark:- ark:- \| \ + segmentation-post-process --remove-labels=0 --pad-label=1 --pad-length=$pad_length ark:- ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=$max_intersegment_length ark:- ark:- \| \ + segmentation-post-process --pad-label=1 --pad-length=$post_pad_length ark:- ark:- \| \ + segmentation-split-segments --alignments="ark,s,cs:gunzip -c $dir/orig_segmentation.JOB.gz | segmentation-to-ali ark:- ark:- |" \ + --max-segment-length=$max_segment_length --min-alignment-chunk-length=$min_silence_length --ali-label=0 ark:- ark:- \| \ + segmentation-split-segments \ + --max-segment-length=$max_segment_length --overlap-length=$overlap_length ark:- ark:- \| \ + segmentation-to-segments --frame-shift=$frame_shift ark:- \ + ark,t:$dir/utt2spk.JOB $dir/segments.JOB || exit 1 +fi + +for n in `seq $nj`; do + cat $dir/utt2spk.$n +done > $segmented_data_dir/utt2spk + +for n in `seq $nj`; do + cat $dir/segments.$n +done > $segmented_data_dir/segments + +if [ ! -s $segmented_data_dir/utt2spk ] || [ ! -s $segmented_data_dir/segments ]; then + echo "$0: Segmentation failed to generate segments or utt2spk!" + exit 1 +fi + +utils/utt2spk_to_spk2utt.pl $segmented_data_dir/utt2spk > $segmented_data_dir/spk2utt || exit 1 +utils/fix_data_dir.sh $segmented_data_dir + +if [ ! -s $segmented_data_dir/utt2spk ] || [ ! -s $segmented_data_dir/segments ]; then + echo "$0: Segmentation failed to generate segments or utt2spk!" + exit 1 +fi + diff --git a/egs/wsj/s5/steps/segmentation/post_process_sad_to_segments.sh b/egs/wsj/s5/steps/segmentation/post_process_sad_to_segments.sh new file mode 100755 index 00000000000..f4011f20a03 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/post_process_sad_to_segments.sh @@ -0,0 +1,103 @@ +#! /bin/bash + +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +set -e -o pipefail -u +. path.sh + +cmd=run.pl +stage=-10 + +segmentation_config=conf/segmentation.conf +nj=18 + +frame_shift=0.01 +weight_threshold=0.5 +ali_suffix=_acwt0.1 + +phone_map= + +. 
utils/parse_options.sh + +if [ $# -ne 5 ] && [ $# -ne 4 ]; then + echo "This script converts an alignment directory containing per-frame SAD " + echo "labels or per-frame speech probabilities into kaldi-style " + echo "segmented data directory. " + echo "This script first converts the per-frame labels or weights into " + echo "segmentation and then calls " + echo "steps/segmentation/internal/post_process_sad_to_segments.sh, " + echo "which does the actual post-processing step." + echo "Usage: $0 ( |) " + echo " e.g.: $0 data/dev_aspire_whole exp/vad_dev_aspire data/dev_aspire_seg" + exit 1 +fi + +data_dir=$1 +vad_dir= + +if [ $# -eq 5 ]; then + lang=$2 + vad_dir=$3 + shift; shift; shift +else + weights_scp=$2 + shift; shift +fi + +dir=$1 +segmented_data_dir=$2 + +cat $data_dir/segments | awk '{print $1" "$2}' | \ + utils/utt2spk_to_spk2utt.pl > $data_dir/reco2utt + +utils/split_data.sh $data_dir $nj + +for n in `seq $nj`; do + cat $data_dir/split$nj/$n/segments | awk '{print $1" "$2}' | \ + utils/utt2spk_to_spk2utt.pl > $data_dir/split$nj/$n/reco2utt +done + + +mkdir -p $dir + +if [ ! -z "$vad_dir" ]; then + nj=`cat $vad_dir/num_jobs` || exit 1 + + if [ -z "$phone_map" ]; then + phone_map=$dir/phone_map + + { + cat $lang/phones/silence.int | awk '{print $1" 0"}'; + cat $lang/phones/nonsilence.int | awk '{print $1" 1"}'; + } | sort -k1,1 -n > $dir/phone_map + fi + + if [ $stage -le 0 ]; then + # Convert the original SAD into segmentation + $cmd JOB=1:$nj $dir/log/segmentation.JOB.log \ + segmentation-init-from-ali --reco2utt-rspecifier="ark,t:$data_dir/split$nj/JOB/reco2utt" \ + --segmentation-rspecifier="ark:segmentation-init-from-segments --shift-to-zero=false --frame-shift=$frame_shift $data_dir/split$nj/JOB/segments ark:- |" \ + "ark:gunzip -c $vad_dir/ali${ali_suffix}.JOB.gz |" ark:- \| \ + segmentation-copy --label-map=$phone_map ark:- \ + "ark:| gzip -c > $dir/orig_segmentation.JOB.gz" + fi +else + for n in `seq $nj`; do + utils/filter_scp.pl $data_dir/split$nj/$n/reco2utt $weights_scp > $dir/weights.$n.scp + done + + $cmd JOB=1:$nj $dir/log/weights_to_segments.JOB.log \ + copy-vector scp:$dir/weights.JOB.scp ark,t:- \| \ + awk -v t=$weight_threshold '{printf $1; for (i=3; i < NF; i++) { if ($i >= t) printf (" 1"); else printf (" 0"); }; print "";}' \| \ + segmentation-init-from-ali --reco2utt-rspecifier="ark,t:$data_dir/split$nj/JOB/reco2utt" \ + --segmentation-rspecifier="ark:segmentation-init-from-segments --shift-to-zero=false --frame-shift=$frame_shift $data_dir/split$nj/JOB/segments ark:- |" \ + ark,t:- "ark:| gzip -c > $dir/orig_segmentation.JOB.gz" +fi + +echo $nj > $dir/num_jobs + +steps/segmentation/internal/post_process_segments.sh \ + --stage $stage --cmd "$cmd" \ + --config $segmentation_config --frame-shift $frame_shift \ + $data_dir $dir $segmented_data_dir From 1913caf15d255feef6dbe4355764848a4496a6cd Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 24 Nov 2016 01:58:27 -0500 Subject: [PATCH 014/213] asr_diarization: Modify modify_speaker_info to add --respect-recording-info option --- egs/wsj/s5/utils/data/modify_speaker_info.sh | 25 ++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/egs/wsj/s5/utils/data/modify_speaker_info.sh b/egs/wsj/s5/utils/data/modify_speaker_info.sh index f75e9be5f67..e42f0df551d 100755 --- a/egs/wsj/s5/utils/data/modify_speaker_info.sh +++ b/egs/wsj/s5/utils/data/modify_speaker_info.sh @@ -37,6 +37,7 @@ utts_per_spk_max=-1 seconds_per_spk_max=-1 respect_speaker_info=true +respect_recording_info=true # 
end configuration section . utils/parse_options.sh @@ -93,10 +94,26 @@ else utt2dur_opt= fi -utils/data/internal/modify_speaker_info.py \ - $utt2dur_opt --respect-speaker-info=$respect_speaker_info \ - --utts-per-spk-max=$utts_per_spk_max --seconds-per-spk-max=$seconds_per_spk_max \ - <$srcdir/utt2spk >$destdir/utt2spk +if ! $respect_speaker_info && $respect_recording_info; then + if [ -f $srcdir/segments ]; then + cat $srcdir/segments | awk '{print $1" "$2}' | \ + utils/data/internal/modify_speaker_info.py \ + $utt2dur_opt --respect-speaker-info=true \ + --utts-per-spk-max=$utts_per_spk_max --seconds-per-spk-max=$seconds_per_spk_max \ + >$destdir/utt2spk + else + cat $srcdir/wav.scp | awk '{print $1" "$2}' | \ + utils/data/internal/modify_speaker_info.py \ + $utt2dur_opt --respect-speaker-info=true \ + --utts-per-spk-max=$utts_per_spk_max --seconds-per-spk-max=$seconds_per_spk_max \ + >$destdir/utt2spk + fi +else + utils/data/internal/modify_speaker_info.py \ + $utt2dur_opt --respect-speaker-info=$respect_speaker_info \ + --utts-per-spk-max=$utts_per_spk_max --seconds-per-spk-max=$seconds_per_spk_max \ + <$srcdir/utt2spk >$destdir/utt2spk +fi utils/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt From bd98bea8b40e16a13b12de4fdb204442d341bb25 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 24 Nov 2016 02:00:15 -0500 Subject: [PATCH 015/213] asr_diarization: Modify subset_data_dir.sh, copy_data_dir.sh to copy reco2file_and_channel and modify subset_data_dir.sh to add more options --- egs/wsj/s5/utils/copy_data_dir.sh | 7 ++++--- egs/wsj/s5/utils/fix_data_dir.sh | 9 +++++++-- egs/wsj/s5/utils/subset_data_dir.sh | 5 +++++ 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/egs/wsj/s5/utils/copy_data_dir.sh b/egs/wsj/s5/utils/copy_data_dir.sh index 008233daf62..222bc708527 100755 --- a/egs/wsj/s5/utils/copy_data_dir.sh +++ b/egs/wsj/s5/utils/copy_data_dir.sh @@ -83,15 +83,16 @@ fi if [ -f $srcdir/segments ]; then utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments cp $srcdir/wav.scp $destdir - if [ -f $srcdir/reco2file_and_channel ]; then - cp $srcdir/reco2file_and_channel $destdir/ - fi else # no segments->wav indexed by utt. if [ -f $srcdir/wav.scp ]; then utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp fi fi +if [ -f $srcdir/reco2file_and_channel ]; then + cp $srcdir/reco2file_and_channel $destdir/ +fi + if [ -f $srcdir/text ]; then utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text fi diff --git a/egs/wsj/s5/utils/fix_data_dir.sh b/egs/wsj/s5/utils/fix_data_dir.sh index 0333d628544..33e710a605f 100755 --- a/egs/wsj/s5/utils/fix_data_dir.sh +++ b/egs/wsj/s5/utils/fix_data_dir.sh @@ -6,6 +6,11 @@ # It puts the original contents of data-dir into # data-dir/.backup +utt_extra_files= +spk_extra_files= + +. 
utils/parse_options.sh + if [ $# != 1 ]; then echo "Usage: utils/data/fix_data_dir.sh " echo "e.g.: utils/data/fix_data_dir.sh data/train" @@ -110,7 +115,7 @@ function filter_speakers { filter_file $tmpdir/speakers $data/spk2utt utils/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk - for s in cmvn.scp spk2gender; do + for s in cmvn.scp spk2gender $spk_extra_files; do f=$data/$s if [ -f $f ]; then filter_file $tmpdir/speakers $f @@ -158,7 +163,7 @@ function filter_utts { fi fi - for x in utt2spk utt2uniq feats.scp vad.scp text segments utt2lang utt2dur utt2num_frames $maybe_wav; do + for x in utt2spk utt2uniq feats.scp vad.scp text segments utt2lang utt2dur utt2num_frames $maybe_wav $utt_extra_files; do if [ -f $data/$x ]; then cp $data/$x $data/.backup/$x if ! cmp -s $data/$x <( utils/filter_scp.pl $tmpdir/utts $data/$x ) ; then diff --git a/egs/wsj/s5/utils/subset_data_dir.sh b/egs/wsj/s5/utils/subset_data_dir.sh index 5fe3217ddad..9533d0216c9 100755 --- a/egs/wsj/s5/utils/subset_data_dir.sh +++ b/egs/wsj/s5/utils/subset_data_dir.sh @@ -108,6 +108,7 @@ function do_filtering { [ -f $srcdir/vad.scp ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/vad.scp >$destdir/vad.scp [ -f $srcdir/utt2lang ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/utt2lang >$destdir/utt2lang [ -f $srcdir/utt2dur ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/utt2dur >$destdir/utt2dur + [ -f $srcdir/utt2uniq ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/utt2uniq >$destdir/utt2uniq [ -f $srcdir/wav.scp ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/wav.scp >$destdir/wav.scp [ -f $srcdir/spk2warp ] && utils/filter_scp.pl $destdir/spk2utt <$srcdir/spk2warp >$destdir/spk2warp [ -f $srcdir/utt2warp ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/utt2warp >$destdir/utt2warp @@ -126,6 +127,10 @@ function do_filtering { [ -f $srcdir/stm ] && utils/filter_scp.pl $destdir/reco < $srcdir/stm > $destdir/stm rm $destdir/reco + else + awk '{print $1;}' $destdir/wav.scp | sort | uniq > $destdir/reco + [ -f $srcdir/reco2file_and_channel ] && \ + utils/filter_scp.pl $destdir/reco <$srcdir/reco2file_and_channel >$destdir/reco2file_and_channel fi srcutts=`cat $srcdir/utt2spk | wc -l` destutts=`cat $destdir/utt2spk | wc -l` From 64d863ac0c6c950c2da04fe6370b8dd18544feff Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 24 Nov 2016 00:04:34 -0500 Subject: [PATCH 016/213] asr_diarization: Moved evaluate_segmentation.pl to steps/segmentation --- .../local/resegment/evaluate_segmentation.pl | 199 +----------------- .../segmentation/evaluate_segmentation.pl | 198 +++++++++++++++++ 2 files changed, 199 insertions(+), 198 deletions(-) mode change 100755 => 120000 egs/babel/s5c/local/resegment/evaluate_segmentation.pl create mode 100755 egs/wsj/s5/steps/segmentation/evaluate_segmentation.pl diff --git a/egs/babel/s5c/local/resegment/evaluate_segmentation.pl b/egs/babel/s5c/local/resegment/evaluate_segmentation.pl deleted file mode 100755 index 06a762d7762..00000000000 --- a/egs/babel/s5c/local/resegment/evaluate_segmentation.pl +++ /dev/null @@ -1,198 +0,0 @@ -#!/usr/bin/env perl - -# Copyright 2014 Johns Hopkins University (Author: Sanjeev Khudanpur), Vimal Manohar -# Apache 2.0 - -################################################################################ -# -# This script was written to check the goodness of automatic segmentation tools -# It assumes input in the form of two Kaldi segments files, i.e. 
a file each of -# whose lines contain four space-separated values: -# -# UtteranceID FileID StartTime EndTime -# -# It computes # missed frames, # false positives and # overlapping frames. -# -################################################################################ - -if ($#ARGV == 1) { - $ReferenceSegmentation = $ARGV[0]; - $HypothesizedSegmentation = $ARGV[1]; - printf STDERR ("Comparing reference segmentation\n\t%s\nwith proposed segmentation\n\t%s\n", - $ReferenceSegmentation, - $HypothesizedSegmentation); -} else { - printf STDERR "This program compares the reference segmenation with the proposted segmentation\n"; - printf STDERR "Usage: $0 reference_segments_filename proposed_segments_filename\n"; - printf STDERR "e.g. $0 data/dev10h/segments data/dev10h.seg/segments\n"; - exit (0); -} - -################################################################################ -# First read the reference segmentation, and -# store the start- and end-times of all segments in each file. -################################################################################ - -open (SEGMENTS, "cat $ReferenceSegmentation | sort -k2,2 -k3n,3 -k4n,4 |") - || die "Unable to open $ReferenceSegmentation"; -$numLines = 0; -while ($line=) { - chomp $line; - @field = split("[ \t]+", $line); - unless ($#field == 3) { - exit (1); - printf STDERR "Skipping unparseable line in file $ReferenceSegmentation\n\t$line\n"; - next; - } - $fileID = $field[1]; - unless (exists $firstSeg{$fileID}) { - $firstSeg{$fileID} = $numLines; - $actualSpeech{$fileID} = 0.0; - $hypothesizedSpeech{$fileID} = 0.0; - $foundSpeech{$fileID} = 0.0; - $falseAlarm{$fileID} = 0.0; - $minStartTime{$fileID} = 0.0; - $maxEndTime{$fileID} = 0.0; - } - $refSegName[$numLines] = $field[0]; - $refSegStart[$numLines] = $field[2]; - $refSegEnd[$numLines] = $field[3]; - $actualSpeech{$fileID} += ($field[3]-$field[2]); - $minStartTime{$fileID} = $field[2] if ($minStartTime{$fileID}>$field[2]); - $maxEndTime{$fileID} = $field[3] if ($maxEndTime{$fileID}<$field[3]); - $lastSeg{$fileID} = $numLines; - ++$numLines; -} -close(SEGMENTS); -print STDERR "Read $numLines segments from $ReferenceSegmentation\n"; - -################################################################################ -# Process hypothesized segments sequentially, and gather speech/nonspeech stats -################################################################################ - -open (SEGMENTS, "cat $HypothesizedSegmentation | sort -k2,2 -k1,1 |") - # Kaldi segments files are sorted by UtteranceID, but we re-sort them here - # so that all segments of a file are read together, sorted by start-time. 
- || die "Unable to open $HypothesizedSegmentation"; -$numLines = 0; -$totalHypSpeech = 0.0; -$totalFoundSpeech = 0.0; -$totalFalseAlarm = 0.0; -$numShortSegs = 0; -$numLongSegs = 0; -while ($line=) { - chomp $line; - @field = split("[ \t]+", $line); - unless ($#field == 3) { - exit (1); - printf STDERR "Skipping unparseable line in file $HypothesizedSegmentation\n\t$line\n"; - next; - } - $fileID = $field[1]; - $segStart = $field[2]; - $segEnd = $field[3]; - if (exists $firstSeg{$fileID}) { - # This FileID exists in the reference segmentation - # So gather statistics for this UtteranceID - $hypothesizedSpeech{$fileID} += ($segEnd-$segStart); - $totalHypSpeech += ($segEnd-$segStart); - if (($segStart>=$maxEndTime{$fileID}) || ($segEnd<=$minStartTime{$fileID})) { - # This entire segment is a false alarm - $falseAlarm{$fileID} += ($segEnd-$segStart); - $totalFalseAlarm += ($segEnd-$segStart); - } else { - # This segment may overlap one or more reference segments - $p = $firstSeg{$fileID}; - while ($refSegEnd[$p]<=$segStart) { - ++$p; - } - # The overlap, if any, begins at the reference segment p - $q = $lastSeg{$fileID}; - while ($refSegStart[$q]>=$segEnd) { - --$q; - } - # The overlap, if any, ends at the reference segment q - if ($q<$p) { - # This segment sits entirely in the nonspeech region - # between the two reference speech segments q and p - $falseAlarm{$fileID} += ($segEnd-$segStart); - $totalFalseAlarm += ($segEnd-$segStart); - } else { - if (($segEnd-$segStart)<0.20) { - # For diagnosing Pascal's VAD segmentation - print STDOUT "Found short speech region $line\n"; - ++$numShortSegs; - } elsif (($segEnd-$segStart)>60.0) { - ++$numLongSegs; - # For diagnosing Pascal's VAD segmentation - print STDOUT "Found long speech region $line\n"; - } - # There is some overlap with segments p through q - for ($s=$p; $s<=$q; ++$s) { - if ($segStart<$refSegStart[$s]) { - # There is a leading false alarm portion before s - $falseAlarm{$fileID} += ($refSegStart[$s]-$segStart); - $totalFalseAlarm += ($refSegStart[$s]-$segStart); - $segStart=$refSegStart[$s]; - } - $speechPortion = ($refSegEnd[$s]<$segEnd) ? - ($refSegEnd[$s]-$segStart) : ($segEnd-$segStart); - $foundSpeech{$fileID} += $speechPortion; - $totalFoundSpeech += $speechPortion; - $segStart=$refSegEnd[$s]; - } - if ($segEnd>$segStart) { - # There is a trailing false alarm portion after q - $falseAlarm{$fileID} += ($segEnd-$segStart); - $totalFalseAlarm += ($segEnd-$segStart); - } - } - } - } else { - # This FileID does not exist in the reference segmentation - # So all this speech counts as a false alarm - exit (1); - printf STDERR ("Unexpected fileID in hypothesized segments: %s", $fileID); - $totalFalseAlarm += ($segEnd-$segStart); - } - ++$numLines; -} -close(SEGMENTS); -print STDERR "Read $numLines segments from $HypothesizedSegmentation\n"; - -################################################################################ -# Now that all hypothesized segments have been processed, compute needed stats -################################################################################ - -$totalActualSpeech = 0.0; -$totalNonSpeechEst = 0.0; # This is just a crude estimate of total nonspeech. 
-foreach $fileID (sort keys %actualSpeech) { - $totalActualSpeech += $actualSpeech{$fileID}; - $totalNonSpeechEst += $maxEndTime{$fileID} - $actualSpeech{$fileID}; - ####################################################################### - # Print file-wise statistics to STDOUT; can pipe to /dev/null is needed - ####################################################################### - printf STDOUT ("%s: %.2f min actual speech, %.2f min hypothesized: %.2f min overlap (%d\%), %.2f min false alarm (~%d\%)\n", - $fileID, - ($actualSpeech{$fileID}/60.0), - ($hypothesizedSpeech{$fileID}/60.0), - ($foundSpeech{$fileID}/60.0), - ($foundSpeech{$fileID}*100/($actualSpeech{$fileID}+0.01)), - ($falseAlarm{$fileID}/60.0), - ($falseAlarm{$fileID}*100/($maxEndTime{$fileID}-$actualSpeech{$fileID}+0.01))); -} - -################################################################################ -# Finally, we have everything needed to report the segmentation statistics. -################################################################################ - -printf STDERR ("------------------------------------------------------------------------\n"); -printf STDERR ("TOTAL: %.2f hrs actual speech, %.2f hrs hypothesized: %.2f hrs overlap (%d\%), %.2f hrs false alarm (~%d\%)\n", - ($totalActualSpeech/3600.0), - ($totalHypSpeech/3600.0), - ($totalFoundSpeech/3600.0), - ($totalFoundSpeech*100/($totalActualSpeech+0.000001)), - ($totalFalseAlarm/3600.0), - ($totalFalseAlarm*100/($totalNonSpeechEst+0.000001))); -printf STDERR ("\t$numShortSegs segments < 0.2 sec and $numLongSegs segments > 60.0 sec\n"); -printf STDERR ("------------------------------------------------------------------------\n"); diff --git a/egs/babel/s5c/local/resegment/evaluate_segmentation.pl b/egs/babel/s5c/local/resegment/evaluate_segmentation.pl new file mode 120000 index 00000000000..09276466c2b --- /dev/null +++ b/egs/babel/s5c/local/resegment/evaluate_segmentation.pl @@ -0,0 +1 @@ +../../steps/segmentation/evaluate_segmentation.py \ No newline at end of file diff --git a/egs/wsj/s5/steps/segmentation/evaluate_segmentation.pl b/egs/wsj/s5/steps/segmentation/evaluate_segmentation.pl new file mode 100755 index 00000000000..06a762d7762 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/evaluate_segmentation.pl @@ -0,0 +1,198 @@ +#!/usr/bin/env perl + +# Copyright 2014 Johns Hopkins University (Author: Sanjeev Khudanpur), Vimal Manohar +# Apache 2.0 + +################################################################################ +# +# This script was written to check the goodness of automatic segmentation tools +# It assumes input in the form of two Kaldi segments files, i.e. a file each of +# whose lines contain four space-separated values: +# +# UtteranceID FileID StartTime EndTime +# +# It computes # missed frames, # false positives and # overlapping frames. +# +################################################################################ + +if ($#ARGV == 1) { + $ReferenceSegmentation = $ARGV[0]; + $HypothesizedSegmentation = $ARGV[1]; + printf STDERR ("Comparing reference segmentation\n\t%s\nwith proposed segmentation\n\t%s\n", + $ReferenceSegmentation, + $HypothesizedSegmentation); +} else { + printf STDERR "This program compares the reference segmenation with the proposted segmentation\n"; + printf STDERR "Usage: $0 reference_segments_filename proposed_segments_filename\n"; + printf STDERR "e.g. 
$0 data/dev10h/segments data/dev10h.seg/segments\n"; + exit (0); +} + +################################################################################ +# First read the reference segmentation, and +# store the start- and end-times of all segments in each file. +################################################################################ + +open (SEGMENTS, "cat $ReferenceSegmentation | sort -k2,2 -k3n,3 -k4n,4 |") + || die "Unable to open $ReferenceSegmentation"; +$numLines = 0; +while ($line=) { + chomp $line; + @field = split("[ \t]+", $line); + unless ($#field == 3) { + exit (1); + printf STDERR "Skipping unparseable line in file $ReferenceSegmentation\n\t$line\n"; + next; + } + $fileID = $field[1]; + unless (exists $firstSeg{$fileID}) { + $firstSeg{$fileID} = $numLines; + $actualSpeech{$fileID} = 0.0; + $hypothesizedSpeech{$fileID} = 0.0; + $foundSpeech{$fileID} = 0.0; + $falseAlarm{$fileID} = 0.0; + $minStartTime{$fileID} = 0.0; + $maxEndTime{$fileID} = 0.0; + } + $refSegName[$numLines] = $field[0]; + $refSegStart[$numLines] = $field[2]; + $refSegEnd[$numLines] = $field[3]; + $actualSpeech{$fileID} += ($field[3]-$field[2]); + $minStartTime{$fileID} = $field[2] if ($minStartTime{$fileID}>$field[2]); + $maxEndTime{$fileID} = $field[3] if ($maxEndTime{$fileID}<$field[3]); + $lastSeg{$fileID} = $numLines; + ++$numLines; +} +close(SEGMENTS); +print STDERR "Read $numLines segments from $ReferenceSegmentation\n"; + +################################################################################ +# Process hypothesized segments sequentially, and gather speech/nonspeech stats +################################################################################ + +open (SEGMENTS, "cat $HypothesizedSegmentation | sort -k2,2 -k1,1 |") + # Kaldi segments files are sorted by UtteranceID, but we re-sort them here + # so that all segments of a file are read together, sorted by start-time. 
+ || die "Unable to open $HypothesizedSegmentation"; +$numLines = 0; +$totalHypSpeech = 0.0; +$totalFoundSpeech = 0.0; +$totalFalseAlarm = 0.0; +$numShortSegs = 0; +$numLongSegs = 0; +while ($line=) { + chomp $line; + @field = split("[ \t]+", $line); + unless ($#field == 3) { + exit (1); + printf STDERR "Skipping unparseable line in file $HypothesizedSegmentation\n\t$line\n"; + next; + } + $fileID = $field[1]; + $segStart = $field[2]; + $segEnd = $field[3]; + if (exists $firstSeg{$fileID}) { + # This FileID exists in the reference segmentation + # So gather statistics for this UtteranceID + $hypothesizedSpeech{$fileID} += ($segEnd-$segStart); + $totalHypSpeech += ($segEnd-$segStart); + if (($segStart>=$maxEndTime{$fileID}) || ($segEnd<=$minStartTime{$fileID})) { + # This entire segment is a false alarm + $falseAlarm{$fileID} += ($segEnd-$segStart); + $totalFalseAlarm += ($segEnd-$segStart); + } else { + # This segment may overlap one or more reference segments + $p = $firstSeg{$fileID}; + while ($refSegEnd[$p]<=$segStart) { + ++$p; + } + # The overlap, if any, begins at the reference segment p + $q = $lastSeg{$fileID}; + while ($refSegStart[$q]>=$segEnd) { + --$q; + } + # The overlap, if any, ends at the reference segment q + if ($q<$p) { + # This segment sits entirely in the nonspeech region + # between the two reference speech segments q and p + $falseAlarm{$fileID} += ($segEnd-$segStart); + $totalFalseAlarm += ($segEnd-$segStart); + } else { + if (($segEnd-$segStart)<0.20) { + # For diagnosing Pascal's VAD segmentation + print STDOUT "Found short speech region $line\n"; + ++$numShortSegs; + } elsif (($segEnd-$segStart)>60.0) { + ++$numLongSegs; + # For diagnosing Pascal's VAD segmentation + print STDOUT "Found long speech region $line\n"; + } + # There is some overlap with segments p through q + for ($s=$p; $s<=$q; ++$s) { + if ($segStart<$refSegStart[$s]) { + # There is a leading false alarm portion before s + $falseAlarm{$fileID} += ($refSegStart[$s]-$segStart); + $totalFalseAlarm += ($refSegStart[$s]-$segStart); + $segStart=$refSegStart[$s]; + } + $speechPortion = ($refSegEnd[$s]<$segEnd) ? + ($refSegEnd[$s]-$segStart) : ($segEnd-$segStart); + $foundSpeech{$fileID} += $speechPortion; + $totalFoundSpeech += $speechPortion; + $segStart=$refSegEnd[$s]; + } + if ($segEnd>$segStart) { + # There is a trailing false alarm portion after q + $falseAlarm{$fileID} += ($segEnd-$segStart); + $totalFalseAlarm += ($segEnd-$segStart); + } + } + } + } else { + # This FileID does not exist in the reference segmentation + # So all this speech counts as a false alarm + exit (1); + printf STDERR ("Unexpected fileID in hypothesized segments: %s", $fileID); + $totalFalseAlarm += ($segEnd-$segStart); + } + ++$numLines; +} +close(SEGMENTS); +print STDERR "Read $numLines segments from $HypothesizedSegmentation\n"; + +################################################################################ +# Now that all hypothesized segments have been processed, compute needed stats +################################################################################ + +$totalActualSpeech = 0.0; +$totalNonSpeechEst = 0.0; # This is just a crude estimate of total nonspeech. 
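+# (Crude because each file contributes maxEndTime - actualSpeech, i.e. any
+# nonspeech after the last reference segment of a recording is never counted.)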
+foreach $fileID (sort keys %actualSpeech) { + $totalActualSpeech += $actualSpeech{$fileID}; + $totalNonSpeechEst += $maxEndTime{$fileID} - $actualSpeech{$fileID}; + ####################################################################### + # Print file-wise statistics to STDOUT; can pipe to /dev/null is needed + ####################################################################### + printf STDOUT ("%s: %.2f min actual speech, %.2f min hypothesized: %.2f min overlap (%d\%), %.2f min false alarm (~%d\%)\n", + $fileID, + ($actualSpeech{$fileID}/60.0), + ($hypothesizedSpeech{$fileID}/60.0), + ($foundSpeech{$fileID}/60.0), + ($foundSpeech{$fileID}*100/($actualSpeech{$fileID}+0.01)), + ($falseAlarm{$fileID}/60.0), + ($falseAlarm{$fileID}*100/($maxEndTime{$fileID}-$actualSpeech{$fileID}+0.01))); +} + +################################################################################ +# Finally, we have everything needed to report the segmentation statistics. +################################################################################ + +printf STDERR ("------------------------------------------------------------------------\n"); +printf STDERR ("TOTAL: %.2f hrs actual speech, %.2f hrs hypothesized: %.2f hrs overlap (%d\%), %.2f hrs false alarm (~%d\%)\n", + ($totalActualSpeech/3600.0), + ($totalHypSpeech/3600.0), + ($totalFoundSpeech/3600.0), + ($totalFoundSpeech*100/($totalActualSpeech+0.000001)), + ($totalFalseAlarm/3600.0), + ($totalFalseAlarm*100/($totalNonSpeechEst+0.000001))); +printf STDERR ("\t$numShortSegs segments < 0.2 sec and $numLongSegs segments > 60.0 sec\n"); +printf STDERR ("------------------------------------------------------------------------\n"); From 7478ae14ed98b95619b5f75e3549b3b5c34380d2 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 24 Nov 2016 01:34:26 -0500 Subject: [PATCH 017/213] asr_diarization: Modify perturb_data_dir_volume.sh to write reco2vol and have limits --- .../s5/utils/data/perturb_data_dir_volume.sh | 60 ++++++++++++++++--- 1 file changed, 51 insertions(+), 9 deletions(-) diff --git a/egs/wsj/s5/utils/data/perturb_data_dir_volume.sh b/egs/wsj/s5/utils/data/perturb_data_dir_volume.sh index bc76939643c..185c7abf426 100755 --- a/egs/wsj/s5/utils/data/perturb_data_dir_volume.sh +++ b/egs/wsj/s5/utils/data/perturb_data_dir_volume.sh @@ -7,6 +7,11 @@ # the wav.scp to perturb the volume (typically useful for training data when # using systems that don't have cepstral mean normalization). +reco2vol= +force=false +scale_low=0.125 +scale_high=2 + . utils/parse_options.sh if [ $# != 1 ]; then @@ -25,30 +30,67 @@ if [ ! -f $data/wav.scp ]; then exit 1 fi -if grep -q "sox --vol" $data/wav.scp; then +if ! $force && grep -q "sox --vol" $data/wav.scp; then echo "$0: It looks like the data was already volume perturbed. Not doing anything." 
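+  # (Pass --force true to perturb anyway, optionally together with
+  # --reco2vol to reapply a fixed set of per-recording volumes.)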
exit 0 fi -cat $data/wav.scp | python -c " +if [ -z "$reco2vol" ]; then + cat $data/wav.scp | python -c " import sys, os, subprocess, re, random random.seed(0) -scale_low = 1.0/8 -scale_high = 2.0 +scale_low = $scale_low +scale_high = $scale_high +volume_writer = open('$data/reco2vol', 'w') +for line in sys.stdin.readlines(): + if len(line.strip()) == 0: + continue + # Handle three cases of rxfilenames appropriately; 'input piped command', 'file offset' and 'filename' + vol = random.uniform(scale_low, scale_high) + + parts = line.strip().split() + if line.strip()[-1] == '|': + print '{0} sox --vol {1} -t wav - -t wav - |'.format(line.strip(), vol) + elif re.search(':[0-9]+$', line.strip()) is not None: + print '{id} wav-copy {wav} - | sox --vol {vol} -t wav - -t wav - |'.format(id = parts[0], wav=' '.join(parts[1:]), vol = vol) + else: + print '{id} sox --vol {vol} -t wav {wav} -t wav - |'.format(id = parts[0], wav=' '.join(parts[1:]), vol = vol) + volume_writer.write('{id} {vol}\n'.format(id = parts[0], vol = vol)) +" > $data/wav.scp_scaled || exit 1; +else + cat $data/wav.scp | python -c " +import sys, os, subprocess, re +volumes = {} +for line in open('$reco2vol'): + if len(line.strip()) == 0: + continue + parts = line.strip().split() + volumes[parts[0]] = float(parts[1]) + for line in sys.stdin.readlines(): if len(line.strip()) == 0: continue # Handle three cases of rxfilenames appropriately; 'input piped command', 'file offset' and 'filename' + + parts = line.strip().split() + id = parts[0] + + if id not in volumes: + raise Exception('Could not find volume for id {id}'.format(id = id)) + + vol = volumes[id] + if line.strip()[-1] == '|': - print '{0} sox --vol {1} -t wav - -t wav - |'.format(line.strip(), random.uniform(scale_low, scale_high)) + print '{0} sox --vol {1} -t wav - -t wav - |'.format(line.strip(), vol) elif re.search(':[0-9]+$', line.strip()) is not None: - parts = line.split() - print '{id} wav-copy {wav} - | sox --vol {vol} -t wav - -t wav - |'.format(id = parts[0], wav=' '.join(parts[1:]), vol = random.uniform(scale_low, scale_high)) + print '{id} wav-copy {wav} - | sox --vol {vol} -t wav - -t wav - |'.format(id = parts[0], wav=' '.join(parts[1:]), vol = vol) else: - parts = line.split() - print '{id} sox --vol {vol} -t wav {wav} -t wav - |'.format(id = parts[0], wav=' '.join(parts[1:]), vol = random.uniform(scale_low, scale_high)) + print '{id} sox --vol {vol} -t wav {wav} -t wav - |'.format(id = parts[0], wav=' '.join(parts[1:]), vol = vol) " > $data/wav.scp_scaled || exit 1; + cp $reco2vol $data/reco2vol +fi + len1=$(cat $data/wav.scp | wc -l) len2=$(cat $data/wav.scp_scaled | wc -l) if [ "$len1" != "$len2" ]; then From aaa35ff772a51cf14bedb60ffa3b68dea60fba02 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 24 Nov 2016 01:36:35 -0500 Subject: [PATCH 018/213] asr_diarization: Get reverberated version of scp --- .../s5/steps/segmentation/get_reverb_scp.pl | 58 +++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100755 egs/wsj/s5/steps/segmentation/get_reverb_scp.pl diff --git a/egs/wsj/s5/steps/segmentation/get_reverb_scp.pl b/egs/wsj/s5/steps/segmentation/get_reverb_scp.pl new file mode 100755 index 00000000000..57f63b517f2 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/get_reverb_scp.pl @@ -0,0 +1,58 @@ +#! 
/usr/bin/perl +use strict; +use warnings; + +my $field_begin = -1; +my $field_end = -1; + +if ($ARGV[0] eq "-f") { + shift @ARGV; + my $field_spec = shift @ARGV; + if ($field_spec =~ m/^\d+$/) { + $field_begin = $field_spec - 1; $field_end = $field_spec - 1; + } + if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesty (properly, 1-10) + if ($1 ne "") { + $field_begin = $1 - 1; # Change to zero-based indexing. + } + if ($2 ne "") { + $field_end = $2 - 1; # Change to zero-based indexing. + } + } + if (!defined $field_begin && !defined $field_end) { + die "Bad argument to -f option: $field_spec"; + } +} + +if (scalar @ARGV != 1 && scalar @ARGV != 2 ) { + print "Usage: get_reverb_scp.pl [-f -] [] < input_scp > output_scp\n"; + exit(1); +} + +my $num_reps = $ARGV[0]; +my $prefix = "rev"; + +if (scalar @ARGV == 2) { + $prefix = $ARGV[1]; +} + +while () { + chomp; + my @A = split; + + for (my $i = 1; $i <= $num_reps; $i++) { + for (my $pos = 0; $pos <= $#A; $pos++) { + my $a = $A[$pos]; + if ( ($field_begin < 0 || $pos >= $field_begin) + && ($field_end < 0 || $pos <= $field_end) ) { + if ($a =~ m/^(sp[0-9.]+-)(.+)$/) { + $a = $1 . "$prefix" . $i . "_" . $2; + } else { + $a = "$prefix" . $i . "_" . $a; + } + } + print $a . " "; + } + print "\n"; + } +} From e04f86f04d8bd993f3300a2619cc643aef46b15c Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 24 Nov 2016 01:36:13 -0500 Subject: [PATCH 019/213] asr_diarization: Adding script split_data_on_reco.sh --- .../steps/segmentation/split_data_on_reco.sh | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100755 egs/wsj/s5/steps/segmentation/split_data_on_reco.sh diff --git a/egs/wsj/s5/steps/segmentation/split_data_on_reco.sh b/egs/wsj/s5/steps/segmentation/split_data_on_reco.sh new file mode 100755 index 00000000000..4c167d99a1e --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/split_data_on_reco.sh @@ -0,0 +1,29 @@ +#! /bin/bash + +set -e + +if [ $# -ne 3 ]; then + echo "Usage: split_data_on_reco.sh " + exit 1 +fi + +ref_data=$1 +data=$2 +nj=$3 + +utils/data/get_reco2utt.sh $ref_data +utils/data/get_reco2utt.sh $data + +utils/split_data.sh --per-reco $ref_data $nj + +for n in `seq $nj`; do + srn=$ref_data/split${nj}reco/$n + dsn=$data/split${nj}reco/$n + + mkdir -p $dsn + + utils/data/get_reco2utt.sh $srn + utils/filter_scp.pl $srn/reco2utt $data/reco2utt > $dsn/reco2utt + utils/spk2utt_to_utt2spk.pl $dsn/reco2utt > $dsn/utt2reco + utils/subset_data_dir.sh --utt-list $dsn/utt2reco $data $dsn +done From 3dc469299dc054a416a12ae2fae29b92f752546a Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 24 Nov 2016 01:39:18 -0500 Subject: [PATCH 020/213] asr_diarization: add per-reco option to split_data.sh --- egs/wsj/s5/utils/split_data.sh | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/egs/wsj/s5/utils/split_data.sh b/egs/wsj/s5/utils/split_data.sh index e44a4ab6359..646830481db 100755 --- a/egs/wsj/s5/utils/split_data.sh +++ b/egs/wsj/s5/utils/split_data.sh @@ -16,9 +16,14 @@ # limitations under the License. split_per_spk=true +split_per_reco=false if [ "$1" == "--per-utt" ]; then split_per_spk=false shift +elif [ "$1" == "--per-reco" ]; then + split_per_spk=false + split_per_reco=true + shift fi if [ $# != 2 ]; then @@ -59,10 +64,14 @@ if [ -f $data/text ] && [ $nu -ne $nt ]; then echo "** use utils/fix_data_dir.sh to fix this." 
fi - if $split_per_spk; then utt2spk_opt="--utt2spk=$data/utt2spk" utt="" +elif $split_per_reco; then + utils/data/get_reco2utt.sh $data + utils/spk2utt_to_utt2spk.pl $data/reco2utt > $data/utt2reco + utt2spk_opt="--utt2spk=$data/utt2reco" + utt="reco" else utt2spk_opt= utt="utt" @@ -86,6 +95,7 @@ if ! $need_to_split; then fi utt2spks=$(for n in `seq $numsplit`; do echo $data/split${numsplit}${utt}/$n/utt2spk; done) +utt2recos=$(for n in `seq $numsplit`; do echo $data/split${numsplit}${utt}/$n/utt2reco; done) directories=$(for n in `seq $numsplit`; do echo $data/split${numsplit}${utt}/$n; done) @@ -100,11 +110,20 @@ fi which lockfile >&/dev/null && lockfile -l 60 $data/.split_lock trap 'rm -f $data/.split_lock' EXIT HUP INT PIPE TERM -utils/split_scp.pl $utt2spk_opt $data/utt2spk $utt2spks || exit 1 +if $split_per_reco; then + utils/split_scp.pl $utt2spk_opt $data/utt2reco $utt2recos || exit 1 +else + utils/split_scp.pl $utt2spk_opt $data/utt2spk $utt2spks || exit 1 +fi for n in `seq $numsplit`; do dsn=$data/split${numsplit}${utt}/$n - utils/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1; + + if $split_per_reco; then + utils/filter_scp.pl $dsn/utt2reco $data/utt2spk > $dsn/utt2spk + fi + + utils/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1 done maybe_wav_scp= From bfec70247b20d1e71792dca5f5d97a8ffcda4430 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 23 Nov 2016 15:47:38 -0500 Subject: [PATCH 021/213] asr_diarization: Added deriv weights and xent per dim objective --- src/nnet3/nnet-diagnostics.cc | 24 ++++--- src/nnet3/nnet-diagnostics.h | 13 +++- src/nnet3/nnet-example-utils.cc | 65 ++++++++++++++----- src/nnet3/nnet-example.cc | 70 ++++++++++++++++----- src/nnet3/nnet-example.h | 35 ++++++++++- src/nnet3/nnet-nnet.cc | 12 +++- src/nnet3/nnet-nnet.h | 9 ++- src/nnet3/nnet-training.cc | 98 +++++++++++++++++++++++++---- src/nnet3/nnet-training.h | 11 +++- src/nnet3bin/nnet3-acc-lda-stats.cc | 14 ++++- src/nnet3bin/nnet3-copy-egs.cc | 14 +++++ 11 files changed, 302 insertions(+), 63 deletions(-) diff --git a/src/nnet3/nnet-diagnostics.cc b/src/nnet3/nnet-diagnostics.cc index 7f7d485ffe0..64abe8a0578 100644 --- a/src/nnet3/nnet-diagnostics.cc +++ b/src/nnet3/nnet-diagnostics.cc @@ -92,20 +92,24 @@ void NnetComputeProb::ProcessOutputs(const NnetExample &eg, << "mismatch for '" << io.name << "': " << output.NumCols() << " (nnet) vs. 
" << io.features.NumCols() << " (egs)\n"; } + + const Vector *deriv_weights = NULL; + if (config_.apply_deriv_weights && io.deriv_weights.Dim() > 0) + deriv_weights = &(io.deriv_weights); { BaseFloat tot_weight, tot_objf; bool supply_deriv = config_.compute_deriv; ComputeObjectiveFunction(io.features, obj_type, io.name, supply_deriv, computer, - &tot_weight, &tot_objf); + &tot_weight, &tot_objf, deriv_weights); SimpleObjectiveInfo &totals = objf_info_[io.name]; totals.tot_weight += tot_weight; totals.tot_objective += tot_objf; } - if (obj_type == kLinear && config_.compute_accuracy) { + if (config_.compute_accuracy) { BaseFloat tot_weight, tot_accuracy; ComputeAccuracy(io.features, output, - &tot_weight, &tot_accuracy); + &tot_weight, &tot_accuracy, deriv_weights); SimpleObjectiveInfo &totals = accuracy_info_[io.name]; totals.tot_weight += tot_weight; totals.tot_objective += tot_accuracy; @@ -156,7 +160,8 @@ bool NnetComputeProb::PrintTotalStats() const { void ComputeAccuracy(const GeneralMatrix &supervision, const CuMatrixBase &nnet_output, BaseFloat *tot_weight_out, - BaseFloat *tot_accuracy_out) { + BaseFloat *tot_accuracy_out, + const Vector *deriv_weights) { int32 num_rows = nnet_output.NumRows(), num_cols = nnet_output.NumCols(); KALDI_ASSERT(supervision.NumRows() == num_rows && @@ -181,24 +186,27 @@ void ComputeAccuracy(const GeneralMatrix &supervision, for (int32 r = 0; r < num_rows; r++) { SubVector vec(mat, r); BaseFloat row_sum = vec.Sum(); - KALDI_ASSERT(row_sum >= 0.0); + // KALDI_ASSERT(row_sum >= 0.0); // For conventional ASR systems int32 best_index; vec.Max(&best_index); // discard max value. + if (deriv_weights) + row_sum *= (*deriv_weights)(r); tot_weight += row_sum; if (best_index == best_index_cpu[r]) tot_accuracy += row_sum; } break; - } case kFullMatrix: { const Matrix &mat = supervision.GetFullMatrix(); for (int32 r = 0; r < num_rows; r++) { SubVector vec(mat, r); BaseFloat row_sum = vec.Sum(); - KALDI_ASSERT(row_sum >= 0.0); + // KALDI_ASSERT(row_sum >= 0.0); // For conventional ASR systems int32 best_index; vec.Max(&best_index); // discard max value. + if (deriv_weights) + row_sum *= (*deriv_weights)(r); tot_weight += row_sum; if (best_index == best_index_cpu[r]) tot_accuracy += row_sum; @@ -212,6 +220,8 @@ void ComputeAccuracy(const GeneralMatrix &supervision, BaseFloat row_sum = row.Sum(); int32 best_index; row.Max(&best_index); + if (deriv_weights) + row_sum *= (*deriv_weights)(r); KALDI_ASSERT(best_index < num_cols); tot_weight += row_sum; if (best_index == best_index_cpu[r]) diff --git a/src/nnet3/nnet-diagnostics.h b/src/nnet3/nnet-diagnostics.h index 298548857dd..6ed6c4a33a7 100644 --- a/src/nnet3/nnet-diagnostics.h +++ b/src/nnet3/nnet-diagnostics.h @@ -36,7 +36,6 @@ struct SimpleObjectiveInfo { double tot_objective; SimpleObjectiveInfo(): tot_weight(0.0), tot_objective(0.0) { } - }; @@ -44,12 +43,15 @@ struct NnetComputeProbOptions { bool debug_computation; bool compute_deriv; bool compute_accuracy; + bool apply_deriv_weights; + NnetOptimizeOptions optimize_config; NnetComputeOptions compute_config; NnetComputeProbOptions(): debug_computation(false), compute_deriv(false), - compute_accuracy(true) { } + compute_accuracy(true), + apply_deriv_weights(true) { } void Register(OptionsItf *opts) { // compute_deriv is not included in the command line options // because it's not relevant for nnet3-compute-prob. 
@@ -57,6 +59,9 @@ struct NnetComputeProbOptions { "debug for the actual computation (very verbose!)"); opts->Register("compute-accuracy", &compute_accuracy, "If true, compute " "accuracy values as well as objective functions"); + opts->Register("apply-deriv-weights", &apply_deriv_weights, + "Apply per-frame deriv weights"); + // register the optimization options with the prefix "optimization". ParseOptions optimization_opts("optimization", opts); optimize_config.Register(&optimization_opts); @@ -102,6 +107,7 @@ class NnetComputeProb { const Nnet &GetDeriv() const; ~NnetComputeProb(); + private: void ProcessOutputs(const NnetExample &eg, NnetComputer *computer); @@ -152,7 +158,8 @@ class NnetComputeProb { void ComputeAccuracy(const GeneralMatrix &supervision, const CuMatrixBase &nnet_output, BaseFloat *tot_weight, - BaseFloat *tot_accuracy); + BaseFloat *tot_accuracy, + const Vector *deriv_weights = NULL); } // namespace nnet3 diff --git a/src/nnet3/nnet-example-utils.cc b/src/nnet3/nnet-example-utils.cc index 30f7840f6f8..39922153db4 100644 --- a/src/nnet3/nnet-example-utils.cc +++ b/src/nnet3/nnet-example-utils.cc @@ -63,9 +63,9 @@ static void GetIoSizes(const std::vector &src, KALDI_ASSERT(*names_iter == io.name); int32 i = names_iter - names_begin; int32 this_dim = io.features.NumCols(); - if (dims[i] == -1) + if (dims[i] == -1) { dims[i] = this_dim; - else if(dims[i] != this_dim) { + } else if (dims[i] != this_dim) { KALDI_ERR << "Merging examples with inconsistent feature dims: " << dims[i] << " vs. " << this_dim << " for '" << io.name << "'."; @@ -87,9 +87,20 @@ static void MergeIo(const std::vector &src, const std::vector &sizes, bool compress, NnetExample *merged_eg) { + // The total number of Indexes we have across all examples. int32 num_feats = names.size(); + std::vector cur_size(num_feats, 0); + + // The features in the different NnetIo in the Indexes across all examples std::vector > output_lists(num_feats); + + // The deriv weights in the different NnetIo in the Indexes across all + // examples + std::vector const*> > + output_deriv_weights(num_feats); + + // Initialize the merged_eg merged_eg->io.clear(); merged_eg->io.resize(num_feats); for (int32 f = 0; f < num_feats; f++) { @@ -100,22 +111,29 @@ static void MergeIo(const std::vector &src, io.indexes.resize(size); } - std::vector::const_iterator names_begin = names.begin(), + std::vector::const_iterator names_begin = names.begin(); names_end = names.end(); - std::vector::const_iterator iter = src.begin(), end = src.end(); - for (int32 n = 0; iter != end; ++iter,++n) { - std::vector::const_iterator iter2 = iter->io.begin(), - end2 = iter->io.end(); - for (; iter2 != end2; ++iter2) { - const NnetIo &io = *iter2; + std::vector::const_iterator eg_iter = src.begin(), + eg_end = src.end(); + for (int32 n = 0; eg_iter != eg_end; ++eg_iter, ++n) { + std::vector::const_iterator io_iter = eg_iter->io.begin(), + io_end = eg_iter->io.end(); + for (; io_iter != io_end; ++io_iter) { + const NnetIo &io = *io_iter; std::vector::const_iterator names_iter = std::lower_bound(names_begin, names_end, io.name); KALDI_ASSERT(*names_iter == io.name); + int32 f = names_iter - names_begin; - int32 this_size = io.indexes.size(), - &this_offset = cur_size[f]; + int32 this_size = io.indexes.size(); + int32 &this_offset = cur_size[f]; KALDI_ASSERT(this_size + this_offset <= sizes[f]); + + // Add f^th Io's features and deriv_weights output_lists[f].push_back(&(io.features)); + output_deriv_weights[f].push_back(&(io.deriv_weights)); + + // Work on the 
Indexes for the f^th Io in merged_eg NnetIo &output_io = merged_eg->io[f]; std::copy(io.indexes.begin(), io.indexes.end(), output_io.indexes.begin() + this_offset); @@ -139,10 +157,26 @@ static void MergeIo(const std::vector &src, // the following won't do anything if the features were sparse. merged_eg->io[f].features.Compress(); } - } -} + Vector &this_deriv_weights = merged_eg->io[f].deriv_weights; + if (output_deriv_weights[f][0]->Dim() > 0) { + this_deriv_weights.Resize( + merged_eg->io[f].indexes.size(), kUndefined); + KALDI_ASSERT(this_deriv_weights.Dim() == + merged_eg->io[f].features.NumRows()); + std::vector const*>::const_iterator + it = output_deriv_weights[f].begin(), + end = output_deriv_weights[f].end(); + + for (int32 i = 0, cur_offset = 0; it != end; ++it, i++) { + KALDI_ASSERT((*it)->Dim() == output_lists[f][i]->NumRows()); + this_deriv_weights.Range(cur_offset, (*it)->Dim()).CopyFromVec(**it); + cur_offset += (*it)->Dim(); + } + } + } +} void MergeExamples(const std::vector &src, bool compress, @@ -282,9 +316,8 @@ void RoundUpNumFrames(int32 frame_subsampling_factor, KALDI_ERR << "--num-frames-overlap=" << (*num_frames_overlap) << " < " << "--num-frames=" << (*num_frames); } - } -} // namespace nnet3 -} // namespace kaldi +} // namespace nnet3 +} // namespace kaldi diff --git a/src/nnet3/nnet-example.cc b/src/nnet3/nnet-example.cc index 9a34258e0ee..11305f55324 100644 --- a/src/nnet3/nnet-example.cc +++ b/src/nnet3/nnet-example.cc @@ -19,6 +19,7 @@ // limitations under the License. #include "nnet3/nnet-example.h" +#include "nnet3/nnet-example-utils.h" #include "lat/lattice-functions.h" #include "hmm/posterior.h" @@ -31,6 +32,8 @@ void NnetIo::Write(std::ostream &os, bool binary) const { WriteToken(os, binary, name); WriteIndexVector(os, binary, indexes); features.Write(os, binary); + WriteToken(os, binary, ""); // for DerivWeights. Want to save space. + WriteVectorAsChar(os, binary, deriv_weights); WriteToken(os, binary, ""); KALDI_ASSERT(static_cast(features.NumRows()) == indexes.size()); } @@ -40,7 +43,14 @@ void NnetIo::Read(std::istream &is, bool binary) { ReadToken(is, binary, &name); ReadIndexVector(is, binary, &indexes); features.Read(is, binary); - ExpectToken(is, binary, ""); + std::string token; + ReadToken(is, binary, &token); + // in the future this back-compatibility code can be reworked. + if (token != "") { + KALDI_ASSERT(token == ""); + ReadVectorAsChar(is, binary, &deriv_weights); + ExpectToken(is, binary, ""); + } } bool NnetIo::operator == (const NnetIo &other) const { @@ -52,40 +62,70 @@ bool NnetIo::operator == (const NnetIo &other) const { Matrix this_mat, other_mat; features.GetMatrix(&this_mat); other.features.GetMatrix(&other_mat); - return ApproxEqual(this_mat, other_mat); + return (ApproxEqual(this_mat, other_mat) && + deriv_weights.ApproxEqual(other.deriv_weights)); } NnetIo::NnetIo(const std::string &name, - int32 t_begin, const MatrixBase &feats): + int32 t_begin, const MatrixBase &feats, + int32 skip_frame): name(name), features(feats) { - int32 num_rows = feats.NumRows(); - KALDI_ASSERT(num_rows > 0); - indexes.resize(num_rows); // sets all n,t,x to zeros. - for (int32 i = 0; i < num_rows; i++) - indexes[i].t = t_begin + i; + int32 num_skipped_rows = feats.NumRows(); + KALDI_ASSERT(num_skipped_rows > 0); + indexes.resize(num_skipped_rows); // sets all n,t,x to zeros. 
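  // Note: when skip_frame > 1 the supervision covers only every skip_frame'th
  // input frame, so the 't' values below advance in steps of skip_frame.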
+ for (int32 i = 0; i < num_skipped_rows; i++) + indexes[i].t = t_begin + i * skip_frame; +} + +NnetIo::NnetIo(const std::string &name, + const VectorBase &deriv_weights, + int32 t_begin, const MatrixBase &feats, + int32 skip_frame): + name(name), features(feats), deriv_weights(deriv_weights) { + int32 num_skipped_rows = feats.NumRows(); + KALDI_ASSERT(num_skipped_rows > 0); + indexes.resize(num_skipped_rows); // sets all n,t,x to zeros. + for (int32 i = 0; i < num_skipped_rows; i++) + indexes[i].t = t_begin + i * skip_frame; } void NnetIo::Swap(NnetIo *other) { name.swap(other->name); indexes.swap(other->indexes); features.Swap(&(other->features)); + deriv_weights.Swap(&(other->deriv_weights)); } NnetIo::NnetIo(const std::string &name, int32 dim, int32 t_begin, - const Posterior &labels): + const Posterior &labels, + int32 skip_frame): name(name) { - int32 num_rows = labels.size(); - KALDI_ASSERT(num_rows > 0); + int32 num_skipped_rows = labels.size(); + KALDI_ASSERT(num_skipped_rows > 0); SparseMatrix sparse_feats(dim, labels); features = sparse_feats; - indexes.resize(num_rows); // sets all n,t,x to zeros. - for (int32 i = 0; i < num_rows; i++) - indexes[i].t = t_begin + i; + indexes.resize(num_skipped_rows); // sets all n,t,x to zeros. + for (int32 i = 0; i < num_skipped_rows; i++) + indexes[i].t = t_begin + i * skip_frame; } - +NnetIo::NnetIo(const std::string &name, + const VectorBase &deriv_weights, + int32 dim, + int32 t_begin, + const Posterior &labels, + int32 skip_frame): + name(name), deriv_weights(deriv_weights) { + int32 num_skipped_rows = labels.size(); + KALDI_ASSERT(num_skipped_rows > 0); + SparseMatrix sparse_feats(dim, labels); + features = sparse_feats; + indexes.resize(num_skipped_rows); // sets all n,t,x to zeros. + for (int32 i = 0; i < num_skipped_rows; i++) + indexes[i].t = t_begin + i * skip_frame; +} void NnetExample::Write(std::ostream &os, bool binary) const { // Note: weight, label, input_frames and spk_info are members. This is a diff --git a/src/nnet3/nnet-example.h b/src/nnet3/nnet-example.h index 1df7cd1e78e..b1ae42a78c9 100644 --- a/src/nnet3/nnet-example.h +++ b/src/nnet3/nnet-example.h @@ -45,12 +45,32 @@ struct NnetIo { /// a Matrix, or SparseMatrix (a SparseMatrix would be the natural format for posteriors). GeneralMatrix features; + /// This is a vector of per-frame weights, required to be between 0 and 1, + /// that is applied to the derivative during training (but not during model + /// combination, where the derivatives need to agree with the computed objf + /// values for the optimization code to work). + /// If this vector is empty it means we're not applying per-frame weights, + /// so it's equivalent to a vector of all ones. This vector is written + /// to disk compactly as unsigned char. + Vector deriv_weights; + /// This constructor creates NnetIo with name "name", indexes with n=0, x=0, /// and t values ranging from t_begin to t_begin + feats.NumRows() - 1, and /// the provided features. t_begin should be the frame that feats.Row(0) /// represents. NnetIo(const std::string &name, - int32 t_begin, const MatrixBase &feats); + int32 t_begin, + const MatrixBase &feats, + int32 skip_frame = 1); + + /// This is similar to the above constructor but also takes in a + /// a deriv weights argument. 
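  /// deriv_weights is expected to have one entry per row of feats (i.e. one
  /// weight per supervised frame).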
+ NnetIo(const std::string &name, + const VectorBase &deriv_weights, + int32 t_begin, + const MatrixBase &feats, + int32 skip_frame = 1); + /// This constructor sets "name" to the provided string, sets "indexes" with /// n=0, x=0, and t from t_begin to t_begin + labels.size() - 1, and the labels @@ -58,7 +78,17 @@ struct NnetIo { NnetIo(const std::string &name, int32 dim, int32 t_begin, - const Posterior &labels); + const Posterior &labels, + int32 skip_frame = 1); + + /// This is similar to the above constructor but also takes in a + /// a deriv weights argument. + NnetIo(const std::string &name, + const VectorBase &deriv_weights, + int32 dim, + int32 t_begin, + const Posterior &labels, + int32 skip_frame = 1); void Swap(NnetIo *other); @@ -80,7 +110,6 @@ struct NnetIo { /// more frames of input, used for standard cross-entropy training of neural /// nets (and possibly for other objective functions). struct NnetExample { - /// "io" contains the input and output. In principle there can be multiple /// types of both input and output, with different names. The order is /// irrelevant. diff --git a/src/nnet3/nnet-nnet.cc b/src/nnet3/nnet-nnet.cc index ad5f715a294..4fcbbc70a1f 100644 --- a/src/nnet3/nnet-nnet.cc +++ b/src/nnet3/nnet-nnet.cc @@ -84,8 +84,14 @@ std::string Nnet::GetAsConfigLine(int32 node_index, bool include_dim) const { node.descriptor.WriteConfig(ans, node_names_); if (include_dim) ans << " dim=" << node.Dim(*this); - ans << " objective=" << (node.u.objective_type == kLinear ? "linear" : - "quadratic"); + + if (node.u.objective_type == kLinear) + ans << " objective=linear"; + else if (node.u.objective_type == kQuadratic) + ans << " objective=quadratic"; + else if (node.u.objective_type == kXentPerDim) + ans << " objective=xent-per-dim"; + break; case kComponent: ans << "component-node name=" << name << " component=" @@ -390,6 +396,8 @@ void Nnet::ProcessOutputNodeConfigLine( nodes_[node_index].u.objective_type = kLinear; } else if (objective_type == "quadratic") { nodes_[node_index].u.objective_type = kQuadratic; + } else if (objective_type == "xent-per-dim") { + nodes_[node_index].u.objective_type = kXentPerDim; } else { KALDI_ERR << "Invalid objective type: " << objective_type; } diff --git a/src/nnet3/nnet-nnet.h b/src/nnet3/nnet-nnet.h index 16e8333d5b1..b9ed3c1052b 100644 --- a/src/nnet3/nnet-nnet.h +++ b/src/nnet3/nnet-nnet.h @@ -49,7 +49,12 @@ namespace nnet3 { /// - Objective type kQuadratic is used to mean the objective function /// f(x, y) = -0.5 (x-y).(x-y), which is to be maximized, as in the kLinear /// case. -enum ObjectiveType { kLinear, kQuadratic }; +/// - Objective type kXentPerDim is the objective function that is used +/// to learn a set of bernoulli random variables. +/// f(x, y) = x * y + (1-x) * Log(1-Exp(y)), where +/// x is the true probability of class 1 and +/// y is the predicted log probability of class 1 +enum ObjectiveType { kLinear, kQuadratic, kXentPerDim }; enum NodeType { kInput, kDescriptor, kComponent, kDimRange, kNone }; @@ -249,7 +254,7 @@ class Nnet { void ResetGenerators(); // resets random-number generators for all // random components. You must also set srand() for this to be // effective. - + private: void Destroy(); diff --git a/src/nnet3/nnet-training.cc b/src/nnet3/nnet-training.cc index 87d64e27871..9d957afe1de 100644 --- a/src/nnet3/nnet-training.cc +++ b/src/nnet3/nnet-training.cc @@ -39,7 +39,7 @@ NnetTrainer::NnetTrainer(const NnetTrainerOptions &config, // natural-gradient updates. 
SetZero(is_gradient, delta_nnet_); const int32 num_updatable = NumUpdatableComponents(*delta_nnet_); - num_max_change_per_component_applied_.resize(num_updatable, 0); + num_max_change_per_component_applied_.resize(num_updatable, 0); num_max_change_global_applied_ = 0; if (config_.read_cache != "") { @@ -52,7 +52,7 @@ NnetTrainer::NnetTrainer(const NnetTrainerOptions &config, KALDI_WARN << "Could not open cached computation. " "Probably this is the first training iteration."; } - } + } } @@ -88,9 +88,12 @@ void NnetTrainer::ProcessOutputs(const NnetExample &eg, ObjectiveType obj_type = nnet_->GetNode(node_index).u.objective_type; BaseFloat tot_weight, tot_objf; bool supply_deriv = true; + const Vector *deriv_weights = NULL; + if (config_.apply_deriv_weights && io.deriv_weights.Dim() > 0) + deriv_weights = &(io.deriv_weights); ComputeObjectiveFunction(io.features, obj_type, io.name, supply_deriv, computer, - &tot_weight, &tot_objf); + &tot_weight, &tot_objf, deriv_weights); objf_info_[io.name].UpdateStats(io.name, config_.print_interval, num_minibatches_processed_++, tot_weight, tot_objf); @@ -167,7 +170,7 @@ void NnetTrainer::UpdateParamsWithMaxChange() { << " / " << num_updatable << " Updatable Components." << "(smallest factor=" << min_scale << " on " << component_name_with_min_scale - << " with max-change=" << max_change_with_min_scale <<"). "; + << " with max-change=" << max_change_with_min_scale <<"). "; if (param_delta > config_.max_param_change) ostr << "Global max-change factor was " << config_.max_param_change / param_delta @@ -276,7 +279,7 @@ bool ObjectiveFunctionInfo::PrintTotalStats(const std::string &name) const { << (tot_objf / tot_weight) << " over " << tot_weight << " frames."; } else { KALDI_LOG << "Overall average objective function for '" << name << "' is " - << objf << " + " << aux_objf << " = " << sum_objf + << objf << " + " << aux_objf << " = " << sum_objf << " over " << tot_weight << " frames."; } KALDI_LOG << "[this line is to be parsed by a script:] " @@ -290,7 +293,7 @@ NnetTrainer::~NnetTrainer() { Output ko(config_.write_cache, config_.binary_write_cache); compiler_.WriteCache(ko.Stream(), config_.binary_write_cache); KALDI_LOG << "Wrote computation cache to " << config_.write_cache; - } + } delete delta_nnet_; } @@ -300,7 +303,8 @@ void ComputeObjectiveFunction(const GeneralMatrix &supervision, bool supply_deriv, NnetComputer *computer, BaseFloat *tot_weight, - BaseFloat *tot_objf) { + BaseFloat *tot_objf, + const VectorBase *deriv_weights) { const CuMatrixBase &output = computer->GetOutput(output_name); if (output.NumCols() != supervision.NumCols()) @@ -309,6 +313,51 @@ void ComputeObjectiveFunction(const GeneralMatrix &supervision, << " (nnet) vs. 
" << supervision.NumCols() << " (egs)\n"; switch (objective_type) { + case kXentPerDim: { + // objective is x * log(y) + (1-x) * log(1-y) + CuMatrix cu_post(supervision.NumRows(), supervision.NumCols(), + kUndefined); // x + cu_post.CopyFromGeneralMat(supervision); + + CuMatrix n_cu_post(cu_post.NumRows(), cu_post.NumCols()); + n_cu_post.Set(1.0); + n_cu_post.AddMat(-1.0, cu_post); // 1-x + + CuMatrix log_prob(output); // y + log_prob.ApplyLog(); // log(y) + + CuMatrix n_output(output.NumRows(), + output.NumCols(), kSetZero); + n_output.Set(1.0); + n_output.AddMat(-1.0, output); // 1-y + n_output.ApplyLog(); // log(1-y) + + BaseFloat num_elements = static_cast(cu_post.NumRows()); + if (deriv_weights) { + CuVector cu_deriv_weights(*deriv_weights); + num_elements = cu_deriv_weights.Sum(); + cu_post.MulRowsVec(cu_deriv_weights); + n_cu_post.MulRowsVec(cu_deriv_weights); + } + + *tot_weight = num_elements * cu_post.NumCols(); + *tot_objf = TraceMatMat(log_prob, cu_post, kTrans) + + TraceMatMat(n_output, n_cu_post, kTrans); + + if (supply_deriv) { + // deriv is x / y - (1-x) / (1-y) + n_output.ApplyExp(); // 1-y + n_cu_post.DivElements(n_output); // 1-x / (1-y) + + log_prob.ApplyExp(); // y + cu_post.DivElements(log_prob); // x / y + + cu_post.AddMat(-1.0, n_cu_post); // x / y - (1-x) / (1-y) + computer->AcceptOutputDeriv(output_name, &cu_post); + } + + break; + } case kLinear: { // objective is x * y. switch (supervision.Type()) { @@ -318,20 +367,38 @@ void ComputeObjectiveFunction(const GeneralMatrix &supervision, // The cross-entropy objective is computed by a simple dot product, // because after the LogSoftmaxLayer, the output is already in the form // of log-likelihoods that are normalized to sum to one. - *tot_weight = cu_post.Sum(); - *tot_objf = TraceMatSmat(output, cu_post, kTrans); - if (supply_deriv) { + if (deriv_weights) { CuMatrix output_deriv(output.NumRows(), output.NumCols(), kUndefined); cu_post.CopyToMat(&output_deriv); - computer->AcceptOutputDeriv(output_name, &output_deriv); + CuVector cu_deriv_weights(*deriv_weights); + output_deriv.MulRowsVec(cu_deriv_weights); + *tot_weight = cu_deriv_weights.Sum(); + *tot_objf = TraceMatMat(output, output_deriv, kTrans); + if (supply_deriv) { + computer->AcceptOutputDeriv(output_name, &output_deriv); + } + } else { + *tot_weight = cu_post.Sum(); + *tot_objf = TraceMatSmat(output, cu_post, kTrans); + if (supply_deriv) { + CuMatrix output_deriv(output.NumRows(), output.NumCols(), + kUndefined); + cu_post.CopyToMat(&output_deriv); + computer->AcceptOutputDeriv(output_name, &output_deriv); + } } + break; } case kFullMatrix: { // there is a redundant matrix copy in here if we're not using a GPU // but we don't anticipate this code branch being used in many cases. 
CuMatrix cu_post(supervision.GetFullMatrix()); + if (deriv_weights) { + CuVector cu_deriv_weights(*deriv_weights); + cu_post.MulRowsVec(cu_deriv_weights); + } *tot_weight = cu_post.Sum(); *tot_objf = TraceMatMat(output, cu_post, kTrans); if (supply_deriv) @@ -343,6 +410,10 @@ void ComputeObjectiveFunction(const GeneralMatrix &supervision, supervision.GetMatrix(&post); CuMatrix cu_post; cu_post.Swap(&post); + if (deriv_weights) { + CuVector cu_deriv_weights(*deriv_weights); + cu_post.MulRowsVec(cu_deriv_weights); + } *tot_weight = cu_post.Sum(); *tot_objf = TraceMatMat(output, cu_post, kTrans); if (supply_deriv) @@ -360,6 +431,11 @@ void ComputeObjectiveFunction(const GeneralMatrix &supervision, diff.CopyFromGeneralMat(supervision); diff.AddMat(-1.0, output); *tot_weight = diff.NumRows(); + if (deriv_weights) { + CuVector cu_deriv_weights(*deriv_weights); + diff.MulRowsVec(cu_deriv_weights); + *tot_weight = deriv_weights->Sum(); + } *tot_objf = -0.5 * TraceMatMat(diff, diff, kTrans); if (supply_deriv) computer->AcceptOutputDeriv(output_name, &diff); diff --git a/src/nnet3/nnet-training.h b/src/nnet3/nnet-training.h index 70c90267c66..7b22bc75211 100644 --- a/src/nnet3/nnet-training.h +++ b/src/nnet3/nnet-training.h @@ -42,6 +42,8 @@ struct NnetTrainerOptions { BaseFloat max_param_change; NnetOptimizeOptions optimize_config; NnetComputeOptions compute_config; + bool apply_deriv_weights; + NnetTrainerOptions(): zero_component_stats(true), store_component_stats(true), @@ -49,7 +51,8 @@ struct NnetTrainerOptions { debug_computation(false), momentum(0.0), binary_write_cache(true), - max_param_change(2.0) { } + max_param_change(2.0), + apply_deriv_weights(true) { } void Register(OptionsItf *opts) { opts->Register("store-component-stats", &store_component_stats, "If true, store activations and derivatives for nonlinear " @@ -69,6 +72,9 @@ struct NnetTrainerOptions { "so that the 'effective' learning rate is the same as " "before (because momentum would normally increase the " "effective learning rate by 1/(1-momentum))"); + opts->Register("apply-deriv-weights", &apply_deriv_weights, + "If true, apply the per-frame derivative weights stored with " + "the example"); opts->Register("read-cache", &read_cache, "the location where we can read " "the cached computation from"); opts->Register("write-cache", &write_cache, "the location where we want to " @@ -226,7 +232,8 @@ void ComputeObjectiveFunction(const GeneralMatrix &supervision, bool supply_deriv, NnetComputer *computer, BaseFloat *tot_weight, - BaseFloat *tot_objf); + BaseFloat *tot_objf, + const VectorBase* deriv_weights = NULL); diff --git a/src/nnet3bin/nnet3-acc-lda-stats.cc b/src/nnet3bin/nnet3-acc-lda-stats.cc index 0b3b537855e..b41c4a6704d 100644 --- a/src/nnet3bin/nnet3-acc-lda-stats.cc +++ b/src/nnet3bin/nnet3-acc-lda-stats.cc @@ -87,13 +87,18 @@ class NnetLdaStatsAccumulator { // but we're about to do an outer product, so this doesn't dominate. 
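      // If the example carries per-frame deriv weights, they scale this
      // frame's contribution to the LDA stats below; a weight of zero
      // effectively drops the frame.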
Vector row(cu_row); + BaseFloat deriv_weight = 1.0; + if (output_supervision->deriv_weights.Dim() > 0 && r < output_supervision->deriv_weights.Dim()) { + deriv_weight = output_supervision->deriv_weights(r); + } + const SparseVector &post(smat.Row(r)); const std::pair *post_data = post.Data(), *post_end = post_data + post.NumElements(); for (; post_data != post_end; ++post_data) { MatrixIndexT pdf = post_data->first; BaseFloat weight = post_data->second; - BaseFloat pruned_weight = RandPrune(weight, rand_prune); + BaseFloat pruned_weight = RandPrune(weight, rand_prune) * deriv_weight; if (pruned_weight != 0.0) lda_stats_.Accumulate(row, pdf, pruned_weight); } @@ -110,11 +115,16 @@ class NnetLdaStatsAccumulator { // but we're about to do an outer product, so this doesn't dominate. Vector row(cu_row); + BaseFloat deriv_weight = 1.0; + if (output_supervision->deriv_weights.Dim() > 0 && r < output_supervision->deriv_weights.Dim()) { + deriv_weight = output_supervision->deriv_weights(r); + } + SubVector post(output_mat, r); int32 num_pdfs = post.Dim(); for (int32 pdf = 0; pdf < num_pdfs; pdf++) { BaseFloat weight = post(pdf); - BaseFloat pruned_weight = RandPrune(weight, rand_prune); + BaseFloat pruned_weight = RandPrune(weight, rand_prune) * deriv_weight; if (pruned_weight != 0.0) lda_stats_.Accumulate(row, pdf, pruned_weight); } diff --git a/src/nnet3bin/nnet3-copy-egs.cc b/src/nnet3bin/nnet3-copy-egs.cc index efb51f51910..0b82d91353a 100644 --- a/src/nnet3bin/nnet3-copy-egs.cc +++ b/src/nnet3bin/nnet3-copy-egs.cc @@ -137,6 +137,7 @@ void FilterExample(const NnetExample &eg, if (!is_input_or_output) { // Just copy everything. io_out.indexes = io_in.indexes; io_out.features = io_in.features; + io_out.deriv_weights = io_in.deriv_weights; } else { const std::vector &indexes_in = io_in.indexes; std::vector &indexes_out = io_out.indexes; @@ -157,6 +158,19 @@ void FilterExample(const NnetExample &eg, } } KALDI_ASSERT(iter_out == keep.end()); + + if (io_in.deriv_weights.Dim() > 0) { + io_out.deriv_weights.Resize(num_kept, kUndefined); + int32 in_dim = 0, out_dim = 0; + iter_out = keep.begin(); + for (; iter_out != keep.end(); ++iter_out, in_dim++) { + if (*iter_out) + io_out.deriv_weights(out_dim++) = io_in.deriv_weights(in_dim); + } + KALDI_ASSERT(out_dim == num_kept); + KALDI_ASSERT(iter_out == keep.end()); + } + if (num_kept == 0) KALDI_ERR << "FilterExample removed all indexes for '" << name << "'"; From 99dcd967e4c0fce094456469b249943f4a1ec464 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 23 Nov 2016 15:54:29 -0500 Subject: [PATCH 022/213] asr_diarization: Adding compress format option --- src/matrix/compressed-matrix.cc | 52 +++++++++++++++++++-------------- src/matrix/compressed-matrix.h | 14 +++++---- src/matrix/sparse-matrix.cc | 4 +-- src/matrix/sparse-matrix.h | 6 ++-- src/nnet3/nnet-example.cc | 4 +-- src/nnet3/nnet-example.h | 17 +++++++++-- 6 files changed, 61 insertions(+), 36 deletions(-) diff --git a/src/matrix/compressed-matrix.cc b/src/matrix/compressed-matrix.cc index 2ac2c544bc8..6fc365c8f03 100644 --- a/src/matrix/compressed-matrix.cc +++ b/src/matrix/compressed-matrix.cc @@ -24,14 +24,14 @@ namespace kaldi { -//static +//static MatrixIndexT CompressedMatrix::DataSize(const GlobalHeader &header) { // Returns size in bytes of the data. 
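  // Format 1 stores a PerColHeader per column plus one byte per element;
  // format 2 stores each element directly as a uint16 (two bytes).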
if (header.format == 1) { return sizeof(GlobalHeader) + header.num_cols * (sizeof(PerColHeader) + header.num_rows); } else { - KALDI_ASSERT(header.format == 2) ; + KALDI_ASSERT(header.format == 2); return sizeof(GlobalHeader) + 2 * header.num_rows * header.num_cols; } @@ -40,7 +40,7 @@ MatrixIndexT CompressedMatrix::DataSize(const GlobalHeader &header) { template void CompressedMatrix::CopyFromMat( - const MatrixBase &mat) { + const MatrixBase &mat, int32 format) { if (data_ != NULL) { delete [] static_cast(data_); // call delete [] because was allocated with new float[] data_ = NULL; @@ -52,7 +52,7 @@ void CompressedMatrix::CopyFromMat( KALDI_COMPILE_TIME_ASSERT(sizeof(global_header) == 20); // otherwise // something weird is happening and our code probably won't work or // won't be robust across platforms. - + // Below, the point of the "safety_margin" is that the minimum // and maximum values in the matrix shouldn't coincide with // the minimum and maximum ranges of the 16-bit range, because @@ -80,16 +80,22 @@ void CompressedMatrix::CopyFromMat( global_header.num_rows = mat.NumRows(); global_header.num_cols = mat.NumCols(); - if (mat.NumRows() > 8) { - global_header.format = 1; // format where each row has a PerColHeader. + if (format <= 0) { + if (mat.NumRows() > 8) { + global_header.format = 1; // format where each row has a PerColHeader. + } else { + global_header.format = 2; // format where all data is uint16. + } + } else if (format == 1 || format == 2) { + global_header.format = format; } else { - global_header.format = 2; // format where all data is uint16. + KALDI_ERR << "Error format for compression:format should be <=2."; } - + int32 data_size = DataSize(global_header); data_ = AllocateData(data_size); - + *(reinterpret_cast(data_)) = global_header; if (global_header.format == 1) { @@ -124,10 +130,12 @@ void CompressedMatrix::CopyFromMat( // Instantiate the template for float and double. template -void CompressedMatrix::CopyFromMat(const MatrixBase &mat); +void CompressedMatrix::CopyFromMat(const MatrixBase &mat, + int32 format); template -void CompressedMatrix::CopyFromMat(const MatrixBase &mat); +void CompressedMatrix::CopyFromMat(const MatrixBase &mat, + int32 format); CompressedMatrix::CompressedMatrix( @@ -146,10 +154,10 @@ CompressedMatrix::CompressedMatrix( if (old_num_rows == 0) { return; } // Zero-size matrix stored as zero pointer. if (num_rows == 0 || num_cols == 0) { return; } - + GlobalHeader new_global_header; KALDI_COMPILE_TIME_ASSERT(sizeof(new_global_header) == 20); - + GlobalHeader *old_global_header = reinterpret_cast(cmat.Data()); new_global_header = *old_global_header; @@ -159,10 +167,10 @@ CompressedMatrix::CompressedMatrix( // We don't switch format from 1 -> 2 (in case of size reduction) yet; if this // is needed, we will do this below by creating a temporary Matrix. new_global_header.format = old_global_header->format; - + data_ = AllocateData(DataSize(new_global_header)); // allocate memory *(reinterpret_cast(data_)) = new_global_header; - + if (old_global_header->format == 1) { // Both have the format where we have a PerColHeader and then compress as // chars... 
@@ -196,7 +204,7 @@ CompressedMatrix::CompressedMatrix( reinterpret_cast(old_global_header + 1); uint16 *new_data = reinterpret_cast(reinterpret_cast(data_) + 1); - + old_data += col_offset + (old_num_cols * row_offset); for (int32 row = 0; row < num_rows; row++) { @@ -281,7 +289,7 @@ void CompressedMatrix::ComputeColHeader( // Now, sdata.begin(), sdata.begin() + quarter_nr, and sdata.begin() + // 3*quarter_nr, and sdata.end() - 1, contain the elements that would appear // at those positions in sorted order. - + header->percentile_0 = std::min(FloatToUint16(global_header, sdata[0]), 65532); header->percentile_25 = @@ -297,7 +305,7 @@ void CompressedMatrix::ComputeColHeader( header->percentile_100 = std::max( FloatToUint16(global_header, sdata[num_rows-1]), header->percentile_75 + static_cast(1)); - + } else { // handle this pathological case. std::sort(sdata.begin(), sdata.end()); // Note: we know num_rows is at least 1. @@ -382,7 +390,7 @@ void CompressedMatrix::CompressColumn( unsigned char *byte_data) { ComputeColHeader(global_header, data, stride, num_rows, header); - + float p0 = Uint16ToFloat(global_header, header->percentile_0), p25 = Uint16ToFloat(global_header, header->percentile_25), p75 = Uint16ToFloat(global_header, header->percentile_75), @@ -491,7 +499,7 @@ void CompressedMatrix::CopyToMat(MatrixBase *mat, mat->CopyFromMat(temp, kTrans); return; } - + if (data_ == NULL) { KALDI_ASSERT(mat->NumRows() == 0); KALDI_ASSERT(mat->NumCols() == 0); @@ -501,7 +509,7 @@ void CompressedMatrix::CopyToMat(MatrixBase *mat, int32 num_cols = h->num_cols, num_rows = h->num_rows; KALDI_ASSERT(mat->NumRows() == num_rows); KALDI_ASSERT(mat->NumCols() == num_cols); - + if (h->format == 1) { PerColHeader *per_col_header = reinterpret_cast(h+1); unsigned char *byte_data = reinterpret_cast(per_col_header + @@ -625,7 +633,7 @@ void CompressedMatrix::CopyToMat(int32 row_offset, GlobalHeader *h = reinterpret_cast(data_); int32 num_rows = h->num_rows, num_cols = h->num_cols, tgt_cols = dest->NumCols(), tgt_rows = dest->NumRows(); - + if (h->format == 1) { // format where we have a per-column header and use one byte per // element. diff --git a/src/matrix/compressed-matrix.h b/src/matrix/compressed-matrix.h index 603134ab800..a9dd1e4fcd2 100644 --- a/src/matrix/compressed-matrix.h +++ b/src/matrix/compressed-matrix.h @@ -35,12 +35,12 @@ namespace kaldi { /// column). /// The basic idea is for each column (in the normal configuration) -/// we work out the values at the 0th, 25th, 50th and 100th percentiles +/// we work out the values at the 0th, 25th, 75th and 100th percentiles /// and store them as 16-bit integers; we then encode each value in /// the column as a single byte, in 3 separate ranges with different -/// linear encodings (0-25th, 25-50th, 50th-100th). -/// If the matrix has 8 rows or fewer, we simply store all values as -/// uint16. +/// linear encodings (0-25th, 25-75th, 75th-100th). +/// If the matrix has 8 rows or fewer or format=2, we simply store all values +/// as uint16. 
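/// The optional 'format' argument to the constructor and CopyFromMat()
/// selects the layout explicitly: 1 forces the per-column byte encoding,
/// 2 forces the plain uint16 encoding, and 0 (the default) chooses
/// automatically based on the number of rows.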
class CompressedMatrix { public: @@ -49,7 +49,9 @@ class CompressedMatrix { ~CompressedMatrix() { Clear(); } template - CompressedMatrix(const MatrixBase &mat): data_(NULL) { CopyFromMat(mat); } + CompressedMatrix(const MatrixBase &mat, int32 format = 0): data_(NULL) { + CopyFromMat(mat, format); + } /// Initializer that can be used to select part of an existing /// CompressedMatrix without un-compressing and re-compressing (note: unlike @@ -65,7 +67,7 @@ class CompressedMatrix { /// This will resize *this and copy the contents of mat to *this. template - void CopyFromMat(const MatrixBase &mat); + void CopyFromMat(const MatrixBase &mat, int32 format = 0); CompressedMatrix(const CompressedMatrix &mat); diff --git a/src/matrix/sparse-matrix.cc b/src/matrix/sparse-matrix.cc index 2ef909f66dd..777819ed677 100644 --- a/src/matrix/sparse-matrix.cc +++ b/src/matrix/sparse-matrix.cc @@ -705,9 +705,9 @@ MatrixIndexT GeneralMatrix::NumCols() const { } -void GeneralMatrix::Compress() { +void GeneralMatrix::Compress(int32 format) { if (mat_.NumRows() != 0) { - cmat_.CopyFromMat(mat_); + cmat_.CopyFromMat(mat_, format); mat_.Resize(0, 0); } } diff --git a/src/matrix/sparse-matrix.h b/src/matrix/sparse-matrix.h index 9f9362542e1..88619da3034 100644 --- a/src/matrix/sparse-matrix.h +++ b/src/matrix/sparse-matrix.h @@ -228,8 +228,10 @@ class GeneralMatrix { public: GeneralMatrixType Type() const; - void Compress(); // If it was a full matrix, compresses, changing Type() to - // kCompressedMatrix; otherwise does nothing. + /// If it was a full matrix, compresses, changing Type() to + /// kCompressedMatrix; otherwise does nothing. + /// format shows the compression format. + void Compress(int32 format = 0); void Uncompress(); // If it was a compressed matrix, uncompresses, changing // Type() to kFullMatrix; otherwise does nothing. diff --git a/src/nnet3/nnet-example.cc b/src/nnet3/nnet-example.cc index 11305f55324..89d40b9ef89 100644 --- a/src/nnet3/nnet-example.cc +++ b/src/nnet3/nnet-example.cc @@ -154,12 +154,12 @@ void NnetExample::Read(std::istream &is, bool binary) { } -void NnetExample::Compress() { +void NnetExample::Compress(int32 format) { std::vector::iterator iter = io.begin(), end = io.end(); // calling features.Compress() will do nothing if they are sparse or already // compressed. for (; iter != end; ++iter) - iter->features.Compress(); + iter->features.Compress(format); } } // namespace nnet3 diff --git a/src/nnet3/nnet-example.h b/src/nnet3/nnet-example.h index b1ae42a78c9..f097369443a 100644 --- a/src/nnet3/nnet-example.h +++ b/src/nnet3/nnet-example.h @@ -94,6 +94,14 @@ struct NnetIo { NnetIo() { } + // Compress the features in this NnetIo structure with specified format. + // the "format" will be 1 for the original format where each column has a + // PerColHeader, and 2 for the format, where everything is represented as + // 16-bit integers. + // If format <= 0, then format 1 will be used, unless the matrix has 8 or + // fewer rows (in which case format 2 will be used). + void Compress(int32 format = 0) { features.Compress(format); } + // Use default copy constructor and assignment operators. void Write(std::ostream &os, bool binary) const; @@ -124,8 +132,13 @@ struct NnetExample { void Swap(NnetExample *other) { io.swap(other->io); } - /// Compresses any (input) features that are not sparse. - void Compress(); + // Compresses any features that are not sparse and not compressed. 
+ // The "format" is 1 for the original format where each column has a + // PerColHeader, and 2 for the format, where everything is represented as + // 16-bit integers. + // If format <= 0, then format 1 will be used, unless the matrix has 8 or + // fewer rows (in which case format 2 will be used). + void Compress(int32 format = 0); /// Caution: this operator == is not very efficient. It's only used in /// testing code. From fb4737eedf96e0e6f2de5d7d9adc19057ad9d234 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 23 Nov 2016 15:54:56 -0500 Subject: [PATCH 023/213] asr_diarization: nnet3-get-egs etc. modified with deriv weights and compress format --- src/nnet3bin/nnet3-get-egs-dense-targets.cc | 150 +++++++++++++++++--- src/nnet3bin/nnet3-get-egs.cc | 82 +++++++++-- 2 files changed, 199 insertions(+), 33 deletions(-) diff --git a/src/nnet3bin/nnet3-get-egs-dense-targets.cc b/src/nnet3bin/nnet3-get-egs-dense-targets.cc index 23bf8922a5b..502e0700f27 100644 --- a/src/nnet3bin/nnet3-get-egs-dense-targets.cc +++ b/src/nnet3bin/nnet3-get-egs-dense-targets.cc @@ -32,9 +32,13 @@ namespace nnet3 { static void ProcessFile(const MatrixBase &feats, const MatrixBase *ivector_feats, + const VectorBase *deriv_weights, + const MatrixBase *l2reg_targets, const MatrixBase &targets, const std::string &utt_id, bool compress, + int32 input_compress_format, + int32 feats_compress_format, int32 num_targets, int32 left_context, int32 right_context, @@ -42,9 +46,9 @@ static void ProcessFile(const MatrixBase &feats, int64 *num_frames_written, int64 *num_egs_written, NnetExampleWriter *example_writer) { - KALDI_ASSERT(feats.NumRows() == static_cast(targets.NumRows())); - - for (int32 t = 0; t < feats.NumRows(); t += frames_per_eg) { + //KALDI_ASSERT(feats.NumRows() == static_cast(targets.NumRows())); + int min_size = std::min(feats.NumRows(), targets.NumRows()); + for (int32 t = 0; t < min_size; t += frames_per_eg) { // actual_frames_per_eg is the number of frames with actual targets. // At the end of the file, we pad with the last frame repeated @@ -52,18 +56,18 @@ static void ProcessFile(const MatrixBase &feats, // for recompilations). // TODO: We might need to ignore the end of the file. int32 actual_frames_per_eg = std::min(frames_per_eg, - feats.NumRows() - t); + min_size - t); int32 tot_frames = left_context + frames_per_eg + right_context; - Matrix input_frames(tot_frames, feats.NumCols()); + Matrix input_frames(tot_frames, feats.NumCols(), kUndefined); // Set up "input_frames". for (int32 j = -left_context; j < frames_per_eg + right_context; j++) { int32 t2 = j + t; if (t2 < 0) t2 = 0; - if (t2 >= feats.NumRows()) t2 = feats.NumRows() - 1; + if (t2 >= min_size) t2 = min_size - 1; SubVector src(feats, t2), dest(input_frames, j + left_context); dest.CopyFromVec(src); @@ -75,8 +79,11 @@ static void ProcessFile(const MatrixBase &feats, eg.io.push_back(NnetIo("input", - left_context, input_frames)); + if (compress) + eg.io.back().Compress(input_compress_format); + // if applicable, add the iVector feature. - if (ivector_feats != NULL) { + if (ivector_feats) { // try to get closest frame to middle of window to get // a representative iVector. 
int32 closest_frame = t + (actual_frames_per_eg / 2); @@ -102,17 +109,57 @@ static void ProcessFile(const MatrixBase &feats, for (int32 i = actual_frames_per_eg; i < frames_per_eg; i++) { // Copy the i^th row of the target matrix from the last row of the // input targets matrix - KALDI_ASSERT(t + actual_frames_per_eg - 1 == feats.NumRows() - 1); + KALDI_ASSERT(t + actual_frames_per_eg - 1 == min_size - 1); SubVector this_target_dest(targets_dest, i); SubVector this_target_src(targets, t+actual_frames_per_eg-1); this_target_dest.CopyFromVec(this_target_src); } - // push this created targets matrix into the eg - eg.io.push_back(NnetIo("output", 0, targets_dest)); + if (!deriv_weights) { + // push this created targets matrix into the eg + eg.io.push_back(NnetIo("output", 0, targets_dest)); + } else { + Vector this_deriv_weights(targets_dest.NumRows()); + int32 frames_to_copy = std::min(t + actual_frames_per_eg, deriv_weights->Dim()) - t; + this_deriv_weights.Range(0, frames_to_copy).CopyFromVec(deriv_weights->Range(t, frames_to_copy)); + if (this_deriv_weights.Sum() == 0) continue; // Ignore frames that have frame weights 0 + eg.io.push_back(NnetIo("output", this_deriv_weights, 0, targets_dest)); + } + + if (l2reg_targets) { + // add the labels. + Matrix l2reg_targets_dest(frames_per_eg, l2reg_targets->NumCols()); + for (int32 i = 0; i < actual_frames_per_eg; i++) { + // Copy the i^th row of the target matrix from the (t+i)^th row of the + // input targets matrix + SubVector this_target_dest(l2reg_targets_dest, i); + SubVector this_target_src(*l2reg_targets, t+i); + this_target_dest.CopyFromVec(this_target_src); + } + + // Copy the last frame's target to the padded frames + for (int32 i = actual_frames_per_eg; i < frames_per_eg; i++) { + // Copy the i^th row of the target matrix from the last row of the + // input targets matrix + KALDI_ASSERT(t + actual_frames_per_eg - 1 == feats.NumRows() - 1); + SubVector this_target_dest(l2reg_targets_dest, i); + SubVector this_target_src(*l2reg_targets, t+actual_frames_per_eg-1); + this_target_dest.CopyFromVec(this_target_src); + } + + if (!deriv_weights) { + eg.io.push_back(NnetIo("output-l2reg", 0, l2reg_targets_dest)); + } else { + Vector this_deriv_weights(l2reg_targets_dest.NumRows()); + int32 frames_to_copy = std::min(t + actual_frames_per_eg, deriv_weights->Dim()) - t; + this_deriv_weights.Range(0, frames_to_copy).CopyFromVec(deriv_weights->Range(t, frames_to_copy)); + if (this_deriv_weights.Sum() == 0) continue; // Ignore frames that have frame weights 0 + eg.io.push_back(NnetIo("output-l2reg", this_deriv_weights, 0, l2reg_targets_dest)); + } + } if (compress) - eg.Compress(); + eg.Compress(feats_compress_format); std::ostringstream os; os << utt_id << "-" << t; @@ -155,14 +202,20 @@ int main(int argc, char *argv[]) { bool compress = true; + int32 input_compress_format = 0, feats_compress_format = 0; int32 num_targets = -1, left_context = 0, right_context = 0, - num_frames = 1, length_tolerance = 100; + num_frames = 1, length_tolerance = 2; - std::string ivector_rspecifier; + std::string ivector_rspecifier, deriv_weights_rspecifier, + l2reg_targets_rspecifier; ParseOptions po(usage); po.Register("compress", &compress, "If true, write egs in " "compressed format."); + po.Register("compress-format", &feats_compress_format, "Format for " + "compressing all feats in general"); + po.Register("input-compress-format", &input_compress_format, "Format for " + "compressing input feats e.g. 
Use 2 for compressing wave"); po.Register("num-targets", &num_targets, "Number of targets for the neural network"); po.Register("left-context", &left_context, "Number of frames of left " "context the neural net requires."); @@ -174,6 +227,13 @@ int main(int argc, char *argv[]) { "features, as matrix."); po.Register("length-tolerance", &length_tolerance, "Tolerance for " "difference in num-frames between feat and ivector matrices"); + po.Register("deriv-weights-rspecifier", &deriv_weights_rspecifier, + "Per-frame weights (only binary - 0 or 1) that specifies " + "whether a frame's gradient must be backpropagated or not. " + "Not specifying this is equivalent to specifying a vector of " + "all 1s."); + po.Register("l2reg-targets-rspecifier", &l2reg_targets_rspecifier, + "Add l2 regularizer targets"); po.Read(argc, argv); @@ -194,6 +254,8 @@ int main(int argc, char *argv[]) { RandomAccessBaseFloatMatrixReader matrix_reader(matrix_rspecifier); NnetExampleWriter example_writer(examples_wspecifier); RandomAccessBaseFloatMatrixReader ivector_reader(ivector_rspecifier); + RandomAccessBaseFloatVectorReader deriv_weights_reader(deriv_weights_rspecifier); + RandomAccessBaseFloatMatrixReader l2reg_targets_reader(l2reg_targets_rspecifier); int32 num_done = 0, num_err = 0; int64 num_frames_written = 0, num_egs_written = 0; @@ -206,10 +268,10 @@ int main(int argc, char *argv[]) { num_err++; } else { const Matrix &target_matrix = matrix_reader.Value(key); - if (target_matrix.NumRows() != feats.NumRows()) { - KALDI_WARN << "Target matrix has wrong size " - << target_matrix.NumRows() - << " versus " << feats.NumRows(); + if ((target_matrix.NumRows() - feats.NumRows()) > length_tolerance) { + KALDI_WARN << "Length difference between feats " << feats.NumRows() + << " and target matrix " << target_matrix.NumRows() + << "exceeds tolerance " << length_tolerance; num_err++; continue; } @@ -226,7 +288,7 @@ int main(int argc, char *argv[]) { } } - if (ivector_feats != NULL && + if (ivector_feats && (abs(feats.NumRows() - ivector_feats->NumRows()) > length_tolerance || ivector_feats->NumRows() == 0)) { KALDI_WARN << "Length difference between feats " << feats.NumRows() @@ -235,8 +297,56 @@ int main(int argc, char *argv[]) { num_err++; continue; } - - ProcessFile(feats, ivector_feats, target_matrix, key, compress, + + const Vector *deriv_weights = NULL; + if (!deriv_weights_rspecifier.empty()) { + if (!deriv_weights_reader.HasKey(key)) { + KALDI_WARN << "No deriv weights for utterance " << key; + num_err++; + continue; + } else { + // this address will be valid until we call HasKey() or Value() + // again. + deriv_weights = &(deriv_weights_reader.Value(key)); + } + } + + if (deriv_weights && + (abs(feats.NumRows() - deriv_weights->Dim()) > length_tolerance + || deriv_weights->Dim() == 0)) { + KALDI_WARN << "Length difference between feats " << feats.NumRows() + << " and deriv weights " << deriv_weights->Dim() + << " exceeds tolerance " << length_tolerance; + num_err++; + continue; + } + + const Matrix *l2reg_target_matrix = NULL; + if (!l2reg_targets_rspecifier.empty()) { + if (!l2reg_targets_reader.HasKey(key)) { + KALDI_WARN << "No l2 regularizer targets for utterance " << key; + num_err++; + continue; + } + { + // this address will be valid until we call HasKey() or Value() + // again. 
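          // The l2 regularizer targets become a second output in each example,
          // named "output-l2reg", alongside the main "output".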
+ l2reg_target_matrix = &(l2reg_targets_reader.Value(key)); + + if (l2reg_target_matrix->NumRows() != feats.NumRows()) { + KALDI_WARN << "l2 regularizer target matrix has wrong size " + << l2reg_target_matrix->NumRows() + << " versus " << feats.NumRows(); + num_err++; + continue; + } + } + } + + + ProcessFile(feats, ivector_feats, deriv_weights, + l2reg_target_matrix, target_matrix, + key, compress, input_compress_format, feats_compress_format, num_targets, left_context, right_context, num_frames, &num_frames_written, &num_egs_written, &example_writer); diff --git a/src/nnet3bin/nnet3-get-egs.cc b/src/nnet3bin/nnet3-get-egs.cc index 75f264f1ceb..dbf8b636305 100644 --- a/src/nnet3bin/nnet3-get-egs.cc +++ b/src/nnet3bin/nnet3-get-egs.cc @@ -32,9 +32,12 @@ namespace nnet3 { static void ProcessFile(const MatrixBase &feats, const MatrixBase *ivector_feats, + const VectorBase *deriv_weights, const Posterior &pdf_post, const std::string &utt_id, bool compress, + int32 input_compress_format, + int32 feats_compress_format, int32 num_pdfs, int32 left_context, int32 right_context, @@ -42,16 +45,16 @@ static void ProcessFile(const MatrixBase &feats, int64 *num_frames_written, int64 *num_egs_written, NnetExampleWriter *example_writer) { - KALDI_ASSERT(feats.NumRows() == static_cast(pdf_post.size())); - - for (int32 t = 0; t < feats.NumRows(); t += frames_per_eg) { + //KALDI_ASSERT(feats.NumRows() == static_cast(pdf_post.size())); + int32 min_size = std::min(feats.NumRows(), static_cast(pdf_post.size())); + for (int32 t = 0; t < min_size; t += frames_per_eg) { // actual_frames_per_eg is the number of frames with nonzero // posteriors. At the end of the file we pad with zero posteriors // so that all examples have the same structure (prevents the need // for recompilations). int32 actual_frames_per_eg = std::min(frames_per_eg, - feats.NumRows() - t); + min_size - t); int32 tot_frames = left_context + frames_per_eg + right_context; @@ -62,7 +65,7 @@ static void ProcessFile(const MatrixBase &feats, for (int32 j = -left_context; j < frames_per_eg + right_context; j++) { int32 t2 = j + t; if (t2 < 0) t2 = 0; - if (t2 >= feats.NumRows()) t2 = feats.NumRows() - 1; + if (t2 >= min_size) t2 = min_size - 1; SubVector src(feats, t2), dest(input_frames, j + left_context); dest.CopyFromVec(src); @@ -73,9 +76,12 @@ static void ProcessFile(const MatrixBase &feats, // call the regular input "input". eg.io.push_back(NnetIo("input", - left_context, input_frames)); + + if (compress) + eg.io.back().Compress(input_compress_format); // if applicable, add the iVector feature. - if (ivector_feats != NULL) { + if (ivector_feats) { // try to get closest frame to middle of window to get // a representative iVector. int32 closest_frame = t + (actual_frames_per_eg / 2); @@ -92,10 +98,20 @@ static void ProcessFile(const MatrixBase &feats, for (int32 i = 0; i < actual_frames_per_eg; i++) labels[i] = pdf_post[t + i]; // remaining posteriors for frames are empty. 
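      // When deriv weights are supplied, chunks whose weights are all zero
      // are skipped entirely rather than written with a zero-weighted output.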
- eg.io.push_back(NnetIo("output", num_pdfs, 0, labels)); + + if (!deriv_weights) { + eg.io.push_back(NnetIo("output", num_pdfs, 0, labels)); + } else { + Vector this_deriv_weights(frames_per_eg); + int32 frames_to_copy = std::min(t + actual_frames_per_eg, deriv_weights->Dim()) - t; + this_deriv_weights.Range(0, frames_to_copy).CopyFromVec(deriv_weights->Range(t, frames_to_copy)); + if (this_deriv_weights.Sum() == 0) continue; // Ignore frames that have frame weights 0 + eg.io.push_back(NnetIo("output", this_deriv_weights, num_pdfs, 0, labels)); + } + if (compress) - eg.Compress(); + eg.Compress(feats_compress_format); std::ostringstream os; os << utt_id << "-" << t; @@ -140,14 +156,19 @@ int main(int argc, char *argv[]) { bool compress = true; + int32 input_compress_format = 0, feats_compress_format = 0; int32 num_pdfs = -1, left_context = 0, right_context = 0, num_frames = 1, length_tolerance = 100; - std::string ivector_rspecifier; + std::string ivector_rspecifier, deriv_weights_rspecifier; ParseOptions po(usage); po.Register("compress", &compress, "If true, write egs in " "compressed format."); + po.Register("compress-format", &feats_compress_format, "Format for " + "compressing all feats in general"); + po.Register("input-compress-format", &input_compress_format, "Format for " + "compressing input feats e.g. Use 2 for compressing wave"); po.Register("num-pdfs", &num_pdfs, "Number of pdfs in the acoustic " "model"); po.Register("left-context", &left_context, "Number of frames of left " @@ -160,6 +181,11 @@ int main(int argc, char *argv[]) { "features, as a matrix."); po.Register("length-tolerance", &length_tolerance, "Tolerance for " "difference in num-frames between feat and ivector matrices"); + po.Register("deriv-weights-rspecifier", &deriv_weights_rspecifier, + "Per-frame weights (only binary - 0 or 1) that specifies " + "whether a frame's gradient must be backpropagated or not. 
" + "Not specifying this is equivalent to specifying a vector of " + "all 1s."); po.Read(argc, argv); @@ -181,6 +207,7 @@ int main(int argc, char *argv[]) { RandomAccessPosteriorReader pdf_post_reader(pdf_post_rspecifier); NnetExampleWriter example_writer(examples_wspecifier); RandomAccessBaseFloatMatrixReader ivector_reader(ivector_rspecifier); + RandomAccessBaseFloatVectorReader deriv_weights_reader(deriv_weights_rspecifier); int32 num_done = 0, num_err = 0; int64 num_frames_written = 0, num_egs_written = 0; @@ -192,13 +219,17 @@ int main(int argc, char *argv[]) { KALDI_WARN << "No pdf-level posterior for key " << key; num_err++; } else { - const Posterior &pdf_post = pdf_post_reader.Value(key); - if (pdf_post.size() != feats.NumRows()) { + Posterior pdf_post = pdf_post_reader.Value(key); + if (abs(static_cast(pdf_post.size()) - feats.NumRows()) > length_tolerance + || pdf_post.size() < feats.NumRows()) { KALDI_WARN << "Posterior has wrong size " << pdf_post.size() << " versus " << feats.NumRows(); num_err++; continue; } + while (static_cast(pdf_post.size()) > feats.NumRows()) { + pdf_post.pop_back(); + } const Matrix *ivector_feats = NULL; if (!ivector_rspecifier.empty()) { if (!ivector_reader.HasKey(key)) { @@ -212,7 +243,7 @@ int main(int argc, char *argv[]) { } } - if (ivector_feats != NULL && + if (ivector_feats && (abs(feats.NumRows() - ivector_feats->NumRows()) > length_tolerance || ivector_feats->NumRows() == 0)) { KALDI_WARN << "Length difference between feats " << feats.NumRows() @@ -221,8 +252,33 @@ int main(int argc, char *argv[]) { num_err++; continue; } + + const Vector *deriv_weights = NULL; + if (!deriv_weights_rspecifier.empty()) { + if (!deriv_weights_reader.HasKey(key)) { + KALDI_WARN << "No deriv weights for utterance " << key; + num_err++; + continue; + } else { + // this address will be valid until we call HasKey() or Value() + // again. 
+ deriv_weights = &(deriv_weights_reader.Value(key)); + } + } + + if (deriv_weights && + (abs(feats.NumRows() - deriv_weights->Dim()) > length_tolerance + || deriv_weights->Dim() == 0)) { + KALDI_WARN << "Length difference between feats " << feats.NumRows() + << " and deriv weights " << deriv_weights->Dim() + << " exceeds tolerance " << length_tolerance; + num_err++; + continue; + } + - ProcessFile(feats, ivector_feats, pdf_post, key, compress, + ProcessFile(feats, ivector_feats, deriv_weights, pdf_post, + key, compress, input_compress_format, feats_compress_format, num_pdfs, left_context, right_context, num_frames, &num_frames_written, &num_egs_written, &example_writer); From 94367327d1766774c7a77e94324e6393cc905aa3 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 7 Dec 2016 00:22:31 -0500 Subject: [PATCH 024/213] asr_diarization: Log and Exp component --- src/nnet3/nnet-component-itf.cc | 9 +- src/nnet3/nnet-component-itf.h | 8 ++ src/nnet3/nnet-simple-component.cc | 166 +++++++++++++++++++++++++++-- src/nnet3/nnet-simple-component.h | 67 ++++++++++++ 4 files changed, 243 insertions(+), 7 deletions(-) diff --git a/src/nnet3/nnet-component-itf.cc b/src/nnet3/nnet-component-itf.cc index f94843b725e..695dbb6de56 100644 --- a/src/nnet3/nnet-component-itf.cc +++ b/src/nnet3/nnet-component-itf.cc @@ -89,6 +89,10 @@ Component* Component::NewComponentOfType(const std::string &component_type) { ans = new SoftmaxComponent(); } else if (component_type == "LogSoftmaxComponent") { ans = new LogSoftmaxComponent(); + } else if (component_type == "LogComponent") { + ans = new LogComponent(); + } else if (component_type == "ExpComponent") { + ans = new ExpComponent(); } else if (component_type == "RectifiedLinearComponent") { ans = new RectifiedLinearComponent(); } else if (component_type == "NormalizeComponent") { @@ -310,11 +314,14 @@ std::string NonlinearComponent::Info() const { std::stringstream stream; if (InputDim() == OutputDim()) { stream << Type() << ", dim=" << InputDim(); - } else { + } else if (OutputDim() - InputDim() == 1) { // Note: this is a very special case tailored for class NormalizeComponent. stream << Type() << ", input-dim=" << InputDim() << ", output-dim=" << OutputDim() << ", add-log-stddev=true"; + } else { + stream << Type() << ", input-dim=" << InputDim() + << ", output-dim=" << OutputDim(); } if (self_repair_lower_threshold_ != BaseFloat(kUnsetThreshold)) diff --git a/src/nnet3/nnet-component-itf.h b/src/nnet3/nnet-component-itf.h index e5974b46f46..3013c485ea4 100644 --- a/src/nnet3/nnet-component-itf.h +++ b/src/nnet3/nnet-component-itf.h @@ -403,6 +403,11 @@ class UpdatableComponent: public Component { /// Sets the learning rate directly, bypassing learning_rate_factor_. virtual void SetActualLearningRate(BaseFloat lrate) { learning_rate_ = lrate; } + /// Sets the learning rate factor + virtual void SetLearningRateFactor(BaseFloat lrate_factor) { + learning_rate_factor_ = lrate_factor; + } + /// Gets the learning rate of gradient descent. Note: if you call /// SetLearningRate(x), and learning_rate_factor_ != 1.0, /// a different value than x will returned. @@ -413,6 +418,9 @@ class UpdatableComponent: public Component { /// NnetTrainer by querying the max-changes for each component. /// See NnetTrainer::UpdateParamsWithMaxChange() in nnet3/nnet-training.cc. 
BaseFloat MaxChange() const { return max_change_; } + + /// Gets the learning rate factor + BaseFloat LearningRateFactor() const { return learning_rate_factor_; } virtual std::string Info() const; diff --git a/src/nnet3/nnet-simple-component.cc b/src/nnet3/nnet-simple-component.cc index 58908a0fe09..aa56dce1f23 100644 --- a/src/nnet3/nnet-simple-component.cc +++ b/src/nnet3/nnet-simple-component.cc @@ -2517,6 +2517,26 @@ void ConstantFunctionComponent::UnVectorize(const VectorBase ¶ms) output_.CopyFromVec(params); } +void ExpComponent::Propagate(const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in, + CuMatrixBase *out) const { + // Applied exp function + out->CopyFromMat(in); + out->ApplyExp(); +} + +void ExpComponent::Backprop(const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &,//in_value, + const CuMatrixBase &out_value, + const CuMatrixBase &out_deriv, + Component *to_update, + CuMatrixBase *in_deriv) const { + if (in_deriv != NULL) { + in_deriv->CopyFromMat(out_value); + in_deriv->MulElements(out_deriv); + } +} NaturalGradientAffineComponent::NaturalGradientAffineComponent(): max_change_per_sample_(0.0), @@ -2568,10 +2588,15 @@ void NaturalGradientAffineComponent::Read(std::istream &is, bool binary) { ReadBasicType(is, binary, &max_change_scale_stats_); ReadToken(is, binary, &token); } - if (token != "" && - token != "") - KALDI_ERR << "Expected or " - << ", got " << token; + + std::ostringstream ostr_beg, ostr_end; + ostr_beg << "<" << Type() << ">"; // e.g. "" + ostr_end << ""; // e.g. "" + + if (token != ostr_end.str() && + token != ostr_beg.str()) + KALDI_ERR << "Expected " << ostr_beg.str() << " or " + << ostr_end.str() << ", got " << token; SetNaturalGradientConfigs(); } @@ -2720,7 +2745,10 @@ void NaturalGradientAffineComponent::Write(std::ostream &os, WriteBasicType(os, binary, active_scaling_count_); WriteToken(os, binary, ""); WriteBasicType(os, binary, max_change_scale_stats_); - WriteToken(os, binary, ""); + + std::ostringstream ostr_end; + ostr_end << ""; // e.g. "" + WriteToken(os, binary, ostr_end.str()); } std::string NaturalGradientAffineComponent::Info() const { @@ -3095,6 +3123,126 @@ void SoftmaxComponent::StoreStats(const CuMatrixBase &out_value) { StoreStatsInternal(out_value, NULL); } +std::string LogComponent::Info() const { + std::stringstream stream; + stream << NonlinearComponent::Info() + << ", log-floor=" << log_floor_; + return stream.str(); +} + +void LogComponent::InitFromConfig(ConfigLine *cfl) { + cfl->GetValue("log-floor", &log_floor_); + NonlinearComponent::InitFromConfig(cfl); +} + +void LogComponent::Propagate(const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in, + CuMatrixBase *out) const { + // Apllies log function (x >= epsi ? log(x) : log(epsi)). + out->CopyFromMat(in); + out->ApplyFloor(log_floor_); + out->ApplyLog(); +} + +void LogComponent::Backprop(const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in_value, + const CuMatrixBase &out_value, + const CuMatrixBase &out_deriv, + Component *to_update, + CuMatrixBase *in_deriv) const { + if (in_deriv != NULL) { + CuMatrix divided_in_value(in_value), floored_in_value(in_value); + divided_in_value.Set(1.0); + floored_in_value.CopyFromMat(in_value); + floored_in_value.ApplyFloor(log_floor_); // (x > epsi ? x : epsi) + + divided_in_value.DivElements(floored_in_value); // (x > epsi ? 
1/x : 1/epsi) + in_deriv->CopyFromMat(in_value); + in_deriv->Add(-1.0 * log_floor_); // (x - epsi) + in_deriv->ApplyHeaviside(); // (x > epsi ? 1 : 0) + in_deriv->MulElements(divided_in_value); // (dy/dx: x > epsi ? 1/x : 0) + in_deriv->MulElements(out_deriv); // dF/dx = dF/dy * dy/dx + } +} + +void LogComponent::Read(std::istream &is, bool binary) { + std::ostringstream ostr_beg, ostr_end; + ostr_beg << "<" << Type() << ">"; // e.g. "" + ostr_end << ""; // e.g. "" + ExpectOneOrTwoTokens(is, binary, ostr_beg.str(), ""); + ReadBasicType(is, binary, &dim_); // Read dimension. + ExpectToken(is, binary, ""); + value_sum_.Read(is, binary); + ExpectToken(is, binary, ""); + deriv_sum_.Read(is, binary); + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &count_); + value_sum_.Scale(count_); + deriv_sum_.Scale(count_); + + std::string token; + ReadToken(is, binary, &token); + if (token == "") { + ReadBasicType(is, binary, &self_repair_lower_threshold_); + ReadToken(is, binary, &token); + } + if (token == "") { + ReadBasicType(is, binary, &self_repair_upper_threshold_); + ReadToken(is, binary, &token); + } + if (token == "") { + ReadBasicType(is, binary, &self_repair_scale_); + ReadToken(is, binary, &token); + } + if (token == "") { + ReadBasicType(is, binary, &log_floor_); + ReadToken(is, binary, &token); + } + if (token != ostr_end.str()) { + KALDI_ERR << "Expected token " << ostr_end.str() + << ", got " << token; + } +} + +void LogComponent::Write(std::ostream &os, bool binary) const { + std::ostringstream ostr_beg, ostr_end; + ostr_beg << "<" << Type() << ">"; // e.g. "" + ostr_end << ""; // e.g. "" + WriteToken(os, binary, ostr_beg.str()); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, dim_); + // Write the values and derivatives in a count-normalized way, for + // greater readability in text form. 
+ WriteToken(os, binary, ""); + Vector temp(value_sum_); + if (count_ != 0.0) temp.Scale(1.0 / count_); + temp.Write(os, binary); + WriteToken(os, binary, ""); + + temp.Resize(deriv_sum_.Dim(), kUndefined); + temp.CopyFromVec(deriv_sum_); + if (count_ != 0.0) temp.Scale(1.0 / count_); + temp.Write(os, binary); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, count_); + if (self_repair_lower_threshold_ != kUnsetThreshold) { + WriteToken(os, binary, ""); + WriteBasicType(os, binary, self_repair_lower_threshold_); + } + if (self_repair_upper_threshold_ != kUnsetThreshold) { + WriteToken(os, binary, ""); + WriteBasicType(os, binary, self_repair_upper_threshold_); + } + if (self_repair_scale_ != 0.0) { + WriteToken(os, binary, ""); + WriteBasicType(os, binary, self_repair_scale_); + } + WriteToken(os, binary, ""); + WriteBasicType(os, binary, log_floor_); + WriteToken(os, binary, ostr_end.str()); +} + void LogSoftmaxComponent::Propagate(const ComponentPrecomputedIndexes *indexes, const CuMatrixBase &in, @@ -3135,12 +3283,18 @@ void FixedScaleComponent::InitFromConfig(ConfigLine *cfl) { Init(vec); } else { int32 dim; + BaseFloat scale = 1.0; + bool scale_ok = cfl->GetValue("scale", &scale); if (!cfl->GetValue("dim", &dim) || cfl->HasUnusedValues()) KALDI_ERR << "Invalid initializer for layer of type " << Type() << ": \"" << cfl->WholeLine() << "\""; KALDI_ASSERT(dim > 0); CuVector vec(dim); - vec.SetRandn(); + if (scale_ok) { + vec.Set(scale); + } else { + vec.SetRandn(); + } Init(vec); } } diff --git a/src/nnet3/nnet-simple-component.h b/src/nnet3/nnet-simple-component.h index f09a989759a..95a32bbe7a3 100644 --- a/src/nnet3/nnet-simple-component.h +++ b/src/nnet3/nnet-simple-component.h @@ -697,6 +697,71 @@ class LogSoftmaxComponent: public NonlinearComponent { LogSoftmaxComponent &operator = (const LogSoftmaxComponent &other); // Disallow. }; +// The LogComponent outputs the log of input values as y = Log(max(x, epsi)) +class LogComponent: public NonlinearComponent { + public: + explicit LogComponent(const LogComponent &other): + NonlinearComponent(other), log_floor_(other.log_floor_) { } + LogComponent(): log_floor_(1e-20) { } + virtual std::string Type() const { return "LogComponent"; } + virtual int32 Properties() const { + return kSimpleComponent|kBackpropNeedsInput|kStoresStats; + } + + virtual std::string Info() const; + + virtual void InitFromConfig(ConfigLine *cfl); + + virtual void Propagate(const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in, + CuMatrixBase *out) const; + virtual void Backprop(const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in_value, + const CuMatrixBase &out_value, + const CuMatrixBase &out_deriv, + Component *to_update, + CuMatrixBase *in_deriv) const; + + virtual Component* Copy() const { return new LogComponent(*this); } + + virtual void Read(std::istream &is, bool binary); + + virtual void Write(std::ostream &os, bool binary) const; + + private: + LogComponent &operator = (const LogComponent &other); // Disallow. 
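+  // log_floor_ is the floor applied to the input before taking the log, so the
+  // output is log(max(x, log_floor_)); it defaults to 1e-20 and can be changed
+  // with the "log-floor" config value.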
+ BaseFloat log_floor_; +}; + + +// The ExpComponent outputs the exp of input values as y = Exp(x) +class ExpComponent: public NonlinearComponent { + public: + explicit ExpComponent(const ExpComponent &other): + NonlinearComponent(other) { } + ExpComponent() { } + virtual std::string Type() const { return "ExpComponent"; } + virtual int32 Properties() const { + return kSimpleComponent|kBackpropNeedsOutput|kStoresStats; + } + virtual void Propagate(const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in, + CuMatrixBase *out) const; + virtual void Backprop(const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &, + const CuMatrixBase &out_value, + const CuMatrixBase &, + Component *to_update, + CuMatrixBase *in_deriv) const; + + virtual Component* Copy() const { return new ExpComponent(*this); } + private: + ExpComponent &operator = (const ExpComponent &other); // Disallow. +}; + + /// Keywords: natural gradient descent, NG-SGD, naturalgradient. For /// the top-level of the natural gradient code look here, and also in /// nnet-precondition-online.h. @@ -826,6 +891,8 @@ class FixedAffineComponent: public Component { // Function to provide access to linear_params_. const CuMatrix &LinearParams() const { return linear_params_; } + const CuVector &BiasParams() const { return bias_params_; } + protected: friend class AffineComponent; CuMatrix linear_params_; From 828544e0cf681c1e755a8162a487cfc314308eea Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 23 Nov 2016 16:29:51 -0500 Subject: [PATCH 025/213] asr_diarization: Adding ScaleGradientComponent --- src/nnet3/nnet-component-itf.cc | 2 + src/nnet3/nnet-component-test.cc | 5 ++ src/nnet3/nnet-simple-component.cc | 81 ++++++++++++++++++++++++++++++ src/nnet3/nnet-simple-component.h | 40 +++++++++++++++ src/nnet3/nnet-test-utils.cc | 20 ++++++-- 5 files changed, 145 insertions(+), 3 deletions(-) diff --git a/src/nnet3/nnet-component-itf.cc b/src/nnet3/nnet-component-itf.cc index 695dbb6de56..389b9876b3c 100644 --- a/src/nnet3/nnet-component-itf.cc +++ b/src/nnet3/nnet-component-itf.cc @@ -123,6 +123,8 @@ Component* Component::NewComponentOfType(const std::string &component_type) { ans = new NoOpComponent(); } else if (component_type == "ClipGradientComponent") { ans = new ClipGradientComponent(); + } else if (component_type == "ScaleGradientComponent") { + ans = new ScaleGradientComponent(); } else if (component_type == "ElementwiseProductComponent") { ans = new ElementwiseProductComponent(); } else if (component_type == "ConvolutionComponent") { diff --git a/src/nnet3/nnet-component-test.cc b/src/nnet3/nnet-component-test.cc index 3cc6af1c70d..a2e5e23436c 100644 --- a/src/nnet3/nnet-component-test.cc +++ b/src/nnet3/nnet-component-test.cc @@ -379,6 +379,11 @@ bool TestSimpleComponentDataDerivative(const Component &c, KALDI_LOG << "Accepting deriv differences since " << "it is ClipGradientComponent."; return true; + } + else if (c.Type() == "ScaleGradientComponent") { + KALDI_LOG << "Accepting deriv differences since " + << "it is ScaleGradientComponent."; + return true; } return ans; } diff --git a/src/nnet3/nnet-simple-component.cc b/src/nnet3/nnet-simple-component.cc index aa56dce1f23..fcfd4b9affa 100644 --- a/src/nnet3/nnet-simple-component.cc +++ b/src/nnet3/nnet-simple-component.cc @@ -922,6 +922,87 @@ void ClipGradientComponent::Add(BaseFloat alpha, const Component &other_in) { num_clipped_ += alpha * other->num_clipped_; } + +void ScaleGradientComponent::Init(const CuVectorBase 
&scales) { + KALDI_ASSERT(scales.Dim() != 0); + scales_ = scales; +} + + +void ScaleGradientComponent::InitFromConfig(ConfigLine *cfl) { + std::string filename; + // Accepts "scales" config (for filename) or "dim" -> random init, for testing. + if (cfl->GetValue("scales", &filename)) { + if (cfl->HasUnusedValues()) + KALDI_ERR << "Invalid initializer for layer of type " + << Type() << ": \"" << cfl->WholeLine() << "\""; + CuVector vec; + ReadKaldiObject(filename, &vec); + Init(vec); + } else { + int32 dim; + BaseFloat scale = 1.0; + bool scale_ok = cfl->GetValue("scale", &scale); + if (!cfl->GetValue("dim", &dim) || cfl->HasUnusedValues()) + KALDI_ERR << "Invalid initializer for layer of type " + << Type() << ": \"" << cfl->WholeLine() << "\""; + KALDI_ASSERT(dim > 0); + CuVector vec(dim); + if (scale_ok) { + vec.Set(scale); + } else { + vec.SetRandn(); + } + Init(vec); + } +} + + +std::string ScaleGradientComponent::Info() const { + std::ostringstream stream; + stream << Component::Info(); + PrintParameterStats(stream, "scales", scales_, true); + return stream.str(); +} + +void ScaleGradientComponent::Propagate(const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in, + CuMatrixBase *out) const { + out->CopyFromMat(in); // does nothing if same matrix. +} + +void ScaleGradientComponent::Backprop(const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &, // in_value + const CuMatrixBase &, // out_value + const CuMatrixBase &out_deriv, + Component *, // to_update + CuMatrixBase *in_deriv) const { + in_deriv->CopyFromMat(out_deriv); // does nothing if same memory. + in_deriv->MulColsVec(scales_); +} + +Component* ScaleGradientComponent::Copy() const { + ScaleGradientComponent *ans = new ScaleGradientComponent(); + ans->scales_ = scales_; + return ans; +} + + +void ScaleGradientComponent::Write(std::ostream &os, bool binary) const { + WriteToken(os, binary, ""); + WriteToken(os, binary, ""); + scales_.Write(os, binary); + WriteToken(os, binary, ""); +} + +void ScaleGradientComponent::Read(std::istream &is, bool binary) { + ExpectOneOrTwoTokens(is, binary, "", ""); + scales_.Read(is, binary); + ExpectToken(is, binary, ""); +} + + void TanhComponent::Propagate(const ComponentPrecomputedIndexes *indexes, const CuMatrixBase &in, CuMatrixBase *out) const { diff --git a/src/nnet3/nnet-simple-component.h b/src/nnet3/nnet-simple-component.h index 95a32bbe7a3..ff9ec5fd26b 100644 --- a/src/nnet3/nnet-simple-component.h +++ b/src/nnet3/nnet-simple-component.h @@ -1196,6 +1196,46 @@ class ClipGradientComponent: public Component { }; +// Applied a per-element scale only on the gradient during back propagation +// Duplicates the input during forward propagation +class ScaleGradientComponent : public Component { + public: + ScaleGradientComponent() { } + virtual std::string Type() const { return "ScaleGradientComponent"; } + virtual std::string Info() const; + virtual int32 Properties() const { + return kSimpleComponent|kLinearInInput|kPropagateInPlace|kBackpropInPlace; + } + + void Init(const CuVectorBase &scales); + + // The ConfigLine cfl contains only the option scales=, + // where the string is the filename of a Kaldi-format matrix to read. 
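+  // (Alternatively, for testing, a "dim" value may be given instead; the scales
+  // are then set to a constant "scale" value if one is supplied, or initialized
+  // randomly otherwise.)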
+ virtual void InitFromConfig(ConfigLine *cfl); + + virtual int32 InputDim() const { return scales_.Dim(); } + virtual int32 OutputDim() const { return scales_.Dim(); } + + virtual void Propagate(const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in, + CuMatrixBase *out) const; + virtual void Backprop(const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &, // in_value + const CuMatrixBase &, // out_value + const CuMatrixBase &out_deriv, + Component *, // to_update + CuMatrixBase *in_deriv) const; + virtual Component* Copy() const; + virtual void Read(std::istream &is, bool binary); + virtual void Write(std::ostream &os, bool binary) const; + + protected: + CuVector scales_; + KALDI_DISALLOW_COPY_AND_ASSIGN(ScaleGradientComponent); +}; + + /** PermuteComponent changes the order of the columns (i.e. the feature or activation dimensions). Output dimension i is mapped to input dimension column_map_[i], so it's like doing: diff --git a/src/nnet3/nnet-test-utils.cc b/src/nnet3/nnet-test-utils.cc index 170ea51ca8f..da519fa1cd3 100644 --- a/src/nnet3/nnet-test-utils.cc +++ b/src/nnet3/nnet-test-utils.cc @@ -1104,7 +1104,7 @@ void ComputeExampleComputationRequestSimple( static void GenerateRandomComponentConfig(std::string *component_type, std::string *config) { - int32 n = RandInt(0, 30); + int32 n = RandInt(0, 33); BaseFloat learning_rate = 0.001 * RandInt(1, 3); std::ostringstream os; @@ -1401,8 +1401,7 @@ static void GenerateRandomComponentConfig(std::string *component_type, *component_type = "DropoutComponent"; os << "dim=" << RandInt(1, 200) << " dropout-proportion=" << RandUniform(); - break; - } + } case 30: { *component_type = "LstmNonlinearityComponent"; // set self-repair scale to zero so the derivative tests will pass. 
@@ -1410,6 +1409,21 @@ static void GenerateRandomComponentConfig(std::string *component_type, << " self-repair-scale=0.0"; break; } + case 31: { + *component_type = "LogComponent"; + os << "dim=" << RandInt(1, 50); + break; + } + case 32: { + *component_type = "ExpComponent"; + os << "dim=" << RandInt(1, 50); + break; + } + case 33: { + *component_type = "ScaleGradientComponent"; + os << "dim=" << RandInt(1, 100); + break; + } default: KALDI_ERR << "Error generating random component"; } From b80cf2456cc40223fe8db4b6a98923bd7b685dbd Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 23 Nov 2016 21:44:34 -0500 Subject: [PATCH 026/213] asr_diarization: Adding AddGradientSacaleLayer to components.py --- egs/wsj/s5/steps/nnet3/components.py | 36 +++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/egs/wsj/s5/steps/nnet3/components.py b/egs/wsj/s5/steps/nnet3/components.py index 3fb92117d78..82566d2e37d 100644 --- a/egs/wsj/s5/steps/nnet3/components.py +++ b/egs/wsj/s5/steps/nnet3/components.py @@ -55,6 +55,35 @@ def AddNoOpLayer(config_lines, name, input): return {'descriptor': '{0}_noop'.format(name), 'dimension': input['dimension']} +def AddGradientScaleLayer(config_lines, name, input, scale = 1.0, scales_vec = None): + components = config_lines['components'] + component_nodes = config_lines['component-nodes'] + + if scales_vec is None: + components.append('component name={0}_gradient_scale type=ScaleGradientComponent dim={1} scale={2}'.format(name, input['dimension'], scale)) + else: + components.append('component name={0}_gradient_scale type=ScaleGradientComponent scales={2}'.format(name, scales_vec)) + + component_nodes.append('component-node name={0}_gradient_scale component={0}_gradient_scale input={1}'.format(name, input['descriptor'])) + + return {'descriptor': '{0}_gradient_scale'.format(name), + 'dimension': input['dimension']} + +def AddFixedScaleLayer(config_lines, name, input, + scale = 1.0, scales_vec = None): + components = config_lines['components'] + component_nodes = config_lines['component-nodes'] + + if scales_vec is None: + components.append('component name={0}-fixed-scale type=FixedScaleComponent dim={1} scale={2}'.format(name, input['dimension'], scale)) + else: + components.append('component name={0}-fixed-scale type=FixedScaleComponent scales={2}'.format(name, scales_vec)) + + component_nodes.append('component-node name={0}-fixed-scale component={0}-fixed-scale input={1}'.format(name, input['descriptor'])) + + return {'descriptor': '{0}-fixed-scale'.format(name), + 'dimension': input['dimension']} + def AddLdaLayer(config_lines, name, input, lda_file): return AddFixedAffineLayer(config_lines, name, input, lda_file) @@ -257,7 +286,9 @@ def AddFinalLayer(config_lines, input, output_dim, include_log_softmax = True, add_final_sigmoid = False, name_affix = None, - objective_type = "linear"): + objective_type = "linear", + objective_scale = 1.0, + objective_scales_vec = None): components = config_lines['components'] component_nodes = config_lines['component-nodes'] @@ -283,6 +314,9 @@ def AddFinalLayer(config_lines, input, output_dim, prev_layer_output = AddSigmoidLayer(config_lines, final_node_prefix, prev_layer_output) # we use the same name_affix as a prefix in for affine/scale nodes but as a # suffix for output node + if (objective_scale != 1.0 or objective_scales_vec is not None): + prev_layer_output = AddGradientScaleLayer(config_lines, final_node_prefix, prev_layer_output, objective_scale, objective_scales_vec) + 
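+    # (The gradient-scale layer above is an identity in the forward pass; it only
+    # scales the gradient flowing back from this output, so it can be used to
+    # reweight this output's objective relative to other outputs.)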
AddOutputLayer(config_lines, prev_layer_output, label_delay, suffix = name_affix, objective_type = objective_type) def AddLstmLayer(config_lines, From 9ef542248d88c30a99d1df2d98618a6e071bec82 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 23 Nov 2016 21:40:35 -0500 Subject: [PATCH 027/213] asr_diarization: Adding get_egs changes into get_egs_targets --- egs/wsj/s5/steps/nnet3/get_egs_targets.sh | 111 ++++++++++++++++------ 1 file changed, 83 insertions(+), 28 deletions(-) diff --git a/egs/wsj/s5/steps/nnet3/get_egs_targets.sh b/egs/wsj/s5/steps/nnet3/get_egs_targets.sh index 7fbc24858b5..cfecf88df38 100755 --- a/egs/wsj/s5/steps/nnet3/get_egs_targets.sh +++ b/egs/wsj/s5/steps/nnet3/get_egs_targets.sh @@ -24,6 +24,8 @@ feat_type=raw # set it to 'lda' to use LDA features. target_type=sparse # dense to have dense targets, # sparse to have posteriors targets num_targets= # required for target-type=sparse with raw nnet +deriv_weights_scp= +l2_regularizer_targets= frames_per_eg=8 # number of frames of labels per example. more->less disk space and # less time preparing egs, but more I/O during training. # note: the script may reduce this if reduce_frames_per_eg is true. @@ -44,6 +46,12 @@ reduce_frames_per_eg=true # If true, this script may reduce the frames_per_eg # equal to the user-specified value. num_utts_subset=300 # number of utterances in validation and training # subsets used for shrinkage and diagnostics. +num_utts_subset_valid= # number of utterances in validation + # subsets used for shrinkage and diagnostics + # if provided, overrides num-utts-subset +num_utts_subset_train= # number of utterances in training + # subsets used for shrinkage and diagnostics. + # if provided, overrides num-utts-subset num_valid_frames_combine=0 # #valid frames for combination weights at the very end. num_train_frames_combine=10000 # # train frames for the above. num_frames_diagnostic=4000 # number of frames for "compute_prob" jobs @@ -59,6 +67,7 @@ stage=0 nj=6 # This should be set to the maximum number of jobs you are # comfortable to run in parallel; you can increase it if your disk # speed is greater and you have more machines. +srand=0 # rand seed for nnet3-copy-egs and nnet3-shuffle-egs online_ivector_dir= # can be used if we are including speaker information as iVectors. cmvn_opts= # can be used for specifying CMVN options, if feature type is not lda (if lda, # it doesn't make sense to use different options than were used as input to the @@ -111,9 +120,18 @@ utils/split_data.sh $data $nj mkdir -p $dir/log $dir/info +[ -z "$num_utts_subset_valid" ] && num_utts_subset_valid=$num_utts_subset +[ -z "$num_utts_subset_train" ] && num_utts_subset_train=$num_utts_subset + +num_utts=$(cat $data/utt2spk | wc -l) +if ! [ $num_utts -gt $[$num_utts_subset_valid*4] ]; then + echo "$0: number of utterances $num_utts in your training data is too small versus --num-utts-subset=$num_utts_subset" + echo "... you probably have so little data that it doesn't make sense to train a neural net." + exit 1 +fi # Get list of validation utterances. -awk '{print $1}' $data/utt2spk | utils/shuffle_list.pl | head -$num_utts_subset | sort \ +awk '{print $1}' $data/utt2spk | utils/shuffle_list.pl | head -$num_utts_subset_valid | sort \ > $dir/valid_uttlist || exit 1; if [ -f $data/utt2uniq ]; then # this matters if you use data augmentation. @@ -128,7 +146,7 @@ if [ -f $data/utt2uniq ]; then # this matters if you use data augmentation. 
fi awk '{print $1}' $data/utt2spk | utils/filter_scp.pl --exclude $dir/valid_uttlist | \ - utils/shuffle_list.pl | head -$num_utts_subset | sort > $dir/train_subset_uttlist || exit 1; + utils/shuffle_list.pl | head -$num_utts_subset_train > $dir/train_subset_uttlist || exit 1; if [ ! -z "$transform_dir" ] && [ -f $transform_dir/trans.1 ] && [ $feat_type != "raw" ]; then echo "$0: using transforms from $transform_dir" @@ -145,15 +163,33 @@ if [ -f $transform_dir/raw_trans.1 ] && [ $feat_type == "raw" ]; then fi fi +nj_subset=$nj +if [ $nj_subset -gt `cat $dir/train_subset_uttlist | wc -l` ]; then + nj_subset=`cat $dir/train_subset_uttlist | wc -l` +fi + +if [ $nj_subset -gt `cat $dir/valid_uttlist | wc -l` ]; then + nj_subset=`cat $dir/valid_uttlist | wc -l` +fi + +valid_uttlist_all= +train_subset_uttlist_all= +for n in `seq $nj_subset`; do + valid_uttlist_all="$valid_uttlist_all $dir/valid_uttlist.$n" + train_subset_uttlist_all="$train_subset_uttlist_all $dir/train_subset_uttlist.$n" +done + +utils/split_scp.pl $dir/valid_uttlist $valid_uttlist_all +utils/split_scp.pl $dir/train_subset_uttlist $train_subset_uttlist_all ## Set up features. echo "$0: feature type is $feat_type" case $feat_type in raw) feats="ark,s,cs:utils/filter_scp.pl --exclude $dir/valid_uttlist $sdata/JOB/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:- ark:- |" - valid_feats="ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- |" - train_subset_feats="ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- |" + valid_feats="ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist.JOB $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- |" + train_subset_feats="ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist.JOB $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- |" echo $cmvn_opts >$dir/cmvn_opts # caution: the top-level nnet training script should copy this to its own dir now. ;; lda) @@ -164,8 +200,8 @@ case $feat_type in echo "You cannot supply --cmvn-opts option if feature type is LDA." 
&& exit 1; cmvn_opts=$(cat $dir/cmvn_opts) feats="ark,s,cs:utils/filter_scp.pl --exclude $dir/valid_uttlist $sdata/JOB/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:- ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $dir/final.mat ark:- ark:- |" - valid_feats="ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $dir/final.mat ark:- ark:- |" - train_subset_feats="ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $dir/final.mat ark:- ark:- |" + valid_feats="ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist.JOB $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $dir/final.mat ark:- ark:- |" + train_subset_feats="ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist.JOB $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $dir/final.mat ark:- ark:- |" ;; *) echo "$0: invalid feature type --feat-type '$feat_type'" && exit 1; esac @@ -182,8 +218,8 @@ if [ ! -z "$online_ivector_dir" ]; then ivector_period=$(cat $online_ivector_dir/ivector_period) || exit 1; ivector_opt="--ivectors='ark,s,cs:utils/filter_scp.pl $sdata/JOB/utt2spk $online_ivector_dir/ivector_online.scp | subsample-feats --n=-$ivector_period scp:- ark:- |'" - valid_ivector_opt="--ivectors='ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist $online_ivector_dir/ivector_online.scp | subsample-feats --n=-$ivector_period scp:- ark:- |'" - train_subset_ivector_opt="--ivectors='ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist $online_ivector_dir/ivector_online.scp | subsample-feats --n=-$ivector_period scp:- ark:- |'" + valid_ivector_opt="--ivectors='ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist.JOB $online_ivector_dir/ivector_online.scp | subsample-feats --n=-$ivector_period scp:- ark:- |'" + train_subset_ivector_opt="--ivectors='ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist.JOB $online_ivector_dir/ivector_online.scp | subsample-feats --n=-$ivector_period scp:- ark:- |'" else echo 0 >$dir/info/ivector_dim fi @@ -255,9 +291,13 @@ fi egs_opts="--left-context=$left_context --right-context=$right_context --compress=$compress" +[ ! -z "$deriv_weights_scp" ] && egs_opts="$egs_opts --deriv-weights-rspecifier=scp:$deriv_weights_scp" +[ ! -z "$l2_regularizer_targets" ] && egs_opts="$egs_opts --l2reg-targets-rspecifier=scp:$l2_regularizer_targets" + [ -z $valid_left_context ] && valid_left_context=$left_context; [ -z $valid_right_context ] && valid_right_context=$right_context; valid_egs_opts="--left-context=$valid_left_context --right-context=$valid_right_context --compress=$compress" +[ ! 
-z "$deriv_weights_scp" ] && valid_egs_opts="$valid_egs_opts --deriv-weights-rspecifier=scp:$deriv_weights_scp" echo $left_context > $dir/info/left_context echo $right_context > $dir/info/right_context @@ -281,15 +321,15 @@ case $target_type in "dense") get_egs_program="nnet3-get-egs-dense-targets --num-targets=$num_targets" - targets="ark:utils/filter_scp.pl --exclude $dir/valid_uttlist $targets_scp_split | copy-feats scp:- ark:- |" - valid_targets="ark:utils/filter_scp.pl $dir/valid_uttlist $targets_scp | copy-feats scp:- ark:- |" - train_subset_targets="ark:utils/filter_scp.pl $dir/train_subset_uttlist $targets_scp | copy-feats scp:- ark:- |" + targets="ark,s,cs:utils/filter_scp.pl --exclude $dir/valid_uttlist $targets_scp_split | copy-feats scp:- ark:- |" + valid_targets="ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist.JOB $targets_scp | copy-feats scp:- ark:- |" + train_subset_targets="ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist.JOB $targets_scp | copy-feats scp:- ark:- |" ;; "sparse") get_egs_program="nnet3-get-egs --num-pdfs=$num_targets" - targets="ark:utils/filter_scp.pl --exclude $dir/valid_uttlist $targets_scp_split | ali-to-post scp:- ark:- |" - valid_targets="ark:utils/filter_scp.pl $dir/valid_uttlist $targets_scp | ali-to-post scp:- ark:- |" - train_subset_targets="ark:utils/filter_scp.pl $dir/train_subset_uttlist $targets_scp | ali-to-post scp:- ark:- |" + targets="ark,s,cs:utils/filter_scp.pl --exclude $dir/valid_uttlist $targets_scp_split | ali-to-post scp:- ark:- |" + valid_targets="ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist.JOB $targets_scp | ali-to-post scp:- ark:- |" + train_subset_targets="ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist.JOB $targets_scp | ali-to-post scp:- ark:- |" ;; default) echo "$0: Unknown --target-type $target_type. Choices are dense and sparse" @@ -299,31 +339,43 @@ esac if [ $stage -le 3 ]; then echo "$0: Getting validation and training subset examples." rm -f $dir/.error 2>/dev/null - $cmd $dir/log/create_valid_subset.log \ + $cmd JOB=1:$nj_subset $dir/log/create_valid_subset.JOB.log \ $get_egs_program \ $valid_ivector_opt $valid_egs_opts "$valid_feats" \ "$valid_targets" \ - "ark:$dir/valid_all.egs" || touch $dir/.error & - $cmd $dir/log/create_train_subset.log \ + "ark:$dir/valid_all.JOB.egs" || touch $dir/.error & + $cmd JOB=1:$nj_subset $dir/log/create_train_subset.JOB.log \ $get_egs_program \ $train_subset_ivector_opt $valid_egs_opts "$train_subset_feats" \ "$train_subset_targets" \ - "ark:$dir/train_subset_all.egs" || touch $dir/.error & + "ark:$dir/train_subset_all.JOB.egs" || touch $dir/.error & wait; + + valid_egs_all= + train_subset_egs_all= + for n in `seq $nj_subset`; do + valid_egs_all="$valid_egs_all $dir/valid_all.$n.egs" + train_subset_egs_all="$train_subset_egs_all $dir/train_subset_all.$n.egs" + done + [ -f $dir/.error ] && echo "Error detected while creating train/valid egs" && exit 1 echo "... Getting subsets of validation examples for diagnostics and combination." 
$cmd $dir/log/create_valid_subset_combine.log \ - nnet3-subset-egs --n=$num_valid_frames_combine ark:$dir/valid_all.egs \ + cat $valid_egs_all \| \ + nnet3-subset-egs --n=$num_valid_frames_combine ark:- \ ark:$dir/valid_combine.egs || touch $dir/.error & $cmd $dir/log/create_valid_subset_diagnostic.log \ - nnet3-subset-egs --n=$num_frames_diagnostic ark:$dir/valid_all.egs \ + cat $valid_egs_all \| \ + nnet3-subset-egs --n=$num_frames_diagnostic ark:- \ ark:$dir/valid_diagnostic.egs || touch $dir/.error & $cmd $dir/log/create_train_subset_combine.log \ - nnet3-subset-egs --n=$num_train_frames_combine ark:$dir/train_subset_all.egs \ + cat $train_subset_egs_all \| \ + nnet3-subset-egs --n=$num_train_frames_combine ark:- \ ark:$dir/train_combine.egs || touch $dir/.error & $cmd $dir/log/create_train_subset_diagnostic.log \ - nnet3-subset-egs --n=$num_frames_diagnostic ark:$dir/train_subset_all.egs \ + cat $train_subset_egs_all \| \ + nnet3-subset-egs --n=$num_frames_diagnostic ark:- \ ark:$dir/train_diagnostic.egs || touch $dir/.error & wait sleep 5 # wait for file system to sync. @@ -332,7 +384,7 @@ if [ $stage -le 3 ]; then for f in $dir/{combine,train_diagnostic,valid_diagnostic}.egs; do [ ! -s $f ] && echo "No examples in file $f" && exit 1; done - rm -f $dir/valid_all.egs $dir/train_subset_all.egs $dir/{train,valid}_combine.egs + rm $dir/valid_all.*.egs $dir/train_subset_all.*.egs $dir/{train,valid}_combine.egs fi if [ $stage -le 4 ]; then @@ -349,7 +401,7 @@ if [ $stage -le 4 ]; then $get_egs_program \ $ivector_opt $egs_opts --num-frames=$frames_per_eg "$feats" "$targets" \ ark:- \| \ - nnet3-copy-egs --random=true --srand=JOB ark:- $egs_list || exit 1; + nnet3-copy-egs --random=true --srand=\$[JOB+$srand] ark:- $egs_list || exit 1; fi if [ $stage -le 5 ]; then @@ -365,7 +417,7 @@ if [ $stage -le 5 ]; then if [ $archives_multiple == 1 ]; then # normal case. $cmd --max-jobs-run $nj JOB=1:$num_archives_intermediate $dir/log/shuffle.JOB.log \ - nnet3-shuffle-egs --srand=JOB "ark:cat $egs_list|" ark:$dir/egs.JOB.ark || exit 1; + nnet3-shuffle-egs --srand=\$[JOB+$srand] "ark:cat $egs_list|" ark:$dir/egs.JOB.ark || exit 1; else # we need to shuffle the 'intermediate archives' and then split into the # final archives. we create soft links to manage this splitting, because @@ -381,12 +433,14 @@ if [ $stage -le 5 ]; then done done $cmd --max-jobs-run $nj JOB=1:$num_archives_intermediate $dir/log/shuffle.JOB.log \ - nnet3-shuffle-egs --srand=JOB "ark:cat $egs_list|" ark:- \| \ + nnet3-shuffle-egs --srand=\$[JOB+$srand] "ark:cat $egs_list|" ark:- \| \ nnet3-copy-egs ark:- $output_archives || exit 1; fi fi +wait + if [ $stage -le 6 ]; then echo "$0: removing temporary archives" for x in $(seq $nj); do @@ -400,10 +454,11 @@ if [ $stage -le 6 ]; then # there are some extra soft links that we should delete. for f in $dir/egs.*.*.ark; do rm $f; done fi - echo "$0: removing temporary" + echo "$0: removing temporary stuff" # Ignore errors below because trans.* might not exist. 
rm -f $dir/trans.{ark,scp} $dir/targets.*.scp 2>/dev/null fi -echo "$0: Finished preparing training examples" +wait +echo "$0: Finished preparing training examples" From 3827e1c8c558832531ab857d8e93f95d8ae22c98 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 23 Nov 2016 16:24:41 -0500 Subject: [PATCH 028/213] asr_diarization: Multiple outputs in nnet3 --- src/nnet3/nnet-combine.cc | 25 ++++++++++++++++----- src/nnet3/nnet-diagnostics.h | 5 +++++ src/nnet3/nnet-training.cc | 5 +++-- src/nnet3bin/nnet3-copy-egs.cc | 6 ++--- src/nnet3bin/nnet3-merge-egs.cc | 2 +- src/nnet3bin/nnet3-show-progress.cc | 34 ++++++++++++++++++++++++----- 6 files changed, 59 insertions(+), 18 deletions(-) diff --git a/src/nnet3/nnet-combine.cc b/src/nnet3/nnet-combine.cc index 45c1f74477b..d40c63bd3e7 100644 --- a/src/nnet3/nnet-combine.cc +++ b/src/nnet3/nnet-combine.cc @@ -424,15 +424,28 @@ double NnetCombiner::ComputeObjfAndDerivFromNnet( end = egs_.end(); for (; iter != end; ++iter) prob_computer_->Compute(*iter); - const SimpleObjectiveInfo *objf_info = prob_computer_->GetObjective("output"); - if (objf_info == NULL) - KALDI_ERR << "Error getting objective info (unsuitable egs?)"; - KALDI_ASSERT(objf_info->tot_weight > 0.0); + + double tot_weight = 0.0; + double tot_objf = 0.0; + + { + const unordered_map &objf_info = prob_computer_->GetAllObjectiveInfo(); + unordered_map::const_iterator objf_it = objf_info.begin(), + objf_end = objf_info.end(); + + for (; objf_it != objf_end; ++objf_it) { + tot_objf += objf_it->second.tot_objective; + tot_weight += objf_it->second.tot_weight; + } + } + + KALDI_ASSERT(tot_weight > 0.0); + const Nnet &deriv = prob_computer_->GetDeriv(); VectorizeNnet(deriv, nnet_params_deriv); // we prefer to deal with normalized objective functions. - nnet_params_deriv->Scale(1.0 / objf_info->tot_weight); - return objf_info->tot_objective / objf_info->tot_weight; + nnet_params_deriv->Scale(1.0 / tot_weight); + return tot_objf / tot_weight; } diff --git a/src/nnet3/nnet-diagnostics.h b/src/nnet3/nnet-diagnostics.h index 6ed6c4a33a7..59f0cd16f47 100644 --- a/src/nnet3/nnet-diagnostics.h +++ b/src/nnet3/nnet-diagnostics.h @@ -102,6 +102,11 @@ class NnetComputeProb { // or NULL if there is no such info. const SimpleObjectiveInfo *GetObjective(const std::string &output_name) const; + // return objective info for all outputs + const unordered_map & GetAllObjectiveInfo() const { + return objf_info_; + } + // if config.compute_deriv == true, returns a reference to the // computed derivative. Otherwise crashes. const Nnet &GetDeriv() const; diff --git a/src/nnet3/nnet-training.cc b/src/nnet3/nnet-training.cc index 9d957afe1de..bdbe244a648 100644 --- a/src/nnet3/nnet-training.cc +++ b/src/nnet3/nnet-training.cc @@ -188,11 +188,12 @@ bool NnetTrainer::PrintTotalStats() const { unordered_map::const_iterator iter = objf_info_.begin(), end = objf_info_.end(); - bool ans = false; + bool ans = true; for (; iter != end; ++iter) { const std::string &name = iter->first; const ObjectiveFunctionInfo &info = iter->second; - ans = ans || info.PrintTotalStats(name); + if (!info.PrintTotalStats(name)) + ans = false; } PrintMaxChangeStats(); return ans; diff --git a/src/nnet3bin/nnet3-copy-egs.cc b/src/nnet3bin/nnet3-copy-egs.cc index 0b82d91353a..ceb415ffe87 100644 --- a/src/nnet3bin/nnet3-copy-egs.cc +++ b/src/nnet3bin/nnet3-copy-egs.cc @@ -58,7 +58,7 @@ bool ContainsSingleExample(const NnetExample &eg, end = io.indexes.end(); // Should not have an empty input/output type. 
KALDI_ASSERT(!io.indexes.empty()); - if (io.name == "input" || io.name == "output") { + if (io.name == "input" || io.name.find("output") != std::string::npos) { int32 min_t = iter->t, max_t = iter->t; for (; iter != end; ++iter) { int32 this_t = iter->t; @@ -75,7 +75,7 @@ bool ContainsSingleExample(const NnetExample &eg, *min_input_t = min_t; *max_input_t = max_t; } else { - KALDI_ASSERT(io.name == "output"); + KALDI_ASSERT(io.name.find("output") != std::string::npos); done_output = true; *min_output_t = min_t; *max_output_t = max_t; @@ -127,7 +127,7 @@ void FilterExample(const NnetExample &eg, min_t = min_input_t; max_t = max_input_t; is_input_or_output = true; - } else if (name == "output") { + } else if (name.find("output") != std::string::npos) { min_t = min_output_t; max_t = max_output_t; is_input_or_output = true; diff --git a/src/nnet3bin/nnet3-merge-egs.cc b/src/nnet3bin/nnet3-merge-egs.cc index 8627671f53a..7415db8d12a 100644 --- a/src/nnet3bin/nnet3-merge-egs.cc +++ b/src/nnet3bin/nnet3-merge-egs.cc @@ -30,7 +30,7 @@ namespace nnet3 { // or crashes if it is not there. int32 NumOutputIndexes(const NnetExample &eg) { for (size_t i = 0; i < eg.io.size(); i++) - if (eg.io[i].name == "output") + if (eg.io[i].name.find("output") != std::string::npos) return eg.io[i].indexes.size(); KALDI_ERR << "No output named 'output' in the eg."; return 0; // Suppress compiler warning. diff --git a/src/nnet3bin/nnet3-show-progress.cc b/src/nnet3bin/nnet3-show-progress.cc index 10898dc0ca6..785d3d0aa88 100644 --- a/src/nnet3bin/nnet3-show-progress.cc +++ b/src/nnet3bin/nnet3-show-progress.cc @@ -107,17 +107,39 @@ int main(int argc, char *argv[]) { eg_end = examples.end(); for (; eg_iter != eg_end; ++eg_iter) prob_computer.Compute(*eg_iter); - const SimpleObjectiveInfo *objf_info = prob_computer.GetObjective("output"); - double objf_per_frame = objf_info->tot_objective / objf_info->tot_weight; + + double tot_weight = 0.0; + + { + const unordered_map &objf_info = prob_computer.GetAllObjectiveInfo(); + + unordered_map::const_iterator objf_it = objf_info.begin(), + objf_end = objf_info.end(); + + + for (; objf_it != objf_end; ++objf_it) { + double objf_per_frame = objf_it->second.tot_objective / objf_it->second.tot_weight; + + if (objf_it->first == "output") { + KALDI_LOG << "At position " << middle + << ", objf per frame is " << objf_per_frame; + } else { + KALDI_LOG << "At position " << middle + << ", objf per frame for '" << objf_it->first + << "' is " << objf_per_frame; + } + + tot_weight += objf_it->second.tot_weight; + } + } + const Nnet &nnet_gradient = prob_computer.GetDeriv(); - KALDI_LOG << "At position " << middle - << ", objf per frame is " << objf_per_frame; Vector old_dotprod(num_updatable), new_dotprod(num_updatable); ComponentDotProducts(nnet_gradient, nnet1, &old_dotprod); ComponentDotProducts(nnet_gradient, nnet2, &new_dotprod); - old_dotprod.Scale(1.0 / objf_info->tot_weight); - new_dotprod.Scale(1.0 / objf_info->tot_weight); + old_dotprod.Scale(1.0 / tot_weight); + new_dotprod.Scale(1.0 / tot_weight); diff.AddVec(1.0/ num_segments, new_dotprod); diff.AddVec(-1.0 / num_segments, old_dotprod); KALDI_VLOG(1) << "By segment " << s << ", objf change is " From e9535d8aa5f5ed373edae0128347c433c85fe44b Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 23 Nov 2016 22:08:00 -0500 Subject: [PATCH 029/213] raw_python_script: Made LSTM and TDNN raw configs similar --- egs/wsj/s5/steps/nnet3/lstm/make_configs.py | 62 +++++++++++++++------ 1 file changed, 45 insertions(+), 17 deletions(-) diff 
--git a/egs/wsj/s5/steps/nnet3/lstm/make_configs.py b/egs/wsj/s5/steps/nnet3/lstm/make_configs.py index 205b6034fad..9fb9fad1d0c 100755 --- a/egs/wsj/s5/steps/nnet3/lstm/make_configs.py +++ b/egs/wsj/s5/steps/nnet3/lstm/make_configs.py @@ -56,6 +56,18 @@ def GetArgs(): parser.add_argument("--max-change-per-component-final", type=float, help="Enforces per-component max change for the final affine layer. " "if 0 it would not be enforced.", default=1.5) + parser.add_argument("--add-lda", type=str, action=nnet3_train_lib.StrToBoolAction, + help="If \"true\" an LDA matrix computed from the input features " + "(spliced according to the first set of splice-indexes) will be used as " + "the first Affine layer. This affine layer's parameters are fixed during training. " + "This variable needs to be set to \"false\" when using dense-targets.", + default=True, choices = ["false", "true"]) + parser.add_argument("--add-final-sigmoid", type=str, action=nnet3_train_lib.StrToBoolAction, + help="add a sigmoid layer as the final layer. Applicable only if skip-final-softmax is true.", + choices=['true', 'false'], default = False) + parser.add_argument("--objective-type", type=str, default="linear", + choices = ["linear", "quadratic", "xent"], + help = "the type of objective; i.e. quadratic or linear or cross-entropy per dim") # LSTM options parser.add_argument("--num-lstm-layers", type=int, @@ -217,7 +229,9 @@ def ParseLstmDelayString(lstm_delay): raise ValueError("invalid --lstm-delay argument, too-short element: " + lstm_delay) elif len(indexes) == 2 and indexes[0] * indexes[1] >= 0: - raise ValueError('Warning: ' + str(indexes) + ' is not a standard BLSTM mode. There should be a negative delay for the forward, and a postive delay for the backward.') + raise ValueError('Warning: ' + str(indexes) + + ' is not a standard BLSTM mode. 
' + + 'There should be a negative delay for the forward, and a postive delay for the backward.') if len(indexes) == 2 and indexes[0] > 0: # always a negative delay followed by a postive delay indexes[0], indexes[1] = indexes[1], indexes[0] lstm_delay_array.append(indexes) @@ -227,29 +241,35 @@ def ParseLstmDelayString(lstm_delay): return lstm_delay_array -def MakeConfigs(config_dir, feat_dim, ivector_dim, num_targets, +def MakeConfigs(config_dir, feat_dim, ivector_dim, num_targets, add_lda, splice_indexes, lstm_delay, cell_dim, hidden_dim, recurrent_projection_dim, non_recurrent_projection_dim, num_lstm_layers, num_hidden_layers, norm_based_clipping, clipping_threshold, zeroing_threshold, zeroing_interval, ng_per_element_scale_options, ng_affine_options, - label_delay, include_log_softmax, xent_regularize, + label_delay, include_log_softmax, add_final_sigmoid, + objective_type, xent_regularize, self_repair_scale_nonlinearity, self_repair_scale_clipgradient, max_change_per_component, max_change_per_component_final): config_lines = {'components':[], 'component-nodes':[]} config_files={} - prev_layer_output = nodes.AddInputLayer(config_lines, feat_dim, splice_indexes[0], ivector_dim) + prev_layer_output = nodes.AddInputLayer(config_lines, feat_dim, splice_indexes[0], + ivector_dim) # Add the init config lines for estimating the preconditioning matrices init_config_lines = copy.deepcopy(config_lines) init_config_lines['components'].insert(0, '# Config file for initializing neural network prior to') init_config_lines['components'].insert(0, '# preconditioning matrix computation') - nodes.AddOutputLayer(init_config_lines, prev_layer_output) + nodes.AddOutputLayer(init_config_lines, prev_layer_output, label_delay = label_delay, objective_type = objective_type) config_files[config_dir + '/init.config'] = init_config_lines - prev_layer_output = nodes.AddLdaLayer(config_lines, "L0", prev_layer_output, config_dir + '/lda.mat') + # add_lda needs to be set "false" when using dense targets, + # or if the task is not a simple classification task + # (e.g. 
regression, multi-task) + if add_lda: + prev_layer_output = nodes.AddLdaLayer(config_lines, "L0", prev_layer_output, args.config_dir + '/lda.mat') for i in range(num_lstm_layers): if len(lstm_delay[i]) == 2: # add a bi-directional LSTM layer @@ -284,7 +304,7 @@ def MakeConfigs(config_dir, feat_dim, ivector_dim, num_targets, max_change_per_component = max_change_per_component) # make the intermediate config file for layerwise discriminative # training - nodes.AddFinalLayer(config_lines, prev_layer_output, num_targets, ng_affine_options, max_change_per_component = max_change_per_component_final, label_delay = label_delay, include_log_softmax = include_log_softmax) + nodes.AddFinalLayer(config_lines, prev_layer_output, num_targets, ng_affine_options, max_change_per_component = max_change_per_component_final, label_delay = label_delay, include_log_softmax = include_log_softmax, add_final_sigmoid = add_final_sigmoid, objective_type = objective_type) if xent_regularize != 0.0: @@ -302,7 +322,7 @@ def MakeConfigs(config_dir, feat_dim, ivector_dim, num_targets, ng_affine_options, self_repair_scale = self_repair_scale_nonlinearity, max_change_per_component = max_change_per_component) # make the intermediate config file for layerwise discriminative # training - nodes.AddFinalLayer(config_lines, prev_layer_output, num_targets, ng_affine_options, max_change_per_component = max_change_per_component_final, label_delay = label_delay, include_log_softmax = include_log_softmax) + nodes.AddFinalLayer(config_lines, prev_layer_output, num_targets, ng_affine_options, max_change_per_component = max_change_per_component_final, label_delay = label_delay, include_log_softmax = include_log_softmax, add_final_sigmoid = add_final_sigmoid, objective_type = objective_type) if xent_regularize != 0.0: nodes.AddFinalLayer(config_lines, prev_layer_output, num_targets, @@ -331,24 +351,30 @@ def ProcessSpliceIndexes(config_dir, splice_indexes, label_delay, num_lstm_layer if (num_hidden_layers < num_lstm_layers): raise Exception("num-lstm-layers : number of lstm layers has to be greater than number of layers, decided based on splice-indexes") - # write the files used by other scripts like steps/nnet3/get_egs.sh - f = open(config_dir + "/vars", "w") - print('model_left_context=' + str(left_context), file=f) - print('model_right_context=' + str(right_context), file=f) - print('num_hidden_layers=' + str(num_hidden_layers), file=f) - # print('initial_right_context=' + str(splice_array[0][-1]), file=f) - f.close() - return [left_context, right_context, num_hidden_layers, splice_indexes] def Main(): args = GetArgs() - [left_context, right_context, num_hidden_layers, splice_indexes] = ProcessSpliceIndexes(args.config_dir, args.splice_indexes, args.label_delay, args.num_lstm_layers) + [left_context, right_context, + num_hidden_layers, splice_indexes] = ProcessSpliceIndexes(args.config_dir, args.splice_indexes, + args.label_delay, args.num_lstm_layers) + + # write the files used by other scripts like steps/nnet3/get_egs.sh + f = open(args.config_dir + "/vars", "w") + print('model_left_context=' + str(left_context), file=f) + print('model_right_context=' + str(right_context), file=f) + print('num_hidden_layers=' + str(num_hidden_layers), file=f) + print('num_targets=' + str(args.num_targets), file=f) + print('objective_type=' + str(args.objective_type), file=f) + print('add_lda=' + ("true" if args.add_lda else "false"), file=f) + print('include_log_softmax=' + ("true" if args.include_log_softmax else "false"), file=f) + f.close() 
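+    # (Besides the context and num_hidden_layers values written previously, the vars
+    # file now also records num_targets, objective_type, add_lda and
+    # include_log_softmax for the scripts that read it, such as steps/nnet3/get_egs.sh.)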
MakeConfigs(config_dir = args.config_dir, feat_dim = args.feat_dim, ivector_dim = args.ivector_dim, num_targets = args.num_targets, + add_lda = args.add_lda, splice_indexes = splice_indexes, lstm_delay = args.lstm_delay, cell_dim = args.cell_dim, hidden_dim = args.hidden_dim, @@ -364,6 +390,8 @@ def Main(): ng_affine_options = args.ng_affine_options, label_delay = args.label_delay, include_log_softmax = args.include_log_softmax, + add_final_sigmoid = args.add_final_sigmoid, + objective_type = args.objective_type, xent_regularize = args.xent_regularize, self_repair_scale_nonlinearity = args.self_repair_scale_nonlinearity, self_repair_scale_clipgradient = args.self_repair_scale_clipgradient, From 7806dd6bd4a8986d9c876f010a63b61bc9a71251 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 23 Nov 2016 15:02:30 -0500 Subject: [PATCH 030/213] asr_diarization: Create prepare_unsad_data.sh --- .../local/segmentation/prepare_unsad_data.sh | 537 ++++++++++++++++++ 1 file changed, 537 insertions(+) create mode 100755 egs/aspire/s5/local/segmentation/prepare_unsad_data.sh diff --git a/egs/aspire/s5/local/segmentation/prepare_unsad_data.sh b/egs/aspire/s5/local/segmentation/prepare_unsad_data.sh new file mode 100755 index 00000000000..12097811ec9 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/prepare_unsad_data.sh @@ -0,0 +1,537 @@ +#!/bin/bash + +# This script prepares speech labels and deriv weights for +# training an unsad network for speech activity detection and music detection. + +set -u +set -o pipefail +set -e + +. path.sh + +stage=-2 +cmd=queue.pl +reco_nj=40 +nj=100 + +# Options to be passed to get_sad_map.py +map_noise_to_sil=true # Map noise phones to silence label (0) +map_unk_to_speech=true # Map unk phones to speech label (1) +sad_map= # Initial mapping from phones to speech/non-speech labels. + # Overrides the default mapping using phones/silence.txt + # and phones/nonsilence.txt + +# Options for feature extraction +feat_type=mfcc # mfcc or plp +add_pitch=false # Add pitch features + +config_dir=conf +feat_config= +pitch_config= + +mfccdir=mfcc +plpdir=plp + +speed_perturb=true + +sat_model_dir= # Model directory used for getting alignments +lang_test= # Language directory used to build graph. + # If it is not provided, $lang will be used instead. + +. utils/parse_options.sh + +if [ $# -ne 4 ]; then + echo "This script takes a data directory and creates a new data directory " + echo "and speech activity labels" + echo "for the purpose of training a Universal Speech Activity Detector." + echo "Usage: $0 [options] <data-dir> <lang-dir> <model-dir> <dir>" + echo " e.g.: $0 data/train_100k data/lang exp/tri4a exp/vad_data_prep" + echo "" + echo "Main options (for others, see top of script file)" + echo " --config <config-file> # config file containing options" + echo " --cmd (run.pl|queue.pl) # how to run jobs." + echo " --reco-nj <#njobs|4> # Split a whole data directory into these many pieces" + echo " --nj <#njobs|4> # Split a segmented data directory into these many pieces" + exit 1 +fi + +data_dir=$1 +lang=$2 +model_dir=$3 +dir=$4 + +if [ $feat_type != "plp" ] && [ $feat_type != "mfcc" ]; then + echo "$0: --feat-type must be plp or mfcc. Must match the model_dir used." + exit 1 +fi + +[ -z "$feat_config" ] && feat_config=$config_dir/$feat_type.conf +[ -z "$pitch_config" ] && pitch_config=$config_dir/pitch.conf + +extra_files= + +if $add_pitch; then + extra_files="$extra_files $pitch_config" +fi + +for f in $feat_config $extra_files; do + if [ !
-f $f ]; then + echo "$f could not be found" + exit 1 + fi +done + +mkdir -p $dir + +function make_mfcc { + local nj=$nj + local mfcc_config=$feat_config + local add_pitch=$add_pitch + local cmd=$cmd + local pitch_config=$pitch_config + + while [ $# -gt 0 ]; do + if [ $1 == "--nj" ]; then + nj=$2 + shift; shift; + elif [ $1 == "--mfcc-config" ]; then + mfcc_config=$2 + shift; shift; + elif [ $1 == "--add-pitch" ]; then + add_pitch=$2 + shift; shift; + elif [ $1 == "--cmd" ]; then + cmd=$2 + shift; shift; + elif [ $1 == "--pitch-config" ]; then + pitch_config=$2 + shift; shift; + else + break + fi + done + + if [ $# -ne 3 ]; then + echo "Usage: make_mfcc <data-dir> <log-dir> <mfcc-dir>" + exit 1 + fi + + if $add_pitch; then + steps/make_mfcc_pitch.sh --cmd "$cmd" --nj $nj \ + --mfcc-config $mfcc_config --pitch-config $pitch_config $1 $2 $3 || exit 1 + else + steps/make_mfcc.sh --cmd "$cmd" --nj $nj \ + --mfcc-config $mfcc_config $1 $2 $3 || exit 1 + fi + +} + +function make_plp { + local nj=$nj + local plp_config=$feat_config + local add_pitch=$add_pitch + local cmd=$cmd + local pitch_config=$pitch_config + + while [ $# -gt 0 ]; do + if [ $1 == "--nj" ]; then + nj=$2 + shift; shift; + elif [ $1 == "--plp-config" ]; then + plp_config=$2 + shift; shift; + elif [ $1 == "--add-pitch" ]; then + add_pitch=$2 + shift; shift; + elif [ $1 == "--cmd" ]; then + cmd=$2 + shift; shift; + elif [ $1 == "--pitch-config" ]; then + pitch_config=$2 + shift; shift; + else + break + fi + done + + if [ $# -ne 3 ]; then + echo "Usage: make_plp <data-dir> <log-dir> <plp-dir>" + exit 1 + fi + + if $add_pitch; then + steps/make_plp_pitch.sh --cmd "$cmd" --nj $nj \ + --plp-config $plp_config --pitch-config $pitch_config $1 $2 $3 || exit 1 + else + steps/make_plp.sh --cmd "$cmd" --nj $nj \ + --plp-config $plp_config $1 $2 $3 || exit 1 + fi +} + +frame_shift_info=`cat $feat_config | steps/segmentation/get_frame_shift_info_from_config.pl` || exit 1 + +frame_shift=`echo $frame_shift_info | awk '{print $1}'` +frame_overlap=`echo $frame_shift_info | awk '{print $2}'` + +data_id=$(basename $data_dir) +whole_data_dir=${data_dir}_whole +whole_data_id=${data_id}_whole + +if [ $stage -le -2 ]; then + steps/segmentation/get_sad_map.py \ + --init-sad-map="$sad_map" \ + --map-noise-to-sil=$map_noise_to_sil \ + --map-unk-to-speech=$map_unk_to_speech \ + $lang | utils/sym2int.pl -f 1 $lang/phones.txt > $dir/sad_map + + utils/data/convert_data_dir_to_whole.sh ${data_dir} ${whole_data_dir} + utils/data/get_utt2dur.sh ${whole_data_dir} +fi + +if $speed_perturb; then + plpdir=${plpdir}_sp + mfccdir=${mfccdir}_sp + + + if [ $stage -le -1 ]; then + utils/data/perturb_data_dir_speed_3way.sh ${whole_data_dir} ${whole_data_dir}_sp + utils/data/perturb_data_dir_speed_3way.sh ${data_dir} ${data_dir}_sp + + if [ $feat_type == "mfcc" ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage + fi + make_mfcc --cmd "$cmd --max-jobs-run 40" --nj $nj \ + --mfcc-config $feat_config \ + --add-pitch $add_pitch --pitch-config $pitch_config \ + ${whole_data_dir}_sp exp/make_mfcc $mfccdir || exit 1 + steps/compute_cmvn_stats.sh \ + ${whole_data_dir}_sp exp/make_mfcc $mfccdir || exit 1 + elif [ $feat_type == "plp" ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ !
-d $plpdir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$plpdir/storage $plpdir/storage + fi + + make_plp --cmd "$cmd --max-jobs-run 40" --nj $nj \ + --plp-config $feat_config \ + --add-pitch $add_pitch --pitch-config $pitch_config \ + ${whole_data_dir}_sp exp/make_plp $plpdir || exit 1 + steps/compute_cmvn_stats.sh \ + ${whole_data_dir}_sp exp/make_plp $plpdir || exit 1 + else + echo "$0: Unknown feat-type $feat_type. Must be mfcc or plp." + exit 1 + fi + + utils/fix_data_dir.sh ${whole_data_dir}_sp + fi + + data_dir=${data_dir}_sp + whole_data_dir=${whole_data_dir}_sp + data_id=${data_id}_sp +fi + + +############################################################################### +# Compute length of recording +############################################################################### + +utils/data/get_reco2utt.sh $data_dir + +if [ $stage -le 0 ]; then + steps/segmentation/get_utt2num_frames.sh \ + --frame-shift $frame_shift --frame-overlap $frame_overlap \ + --cmd "$cmd" --nj $reco_nj $whole_data_dir + + awk '{print $1" "$2}' ${data_dir}/segments | utils/apply_map.pl -f 2 ${whole_data_dir}/utt2num_frames > $data_dir/utt2max_frames + utils/data/subsegment_feats.sh ${whole_data_dir}/feats.scp \ + $frame_shift $frame_overlap ${data_dir}/segments | \ + utils/data/fix_subsegmented_feats.pl $data_dir/utt2max_frames \ + > ${data_dir}/feats.scp + + if [ $feat_type == mfcc ]; then + steps/compute_cmvn_stats.sh ${data_dir} exp/make_mfcc/${data_id} $mfccdir + else + steps/compute_cmvn_stats.sh ${data_dir} exp/make_plp/${data_id} $plpdir + fi + + utils/fix_data_dir.sh $data_dir +fi + +if [ -z "$sat_model_dir" ]; then + ali_dir=${model_dir}_ali_${data_id} + if [ $stage -le 2 ]; then + steps/align_si.sh --nj $nj --cmd "$cmd" \ + ${data_dir} ${lang} ${model_dir} ${model_dir}_ali_${data_id} || exit 1 + fi +else + ali_dir=${sat_model_dir}_ali_${data_id} + #obtain the alignment of the perturbed data + if [ $stage -le 2 ]; then + steps/align_fmllr.sh --nj $nj --cmd "$cmd" \ + ${data_dir} ${lang} ${sat_model_dir} ${sat_model_dir}_ali_${data_id} || exit 1 + fi +fi + + +# All the data from this point is speed perturbed. + +data_id=$(basename $data_dir) +utils/split_data.sh $data_dir $nj + +############################################################################### +# Convert alignment for the provided segments into +# initial SAD labels at utterance-level in segmentation format +############################################################################### + +vad_dir=$dir/`basename ${ali_dir}`_vad_${data_id} +if [ $stage -le 3 ]; then + steps/segmentation/internal/convert_ali_to_vad.sh --cmd "$cmd" \ + $data_dir $ali_dir \ + $dir/sad_map $vad_dir +fi + +[ ! 
-s $vad_dir/sad_seg.scp ] && echo "$0: $vad_dir/vad.scp is empty" && exit 1 + +if [ $stage -le 4 ]; then + utils/copy_data_dir.sh $data_dir $dir/${data_id}_manual_segments + + awk '{print $1" "$2}' $dir/${data_id}_manual_segments/segments | sort -k1,1 > $dir/${data_id}_manual_segments/utt2spk + utils/utt2spk_to_spk2utt.pl $dir/${data_id}_manual_segments/utt2spk | sort -k1,1 > $dir/${data_id}_manual_segments/spk2utt + + if [ $feat_type == mfcc ]; then + steps/compute_cmvn_stats.sh $dir/${data_id}_manual_segments exp/make_mfcc/${data_id}_manual_segments $mfccdir + else + steps/compute_cmvn_stats.sh $dir/${data_id}_manual_segments exp/make_plp/${data_id}_manual_segments $plpdir + fi + + utils/fix_data_dir.sh $dir/${data_id}_manual_segments || true # Might fail because utt2spk will be not sorted on both utts and spks +fi + + +#utils/split_data.sh --per-reco $data_dir $reco_nj +#segmentation-combine-segments ark,s:$vad_dir/sad_seg.scp +# "ark,s:segmentation-init-from-segments --shift-to-zero=false --frame-shift=$ali_frame_shift --frame-overlap=$ali_frame_overlap ${data}/split${reco_nj}reco/JOB/segments ark:- |" \ +# "ark:cat ${data}/split${reco_nj}reco/JOB/segments | cut -d ' ' -f 1,2 | utils/utt2spk_to_spk2utt.pl | sort -k1,1 |" ark:- + +############################################################################### + + +# Create extended data directory that consists of the provided +# segments along with the segments outside it. +# This is basically dividing the whole recording into pieces +# consisting of pieces corresponding to the provided segments +# and outside the provided segments. + +############################################################################### +# Create segments outside of the manual segments +############################################################################### + +outside_data_dir=$dir/${data_id}_outside +if [ $stage -le 5 ]; then + rm -rf $outside_data_dir + mkdir -p $outside_data_dir/split${reco_nj}reco + + for f in wav.scp reco2file_and_channel stm glm; do + [ -f ${data_dir}/$f ] && cp ${data_dir}/$f $outside_data_dir + done + + steps/segmentation/split_data_on_reco.sh $data_dir $whole_data_dir $reco_nj + + for n in `seq $reco_nj`; do + dsn=$whole_data_dir/split${reco_nj}reco/$n + awk '{print $2}' $dsn/segments | \ + utils/filter_scp.pl /dev/stdin $whole_data_dir/utt2num_frames > \ + $dsn/utt2num_frames + mkdir -p $outside_data_dir/split${reco_nj}reco/$n + done + + $cmd JOB=1:$reco_nj $outside_data_dir/log/get_empty_segments.JOB.log \ + segmentation-init-from-segments --frame-shift=$frame_shift \ + --frame-overlap=$frame_overlap --shift-to-zero=false \ + ${data_dir}/split${reco_nj}reco/JOB/segments ark:- \| \ + segmentation-combine-segments-to-recordings ark:- \ + "ark,t:cut -d ' ' -f 1,2 ${data_dir}/split${reco_nj}reco/JOB/segments | utils/utt2spk_to_spk2utt.pl |" ark:- \| \ + segmentation-create-subsegments --filter-label=1 --subsegment-label=0 \ + "ark:segmentation-init-from-lengths --label=1 ark,t:${whole_data_dir}/split${reco_nj}reco/JOB/utt2num_frames ark:- |" \ + ark:- ark:- \| \ + segmentation-post-process --remove-labels=0 --max-segment-length=1000 \ + --post-process-label=1 --overlap-length=50 \ + ark:- ark:- \| segmentation-to-segments --single-speaker=true \ + --frame-shift=$frame_shift --frame-overlap=$frame_overlap \ + ark:- ark,t:$outside_data_dir/split${reco_nj}reco/JOB/utt2spk \ + $outside_data_dir/split${reco_nj}reco/JOB/segments || exit 1 + + for n in `seq $reco_nj`; do + cat $outside_data_dir/split${reco_nj}reco/$n/utt2spk + done | 
sort -k1,1 > $outside_data_dir/utt2spk + + for n in `seq $reco_nj`; do + cat $outside_data_dir/split${reco_nj}reco/$n/segments + done | sort -k1,1 > $outside_data_dir/segments + + utils/fix_data_dir.sh $outside_data_dir + +fi + + +if [ $stage -le 6 ]; then + utils/data/get_reco2utt.sh $outside_data_dir + awk '{print $1" "$2}' $outside_data_dir/segments | utils/apply_map.pl -f 2 $whole_data_dir/utt2num_frames > $outside_data_dir/utt2max_frames + + utils/data/subsegment_feats.sh ${whole_data_dir}/feats.scp \ + $frame_shift $frame_overlap ${outside_data_dir}/segments | \ + utils/data/fix_subsegmented_feats.pl $outside_data_dir/utt2max_frames \ + > ${outside_data_dir}/feats.scp + +fi + +extended_data_dir=$dir/${data_id}_extended +if [ $stage -le 7 ]; then + cp $dir/${data_id}_manual_segments/cmvn.scp ${outside_data_dir} || exit 1 + utils/fix_data_dir.sh $outside_data_dir + + utils/combine_data.sh $extended_data_dir $data_dir $outside_data_dir + + steps/segmentation/split_data_on_reco.sh $data_dir $extended_data_dir $reco_nj +fi + +############################################################################### +# Create graph for decoding +############################################################################### + +# TODO: By default, we use a word LM. If required, we could +# consider a phone LM instead. +graph_dir=$model_dir/graph +if [ $stage -le 8 ]; then + if [ ! -d $graph_dir ]; then + utils/mkgraph.sh ${lang_test} $model_dir $graph_dir || exit 1 + fi +fi + +############################################################################### +# Decode extended data directory +############################################################################### + + +# Decode without lattice (get only best path) +if [ $stage -le 8 ]; then + steps/decode_nolats.sh --cmd "$cmd --mem 2G" --nj $nj \ + --max-active 1000 --beam 10.0 --write-words false \ + --write-alignments true \ + $graph_dir ${extended_data_dir} \ + ${model_dir}/decode_${data_id}_extended || exit 1 + cp ${model_dir}/final.mdl ${model_dir}/decode_${data_id}_extended +fi + +model_id=`basename $model_dir` + +# Get VAD based on the decoded best path +decode_vad_dir=$dir/${model_id}_decode_vad_${data_id} +if [ $stage -le 9 ]; then + steps/segmentation/internal/convert_ali_to_vad.sh --cmd "$cmd" \ + $extended_data_dir ${model_dir}/decode_${data_id}_extended \ + $dir/sad_map $decode_vad_dir +fi + +[ ! -s $decode_vad_dir/sad_seg.scp ] && echo "$0: $decode_vad_dir/sad_seg.scp is empty" && exit 1 + +vad_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $vad_dir ${PWD}` + +if [ $stage -le 10 ]; then + segmentation-init-from-segments --frame-shift=$frame_shift \ + --frame-overlap=$frame_overlap --segment-label=0 \ + $outside_data_dir/segments \ + ark,scp:$vad_dir/outside_sad_seg.ark,$vad_dir/outside_sad_seg.scp +fi + +reco_vad_dir=$dir/${model_id}_reco_vad_${data_id} +mkdir -p $reco_vad_dir +if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ !
-d $reco_vad_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$reco_vad_dir/storage $reco_vad_dir/storage +fi + +reco_vad_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $reco_vad_dir ${PWD}` + +echo $reco_nj > $reco_vad_dir/num_jobs + +if [ $stage -le 11 ]; then + $cmd JOB=1:$reco_nj $reco_vad_dir/log/intersect_vad.JOB.log \ + segmentation-intersect-segments --mismatch-label=10 \ + "scp:cat $vad_dir/sad_seg.scp $vad_dir/outside_sad_seg.scp | sort -k1,1 | utils/filter_scp.pl $extended_data_dir/split${reco_nj}reco/JOB/utt2spk |" \ + "scp:utils/filter_scp.pl $extended_data_dir/split${reco_nj}reco/JOB/utt2spk $decode_vad_dir/sad_seg.scp |" \ + ark:- \| segmentation-post-process --remove-labels=10 \ + --merge-adjacent-segments --max-intersegment-length=10 ark:- ark:- \| \ + segmentation-combine-segments ark:- "ark:segmentation-init-from-segments --shift-to-zero=false $extended_data_dir/split${reco_nj}reco/JOB/segments ark:- |" \ + ark,t:$extended_data_dir/split${reco_nj}reco/JOB/reco2utt \ + ark,scp:$reco_vad_dir/sad_seg.JOB.ark,$reco_vad_dir/sad_seg.JOB.scp + for n in `seq $reco_nj`; do + cat $reco_vad_dir/sad_seg.$n.scp + done > $reco_vad_dir/sad_seg.scp +fi + +set +e +for n in `seq $reco_nj`; do + utils/create_data_link.pl $reco_vad_dir/deriv_weights.$n.ark + utils/create_data_link.pl $reco_vad_dir/deriv_weights_for_uncorrupted.$n.ark + utils/create_data_link.pl $reco_vad_dir/speech_feat.$n.ark +done +set -e + +if [ $stage -le 12 ]; then + $cmd JOB=1:$reco_nj $reco_vad_dir/log/get_deriv_weights.JOB.log \ + segmentation-post-process --merge-labels=0:1:2:3 --merge-dst-label=1 \ + scp:$reco_vad_dir/sad_seg.JOB.scp ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${whole_data_dir}/utt2num_frames ark:- ark,t:- \| \ + steps/segmentation/convert_ali_to_vec.pl \| copy-vector ark,t:- \ + ark,scp:$reco_vad_dir/deriv_weights.JOB.ark,$reco_vad_dir/deriv_weights.JOB.scp + + for n in `seq $reco_nj`; do + cat $reco_vad_dir/deriv_weights.$n.scp + done > $reco_vad_dir/deriv_weights.scp +fi + +if [ $stage -le 13 ]; then + $cmd JOB=1:$reco_nj $reco_vad_dir/log/get_deriv_weights_for_uncorrupted.JOB.log \ + segmentation-post-process --remove-labels=1:2:3 scp:$reco_vad_dir/sad_seg.JOB.scp \ + ark:- \| segmentation-post-process --merge-labels=0 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${whole_data_dir}/utt2num_frames ark:- ark,t:- \| \ + steps/segmentation/convert_ali_to_vec.pl \| copy-vector ark,t:- \ + ark,scp:$reco_vad_dir/deriv_weights_for_uncorrupted.JOB.ark,$reco_vad_dir/deriv_weights_for_uncorrupted.JOB.scp + for n in `seq $reco_nj`; do + cat $reco_vad_dir/deriv_weights_for_uncorrupted.$n.scp + done > $reco_vad_dir/deriv_weights_for_uncorrupted.scp +fi + +if [ $stage -le 14 ]; then + $cmd JOB=1:$reco_nj $reco_vad_dir/log/get_speech_labels.JOB.log \ + segmentation-post-process --keep-label=1 scp:$reco_vad_dir/sad_seg.JOB.scp ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${whole_data_dir}/utt2num_frames \ + ark:- ark,t:- \| \ + steps/segmentation/convert_ali_to_vec.pl \| vector-to-feat ark:- ark:- \| copy-feats --compress \ + ark:- ark,scp:$reco_vad_dir/speech_feat.JOB.ark,$reco_vad_dir/speech_feat.JOB.scp + for n in `seq $reco_nj`; do + cat $reco_vad_dir/speech_feat.$n.scp + done > $reco_vad_dir/speech_feat.scp +fi + +if [ $stage -le 15 ]; then + $cmd JOB=1:$reco_nj 
$reco_vad_dir/log/convert_manual_segments_to_deriv_weights.JOB.log \ + segmentation-init-from-segments --shift-to-zero=false \ + $data_dir/split${reco_nj}reco/JOB/segments ark:- \| \ + segmentation-combine-segments-to-recordings ark:- \ + ark:$data_dir/split${reco_nj}reco/JOB/reco2utt ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${whole_data_dir}/utt2num_frames \ + ark:- ark,t:- \| \ + steps/segmentation/convert_ali_to_vec.pl \| copy-vector ark,t:- \ + ark,scp:$reco_vad_dir/deriv_weights_manual_seg.JOB.ark,$reco_vad_dir/deriv_weights_manual_seg.JOB.scp + + for n in `seq $reco_nj`; do + cat $reco_vad_dir/deriv_weights_manual_seg.$n.scp + done > $reco_vad_dir/deriv_weights_manual_seg.scp +fi + +echo "$0: Finished creating corpus for training Universal SAD with data in $whole_data_dir and labels in $reco_vad_dir" From b281cea71e83f00d581830dd75b14ced64cc0cae Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 24 Nov 2016 01:18:00 -0500 Subject: [PATCH 031/213] asr_diarization: Temporary changes to mfcc_hires_bp.conf and path.sh in aspire --- egs/aspire/s5/conf/mfcc_hires_bp.conf | 13 +++++++++++++ egs/aspire/s5/path.sh | 4 ++++ 2 files changed, 17 insertions(+) create mode 100644 egs/aspire/s5/conf/mfcc_hires_bp.conf diff --git a/egs/aspire/s5/conf/mfcc_hires_bp.conf b/egs/aspire/s5/conf/mfcc_hires_bp.conf new file mode 100644 index 00000000000..64292e8b489 --- /dev/null +++ b/egs/aspire/s5/conf/mfcc_hires_bp.conf @@ -0,0 +1,13 @@ +# config for high-resolution MFCC features, intended for neural network training. +# Note: we keep all cepstra, so it has the same info as filterbank features, +# but MFCC is more easily compressible (because less correlated) which is why +# we prefer this method. +--use-energy=false # use average of log energy, not energy. +--sample-frequency=8000 # Switchboard is sampled at 8kHz +--num-mel-bins=28 +--num-ceps=28 +--cepstral-lifter=0 +--low-freq=330 # low cutoff frequency for mel bins +--high-freq=-1000 # high cutoff frequency, relative to Nyquist of 4000 (=3000) + + diff --git a/egs/aspire/s5/path.sh b/egs/aspire/s5/path.sh index 1a6fb5f891b..5c0d3a92f19 100755 --- a/egs/aspire/s5/path.sh +++ b/egs/aspire/s5/path.sh @@ -2,4 +2,8 @@ export KALDI_ROOT=`pwd`/../../.. export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH [ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 .
$KALDI_ROOT/tools/config/common_path.sh +export PATH=/home/vmanoha1/kaldi-raw-signal/src/segmenterbin:$PATH +export PATH=$KALDI_ROOT/tools/sph2pipe_v2.5:$PATH +export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH +export PYTHONPATH=steps:${PYTHONPATH} export LC_ALL=C From 30bb9645ccd14f3c10b0b120eb3e1579b046a7d7 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 24 Nov 2016 01:17:37 -0500 Subject: [PATCH 032/213] asr_diarization: Modified reverberation script by moving some functions to library and adding extra options --- .../steps/data/data_dir_manipulation_lib.py | 420 ++++++++++++++- egs/wsj/s5/steps/data/reverberate_data_dir.py | 553 ++++-------------- src/featbin/wav-reverberate.cc | 110 +++- 3 files changed, 632 insertions(+), 451 deletions(-) diff --git a/egs/wsj/s5/steps/data/data_dir_manipulation_lib.py b/egs/wsj/s5/steps/data/data_dir_manipulation_lib.py index 1f7253d4891..7f1a5f74fe2 100644 --- a/egs/wsj/s5/steps/data/data_dir_manipulation_lib.py +++ b/egs/wsj/s5/steps/data/data_dir_manipulation_lib.py @@ -1,4 +1,10 @@ -import subprocess +#!/usr/bin/env python +# Copyright 2016 Tom Ko +# 2016 Vimal Manohar +# Apache 2.0 + +from __future__ import print_function +import subprocess, random, argparse, os, shlex, warnings def RunKaldiCommand(command, wait = True): """ Runs commands frequently seen in Kaldi scripts. These are usually a @@ -16,3 +22,415 @@ def RunKaldiCommand(command, wait = True): else: return p +class list_cyclic_iterator: + def __init__(self, list): + self.list_index = 0 + self.list = list + random.shuffle(self.list) + + def next(self): + item = self.list[self.list_index] + self.list_index = (self.list_index + 1) % len(self.list) + return item + +# This function picks an item from the collection according to the associated probability distribution. +# The probability estimate of each item in the collection is stored in the "probability" field of +# the particular item. x : a collection (list or dictionary) where the values contain a field called probability +def PickItemWithProbability(x): + if isinstance(x, dict): + plist = list(set(x.values())) + else: + plist = x + total_p = sum(item.probability for item in plist) + p = random.uniform(0, total_p) + accumulate_p = 0 + for item in plist: + if accumulate_p + item.probability >= p: + return item + accumulate_p += item.probability + assert False, "Shouldn't get here as the accumulated probability should always equal to 1" + +# This function smooths the probability distribution in the list +def SmoothProbabilityDistribution(list, smoothing_weight=0.0, target_sum=1.0): + if len(list) > 0: + num_unspecified = 0 + accumulated_prob = 0 + for item in list: + if item.probability is None: + num_unspecified += 1 + else: + accumulated_prob += item.probability + + # Compute the probability for the items without specifying their probability + uniform_probability = 0 + if num_unspecified > 0 and accumulated_prob < 1: + uniform_probability = (1 - accumulated_prob) / float(num_unspecified) + elif num_unspecified > 0 and accumulated_prob >= 1: + warnings.warn("The sum of probabilities specified by user is larger than or equal to 1.
" + "The items without probabilities specified will be given zero to their probabilities.") + + for item in list: + if item.probability is None: + item.probability = uniform_probability + else: + # smooth the probability + item.probability = (1 - smoothing_weight) * item.probability + smoothing_weight * uniform_probability + + # Normalize the probability + sum_p = sum(item.probability for item in list) + for item in list: + item.probability = item.probability / sum_p * target_sum + + return list + +# This function parses a file and pack the data into a dictionary +# It is useful for parsing file like wav.scp, utt2spk, text...etc +def ParseFileToDict(file, assert2fields = False, value_processor = None): + if value_processor is None: + value_processor = lambda x: x[0] + + dict = {} + for line in open(file, 'r'): + parts = line.split() + if assert2fields: + assert(len(parts) == 2) + + dict[parts[0]] = value_processor(parts[1:]) + return dict + +# This function creates a file and write the content of a dictionary into it +def WriteDictToFile(dict, file_name): + file = open(file_name, 'w') + keys = dict.keys() + keys.sort() + for key in keys: + value = dict[key] + if type(value) in [list, tuple] : + if type(value) is tuple: + value = list(value) + value.sort() + value = ' '.join([ str(x) for x in value ]) + file.write('{0} {1}\n'.format(key, value)) + file.close() + + +# This function creates the utt2uniq file from the utterance id in utt2spk file +def CreateCorruptedUtt2uniq(input_dir, output_dir, num_replicas, include_original, prefix): + corrupted_utt2uniq = {} + # Parse the utt2spk to get the utterance id + utt2spk = ParseFileToDict(input_dir + "/utt2spk", value_processor = lambda x: " ".join(x)) + keys = utt2spk.keys() + keys.sort() + if include_original: + start_index = 0 + else: + start_index = 1 + + for i in range(start_index, num_replicas+1): + for utt_id in keys: + new_utt_id = GetNewId(utt_id, prefix, i) + corrupted_utt2uniq[new_utt_id] = utt_id + + WriteDictToFile(corrupted_utt2uniq, output_dir + "/utt2uniq") + +# This function generates a new id from the input id +# This is needed when we have to create multiple copies of the original data +# E.g. 
GetNewId("swb0035", prefix="rvb", copy=1) returns a string "rvb1_swb0035" +def GetNewId(id, prefix=None, copy=0): + if prefix is not None: + new_id = prefix + str(copy) + "_" + id + else: + new_id = id + + return new_id + +# This function replicates the entries in files like segments, utt2spk, text +def AddPrefixToFields(input_file, output_file, num_replicas, include_original, prefix, field = [0]): + list = map(lambda x: x.strip(), open(input_file)) + f = open(output_file, "w") + if include_original: + start_index = 0 + else: + start_index = 1 + + for i in range(start_index, num_replicas+1): + for line in list: + if len(line) > 0 and line[0] != ';': + split1 = line.split() + for j in field: + split1[j] = GetNewId(split1[j], prefix, i) + print(" ".join(split1), file=f) + else: + print(line, file=f) + f.close() + +def CopyDataDirFiles(input_dir, output_dir, num_replicas, include_original, prefix): + if not os.path.isfile(output_dir + "/wav.scp"): + raise Exception("CopyDataDirFiles function expects output_dir to contain wav.scp already") + + AddPrefixToFields(input_dir + "/utt2spk", output_dir + "/utt2spk", num_replicas, include_original, prefix, field = [0,1]) + RunKaldiCommand("utils/utt2spk_to_spk2utt.pl <{output_dir}/utt2spk >{output_dir}/spk2utt" + .format(output_dir = output_dir)) + + if os.path.isfile(input_dir + "/utt2uniq"): + AddPrefixToFields(input_dir + "/utt2uniq", output_dir + "/utt2uniq", num_replicas, include_original, prefix, field =[0]) + else: + # Create the utt2uniq file + CreateCorruptedUtt2uniq(input_dir, output_dir, num_replicas, include_original, prefix) + + if os.path.isfile(input_dir + "/text"): + AddPrefixToFields(input_dir + "/text", output_dir + "/text", num_replicas, include_original, prefix, field =[0]) + if os.path.isfile(input_dir + "/segments"): + AddPrefixToFields(input_dir + "/segments", output_dir + "/segments", num_replicas, include_original, prefix, field = [0,1]) + if os.path.isfile(input_dir + "/reco2file_and_channel"): + AddPrefixToFields(input_dir + "/reco2file_and_channel", output_dir + "/reco2file_and_channel", num_replicas, include_original, prefix, field = [0,1]) + + AddPrefixToFields(input_dir + "/reco2dur", output_dir + "/reco2dur", num_replicas, include_original, prefix, field = [0]) + + RunKaldiCommand("utils/validate_data_dir.sh --no-feats {output_dir}" + .format(output_dir = output_dir)) + + +# This function parses the array of rir set parameter strings. +# It will assign probabilities to those rir sets which don't have a probability +# It will also check the existence of the rir list files.
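# A minimal standalone sketch of the set-parameter strings handled below, with
# hypothetical file names; each entry is either "<probability>, <list-file>" or
# just "<list-file>", and entries without a probability get one assigned later
# by SmoothProbabilityDistribution():
#
#   def example_split(set_para):
#       parts = set_para.split(',')
#       if len(parts) == 2:
#           return float(parts[0]), parts[1].strip()
#       return None, parts[0].strip()
#
#   example_split("0.5, rir_lists/smallroom_rir_list")  # -> (0.5, 'rir_lists/smallroom_rir_list')
#   example_split("rir_lists/mediumroom_rir_list")      # -> (None, 'rir_lists/mediumroom_rir_list')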
+def ParseSetParameterStrings(set_para_array): + set_list = [] + for set_para in set_para_array: + set = lambda: None + setattr(set, "filename", None) + setattr(set, "probability", None) + parts = set_para.split(',') + if len(parts) == 2: + set.probability = float(parts[0]) + set.filename = parts[1].strip() + else: + set.filename = parts[0].strip() + if not os.path.isfile(set.filename): + raise Exception(set.filename + " not found") + set_list.append(set) + + return SmoothProbabilityDistribution(set_list) + + +# This function creates the RIR list +# Each rir object in the list contains the following attributes: +# rir_id, room_id, receiver_position_id, source_position_id, rt60, drr, probability +# Please refer to the help messages in the parser for the meaning of these attributes +def ParseRirList(rir_set_para_array, smoothing_weight, sampling_rate = None): + rir_parser = argparse.ArgumentParser() + rir_parser.add_argument('--rir-id', type=str, required=True, help='This id is unique for each RIR and the noise may associate with a particular RIR by refering to this id') + rir_parser.add_argument('--room-id', type=str, required=True, help='This is the room that where the RIR is generated') + rir_parser.add_argument('--receiver-position-id', type=str, default=None, help='receiver position id') + rir_parser.add_argument('--source-position-id', type=str, default=None, help='source position id') + rir_parser.add_argument('--rt60', type=float, default=None, help='RT60 is the time required for reflections of a direct sound to decay 60 dB.') + rir_parser.add_argument('--drr', type=float, default=None, help='Direct-to-reverberant-ratio of the impulse response.') + rir_parser.add_argument('--cte', type=float, default=None, help='Early-to-late index of the impulse response.') + rir_parser.add_argument('--probability', type=float, default=None, help='probability of the impulse response.') + rir_parser.add_argument('rir_rspecifier', type=str, help="""rir rspecifier, it can be either a filename or a piped command. + E.g. data/impulses/Room001-00001.wav or "sox data/impulses/Room001-00001.wav -t wav - |" """) + + set_list = ParseSetParameterStrings(rir_set_para_array) + + rir_list = [] + for rir_set in set_list: + current_rir_list = map(lambda x: rir_parser.parse_args(shlex.split(x.strip())),open(rir_set.filename)) + for rir in current_rir_list: + if sampling_rate is not None: + # check if the rspecifier is a pipe or not + if len(rir.rir_rspecifier.split()) == 1: + rir.rir_rspecifier = "sox {0} -r {1} -t wav - |".format(rir.rir_rspecifier, sampling_rate) + else: + rir.rir_rspecifier = "{0} sox -t wav - -r {1} -t wav - |".format(rir.rir_rspecifier, sampling_rate) + + rir_list += SmoothProbabilityDistribution(current_rir_list, smoothing_weight, rir_set.probability) + + return rir_list + + +# This dunction checks if the inputs are approximately equal assuming they are floats. +def almost_equal(value_1, value_2, accuracy = 10**-8): + return abs(value_1 - value_2) < accuracy + +# This function converts a list of RIRs into a dictionary of RIRs indexed by the room-id. 
+# Its values are objects with two attributes: a local RIR list +# and the probability of the corresponding room +# Please look at the comments at ParseRirList() for the attributes that a RIR object contains +def MakeRoomDict(rir_list): + room_dict = {} + for rir in rir_list: + if rir.room_id not in room_dict: + # add new room + room_dict[rir.room_id] = lambda: None + setattr(room_dict[rir.room_id], "rir_list", []) + setattr(room_dict[rir.room_id], "probability", 0) + room_dict[rir.room_id].rir_list.append(rir) + + # the probability of the room is the sum of probabilities of its RIR + for key in room_dict.keys(): + room_dict[key].probability = sum(rir.probability for rir in room_dict[key].rir_list) + + assert almost_equal(sum(room_dict[key].probability for key in room_dict.keys()), 1.0) + + return room_dict + + +# This function creates the point-source noise list +# and the isotropic noise dictionary from the noise information file +# The isotropic noise dictionary is indexed by the room +# and its value is the corrresponding isotropic noise list +# Each noise object in the list contains the following attributes: +# noise_id, noise_type, bg_fg_type, room_linkage, probability, noise_rspecifier +# Please refer to the help messages in the parser for the meaning of these attributes +def ParseNoiseList(noise_set_para_array, smoothing_weight, sampling_rate = None): + noise_parser = argparse.ArgumentParser() + noise_parser.add_argument('--noise-id', type=str, required=True, help='noise id') + noise_parser.add_argument('--noise-type', type=str, required=True, help='the type of noise; i.e. isotropic or point-source', choices = ["isotropic", "point-source"]) + noise_parser.add_argument('--bg-fg-type', type=str, default="background", help='background or foreground noise, for background noises, ' + 'they will be extended before addition to cover the whole speech; for foreground noise, they will be kept ' + 'to their original duration and added at a random point of the speech.', choices = ["background", "foreground"]) + noise_parser.add_argument('--room-linkage', type=str, default=None, help='required if isotropic, should not be specified if point-source.') + noise_parser.add_argument('--probability', type=float, default=None, help='probability of the noise.') + noise_parser.add_argument('noise_rspecifier', type=str, help="""noise rspecifier, it can be either a filename or a piped command. + E.g. 
type5_noise_cirline_ofc_ambient1.wav or "sox type5_noise_cirline_ofc_ambient1.wav -t wav - |" """) + + set_list = ParseSetParameterStrings(noise_set_para_array) + + pointsource_noise_list = [] + iso_noise_dict = {} + for noise_set in set_list: + current_noise_list = map(lambda x: noise_parser.parse_args(shlex.split(x.strip())),open(noise_set.filename)) + current_pointsource_noise_list = [] + for noise in current_noise_list: + if sampling_rate is not None: + # check if the rspecifier is a pipe or not + if len(noise.noise_rspecifier.split()) == 1: + noise.noise_rspecifier = "sox {0} -r {1} -t wav - |".format(noise.noise_rspecifier, sampling_rate) + else: + noise.noise_rspecifier = "{0} sox -t wav - -r {1} -t wav - |".format(noise.noise_rspecifier, sampling_rate) + + if noise.noise_type == "isotropic": + if noise.room_linkage is None: + raise Exception("--room-linkage must be specified if --noise-type is isotropic") + else: + if noise.room_linkage not in iso_noise_dict: + iso_noise_dict[noise.room_linkage] = [] + iso_noise_dict[noise.room_linkage].append(noise) + else: + current_pointsource_noise_list.append(noise) + + pointsource_noise_list += SmoothProbabilityDistribution(current_pointsource_noise_list, smoothing_weight, noise_set.probability) + + # ensure the point-source noise probabilities sum to 1 + pointsource_noise_list = SmoothProbabilityDistribution(pointsource_noise_list, smoothing_weight, 1.0) + if len(pointsource_noise_list) > 0: + assert almost_equal(sum(noise.probability for noise in pointsource_noise_list), 1.0) + + # ensure the isotropic noise source probabilities for a given room sum to 1 + for key in iso_noise_dict.keys(): + iso_noise_dict[key] = SmoothProbabilityDistribution(iso_noise_dict[key]) + assert almost_equal(sum(noise.probability for noise in iso_noise_dict[key]), 1.0) + + return (pointsource_noise_list, iso_noise_dict) + +def AddPointSourceNoise(room, # the room selected + pointsource_noise_list, # the point source noise list + pointsource_noise_addition_probability, # Probability of adding point-source noises + foreground_snrs, # the SNR for adding the foreground noises + background_snrs, # the SNR for adding the background noises + speech_dur, # duration of the recording + max_noises_recording, # Maximum number of point-source noises that can be added + noise_addition_descriptor # descriptor to store the information of the noise added + ): + num_noises_added = 0 + if len(pointsource_noise_list) > 0 and random.random() < pointsource_noise_addition_probability and max_noises_recording >= 1: + for k in range(random.randint(1, max_noises_recording)): + num_noises_added = num_noises_added + 1 + # pick the RIR to reverberate the point-source noise + noise = PickItemWithProbability(pointsource_noise_list) + noise_rir = PickItemWithProbability(room.rir_list) + # If it is a background noise, the noise will be extended and be added to the whole speech + # if it is a foreground noise, the noise will not extended and be added at a random time of the speech + if noise.bg_fg_type == "background": + noise_rvb_command = """wav-reverberate --impulse-response="{0}" --duration={1}""".format(noise_rir.rir_rspecifier, speech_dur) + noise_addition_descriptor['start_times'].append(0) + noise_addition_descriptor['snrs'].append(background_snrs.next()) + noise_addition_descriptor['durations'].append(speech_dur) + noise_addition_descriptor['noise_ids'].append(noise.noise_id) + else: + noise_rvb_command = """wav-reverberate --impulse-response="{0}" """.format(noise_rir.rir_rspecifier) + 
noise_addition_descriptor['start_times'].append(round(random.random() * speech_dur, 2)) + noise_addition_descriptor['snrs'].append(foreground_snrs.next()) + noise_addition_descriptor['durations'].append(-1) + noise_addition_descriptor['noise_ids'].append(noise.noise_id) + + # check if the rspecifier is a pipe or not + if len(noise.noise_rspecifier.split()) == 1: + noise_addition_descriptor['noise_io'].append("{1} {0} - |".format(noise.noise_rspecifier, noise_rvb_command)) + else: + noise_addition_descriptor['noise_io'].append("{0} {1} - - |".format(noise.noise_rspecifier, noise_rvb_command)) + +# This function randomly decides whether to reverberate, and sample a RIR if it does +# It also decides whether to add the appropriate noises +# This function return the string of options to the binary wav-reverberate +def GenerateReverberationOpts(room_dict, # the room dictionary, please refer to MakeRoomDict() for the format + pointsource_noise_list, # the point source noise list + iso_noise_dict, # the isotropic noise dictionary + foreground_snrs, # the SNR for adding the foreground noises + background_snrs, # the SNR for adding the background noises + speech_rvb_probability, # Probability of reverberating a speech signal + isotropic_noise_addition_probability, # Probability of adding isotropic noises + pointsource_noise_addition_probability, # Probability of adding point-source noises + speech_dur, # duration of the recording + max_noises_recording # Maximum number of point-source noises that can be added + ): + impulse_response_opts = "" + additive_noise_opts = "" + + noise_addition_descriptor = {'noise_io': [], + 'start_times': [], + 'snrs': [], + 'noise_ids': [], + 'durations': [] + } + # Randomly select the room + # Here the room probability is a sum of the probabilities of the RIRs recorded in the room. 
+ room = PickItemWithProbability(room_dict) + # Randomly select the RIR in the room + speech_rir = PickItemWithProbability(room.rir_list) + if random.random() < speech_rvb_probability: + # pick the RIR to reverberate the speech + impulse_response_opts = """--impulse-response="{0}" """.format(speech_rir.rir_rspecifier) + + rir_iso_noise_list = [] + if speech_rir.room_id in iso_noise_dict: + rir_iso_noise_list = iso_noise_dict[speech_rir.room_id] + # Add the corresponding isotropic noise associated with the selected RIR + if len(rir_iso_noise_list) > 0 and random.random() < isotropic_noise_addition_probability: + isotropic_noise = PickItemWithProbability(rir_iso_noise_list) + # extend the isotropic noise to the length of the speech waveform + # check if the rspecifier is really a pipe + if len(isotropic_noise.noise_rspecifier.split()) == 1: + noise_addition_descriptor['noise_io'].append("wav-reverberate --duration={1} {0} - |".format(isotropic_noise.noise_rspecifier, speech_dur)) + else: + noise_addition_descriptor['noise_io'].append("{0} wav-reverberate --duration={1} - - |".format(isotropic_noise.noise_rspecifier, speech_dur)) + noise_addition_descriptor['start_times'].append(0) + noise_addition_descriptor['snrs'].append(background_snrs.next()) + noise_addition_descriptor['noise_ids'].append(isotropic_noise.noise_id) + noise_addition_descriptor['durations'].append(speech_dur) + + AddPointSourceNoise(room, # the room selected + pointsource_noise_list, # the point source noise list + pointsource_noise_addition_probability, # Probability of adding point-source noises + foreground_snrs, # the SNR for adding the foreground noises + background_snrs, # the SNR for adding the background noises + speech_dur, # duration of the recording + max_noises_recording, # Maximum number of point-source noises that can be added + noise_addition_descriptor # descriptor to store the information of the noise added + ) + + assert len(noise_addition_descriptor['noise_io']) == len(noise_addition_descriptor['start_times']) + assert len(noise_addition_descriptor['noise_io']) == len(noise_addition_descriptor['snrs']) + + return [impulse_response_opts, noise_addition_descriptor] + diff --git a/egs/wsj/s5/steps/data/reverberate_data_dir.py b/egs/wsj/s5/steps/data/reverberate_data_dir.py index 0083efa4939..69bc5e08b3b 100755 --- a/egs/wsj/s5/steps/data/reverberate_data_dir.py +++ b/egs/wsj/s5/steps/data/reverberate_data_dir.py @@ -5,7 +5,7 @@ # we're using python 3.x style print but want it to work in python 2.x, from __future__ import print_function -import argparse, shlex, glob, math, os, random, sys, warnings, copy, imp, ast +import argparse, glob, math, os, random, sys, warnings, copy, imp, ast data_lib = imp.load_source('dml', 'steps/data/data_dir_manipulation_lib.py') @@ -20,7 +20,7 @@ def GetArgs(): "--random-seed 1 data/train data/train_rvb", formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument("--rir-set-parameters", type=str, action='append', required = True, dest = "rir_set_para_array", + parser.add_argument("--rir-set-parameters", type=str, action='append', required = True, dest = "rir_set_para_array", help="Specifies the parameters of an RIR set. " "Supports the specification of mixture_weight and rir_list_file_name. The mixture weight is optional. 
" "The default mixture weight is the probability mass remaining after adding the mixture weights " @@ -71,6 +71,9 @@ def GetArgs(): "the RIRs/noises will be resampled to the rate of the source data.") parser.add_argument("--include-original-data", type=str, help="If true, the output data includes one copy of the original data", choices=['true', 'false'], default = "false") + parser.add_argument("--output-additive-noise-dir", type=str, help="Output directory corresponding to the additive noise part of the data corruption") + parser.add_argument("--output-reverb-dir", type=str, help="Output directory corresponding to the reverberated signal part of the data corruption") + parser.add_argument("input_dir", help="Input data directory") parser.add_argument("output_dir", @@ -87,12 +90,29 @@ def CheckArgs(args): if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) - ## Check arguments + ## Check arguments. + if args.prefix is None: if args.num_replicas > 1 or args.include_original_data == "true": args.prefix = "rvb" warnings.warn("--prefix is set to 'rvb' as more than one copy of data is generated") + if args.output_reverb_dir is not None: + if args.output_reverb_dir == "": + args.output_reverb_dir = None + + if args.output_reverb_dir is not None: + if not os.path.exists(args.output_reverb_dir): + os.makedirs(args.output_reverb_dir) + + if args.output_additive_noise_dir is not None: + if args.output_additive_noise_dir == "": + args.output_additive_noise_dir = None + + if args.output_additive_noise_dir is not None: + if not os.path.exists(args.output_additive_noise_dir): + os.makedirs(args.output_additive_noise_dir) + if not args.num_replicas > 0: raise Exception("--num-replications cannot be non-positive") @@ -104,7 +124,7 @@ def CheckArgs(args): if args.isotropic_noise_addition_probability < 0 or args.isotropic_noise_addition_probability > 1: raise Exception("--isotropic-noise-addition-probability must be between 0 and 1") - + if args.rir_smoothing_weight < 0 or args.rir_smoothing_weight > 1: raise Exception("--rir-smoothing-weight must be between 0 and 1") @@ -113,208 +133,20 @@ def CheckArgs(args): if args.max_noises_per_minute < 0: raise Exception("--max-noises-per-minute cannot be negative") - + if args.source_sampling_rate is not None and args.source_sampling_rate <= 0: raise Exception("--source-sampling-rate cannot be non-positive") return args -class list_cyclic_iterator: - def __init__(self, list): - self.list_index = 0 - self.list = list - random.shuffle(self.list) - - def next(self): - item = self.list[self.list_index] - self.list_index = (self.list_index + 1) % len(self.list) - return item - - -# This functions picks an item from the collection according to the associated probability distribution. -# The probability estimate of each item in the collection is stored in the "probability" field of -# the particular item. 
x : a collection (list or dictionary) where the values contain a field called probability -def PickItemWithProbability(x): - if isinstance(x, dict): - plist = list(set(x.values())) - else: - plist = x - total_p = sum(item.probability for item in plist) - p = random.uniform(0, total_p) - accumulate_p = 0 - for item in plist: - if accumulate_p + item.probability >= p: - return item - accumulate_p += item.probability - assert False, "Shouldn't get here as the accumulated probability should always equal to 1" - - -# This function parses a file and pack the data into a dictionary -# It is useful for parsing file like wav.scp, utt2spk, text...etc -def ParseFileToDict(file, assert2fields = False, value_processor = None): - if value_processor is None: - value_processor = lambda x: x[0] - - dict = {} - for line in open(file, 'r'): - parts = line.split() - if assert2fields: - assert(len(parts) == 2) - - dict[parts[0]] = value_processor(parts[1:]) - return dict - -# This function creates a file and write the content of a dictionary into it -def WriteDictToFile(dict, file_name): - file = open(file_name, 'w') - keys = dict.keys() - keys.sort() - for key in keys: - value = dict[key] - if type(value) in [list, tuple] : - if type(value) is tuple: - value = list(value) - value.sort() - value = ' '.join(str(value)) - file.write('{0} {1}\n'.format(key, value)) - file.close() - - -# This function creates the utt2uniq file from the utterance id in utt2spk file -def CreateCorruptedUtt2uniq(input_dir, output_dir, num_replicas, include_original, prefix): - corrupted_utt2uniq = {} - # Parse the utt2spk to get the utterance id - utt2spk = ParseFileToDict(input_dir + "/utt2spk", value_processor = lambda x: " ".join(x)) - keys = utt2spk.keys() - keys.sort() - if include_original: - start_index = 0 - else: - start_index = 1 - - for i in range(start_index, num_replicas+1): - for utt_id in keys: - new_utt_id = GetNewId(utt_id, prefix, i) - corrupted_utt2uniq[new_utt_id] = utt_id - - WriteDictToFile(corrupted_utt2uniq, output_dir + "/utt2uniq") - - -def AddPointSourceNoise(noise_addition_descriptor, # descriptor to store the information of the noise added - room, # the room selected - pointsource_noise_list, # the point source noise list - pointsource_noise_addition_probability, # Probability of adding point-source noises - foreground_snrs, # the SNR for adding the foreground noises - background_snrs, # the SNR for adding the background noises - speech_dur, # duration of the recording - max_noises_recording # Maximum number of point-source noises that can be added - ): - if len(pointsource_noise_list) > 0 and random.random() < pointsource_noise_addition_probability and max_noises_recording >= 1: - for k in range(random.randint(1, max_noises_recording)): - # pick the RIR to reverberate the point-source noise - noise = PickItemWithProbability(pointsource_noise_list) - noise_rir = PickItemWithProbability(room.rir_list) - # If it is a background noise, the noise will be extended and be added to the whole speech - # if it is a foreground noise, the noise will not extended and be added at a random time of the speech - if noise.bg_fg_type == "background": - noise_rvb_command = """wav-reverberate --impulse-response="{0}" --duration={1}""".format(noise_rir.rir_rspecifier, speech_dur) - noise_addition_descriptor['start_times'].append(0) - noise_addition_descriptor['snrs'].append(background_snrs.next()) - else: - noise_rvb_command = """wav-reverberate --impulse-response="{0}" """.format(noise_rir.rir_rspecifier) - 
noise_addition_descriptor['start_times'].append(round(random.random() * speech_dur, 2)) - noise_addition_descriptor['snrs'].append(foreground_snrs.next()) - - # check if the rspecifier is a pipe or not - if len(noise.noise_rspecifier.split()) == 1: - noise_addition_descriptor['noise_io'].append("{1} {0} - |".format(noise.noise_rspecifier, noise_rvb_command)) - else: - noise_addition_descriptor['noise_io'].append("{0} {1} - - |".format(noise.noise_rspecifier, noise_rvb_command)) - - return noise_addition_descriptor - - -# This function randomly decides whether to reverberate, and sample a RIR if it does -# It also decides whether to add the appropriate noises -# This function return the string of options to the binary wav-reverberate -def GenerateReverberationOpts(room_dict, # the room dictionary, please refer to MakeRoomDict() for the format - pointsource_noise_list, # the point source noise list - iso_noise_dict, # the isotropic noise dictionary - foreground_snrs, # the SNR for adding the foreground noises - background_snrs, # the SNR for adding the background noises - speech_rvb_probability, # Probability of reverberating a speech signal - isotropic_noise_addition_probability, # Probability of adding isotropic noises - pointsource_noise_addition_probability, # Probability of adding point-source noises - speech_dur, # duration of the recording - max_noises_recording # Maximum number of point-source noises that can be added - ): - reverberate_opts = "" - noise_addition_descriptor = {'noise_io': [], - 'start_times': [], - 'snrs': []} - # Randomly select the room - # Here the room probability is a sum of the probabilities of the RIRs recorded in the room. - room = PickItemWithProbability(room_dict) - # Randomly select the RIR in the room - speech_rir = PickItemWithProbability(room.rir_list) - if random.random() < speech_rvb_probability: - # pick the RIR to reverberate the speech - reverberate_opts += """--impulse-response="{0}" """.format(speech_rir.rir_rspecifier) - - rir_iso_noise_list = [] - if speech_rir.room_id in iso_noise_dict: - rir_iso_noise_list = iso_noise_dict[speech_rir.room_id] - # Add the corresponding isotropic noise associated with the selected RIR - if len(rir_iso_noise_list) > 0 and random.random() < isotropic_noise_addition_probability: - isotropic_noise = PickItemWithProbability(rir_iso_noise_list) - # extend the isotropic noise to the length of the speech waveform - # check if the rspecifier is a pipe or not - if len(isotropic_noise.noise_rspecifier.split()) == 1: - noise_addition_descriptor['noise_io'].append("wav-reverberate --duration={1} {0} - |".format(isotropic_noise.noise_rspecifier, speech_dur)) - else: - noise_addition_descriptor['noise_io'].append("{0} wav-reverberate --duration={1} - - |".format(isotropic_noise.noise_rspecifier, speech_dur)) - noise_addition_descriptor['start_times'].append(0) - noise_addition_descriptor['snrs'].append(background_snrs.next()) - - noise_addition_descriptor = AddPointSourceNoise(noise_addition_descriptor, # descriptor to store the information of the noise added - room, # the room selected - pointsource_noise_list, # the point source noise list - pointsource_noise_addition_probability, # Probability of adding point-source noises - foreground_snrs, # the SNR for adding the foreground noises - background_snrs, # the SNR for adding the background noises - speech_dur, # duration of the recording - max_noises_recording # Maximum number of point-source noises that can be added - ) - - assert len(noise_addition_descriptor['noise_io']) 
== len(noise_addition_descriptor['start_times']) - assert len(noise_addition_descriptor['noise_io']) == len(noise_addition_descriptor['snrs']) - if len(noise_addition_descriptor['noise_io']) > 0: - reverberate_opts += "--additive-signals='{0}' ".format(','.join(noise_addition_descriptor['noise_io'])) - reverberate_opts += "--start-times='{0}' ".format(','.join(map(lambda x:str(x), noise_addition_descriptor['start_times']))) - reverberate_opts += "--snrs='{0}' ".format(','.join(map(lambda x:str(x), noise_addition_descriptor['snrs']))) - - return reverberate_opts - -# This function generates a new id from the input id -# This is needed when we have to create multiple copies of the original data -# E.g. GetNewId("swb0035", prefix="rvb", copy=1) returns a string "rvb1_swb0035" -def GetNewId(id, prefix=None, copy=0): - if prefix is not None: - new_id = prefix + str(copy) + "_" + id - else: - new_id = id - - return new_id - - # This is the main function to generate pipeline command for the corruption # The generic command of wav-reverberate will be like: -# wav-reverberate --duration=t --impulse-response=rir.wav +# wav-reverberate --duration=t --impulse-response=rir.wav # --additive-signals='noise1.wav,noise2.wav' --snrs='snr1,snr2' --start-times='s1,s2' input.wav output.wav def GenerateReverberatedWavScp(wav_scp, # a dictionary whose values are the Kaldi-IO strings of the speech recordings durations, # a dictionary whose values are the duration (in sec) of the speech recordings - output_dir, # output directory to write the corrupted wav.scp + output_dir, # output directory to write the corrupted wav.scp room_dict, # the room dictionary, please refer to MakeRoomDict() for the format pointsource_noise_list, # the point source noise list iso_noise_dict, # the isotropic noise dictionary @@ -327,13 +159,20 @@ def GenerateReverberatedWavScp(wav_scp, # a dictionary whose values are the Kal shift_output, # option whether to shift the output waveform isotropic_noise_addition_probability, # Probability of adding isotropic noises pointsource_noise_addition_probability, # Probability of adding point-source noises - max_noises_per_minute # maximum number of point-source noises that can be added to a recording according to its duration + max_noises_per_minute, # maximum number of point-source noises that can be added to a recording according to its duration + output_reverb_dir = None, + output_additive_noise_dir = None ): - foreground_snrs = list_cyclic_iterator(foreground_snr_array) - background_snrs = list_cyclic_iterator(background_snr_array) + foreground_snrs = data_lib.list_cyclic_iterator(foreground_snr_array) + background_snrs = data_lib.list_cyclic_iterator(background_snr_array) corrupted_wav_scp = {} + reverb_wav_scp = {} + additive_noise_wav_scp = {} keys = wav_scp.keys() keys.sort() + + additive_signals_info = {} + if include_original: start_index = 0 else: @@ -346,51 +185,71 @@ def GenerateReverberatedWavScp(wav_scp, # a dictionary whose values are the Kal if len(wav_original_pipe.split()) == 1: wav_original_pipe = "cat {0} |".format(wav_original_pipe) speech_dur = durations[recording_id] - max_noises_recording = math.floor(max_noises_per_minute * speech_dur / 60) - - reverberate_opts = GenerateReverberationOpts(room_dict, # the room dictionary, please refer to MakeRoomDict() for the format - pointsource_noise_list, # the point source noise list - iso_noise_dict, # the isotropic noise dictionary - foreground_snrs, # the SNR for adding the foreground noises - background_snrs, # the SNR for adding 
the background noises - speech_rvb_probability, # Probability of reverberating a speech signal - isotropic_noise_addition_probability, # Probability of adding isotropic noises - pointsource_noise_addition_probability, # Probability of adding point-source noises - speech_dur, # duration of the recording - max_noises_recording # Maximum number of point-source noises that can be added - ) + max_noises_recording = math.ceil(max_noises_per_minute * speech_dur / 60) + + [impulse_response_opts, noise_addition_descriptor] = data_lib.GenerateReverberationOpts(room_dict, # the room dictionary, please refer to MakeRoomDict() for the format + pointsource_noise_list, # the point source noise list + iso_noise_dict, # the isotropic noise dictionary + foreground_snrs, # the SNR for adding the foreground noises + background_snrs, # the SNR for adding the background noises + speech_rvb_probability, # Probability of reverberating a speech signal + isotropic_noise_addition_probability, # Probability of adding isotropic noises + pointsource_noise_addition_probability, # Probability of adding point-source noises + speech_dur, # duration of the recording + max_noises_recording # Maximum number of point-source noises that can be added + ) + additive_noise_opts = "" + + if len(noise_addition_descriptor['noise_io']) > 0: + additive_noise_opts += "--additive-signals='{0}' ".format(','.join(noise_addition_descriptor['noise_io'])) + additive_noise_opts += "--start-times='{0}' ".format(','.join(map(lambda x:str(x), noise_addition_descriptor['start_times']))) + additive_noise_opts += "--snrs='{0}' ".format(','.join(map(lambda x:str(x), noise_addition_descriptor['snrs']))) + + reverberate_opts = impulse_response_opts + additive_noise_opts + + new_recording_id = data_lib.GetNewId(recording_id, prefix, i) # prefix using index 0 is reserved for original data e.g. 
rvb0_swb0035 corresponds to the swb0035 recording in original data if reverberate_opts == "" or i == 0: - wav_corrupted_pipe = "{0}".format(wav_original_pipe) + wav_corrupted_pipe = "{0}".format(wav_original_pipe) else: wav_corrupted_pipe = "{0} wav-reverberate --shift-output={1} {2} - - |".format(wav_original_pipe, shift_output, reverberate_opts) - new_recording_id = GetNewId(recording_id, prefix, i) corrupted_wav_scp[new_recording_id] = wav_corrupted_pipe - WriteDictToFile(corrupted_wav_scp, output_dir + "/wav.scp") + if output_reverb_dir is not None: + if impulse_response_opts == "": + wav_reverb_pipe = "{0}".format(wav_original_pipe) + else: + wav_reverb_pipe = "{0} wav-reverberate --shift-output={1} --reverb-out-wxfilename=- {2} - /dev/null |".format(wav_original_pipe, shift_output, reverberate_opts) + reverb_wav_scp[new_recording_id] = wav_reverb_pipe + if output_additive_noise_dir is not None: + if additive_noise_opts != "": + wav_additive_noise_pipe = "{0} wav-reverberate --shift-output={1} --additive-noise-out-wxfilename=- {2} - /dev/null |".format(wav_original_pipe, shift_output, reverberate_opts) + additive_noise_wav_scp[new_recording_id] = wav_additive_noise_pipe -# This function replicate the entries in files like segments, utt2spk, text -def AddPrefixToFields(input_file, output_file, num_replicas, include_original, prefix, field = [0]): - list = map(lambda x: x.strip(), open(input_file)) - f = open(output_file, "w") - if include_original: - start_index = 0 - else: - start_index = 1 - - for i in range(start_index, num_replicas+1): - for line in list: - if len(line) > 0 and line[0] != ';': - split1 = line.split() - for j in field: - split1[j] = GetNewId(split1[j], prefix, i) - print(" ".join(split1), file=f) - else: - print(line, file=f) - f.close() + if additive_noise_opts != "": + additive_signals_info[new_recording_id] = [ + ':'.join(x) + for x in zip(noise_addition_descriptor['noise_ids'], + [ str(x) for x in noise_addition_descriptor['start_times'] ], + [ str(x) for x in noise_addition_descriptor['durations'] ]) + ] + + # Write for each new recording, the id, start time and durations + # of the signals. Duration is -1 for the foreground noise and needs to + # be extracted separately if required by determining the durations + # using the wav file + data_lib.WriteDictToFile(additive_signals_info, output_dir + "/additive_signals_info.txt") + + data_lib.WriteDictToFile(corrupted_wav_scp, output_dir + "/wav.scp") + + if output_reverb_dir is not None: + data_lib.WriteDictToFile(reverb_wav_scp, output_reverb_dir + "/wav.scp") + + if output_additive_noise_dir is not None: + data_lib.WriteDictToFile(additive_noise_wav_scp, output_additive_noise_dir + "/wav.scp") # This function creates multiple copies of the necessary files, e.g. utt2spk, wav.scp ... 
@@ -408,10 +267,12 @@ def CreateReverberatedCopy(input_dir, shift_output, # option whether to shift the output waveform isotropic_noise_addition_probability, # Probability of adding isotropic noises pointsource_noise_addition_probability, # Probability of adding point-source noises - max_noises_per_minute # maximum number of point-source noises that can be added to a recording according to its duration + max_noises_per_minute, # maximum number of point-source noises that can be added to a recording according to its duration + output_reverb_dir = None, + output_additive_noise_dir = None ): - - wav_scp = ParseFileToDict(input_dir + "/wav.scp", value_processor = lambda x: " ".join(x)) + + wav_scp = data_lib.ParseFileToDict(input_dir + "/wav.scp", value_processor = lambda x: " ".join(x)) if not os.path.isfile(input_dir + "/reco2dur"): print("Getting the duration of the recordings..."); read_entire_file="false" @@ -421,225 +282,38 @@ def CreateReverberatedCopy(input_dir, read_entire_file="true" break data_lib.RunKaldiCommand("wav-to-duration --read-entire-file={1} scp:{0}/wav.scp ark,t:{0}/reco2dur".format(input_dir, read_entire_file)) - durations = ParseFileToDict(input_dir + "/reco2dur", value_processor = lambda x: float(x[0])) + durations = data_lib.ParseFileToDict(input_dir + "/reco2dur", value_processor = lambda x: float(x[0])) foreground_snr_array = map(lambda x: float(x), foreground_snr_string.split(':')) background_snr_array = map(lambda x: float(x), background_snr_string.split(':')) GenerateReverberatedWavScp(wav_scp, durations, output_dir, room_dict, pointsource_noise_list, iso_noise_dict, - foreground_snr_array, background_snr_array, num_replicas, include_original, prefix, - speech_rvb_probability, shift_output, isotropic_noise_addition_probability, - pointsource_noise_addition_probability, max_noises_per_minute) + foreground_snr_array, background_snr_array, num_replicas, include_original, prefix, + speech_rvb_probability, shift_output, isotropic_noise_addition_probability, + pointsource_noise_addition_probability, max_noises_per_minute, + output_reverb_dir = output_reverb_dir, + output_additive_noise_dir = output_additive_noise_dir) - AddPrefixToFields(input_dir + "/utt2spk", output_dir + "/utt2spk", num_replicas, include_original, prefix, field = [0,1]) - data_lib.RunKaldiCommand("utils/utt2spk_to_spk2utt.pl <{output_dir}/utt2spk >{output_dir}/spk2utt" - .format(output_dir = output_dir)) + data_lib.CopyDataDirFiles(input_dir, output_dir, num_replicas, include_original, prefix) - if os.path.isfile(input_dir + "/utt2uniq"): - AddPrefixToFields(input_dir + "/utt2uniq", output_dir + "/utt2uniq", num_replicas, include_original, prefix, field =[0]) - else: - # Create the utt2uniq file - CreateCorruptedUtt2uniq(input_dir, output_dir, num_replicas, include_original, prefix) - - if os.path.isfile(input_dir + "/text"): - AddPrefixToFields(input_dir + "/text", output_dir + "/text", num_replicas, include_original, prefix, field =[0]) - if os.path.isfile(input_dir + "/segments"): - AddPrefixToFields(input_dir + "/segments", output_dir + "/segments", num_replicas, include_original, prefix, field = [0,1]) - if os.path.isfile(input_dir + "/reco2file_and_channel"): - AddPrefixToFields(input_dir + "/reco2file_and_channel", output_dir + "/reco2file_and_channel", num_replicas, include_original, prefix, field = [0,1]) - - data_lib.RunKaldiCommand("utils/validate_data_dir.sh --no-feats {output_dir}" - .format(output_dir = output_dir)) - - -# This function smooths the probability distribution in the list 
-def SmoothProbabilityDistribution(list, smoothing_weight=0.0, target_sum=1.0): - if len(list) > 0: - num_unspecified = 0 - accumulated_prob = 0 - for item in list: - if item.probability is None: - num_unspecified += 1 - else: - accumulated_prob += item.probability - - # Compute the probability for the items without specifying their probability - uniform_probability = 0 - if num_unspecified > 0 and accumulated_prob < 1: - uniform_probability = (1 - accumulated_prob) / float(num_unspecified) - elif num_unspecified > 0 and accumulate_prob >= 1: - warnings.warn("The sum of probabilities specified by user is larger than or equal to 1. " - "The items without probabilities specified will be given zero to their probabilities.") - - for item in list: - if item.probability is None: - item.probability = uniform_probability - else: - # smooth the probability - item.probability = (1 - smoothing_weight) * item.probability + smoothing_weight * uniform_probability - - # Normalize the probability - sum_p = sum(item.probability for item in list) - for item in list: - item.probability = item.probability / sum_p * target_sum - - return list - - -# This function parse the array of rir set parameter strings. -# It will assign probabilities to those rir sets which don't have a probability -# It will also check the existence of the rir list files. -def ParseSetParameterStrings(set_para_array): - set_list = [] - for set_para in set_para_array: - set = lambda: None - setattr(set, "filename", None) - setattr(set, "probability", None) - parts = set_para.split(',') - if len(parts) == 2: - set.probability = float(parts[0]) - set.filename = parts[1].strip() - else: - set.filename = parts[0].strip() - if not os.path.isfile(set.filename): - raise Exception(set.filename + " not found") - set_list.append(set) - - return SmoothProbabilityDistribution(set_list) - - -# This function creates the RIR list -# Each rir object in the list contains the following attributes: -# rir_id, room_id, receiver_position_id, source_position_id, rt60, drr, probability -# Please refer to the help messages in the parser for the meaning of these attributes -def ParseRirList(rir_set_para_array, smoothing_weight, sampling_rate = None): - rir_parser = argparse.ArgumentParser() - rir_parser.add_argument('--rir-id', type=str, required=True, help='This id is unique for each RIR and the noise may associate with a particular RIR by refering to this id') - rir_parser.add_argument('--room-id', type=str, required=True, help='This is the room that where the RIR is generated') - rir_parser.add_argument('--receiver-position-id', type=str, default=None, help='receiver position id') - rir_parser.add_argument('--source-position-id', type=str, default=None, help='source position id') - rir_parser.add_argument('--rt60', type=float, default=None, help='RT60 is the time required for reflections of a direct sound to decay 60 dB.') - rir_parser.add_argument('--drr', type=float, default=None, help='Direct-to-reverberant-ratio of the impulse response.') - rir_parser.add_argument('--cte', type=float, default=None, help='Early-to-late index of the impulse response.') - rir_parser.add_argument('--probability', type=float, default=None, help='probability of the impulse response.') - rir_parser.add_argument('rir_rspecifier', type=str, help="""rir rspecifier, it can be either a filename or a piped command. - E.g. 
data/impulses/Room001-00001.wav or "sox data/impulses/Room001-00001.wav -t wav - |" """) - - set_list = ParseSetParameterStrings(rir_set_para_array) - - rir_list = [] - for rir_set in set_list: - current_rir_list = map(lambda x: rir_parser.parse_args(shlex.split(x.strip())),open(rir_set.filename)) - for rir in current_rir_list: - if sampling_rate is not None: - # check if the rspecifier is a pipe or not - if len(rir.rir_rspecifier.split()) == 1: - rir.rir_rspecifier = "sox {0} -r {1} -t wav - |".format(rir.rir_rspecifier, sampling_rate) - else: - rir.rir_rspecifier = "{0} sox -t wav - -r {1} -t wav - |".format(rir.rir_rspecifier, sampling_rate) - - rir_list += SmoothProbabilityDistribution(current_rir_list, smoothing_weight, rir_set.probability) - - return rir_list - - -# This dunction checks if the inputs are approximately equal assuming they are floats. -def almost_equal(value_1, value_2, accuracy = 10**-8): - return abs(value_1 - value_2) < accuracy - -# This function converts a list of RIRs into a dictionary of RIRs indexed by the room-id. -# Its values are objects with two attributes: a local RIR list -# and the probability of the corresponding room -# Please look at the comments at ParseRirList() for the attributes that a RIR object contains -def MakeRoomDict(rir_list): - room_dict = {} - for rir in rir_list: - if rir.room_id not in room_dict: - # add new room - room_dict[rir.room_id] = lambda: None - setattr(room_dict[rir.room_id], "rir_list", []) - setattr(room_dict[rir.room_id], "probability", 0) - room_dict[rir.room_id].rir_list.append(rir) - - # the probability of the room is the sum of probabilities of its RIR - for key in room_dict.keys(): - room_dict[key].probability = sum(rir.probability for rir in room_dict[key].rir_list) - - assert almost_equal(sum(room_dict[key].probability for key in room_dict.keys()), 1.0) - - return room_dict - - -# This function creates the point-source noise list -# and the isotropic noise dictionary from the noise information file -# The isotropic noise dictionary is indexed by the room -# and its value is the corrresponding isotropic noise list -# Each noise object in the list contains the following attributes: -# noise_id, noise_type, bg_fg_type, room_linkage, probability, noise_rspecifier -# Please refer to the help messages in the parser for the meaning of these attributes -def ParseNoiseList(noise_set_para_array, smoothing_weight, sampling_rate = None): - noise_parser = argparse.ArgumentParser() - noise_parser.add_argument('--noise-id', type=str, required=True, help='noise id') - noise_parser.add_argument('--noise-type', type=str, required=True, help='the type of noise; i.e. isotropic or point-source', choices = ["isotropic", "point-source"]) - noise_parser.add_argument('--bg-fg-type', type=str, default="background", help='background or foreground noise, for background noises, ' - 'they will be extended before addition to cover the whole speech; for foreground noise, they will be kept ' - 'to their original duration and added at a random point of the speech.', choices = ["background", "foreground"]) - noise_parser.add_argument('--room-linkage', type=str, default=None, help='required if isotropic, should not be specified if point-source.') - noise_parser.add_argument('--probability', type=float, default=None, help='probability of the noise.') - noise_parser.add_argument('noise_rspecifier', type=str, help="""noise rspecifier, it can be either a filename or a piped command. - E.g. 
type5_noise_cirline_ofc_ambient1.wav or "sox type5_noise_cirline_ofc_ambient1.wav -t wav - |" """) - - set_list = ParseSetParameterStrings(noise_set_para_array) - - pointsource_noise_list = [] - iso_noise_dict = {} - for noise_set in set_list: - current_noise_list = map(lambda x: noise_parser.parse_args(shlex.split(x.strip())),open(noise_set.filename)) - current_pointsource_noise_list = [] - for noise in current_noise_list: - if sampling_rate is not None: - # check if the rspecifier is a pipe or not - if len(noise.noise_rspecifier.split()) == 1: - noise.noise_rspecifier = "sox {0} -r {1} -t wav - |".format(noise.noise_rspecifier, sampling_rate) - else: - noise.noise_rspecifier = "{0} sox -t wav - -r {1} -t wav - |".format(noise.noise_rspecifier, sampling_rate) + if output_reverb_dir is not None: + data_lib.CopyDataDirFiles(input_dir, output_reverb_dir, num_replicas, include_original, prefix) - if noise.noise_type == "isotropic": - if noise.room_linkage is None: - raise Exception("--room-linkage must be specified if --noise-type is isotropic") - else: - if noise.room_linkage not in iso_noise_dict: - iso_noise_dict[noise.room_linkage] = [] - iso_noise_dict[noise.room_linkage].append(noise) - else: - current_pointsource_noise_list.append(noise) - - pointsource_noise_list += SmoothProbabilityDistribution(current_pointsource_noise_list, smoothing_weight, noise_set.probability) - - # ensure the point-source noise probabilities sum to 1 - pointsource_noise_list = SmoothProbabilityDistribution(pointsource_noise_list, smoothing_weight, 1.0) - if len(pointsource_noise_list) > 0: - assert almost_equal(sum(noise.probability for noise in pointsource_noise_list), 1.0) - - # ensure the isotropic noise source probabilities for a given room sum to 1 - for key in iso_noise_dict.keys(): - iso_noise_dict[key] = SmoothProbabilityDistribution(iso_noise_dict[key]) - assert almost_equal(sum(noise.probability for noise in iso_noise_dict[key]), 1.0) - - return (pointsource_noise_list, iso_noise_dict) + if output_additive_noise_dir is not None: + data_lib.CopyDataDirFiles(input_dir, output_additive_noise_dir, num_replicas, include_original, prefix) def Main(): args = GetArgs() random.seed(args.random_seed) - rir_list = ParseRirList(args.rir_set_para_array, args.rir_smoothing_weight, args.source_sampling_rate) + rir_list = data_lib.ParseRirList(args.rir_set_para_array, args.rir_smoothing_weight, args.source_sampling_rate) print("Number of RIRs is {0}".format(len(rir_list))) pointsource_noise_list = [] iso_noise_dict = {} if args.noise_set_para_array is not None: - pointsource_noise_list, iso_noise_dict = ParseNoiseList(args.noise_set_para_array, args.noise_smoothing_weight, args.source_sampling_rate) + pointsource_noise_list, iso_noise_dict = data_lib.ParseNoiseList(args.noise_set_para_array, args.noise_smoothing_weight, args.source_sampling_rate) print("Number of point-source noises is {0}".format(len(pointsource_noise_list))) print("Number of isotropic noises is {0}".format(sum(len(iso_noise_dict[key]) for key in iso_noise_dict.keys()))) - room_dict = MakeRoomDict(rir_list) + room_dict = data_lib.MakeRoomDict(rir_list) if args.include_original_data == "true": include_original = True @@ -660,8 +334,11 @@ def Main(): shift_output = args.shift_output, isotropic_noise_addition_probability = args.isotropic_noise_addition_probability, pointsource_noise_addition_probability = args.pointsource_noise_addition_probability, - max_noises_per_minute = args.max_noises_per_minute) + max_noises_per_minute = 
args.max_noises_per_minute, + output_reverb_dir = args.output_reverb_dir, + output_additive_noise_dir = args.output_additive_noise_dir) if __name__ == "__main__": Main() + diff --git a/src/featbin/wav-reverberate.cc b/src/featbin/wav-reverberate.cc index a9e6d3509c1..3b92f6e0b3e 100644 --- a/src/featbin/wav-reverberate.cc +++ b/src/featbin/wav-reverberate.cc @@ -156,6 +156,8 @@ int main(int argc, char *argv[]) { bool normalize_output = true; BaseFloat volume = 0; BaseFloat duration = 0; + std::string reverb_wxfilename; + std::string additive_noise_wxfilename; po.Register("multi-channel-output", &multi_channel_output, "Specifies if the output should be multi-channel or not"); @@ -212,6 +214,14 @@ int main(int argc, char *argv[]) { "after reverberating and possibly adding noise. " "If you set this option to a nonzero value, it will be as " "if you had also specified --normalize-output=false."); + po.Register("reverb-out-wxfilename", &reverb_wxfilename, + "Output the reverberated wave file, i.e. before adding the " + "additive noise. " + "Useful for computing SNR features or for debugging"); + po.Register("additive-noise-out-wxfilename", + &additive_noise_wxfilename, + "Output the additive noise file used to corrupt the input wave." + "Useful for computing SNR features or for debugging"); po.Read(argc, argv); if (po.NumArgs() != 2) { @@ -314,10 +324,23 @@ int main(int argc, char *argv[]) { int32 num_samp_output = (duration > 0 ? samp_freq_input * duration : (shift_output ? num_samp_input : num_samp_input + num_samp_rir - 1)); + Matrix out_matrix(num_output_channels, num_samp_output); + Matrix out_reverb_matrix; + if (!reverb_wxfilename.empty()) + out_reverb_matrix.Resize(num_output_channels, num_samp_output); + + Matrix out_noise_matrix; + if (!additive_noise_wxfilename.empty()) + out_noise_matrix.Resize(num_output_channels, num_samp_output); + for (int32 output_channel = 0; output_channel < num_output_channels; output_channel++) { Vector input(num_samp_input); + + Vector out_reverb(0); + Vector out_noise(0); + input.CopyRowFromMat(input_matrix, input_channel); float power_before_reverb = VecVec(input, input) / input.Dim(); @@ -337,6 +360,16 @@ int main(int argc, char *argv[]) { } } + if (!reverb_wxfilename.empty()) { + out_reverb.Resize(input.Dim()); + out_reverb.CopyFromVec(input); + } + + if (!additive_noise_wxfilename.empty()) { + out_noise.Resize(input.Dim()); + out_noise.SetZero(); + } + if (additive_signal_matrices.size() > 0) { Vector noise(0); int32 this_noise_channel = (multi_channel_output ? 
output_channel : noise_channel); @@ -345,33 +378,86 @@ int main(int argc, char *argv[]) { for (int32 i = 0; i < additive_signal_matrices.size(); i++) { noise.Resize(additive_signal_matrices[i].NumCols()); noise.CopyRowFromMat(additive_signal_matrices[i], this_noise_channel); - AddNoise(&noise, snr_vector[i], start_time_vector[i], - samp_freq_input, early_energy, &input); + + if (!additive_noise_wxfilename.empty()) { + AddNoise(&noise, snr_vector[i], start_time_vector[i], + samp_freq_input, early_energy, &out_noise); + } else { + AddNoise(&noise, snr_vector[i], start_time_vector[i], + samp_freq_input, early_energy, &input); + } + } + + if (!additive_noise_wxfilename.empty()) { + input.AddVec(1.0, out_noise); } } float power_after_reverb = VecVec(input, input) / input.Dim(); - if (volume > 0) + if (volume > 0) { input.Scale(volume); - else if (normalize_output) + out_reverb.Scale(volume); + out_noise.Scale(volume); + } else if (normalize_output) { input.Scale(sqrt(power_before_reverb / power_after_reverb)); + out_reverb.Scale(sqrt(power_before_reverb / power_after_reverb)); + out_noise.Scale(sqrt(power_before_reverb / power_after_reverb)); + } if (num_samp_output <= num_samp_input) { // trim the signal from the start out_matrix.CopyRowFromVec(input.Range(shift_index, num_samp_output), output_channel); + + if (!reverb_wxfilename.empty()) { + out_reverb_matrix.CopyRowFromVec(out_reverb.Range(shift_index, num_samp_output), output_channel); + } + + if (!additive_noise_wxfilename.empty()) { + out_noise_matrix.CopyRowFromVec(out_noise.Range(shift_index, num_samp_output), output_channel); + } } else { - // repeat the signal to fill up the duration - Vector extended_input(num_samp_output); - extended_input.SetZero(); - AddVectorsOfUnequalLength(input.Range(shift_index, num_samp_input), &extended_input); - out_matrix.CopyRowFromVec(extended_input, output_channel); + { + // repeat the signal to fill up the duration + Vector extended_input(num_samp_output); + extended_input.SetZero(); + AddVectorsOfUnequalLength(input.Range(shift_index, num_samp_input), &extended_input); + out_matrix.CopyRowFromVec(extended_input, output_channel); + } + if (!reverb_wxfilename.empty()) { + // repeat the signal to fill up the duration + Vector extended_input(num_samp_output); + extended_input.SetZero(); + AddVectorsOfUnequalLength(out_reverb.Range(shift_index, num_samp_input), &extended_input); + out_reverb_matrix.CopyRowFromVec(extended_input, output_channel); + } + if (!additive_noise_wxfilename.empty()) { + // repeat the signal to fill up the duration + Vector extended_input(num_samp_output); + extended_input.SetZero(); + AddVectorsOfUnequalLength(out_noise.Range(shift_index, num_samp_input), &extended_input); + out_noise_matrix.CopyRowFromVec(extended_input, output_channel); + } } } + + { + WaveData out_wave(samp_freq_input, out_matrix); + Output ko(output_wave_file, false); + out_wave.Write(ko.Stream()); + } + + if (!reverb_wxfilename.empty()) { + WaveData out_wave(samp_freq_input, out_reverb_matrix); + Output ko(reverb_wxfilename, false); + out_wave.Write(ko.Stream()); + } - WaveData out_wave(samp_freq_input, out_matrix); - Output ko(output_wave_file, false); - out_wave.Write(ko.Stream()); + if (!additive_noise_wxfilename.empty()) { + WaveData out_wave(samp_freq_input, out_noise_matrix); + Output ko(additive_noise_wxfilename, false); + out_wave.Write(ko.Stream()); + } return 0; } catch(const std::exception &e) { From 9ca5aa09507d1826bf36341663ca812b8fa0de8a Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 23 
Nov 2016 22:33:53 -0500 Subject: [PATCH 033/213] asr_diarization: Add extra_egs_copy_cmd --- .../nnet3/train/frame_level_objf/common.py | 64 ++++++++++++------- egs/wsj/s5/steps/nnet3/train_raw_dnn.py | 12 +++- egs/wsj/s5/steps/nnet3/train_raw_rnn.py | 12 +++- 3 files changed, 59 insertions(+), 29 deletions(-) diff --git a/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py b/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py index 87cae801e90..d0cb2a52758 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py @@ -30,7 +30,8 @@ def train_new_models(dir, iter, srand, num_jobs, shuffle_buffer_size, minibatch_size, cache_read_opt, run_opts, frames_per_eg=-1, - min_deriv_time=None, max_deriv_time=None): + min_deriv_time=None, max_deriv_time=None, + extra_egs_copy_cmd=""): """ Called from train_one_iteration(), this model does one iteration of training with 'num_jobs' jobs, and writes files like exp/tdnn_a/24.{1,2,3,..}.raw @@ -92,7 +93,7 @@ def train_new_models(dir, iter, srand, num_jobs, --max-param-change={max_param_change} \ {deriv_time_opts} "{raw_model}" \ "ark,bg:nnet3-copy-egs {frame_opts} {context_opts} """ - """ark:{egs_dir}/egs.{archive_index}.ark ark:- |""" + """ark:{egs_dir}/egs.{archive_index}.ark ark:- |{extra_egs_copy_cmd}""" """nnet3-shuffle-egs --buffer-size={shuffle_buffer_size} """ """--srand={srand} ark:- ark:- | """ """nnet3-merge-egs --minibatch-size={minibatch_size} """ @@ -115,7 +116,9 @@ def train_new_models(dir, iter, srand, num_jobs, raw_model=raw_model_string, context_opts=context_opts, egs_dir=egs_dir, archive_index=archive_index, shuffle_buffer_size=shuffle_buffer_size, - minibatch_size=minibatch_size), wait=False) + minibatch_size=minibatch_size, + extra_egs_copy_cmd=extra_egs_copy_cmd), + wait=False) processes.append(process_handle) @@ -143,7 +146,8 @@ def train_one_iteration(dir, iter, srand, egs_dir, min_deriv_time=None, max_deriv_time=None, shrinkage_value=1.0, get_raw_nnet_from_am=True, - background_process_handler=None): + background_process_handler=None, + extra_egs_copy_cmd=""): """ Called from steps/nnet3/train_*.py scripts for one iteration of neural network training @@ -192,7 +196,8 @@ def train_one_iteration(dir, iter, srand, egs_dir, run_opts=run_opts, mb_size=cv_minibatch_size, get_raw_nnet_from_am=get_raw_nnet_from_am, wait=False, - background_process_handler=background_process_handler) + background_process_handler=background_process_handler, + extra_egs_copy_cmd=extra_egs_copy_cmd) if iter > 0: # Runs in the background @@ -202,7 +207,8 @@ def train_one_iteration(dir, iter, srand, egs_dir, run_opts=run_opts, mb_size=cv_minibatch_size, wait=False, get_raw_nnet_from_am=get_raw_nnet_from_am, - background_process_handler=background_process_handler) + background_process_handler=background_process_handler, + extra_egs_copy_cmd=extra_egs_copy_cmd) # an option for writing cache (storing pairs of nnet-computations # and computation-requests) during training. 
@@ -276,7 +282,8 @@ def train_one_iteration(dir, iter, srand, egs_dir, cache_read_opt=cache_read_opt, run_opts=run_opts, frames_per_eg=frames_per_eg, min_deriv_time=min_deriv_time, - max_deriv_time=max_deriv_time) + max_deriv_time=max_deriv_time, + extra_egs_copy_cmd=extra_egs_copy_cmd) [models_to_average, best_model] = common_train_lib.get_successful_models( num_jobs, '{0}/log/train.{1}.%.log'.format(dir, iter)) @@ -375,7 +382,8 @@ def compute_preconditioning_matrix(dir, egs_dir, num_lda_jobs, run_opts, def compute_train_cv_probabilities(dir, iter, egs_dir, left_context, right_context, run_opts, mb_size=256, wait=False, background_process_handler=None, - get_raw_nnet_from_am=True): + get_raw_nnet_from_am=True, + extra_egs_copy_cmd=""): if get_raw_nnet_from_am: model = "nnet3-am-copy --raw=true {dir}/{iter}.mdl - |".format( dir=dir, iter=iter) @@ -389,7 +397,7 @@ def compute_train_cv_probabilities(dir, iter, egs_dir, left_context, """ {command} {dir}/log/compute_prob_valid.{iter}.log \ nnet3-compute-prob "{model}" \ "ark,bg:nnet3-copy-egs {context_opts} \ - ark:{egs_dir}/valid_diagnostic.egs ark:- | \ + ark:{egs_dir}/valid_diagnostic.egs ark:- |{extra_egs_copy_cmd} \ nnet3-merge-egs --minibatch-size={mb_size} ark:- \ ark:- |" """.format(command=run_opts.command, dir=dir, @@ -397,14 +405,15 @@ def compute_train_cv_probabilities(dir, iter, egs_dir, left_context, context_opts=context_opts, mb_size=mb_size, model=model, - egs_dir=egs_dir), + egs_dir=egs_dir, + extra_egs_copy_cmd=extra_egs_copy_cmd), wait=wait, background_process_handler=background_process_handler) common_lib.run_job( """{command} {dir}/log/compute_prob_train.{iter}.log \ nnet3-compute-prob "{model}" \ "ark,bg:nnet3-copy-egs {context_opts} \ - ark:{egs_dir}/train_diagnostic.egs ark:- | \ + ark:{egs_dir}/train_diagnostic.egs ark:- |{extra_egs_copy_cmd} \ nnet3-merge-egs --minibatch-size={mb_size} ark:- \ ark:- |" """.format(command=run_opts.command, dir=dir, @@ -412,14 +421,16 @@ def compute_train_cv_probabilities(dir, iter, egs_dir, left_context, context_opts=context_opts, mb_size=mb_size, model=model, - egs_dir=egs_dir), + egs_dir=egs_dir, + extra_egs_copy_cmd=extra_egs_copy_cmd), wait=wait, background_process_handler=background_process_handler) def compute_progress(dir, iter, egs_dir, left_context, right_context, run_opts, mb_size=256, background_process_handler=None, wait=False, - get_raw_nnet_from_am=True): + get_raw_nnet_from_am=True, + extra_egs_copy_cmd=""): if get_raw_nnet_from_am: prev_model = "nnet3-am-copy --raw=true {0}/{1}.mdl - |".format( dir, iter - 1) @@ -436,7 +447,7 @@ def compute_progress(dir, iter, egs_dir, left_context, right_context, nnet3-info "{model}" '&&' \ nnet3-show-progress --use-gpu=no "{prev_model}" "{model}" \ "ark,bg:nnet3-copy-egs {context_opts} \ - ark:{egs_dir}/train_diagnostic.egs ark:- | \ + ark:{egs_dir}/train_diagnostic.egs ark:- |{extra_egs_copy_cmd} \ nnet3-merge-egs --minibatch-size={mb_size} ark:- \ ark:- |" """.format(command=run_opts.command, dir=dir, @@ -445,14 +456,16 @@ def compute_progress(dir, iter, egs_dir, left_context, right_context, context_opts=context_opts, mb_size=mb_size, prev_model=prev_model, - egs_dir=egs_dir), + egs_dir=egs_dir, + extra_egs_copy_cmd=extra_egs_copy_cmd), wait=wait, background_process_handler=background_process_handler) def combine_models(dir, num_iters, models_to_combine, egs_dir, left_context, right_context, run_opts, background_process_handler=None, - chunk_width=None, get_raw_nnet_from_am=True): + chunk_width=None, get_raw_nnet_from_am=True, + 
extra_egs_copy_cmd=""): """ Function to do model combination In the nnet3 setup, the logic @@ -499,7 +512,7 @@ def combine_models(dir, num_iters, models_to_combine, egs_dir, --enforce-sum-to-one=true --enforce-positive-weights=true \ --verbose=3 {raw_models} \ "ark,bg:nnet3-copy-egs {context_opts} \ - ark:{egs_dir}/combine.egs ark:- | \ + ark:{egs_dir}/combine.egs ark:- |{extra_egs_copy_cmd} \ nnet3-merge-egs --measure-output-frames=false \ --minibatch-size={mbsize} ark:- ark:- |" \ "{out_model}" @@ -509,7 +522,8 @@ def combine_models(dir, num_iters, models_to_combine, egs_dir, context_opts=context_opts, mbsize=mbsize, out_model=out_model, - egs_dir=egs_dir)) + egs_dir=egs_dir, + extra_egs_copy_cmd=extra_egs_copy_cmd)) # Compute the probability of the final, combined model with # the same subset we used for the previous compute_probs, as the @@ -519,14 +533,16 @@ def combine_models(dir, num_iters, models_to_combine, egs_dir, dir=dir, iter='combined', egs_dir=egs_dir, left_context=left_context, right_context=right_context, run_opts=run_opts, wait=False, - background_process_handler=background_process_handler) + background_process_handler=background_process_handler, + extra_egs_copy_cmd=extra_egs_copy_cmd) else: compute_train_cv_probabilities( dir=dir, iter='final', egs_dir=egs_dir, left_context=left_context, right_context=right_context, run_opts=run_opts, wait=False, background_process_handler=background_process_handler, - get_raw_nnet_from_am=False) + get_raw_nnet_from_am=False, + extra_egs_copy_cmd=extra_egs_copy_cmd) def get_realign_iters(realign_times, num_iters, @@ -639,7 +655,8 @@ def adjust_am_priors(dir, input_model, avg_posterior_vector, output_model, def compute_average_posterior(dir, iter, egs_dir, num_archives, prior_subset_size, left_context, right_context, - run_opts, get_raw_nnet_from_am=True): + run_opts, get_raw_nnet_from_am=True, + extra_egs_copy_cmd=""): """ Computes the average posterior of the network Note: this just uses CPUs, using a smallish subset of data. """ @@ -663,7 +680,7 @@ def compute_average_posterior(dir, iter, egs_dir, num_archives, """{command} JOB=1:{num_jobs_compute_prior} {prior_queue_opt} \ {dir}/log/get_post.{iter}.JOB.log \ nnet3-copy-egs {context_opts} \ - ark:{egs_dir}/egs.{egs_part}.ark ark:- \| \ + ark:{egs_dir}/egs.{egs_part}.ark ark:- \| {extra_egs_copy_cmd}\ nnet3-subset-egs --srand=JOB --n={prior_subset_size} \ ark:- ark:- \| \ nnet3-merge-egs --measure-output-frames=true \ @@ -679,7 +696,8 @@ def compute_average_posterior(dir, iter, egs_dir, num_archives, iter=iter, prior_subset_size=prior_subset_size, egs_dir=egs_dir, egs_part=egs_part, context_opts=context_opts, - prior_gpu_opt=run_opts.prior_gpu_opt)) + prior_gpu_opt=run_opts.prior_gpu_opt, + extra_egs_copy_cmd=extra_egs_copy_cmd)) # make sure there is time for $dir/post.{iter}.*.vec to appear. 
time.sleep(5) diff --git a/egs/wsj/s5/steps/nnet3/train_raw_dnn.py b/egs/wsj/s5/steps/nnet3/train_raw_dnn.py index b67ba8792a8..d7651889d83 100755 --- a/egs/wsj/s5/steps/nnet3/train_raw_dnn.py +++ b/egs/wsj/s5/steps/nnet3/train_raw_dnn.py @@ -53,6 +53,9 @@ def get_args(): parser.add_argument("--egs.frames-per-eg", type=int, dest='frames_per_eg', default=8, help="Number of output labels per example") + parser.add_argument("--egs.extra-copy-cmd", type=str, + dest='extra_egs_copy_cmd', default = "", + help="""Modify egs before passing it to training"""); # trainer options parser.add_argument("--trainer.prior-subset-size", type=int, @@ -322,7 +325,8 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): shuffle_buffer_size=args.shuffle_buffer_size, run_opts=run_opts, get_raw_nnet_from_am=False, - background_process_handler=background_process_handler) + background_process_handler=background_process_handler, + extra_egs_copy_cmd=args.extra_egs_copy_cmd) if args.cleanup: # do a clean up everythin but the last 2 models, under certain @@ -353,7 +357,8 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): left_context=left_context, right_context=right_context, run_opts=run_opts, background_process_handler=background_process_handler, - get_raw_nnet_from_am=False) + get_raw_nnet_from_am=False, + extra_egs_copy_cmd=args.extra_egs_copy_cmd) if include_log_softmax and args.stage <= num_iters + 1: logger.info("Getting average posterior for purposes of " @@ -363,7 +368,8 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): num_archives=num_archives, left_context=left_context, right_context=right_context, prior_subset_size=args.prior_subset_size, run_opts=run_opts, - get_raw_nnet_from_am=False) + get_raw_nnet_from_am=False, + extra_egs_copy_cmd=args.extra_egs_copy_cmd) if args.cleanup: logger.info("Cleaning up the experiment directory " diff --git a/egs/wsj/s5/steps/nnet3/train_raw_rnn.py b/egs/wsj/s5/steps/nnet3/train_raw_rnn.py index 29df61ab546..ae038445fc0 100755 --- a/egs/wsj/s5/steps/nnet3/train_raw_rnn.py +++ b/egs/wsj/s5/steps/nnet3/train_raw_rnn.py @@ -69,6 +69,9 @@ def get_args(): help="""Number of left steps used in the estimation of LSTM state before prediction of the first label. 
Overrides the default value in CommonParser""") + parser.add_argument("--egs.extra-copy-cmd", type=str, + dest='extra_egs_copy_cmd', default = "", + help="""Modify egs before passing it to training"""); # trainer options parser.add_argument("--trainer.samples-per-iter", type=int, @@ -424,7 +427,8 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): cv_minibatch_size=args.cv_minibatch_size, run_opts=run_opts, get_raw_nnet_from_am=False, - background_process_handler=background_process_handler) + background_process_handler=background_process_handler, + extra_egs_copy_cmd=args.extra_egs_copy_cmd) if args.cleanup: # do a clean up everythin but the last 2 models, under certain @@ -455,7 +459,8 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): left_context=left_context, right_context=right_context, run_opts=run_opts, chunk_width=args.chunk_width, background_process_handler=background_process_handler, - get_raw_nnet_from_am=False) + get_raw_nnet_from_am=False, + extra_egs_copy_cmd=args.extra_egs_copy_cmd) if include_log_softmax and args.stage <= num_iters + 1: logger.info("Getting average posterior for purposes of " @@ -465,7 +470,8 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): num_archives=num_archives, left_context=left_context, right_context=right_context, prior_subset_size=args.prior_subset_size, run_opts=run_opts, - get_raw_nnet_from_am=False) + get_raw_nnet_from_am=False, + extra_egs_copy_cmd=args.extra_egs_copy_cmd) if args.cleanup: logger.info("Cleaning up the experiment directory " From c5796c3c206795f6ce5ecbe154e7c1be592981b1 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 23 Nov 2016 21:49:35 -0500 Subject: [PATCH 034/213] asr_diarization: Create get_egs.py supporting multiple targets and get-egs-multiple-targets --- egs/wsj/s5/steps/libs/data.py | 57 ++ .../steps/nnet3/get_egs_multiple_targets.py | 910 ++++++++++++++++++ src/nnet3bin/Makefile | 2 +- .../nnet3-get-egs-multiple-targets.cc | 538 +++++++++++ 4 files changed, 1506 insertions(+), 1 deletion(-) create mode 100644 egs/wsj/s5/steps/libs/data.py create mode 100755 egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py create mode 100644 src/nnet3bin/nnet3-get-egs-multiple-targets.cc diff --git a/egs/wsj/s5/steps/libs/data.py b/egs/wsj/s5/steps/libs/data.py new file mode 100644 index 00000000000..44895cae1a4 --- /dev/null +++ b/egs/wsj/s5/steps/libs/data.py @@ -0,0 +1,57 @@ +import os + +import libs.common as common_lib + +def get_frame_shift(data_dir): + frame_shift = common_lib.run_kaldi_command("utils/data/get_frame_shift.sh {0}".format(data_dir))[0] + return float(frame_shift.strip()) + +def generate_utt2dur(data_dir): + common_lib.run_kaldi_command("utils/data/get_utt2dur.sh {0}".format(data_dir)) + +def get_utt2dur(data_dir): + generate_utt2dur(data_dir) + utt2dur = {} + for line in open('{0}/utt2dur'.format(data_dir), 'r').readlines(): + parts = line.split() + utt2dur[parts[0]] = float(parts[1]) + return utt2dur + +def get_utt2uniq(data_dir): + utt2uniq_file = '{0}/utt2uniq'.format(data_dir) + if not os.path.exists(utt2uniq_file): + return None, None + utt2uniq = {} + uniq2utt = {} + for line in open(utt2uniq_file, 'r').readlines(): + parts = line.split() + utt2uniq[parts[0]] = parts[1] + if uniq2utt.has_key(parts[1]): + uniq2utt[parts[1]].append(parts[0]) + else: + uniq2utt[parts[1]] = [parts[0]] + return utt2uniq, uniq2utt + +def get_num_frames(data_dir, utts = None): + generate_utt2dur(data_dir) + frame_shift = get_frame_shift(data_dir) + total_duration = 
0 + utt2dur = get_utt2dur(data_dir) + if utts is None: + utts = utt2dur.keys() + for utt in utts: + total_duration = total_duration + utt2dur[utt] + return int(float(total_duration)/frame_shift) + +def create_data_links(file_names): + # if file_names already exist create_data_link.pl returns with code 1 + # so we just delete them before calling create_data_link.pl + for file_name in file_names: + try_to_delete(file_name) + common_lib.run_kaldi_command(" utils/create_data_link.pl {0}".format(" ".join(file_names))) + +def try_to_delete(file_name): + try: + os.remove(file_name) + except OSError: + pass diff --git a/egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py b/egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py new file mode 100755 index 00000000000..16e1f98a019 --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py @@ -0,0 +1,910 @@ +#!/usr/bin/env python + +# Copyright 2016 Vijayaditya Peddinti +# 2016 Vimal Manohar +# Apache 2.0. + +from __future__ import print_function +import os +import argparse +import sys +import logging +import shlex +import random +import math +import glob + +import libs.data as data_lib +import libs.common as common_lib + +logger = logging.getLogger('libs') +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +handler.setLevel(logging.INFO) +formatter = logging.Formatter("%(asctime)s [%(filename)s:%(lineno)s - " + "%(funcName)s - %(levelname)s ] %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) +logger.info('Getting egs for training') + + +def get_args(): + # we add compulsary arguments as named arguments for readability + parser = argparse.ArgumentParser( + description="""Generates training examples used to train the 'nnet3' + network (and also the validation examples used for diagnostics), + and puts them in separate archives.""", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("--cmd", type=str, default="run.pl", + help="Specifies the script to launch jobs." + " e.g. queue.pl for launching on SGE cluster run.pl" + " for launching on local machine") + # feat options + parser.add_argument("--feat.dir", type=str, dest='feat_dir', required=True, + help="Directory with features used for training " + "the neural network.") + parser.add_argument("--feat.online-ivector-dir", type=str, + dest='online_ivector_dir', + default=None, action=common_lib.NullstrToNoneAction, + help="directory with the ivectors extracted in an " + "online fashion.") + parser.add_argument("--feat.cmvn-opts", type=str, dest='cmvn_opts', + default=None, action=common_lib.NullstrToNoneAction, + help="A string specifying '--norm-means' and " + "'--norm-vars' values") + parser.add_argument("--feat.apply-cmvn-sliding", type=str, + dest='apply_cmvn_sliding', + default=False, action=common_lib.StrToBoolAction, + help="Apply CMVN sliding, instead of per-utteance " + "or speakers") + + # egs extraction options + parser.add_argument("--frames-per-eg", type=int, default=8, + help="""Number of frames of labels per example. + more->less disk space and less time preparing egs, but + more I/O during training. + note: the script may reduce this if + reduce-frames-per-eg is true.""") + parser.add_argument("--left-context", type=int, default=4, + help="""Amount of left-context per eg (i.e. 
extra + frames of input features not present in the output + supervision).""") + parser.add_argument("--right-context", type=int, default=4, + help="Amount of right-context per eg") + parser.add_argument("--valid-left-context", type=int, default=None, + help="""Amount of left-context for validation egs, + typically used in recurrent architectures to ensure + matched condition with training egs""") + parser.add_argument("--valid-right-context", type=int, default=None, + help="""Amount of right-context for validation egs, + typically used in recurrent architectures to ensure + matched condition with training egs""") + parser.add_argument("--compress-input", type=str, default=True, + action=common_lib.StrToBoolAction, + choices=["true", "false"], + help="If false, disables compression. Might be " + "necessary to check if results will be affected.") + parser.add_argument("--input-compress-format", type=int, default=0, + help="Format used for compressing the input features") + + parser.add_argument("--reduce-frames-per-eg", type=str, default=True, + action=common_lib.StrToBoolAction, + choices=["true", "false"], + help="""If true, this script may reduce the + frames-per-eg if there is only one archive and even + with the reduced frames-per-eg, the number of + samples-per-iter that would result is less than or + equal to the user-specified value.""") + + parser.add_argument("--num-utts-subset", type=int, default=300, + help="Number of utterances in validation and training" + " subsets used for shrinkage and diagnostics") + parser.add_argument("--num-utts-subset-valid", type=int, + help="Number of utterances in validation" + " subset used for diagnostics") + parser.add_argument("--num-utts-subset-train", type=int, + help="Number of utterances in training" + " subset used for shrinkage and diagnostics") + parser.add_argument("--num-train-egs-combine", type=int, default=10000, + help="Training examples for combination weights at the" + " very end.") + parser.add_argument("--num-valid-egs-combine", type=int, default=0, + help="Validation examples for combination weights at " + "the very end.") + parser.add_argument("--num-egs-diagnostic", type=int, default=4000, + help="Numer of frames for 'compute-probs' jobs") + + parser.add_argument("--samples-per-iter", type=int, default=400000, + help="""This is the target number of egs in each + archive of egs (prior to merging egs). We probably + should have called it egs_per_iter. This is just a + guideline; it will pick a number that divides the + number of samples in the entire data.""") + + parser.add_argument("--stage", type=int, default=0, + help="Stage to start running script from") + parser.add_argument("--num-jobs", type=int, default=6, + help="""This should be set to the maximum number of + jobs you are comfortable to run in parallel; you can + increase it if your disk speed is greater and you have + more machines.""") + parser.add_argument("--srand", type=int, default=0, + help="Rand seed for nnet3-copy-egs and " + "nnet3-shuffle-egs") + + parser.add_argument("--targets-parameters", type=str, action='append', + required=True, dest='targets_para_array', + help="""Parameters for targets. Each set of parameters + corresponds to a separate output node of the neural + network. The targets can be sparse or dense. + The parameters used are: + --targets-rspecifier= + # rspecifier for the targets, can be alignment or + # matrix. + --num-targets= + # targets dimension. required for sparse feats. 
+ --target-type=""") + + parser.add_argument("--dir", type=str, required=True, + help="Directory to store the examples") + + print(' '.join(sys.argv)) + print(sys.argv) + + args = parser.parse_args() + + args = process_args(args) + + return args + + +def process_args(args): + # process the options + if args.num_utts_subset_valid is None: + args.num_utts_subset_valid = args.num_utts_subset + + if args.num_utts_subset_train is None: + args.num_utts_subset_train = args.num_utts_subset + + if args.valid_left_context is None: + args.valid_left_context = args.left_context + if args.valid_right_context is None: + args.valid_right_context = args.right_context + + if (args.left_context < 0 or args.right_context < 0 + or args.valid_left_context < 0 or args.valid_right_context < 0): + raise Exception( + "--{,valid-}{left,right}-context should be non-negative") + + return args + + +def check_for_required_files(feat_dir, targets_scps, online_ivector_dir=None): + required_files = ['{0}/feats.scp'.format(feat_dir), + '{0}/cmvn.scp'.format(feat_dir)] + if online_ivector_dir is not None: + required_files.append('{0}/ivector_online.scp'.format( + online_ivector_dir)) + required_files.append('{0}/ivector_period'.format( + online_ivector_dir)) + + for file in required_files: + if not os.path.isfile(file): + raise Exception('Expected {0} to exist.'.format(file)) + + +def parse_targets_parameters_array(para_array): + targets_parser = argparse.ArgumentParser() + targets_parser.add_argument("--output-name", type=str, required=True, + help="Name of the output. e.g. output-xent") + targets_parser.add_argument("--dim", type=int, default=-1, + help="Target dimension (required for sparse " + "targets") + targets_parser.add_argument("--target-type", type=str, default="dense", + choices=["dense", "sparse"], + help="Dense for matrix format") + targets_parser.add_argument("--targets-scp", type=str, required=True, + help="Scp file of targets; can be posteriors " + "or matrices") + targets_parser.add_argument("--compress", type=str, default=True, + action=common_lib.StrToBoolAction, + help="Specifies whether the output must be " + "compressed") + targets_parser.add_argument("--compress-format", type=int, default=0, + help="Format for compressing target") + targets_parser.add_argument("--deriv-weights-scp", type=str, default="", + help="Per-frame deriv weights for this output") + targets_parser.add_argument("--scp2ark-cmd", type=str, default="", + help="""The command that is used to convert + targets scp to archive. e.g. 
An scp of + alignments can be converted to posteriors using + ali-to-post""") + + targets_parameters = [targets_parser.parse_args(shlex.split(x)) + for x in para_array] + + for t in targets_parameters: + if not os.path.isfile(t.targets_scp): + raise Exception("Expected {0} to exist.".format(t.targets_scp)) + + if (t.target_type == "dense"): + dim = common_lib.get_feat_dim_from_scp(t.targets_scp) + if (t.dim != -1 and t.dim != dim): + raise Exception('Mismatch in --dim provided and feat dim for ' + 'file {0}; {1} vs {2}'.format(t.targets_scp, + t.dim, dim)) + t.dim = -dim + + return targets_parameters + + +def sample_utts(feat_dir, num_utts_subset, min_duration, exclude_list=None): + utt2durs_dict = data_lib.get_utt2dur(feat_dir) + utt2durs = utt2durs_dict.items() + utt2uniq, uniq2utt = data_lib.get_utt2uniq(feat_dir) + if num_utts_subset is None: + num_utts_subset = len(utt2durs) + if exclude_list is not None: + num_utts_subset = num_utts_subset - len(exclude_list) + + random.shuffle(utt2durs) + sampled_utts = [] + + index = 0 + num_trials = 0 + while (len(sampled_utts) < num_utts_subset + and num_trials <= len(utt2durs)): + if utt2durs[index][-1] >= min_duration: + if utt2uniq is not None: + uniq_id = utt2uniq[utt2durs[index][0]] + utts2add = uniq2utt[uniq_id] + else: + utts2add = [utt2durs[index][0]] + exclude_utt = False + if exclude_list is not None: + for utt in utts2add: + if utt in exclude_list: + exclude_utt = True + break + if not exclude_utt: + for utt in utts2add: + sampled_utts.append(utt) + + index = index + 1 + num_trials = num_trials + 1 + if exclude_list is not None: + assert(len(set(exclude_list).intersection(sampled_utts)) == 0) + if len(sampled_utts) < num_utts_subset: + raise Exception( + """Number of utterances which have duration of at least {md} + seconds is really low (required={rl}, available={al}). 
Please + check your data.""".format( + md=min_duration, al=len(sampled_utts), rl=num_utts_subset)) + + sampled_utts_durs = [] + for utt in sampled_utts: + sampled_utts_durs.append([utt, utt2durs_dict[utt]]) + return sampled_utts, sampled_utts_durs + + +def write_list(listd, file_name): + file_handle = open(file_name, 'w') + assert(type(listd) == list) + for item in listd: + file_handle.write(str(item)+"\n") + file_handle.close() + + +def get_max_open_files(): + stdout, stderr = common_lib.run_kaldi_command("ulimit -n") + return int(stdout) + + +def get_feat_ivector_strings(dir, feat_dir, split_feat_dir, + cmvn_opt_string, ivector_dir=None, + apply_cmvn_sliding=False): + + if not apply_cmvn_sliding: + train_feats = ("ark,s,cs:utils/filter_scp.pl --exclude " + "{dir}/valid_uttlist {sdir}/JOB/feats.scp | " + "apply-cmvn {cmvn} --utt2spk=ark:{sdir}/JOB/utt2spk " + "scp:{sdir}/JOB/cmvn.scp scp:- ark:- |".format( + dir=dir, sdir=split_feat_dir, + cmvn=cmvn_opt_string)) + valid_feats = ("ark,s,cs:utils/filter_scp.pl {dir}/valid_uttlist " + "{fdir}/feats.scp | " + "apply-cmvn {cmvn} --utt2spk=ark:{fdir}/utt2spk " + "scp:{fdir}/cmvn.scp scp:- ark:- |".format( + dir=dir, fdir=feat_dir, cmvn=cmvn_opt_string)) + train_subset_feats = ("ark,s,cs:utils/filter_scp.pl " + "{dir}/train_subset_uttlist {fdir}/feats.scp | " + "apply-cmvn {cmvn} --utt2spk=ark:{fdir}/utt2spk " + "scp:{fdir}/cmvn.scp scp:- ark:- |".format( + dir=dir, fdir=feat_dir, + cmvn=cmvn_opt_string)) + + def feats_subset_func(subset_list): + return ("ark,s,cs:utils/filter_scp.pl {subset_list} " + "{fdir}/feats.scp | " + "apply-cmvn {cmvn} --utt2spk=ark:{fdir}/utt2spk " + "scp:{fdir}/cmvn.scp scp:- ark:- |".format( + dir=dir, subset_list=subset_list, + fdir=feat_dir, cmvn=cmvn_opt_string)) + + else: + train_feats = ("ark,s,cs:utils/filter_scp.pl --exclude " + "{dir}/valid_uttlist {sdir}/JOB/feats.scp | " + "apply-cmvn-sliding scp:{sdir}/JOB/cmvn.scp scp:- " + "ark:- |".format(dir=dir, sdir=split_feat_dir, + cmvn=cmvn_opt_string)) + + def feats_subset_func(subset_list): + return ("ark,s,cs:utils/filter_scp.pl {subset_list} " + "{fdir}/feats.scp | " + "apply-cmvn-sliding {cmvn} scp:{fdir}/cmvn.scp scp:- " + "ark:- |".format(dir=dir, subset_list=subset_list, + fdir=feat_dir, cmvn=cmvn_opt_string)) + + train_subset_feats = feats_subset_func( + "{0}/train_subset_uttlist".format(dir)) + valid_feats = feats_subset_func("{0}/valid_uttlist".format(dir)) + + if ivector_dir is not None: + ivector_period = common_lib.GetIvectorPeriod(ivector_dir) + ivector_opt = ("--ivectors='ark,s,cs:utils/filter_scp.pl " + "{sdir}/JOB/utt2spk {idir}/ivector_online.scp | " + "subsample-feats --n=-{period} scp:- ark:- |'".format( + sdir=split_feat_dir, idir=ivector_dir, + period=ivector_period)) + valid_ivector_opt = ("--ivectors='ark,s,cs:utils/filter_scp.pl " + "{dir}/valid_uttlist {idir}/ivector_online.scp | " + "subsample-feats --n=-{period} " + "scp:- ark:- |'".format( + dir=dir, idir=ivector_dir, + period=ivector_period)) + train_subset_ivector_opt = ( + "--ivectors='ark,s,cs:utils/filter_scp.pl " + "{dir}/train_subset_uttlist {idir}/ivector_online.scp | " + "subsample-feats --n=-{period} scp:- ark:- |'".format( + dir=dir, idir=ivector_dir, period=ivector_period)) + else: + ivector_opt = '' + valid_ivector_opt = '' + train_subset_ivector_opt = '' + + return {'train_feats': train_feats, + 'valid_feats': valid_feats, + 'train_subset_feats': train_subset_feats, + 'feats_subset_func': feats_subset_func, + 'ivector_opts': ivector_opt, + 'valid_ivector_opts': 
valid_ivector_opt, + 'train_subset_ivector_opts': train_subset_ivector_opt, + 'feat_dim': common_lib.get_feat_dim(feat_dir), + 'ivector_dim': common_lib.get_ivector_dim(ivector_dir)} + + +def get_egs_options(targets_parameters, frames_per_eg, + left_context, right_context, + valid_left_context, valid_right_context, + compress_input, + input_compress_format=0, length_tolerance=0): + + train_egs_opts = [] + train_egs_opts.append("--left-context={0}".format(left_context)) + train_egs_opts.append("--right-context={0}".format(right_context)) + train_egs_opts.append("--num-frames={0}".format(frames_per_eg)) + train_egs_opts.append("--compress-input={0}".format(compress_input)) + train_egs_opts.append("--input-compress-format={0}".format( + input_compress_format)) + train_egs_opts.append("--compress-targets={0}".format( + ':'.join(["true" if t.compress else "false" + for t in targets_parameters]))) + train_egs_opts.append("--targets-compress-formats={0}".format( + ':'.join([str(t.compress_format) + for t in targets_parameters]))) + train_egs_opts.append("--length-tolerance={0}".format(length_tolerance)) + train_egs_opts.append("--output-names={0}".format( + ':'.join([t.output_name + for t in targets_parameters]))) + train_egs_opts.append("--output-dims={0}".format( + ':'.join([str(t.dim) + for t in targets_parameters]))) + + valid_egs_opts = ( + "--left-context={vlc} --right-context={vrc} " + "--num-frames={n} --compress-input={comp} " + "--input-compress-format={icf} --compress-targets={ct} " + "--targets-compress-formats={tcf} --length-tolerance={tol} " + "--output-names={names} --output-dims={dims}".format( + vlc=valid_left_context, vrc=valid_right_context, n=frames_per_eg, + comp=compress_input, icf=input_compress_format, + ct=':'.join(["true" if t.compress else "false" + for t in targets_parameters]), + tcf=':'.join([str(t.compress_format) + for t in targets_parameters]), + tol=length_tolerance, + names=':'.join([t.output_name + for t in targets_parameters]), + dims=':'.join([str(t.dim) for t in targets_parameters]))) + + return {'train_egs_opts': " ".join(train_egs_opts), + 'valid_egs_opts': valid_egs_opts} + + +def get_targets_list(targets_parameters, subset_list): + targets_list = [] + for t in targets_parameters: + rspecifier = "ark,s,cs:" if t.scp2ark_cmd != "" else "scp,s,cs:" + rspecifier += get_subset_rspecifier(t.targets_scp, subset_list) + rspecifier += t.scp2ark_cmd + deriv_weights_rspecifier = "" + if t.deriv_weights_scp != "": + deriv_weights_rspecifier = "scp,s,cs:{0}".format( + get_subset_rspecifier(t.deriv_weights_scp, subset_list)) + this_targets = '''"{rspecifier}" "{dw}"'''.format( + rspecifier=rspecifier, dw=deriv_weights_rspecifier) + + targets_list.append(this_targets) + return " ".join(targets_list) + + +def get_subset_rspecifier(scp_file, subset_list): + if scp_file == "": + return "" + return "utils/filter_scp.pl {subset} {scp} |".format(subset=subset_list, + scp=scp_file) + + +def split_scp(scp_file, num_jobs): + out_scps = ["{0}.{1}".format(scp_file, n) for n in range(1, num_jobs + 1)] + common_lib.run_kaldi_command("utils/split_scp.pl {scp} {oscps}".format( + scp=scp_file, + oscps=' '.join(out_scps))) + return out_scps + + +def generate_valid_train_subset_egs(dir, targets_parameters, + feat_ivector_strings, egs_opts, + num_train_egs_combine, + num_valid_egs_combine, + num_egs_diagnostic, cmd, + num_jobs=1): + wait_pids = [] + + logger.info("Creating validation and train subset examples.") + + split_scp('{0}/valid_uttlist'.format(dir), num_jobs) + 
split_scp('{0}/train_subset_uttlist'.format(dir), num_jobs) + + valid_pid = common_lib.run_kaldi_command( + """{cmd} JOB=1:{nj} {dir}/log/create_valid_subset.JOB.log \ + nnet3-get-egs-multiple-targets {v_iv_opt} {v_egs_opt} "{v_feats}" \ + {targets} ark:{dir}/valid_all.JOB.egs""".format( + cmd=cmd, nj=num_jobs, dir=dir, + v_egs_opt=egs_opts['valid_egs_opts'], + v_iv_opt=feat_ivector_strings['valid_ivector_opts'], + v_feats=feat_ivector_strings['feats_subset_func']( + '{dir}/valid_uttlist.JOB'.format(dir=dir)), + targets=get_targets_list( + targets_parameters, + '{dir}/valid_uttlist.JOB'.format(dir=dir))), + wait=False) + + train_pid = common_lib.run_kaldi_command( + """{cmd} JOB=1:{nj} {dir}/log/create_train_subset.JOB.log \ + nnet3-get-egs-multiple-targets {t_iv_opt} {v_egs_opt} "{t_feats}" \ + {targets} ark:{dir}/train_subset_all.JOB.egs""".format( + cmd=cmd, nj=num_jobs, dir=dir, + v_egs_opt=egs_opts['valid_egs_opts'], + t_iv_opt=feat_ivector_strings['train_subset_ivector_opts'], + t_feats=feat_ivector_strings['feats_subset_func']( + '{dir}/train_subset_uttlist.JOB'.format(dir=dir)), + targets=get_targets_list( + targets_parameters, + '{dir}/train_subset_uttlist.JOB'.format(dir=dir))), + wait=False) + + wait_pids.append(valid_pid) + wait_pids.append(train_pid) + + for pid in wait_pids: + stdout, stderr = pid.communicate() + if pid.returncode != 0: + raise Exception(stderr) + + valid_egs_all = ' '.join(['{dir}/valid_all.{n}.egs'.format(dir=dir, n=n) + for n in range(1, num_jobs + 1)]) + train_subset_egs_all = ' '.join(['{dir}/train_subset_all.{n}.egs'.format( + dir=dir, n=n) + for n in range(1, num_jobs + 1)]) + + wait_pids = [] + logger.info("... Getting subsets of validation examples for diagnostics " + " and combination.") + pid = common_lib.run_kaldi_command( + """{cmd} {dir}/log/create_valid_subset_combine.log \ + cat {valid_egs_all} \| nnet3-subset-egs --n={nve_combine} ark:- \ + ark:{dir}/valid_combine.egs""".format( + cmd=cmd, dir=dir, valid_egs_all=valid_egs_all, + nve_combine=num_valid_egs_combine), + wait=False) + wait_pids.append(pid) + + pid = common_lib.run_kaldi_command( + """{cmd} {dir}/log/create_valid_subset_diagnostic.log \ + cat {valid_egs_all} \| nnet3-subset-egs --n={ne_diagnostic} ark:- \ + ark:{dir}/valid_diagnostic.egs""".format( + cmd=cmd, dir=dir, valid_egs_all=valid_egs_all, + ne_diagnostic=num_egs_diagnostic), + wait=False) + wait_pids.append(pid) + + pid = common_lib.run_kaldi_command( + """{cmd} {dir}/log/create_train_subset_combine.log \ + cat {train_subset_egs_all} \| \ + nnet3-subset-egs --n={nte_combine} ark:- \ + ark:{dir}/train_combine.egs""".format( + cmd=cmd, dir=dir, train_subset_egs_all=train_subset_egs_all, + nte_combine=num_train_egs_combine), + wait=False) + wait_pids.append(pid) + + pid = common_lib.run_kaldi_command( + """{cmd} {dir}/log/create_train_subset_diagnostic.log \ + cat {train_subset_egs_all} \| \ + nnet3-subset-egs --n={ne_diagnostic} ark:- \ + ark:{dir}/train_diagnostic.egs""".format( + cmd=cmd, dir=dir, train_subset_egs_all=train_subset_egs_all, + ne_diagnostic=num_egs_diagnostic), wait=False) + wait_pids.append(pid) + + for pid in wait_pids: + stdout, stderr = pid.communicate() + if pid.returncode != 0: + raise Exception(stderr) + + common_lib.run_kaldi_command( + """cat {dir}/valid_combine.egs {dir}/train_combine.egs > \ + {dir}/combine.egs""".format(dir=dir)) + + # perform checks + for file_name in ('{0}/combine.egs {0}/train_diagnostic.egs ' + '{0}/valid_diagnostic.egs'.format(dir).split()): + if os.path.getsize(file_name) == 
0: + raise Exception("No examples in {0}".format(file_name)) + + # clean-up + for x in ('{0}/valid_all.*.egs {0}/train_subset_all.*.egs ' + '{0}/train_combine.egs ' + '{0}/valid_combine.egs'.format(dir).split()): + for file_name in glob.glob(x): + os.remove(file_name) + + +def generate_training_examples_internal(dir, targets_parameters, feat_dir, + train_feats_string, + train_egs_opts_string, + ivector_opts, + num_jobs, frames_per_eg, + samples_per_iter, cmd, srand=0, + reduce_frames_per_eg=True, + only_shuffle=False, + dry_run=False): + + # The examples will go round-robin to egs_list. Note: we omit the + # 'normalization.fst' argument while creating temporary egs: the phase of + # egs preparation that involves the normalization FST is quite + # CPU-intensive and it's more convenient to do it later, in the 'shuffle' + # stage. Otherwise to make it efficient we need to use a large 'nj', like + # 40, and in that case there can be too many small files to deal with, + # because the total number of files is the product of 'nj' by + # 'num_archives_intermediate', which might be quite large. + num_frames = data_lib.get_num_frames(feat_dir) + num_archives = (num_frames) / (frames_per_eg * samples_per_iter) + 1 + + reduced = False + while (reduce_frames_per_eg and frames_per_eg > 1 and + num_frames / ((frames_per_eg-1)*samples_per_iter) == 0): + frames_per_eg -= 1 + num_archives = 1 + reduced = True + + if reduced: + logger.info("Reduced frames-per-eg to {0} " + "because amount of data is small".format(frames_per_eg)) + + max_open_files = get_max_open_files() + num_archives_intermediate = num_archives + archives_multiple = 1 + while (num_archives_intermediate+4) > max_open_files: + archives_multiple = archives_multiple + 1 + num_archives_intermediate = int(math.ceil(float(num_archives) + / archives_multiple)) + num_archives = num_archives_intermediate * archives_multiple + egs_per_archive = num_frames/(frames_per_eg * num_archives) + + if egs_per_archive > samples_per_iter: + raise Exception( + """egs_per_archive({epa}) > samples_per_iter({fpi}). 
+ This is an error in the logic for determining + egs_per_archive""".format(epa=egs_per_archive, + fpi=samples_per_iter)) + + if dry_run: + cleanup(dir, archives_multiple) + return {'num_frames': num_frames, + 'num_archives': num_archives, + 'egs_per_archive': egs_per_archive} + + logger.info("Splitting a total of {nf} frames into {na} archives, " + "each with {epa} egs.".format(nf=num_frames, na=num_archives, + epa=egs_per_archive)) + + if os.path.isdir('{0}/storage'.format(dir)): + # this is a striped directory, so create the softlinks + data_lib.create_data_links(["{dir}/egs.{x}.ark".format(dir=dir, x=x) + for x in range(1, num_archives + 1)]) + for x in range(1, num_archives_intermediate + 1): + data_lib.create_data_links( + ["{dir}/egs_orig.{y}.{x}.ark".format(dir=dir, x=x, y=y) + for y in range(1, num_jobs + 1)]) + + split_feat_dir = "{0}/split{1}".format(feat_dir, num_jobs) + egs_list = ' '.join(['ark:{dir}/egs_orig.JOB.{ark_num}.ark'.format( + dir=dir, ark_num=x) + for x in range(1, num_archives_intermediate + 1)]) + + if not only_shuffle: + common_lib.run_kaldi_command( + """{cmd} JOB=1:{nj} {dir}/log/get_egs.JOB.log \ + nnet3-get-egs-multiple-targets {iv_opts} {egs_opts} \ + "{feats}" {targets} ark:- \| \ + nnet3-copy-egs --random=true --srand=$[JOB+{srand}] \ + ark:- {egs_list}""".format( + cmd=cmd, nj=num_jobs, dir=dir, srand=srand, + iv_opts=ivector_opts, egs_opts=train_egs_opts_string, + feats=train_feats_string, + targets=get_targets_list(targets_parameters, + '{sdir}/JOB/utt2spk'.format( + sdir=split_feat_dir)), + egs_list=egs_list)) + + logger.info("Recombining and shuffling order of archives on disk") + egs_list = ' '.join(['{dir}/egs_orig.{n}.JOB.ark'.format(dir=dir, n=x) + for x in range(1, num_jobs + 1)]) + + if archives_multiple == 1: + # there are no intermediate archives so just shuffle egs across + # jobs and dump them into a single output + common_lib.run_kaldi_command( + """{cmd} --max-jobs-run {msjr} JOB=1:{nai} \ + {dir}/log/shuffle.JOB.log \ + nnet3-shuffle-egs --srand=$[JOB+{srand}] \ + "ark:cat {egs_list}|" ark:{dir}/egs.JOB.ark""".format( + cmd=cmd, msjr=num_jobs, + nai=num_archives_intermediate, srand=srand, + dir=dir, egs_list=egs_list)) + else: + # there are intermediate archives so we shuffle egs across jobs + # and split them into archives_multiple output archives + output_archives = ' '.join(["ark:{dir}/egs.JOB.{ark_num}.ark".format( + dir=dir, ark_num=x) + for x in range(1, archives_multiple + 1)]) + # archives were created as egs.x.y.ark + # linking them to egs.i.ark format which is expected by the training + # scripts + for i in range(1, num_archives_intermediate + 1): + for j in range(1, archives_multiple + 1): + archive_index = (i-1) * archives_multiple + j + common_lib.force_sym_link( + "egs.{0}.ark".format(archive_index), + "{dir}/egs.{i}.{j}.ark".format(dir=dir, i=i, j=j)) + + common_lib.run_kaldi_command( + """{cmd} --max-jobs-run {msjr} JOB=1:{nai} \ + {dir}/log/shuffle.JOB.log \ + nnet3-shuffle-egs --srand=$[JOB+{srand}] \ + "ark:cat {egs_list}|" ark:- \| \ + nnet3-copy-egs ark:- {oarks}""".format( + cmd=cmd, msjr=num_jobs, + nai=num_archives_intermediate, srand=srand, + dir=dir, egs_list=egs_list, oarks=output_archives)) + + cleanup(dir, archives_multiple) + return {'num_frames': num_frames, + 'num_archives': num_archives, + 'egs_per_archive': egs_per_archive} + + +def cleanup(dir, archives_multiple): + logger.info("Removing temporary archives in {0}.".format(dir)) + for file_name in glob.glob("{0}/egs_orig*".format(dir)): + real_path = 
os.path.realpath(file_name) + data_lib.try_to_delete(real_path) + data_lib.try_to_delete(file_name) + + if archives_multiple > 1: + # there will be some extra soft links we want to delete + for file_name in glob.glob('{0}/egs.*.*.ark'.format(dir)): + os.remove(file_name) + + +def create_directory(dir): + import errno + try: + os.makedirs(dir) + except OSError, e: + if e.errno == errno.EEXIST: + pass + + +def generate_training_examples(dir, targets_parameters, feat_dir, + feat_ivector_strings, egs_opts, + frame_shift, frames_per_eg, samples_per_iter, + cmd, num_jobs, srand=0, + only_shuffle=False, dry_run=False): + + # generate the training options string with the given chunk_width + train_egs_opts = egs_opts['train_egs_opts'] + # generate the feature vector string with the utt list for the + # current chunk width + train_feats = feat_ivector_strings['train_feats'] + + if os.path.isdir('{0}/storage'.format(dir)): + real_paths = [os.path.realpath(x).strip("/") + for x in glob.glob('{0}/storage/*'.format(dir))] + common_lib.run_kaldi_command( + """utils/create_split_dir.pl {target_dirs} \ + {dir}/storage""".format( + target_dirs=" ".join(real_paths), dir=dir)) + + info = generate_training_examples_internal( + dir=dir, targets_parameters=targets_parameters, + feat_dir=feat_dir, train_feats_string=train_feats, + train_egs_opts_string=train_egs_opts, + ivector_opts=feat_ivector_strings['ivector_opts'], + num_jobs=num_jobs, frames_per_eg=frames_per_eg, + samples_per_iter=samples_per_iter, cmd=cmd, + srand=srand, + only_shuffle=only_shuffle, + dry_run=dry_run) + + return info + + +def write_egs_info(info, info_dir): + for x in ['num_frames', 'num_archives', 'egs_per_archive', + 'feat_dim', 'ivector_dim', + 'left_context', 'right_context', 'frames_per_eg']: + write_list([info['{0}'.format(x)]], '{0}/{1}'.format(info_dir, x)) + + +def generate_egs(egs_dir, feat_dir, targets_para_array, + online_ivector_dir=None, + frames_per_eg=8, + left_context=4, + right_context=4, + valid_left_context=None, + valid_right_context=None, + cmd="run.pl", stage=0, + cmvn_opts=None, apply_cmvn_sliding=False, + compress_input=True, + input_compress_format=0, + num_utts_subset=300, + num_train_egs_combine=1000, + num_valid_egs_combine=0, + num_egs_diagnostic=4000, + samples_per_iter=400000, + num_jobs=6, + srand=0): + + for directory in '{0}/log {0}/info'.format(egs_dir).split(): + create_directory(directory) + + print (cmvn_opts if cmvn_opts is not None else '', + file=open('{0}/cmvn_opts'.format(egs_dir), 'w')) + print ("true" if apply_cmvn_sliding else "false", + file=open('{0}/apply_cmvn_sliding'.format(egs_dir), 'w')) + + targets_parameters = parse_targets_parameters_array(targets_para_array) + + # Check files + check_for_required_files(feat_dir, + [t.targets_scp for t in targets_parameters], + online_ivector_dir) + + frame_shift = data_lib.get_frame_shift(feat_dir) + min_duration = frames_per_eg * frame_shift + valid_utts = sample_utts(feat_dir, num_utts_subset, min_duration)[0] + train_subset_utts = sample_utts(feat_dir, num_utts_subset, min_duration, + exclude_list=valid_utts)[0] + train_utts, train_utts_durs = sample_utts(feat_dir, None, -1, + exclude_list=valid_utts) + + write_list(valid_utts, '{0}/valid_uttlist'.format(egs_dir)) + write_list(train_subset_utts, '{0}/train_subset_uttlist'.format(egs_dir)) + write_list(train_utts, '{0}/train_uttlist'.format(egs_dir)) + + # split the training data into parts for individual jobs + # we will use the same number of jobs as that used for alignment + split_feat_dir = 
common_lib.split_data(feat_dir, num_jobs) + feat_ivector_strings = get_feat_ivector_strings( + dir=egs_dir, feat_dir=feat_dir, split_feat_dir=split_feat_dir, + cmvn_opt_string=cmvn_opts, + ivector_dir=online_ivector_dir, + apply_cmvn_sliding=apply_cmvn_sliding) + + egs_opts = get_egs_options(targets_parameters=targets_parameters, + frames_per_eg=frames_per_eg, + left_context=left_context, + right_context=right_context, + valid_left_context=valid_left_context, + valid_right_context=valid_right_context, + compress_input=compress_input, + input_compress_format=input_compress_format) + + if stage <= 2: + logger.info("Generating validation and training subset examples") + + generate_valid_train_subset_egs( + dir=egs_dir, + targets_parameters=targets_parameters, + feat_ivector_strings=feat_ivector_strings, + egs_opts=egs_opts, + num_train_egs_combine=num_train_egs_combine, + num_valid_egs_combine=num_valid_egs_combine, + num_egs_diagnostic=num_egs_diagnostic, + cmd=cmd, + num_jobs=num_jobs) + + logger.info("Generating training examples on disk.") + info = generate_training_examples( + dir=egs_dir, + targets_parameters=targets_parameters, + feat_dir=feat_dir, + feat_ivector_strings=feat_ivector_strings, + egs_opts=egs_opts, + frame_shift=frame_shift, + frames_per_eg=frames_per_eg, + samples_per_iter=samples_per_iter, + cmd=cmd, + num_jobs=num_jobs, + srand=srand, + only_shuffle=True if stage > 3 else False, + dry_run=True if stage > 4 else False) + + info['feat_dim'] = feat_ivector_strings['feat_dim'] + info['ivector_dim'] = feat_ivector_strings['ivector_dim'] + info['left_context'] = left_context + info['right_context'] = right_context + info['frames_per_eg'] = frames_per_eg + + write_egs_info(info, '{dir}/info'.format(dir=egs_dir)) + + +def main(): + args = get_args() + generate_egs(args.dir, args.feat_dir, args.targets_para_array, + online_ivector_dir=args.online_ivector_dir, + frames_per_eg=args.frames_per_eg, + left_context=args.left_context, + right_context=args.right_context, + valid_left_context=args.valid_left_context, + valid_right_context=args.valid_right_context, + cmd=args.cmd, stage=args.stage, + cmvn_opts=args.cmvn_opts, + apply_cmvn_sliding=args.apply_cmvn_sliding, + compress_input=args.compress_input, + input_compress_format=args.input_compress_format, + num_utts_subset=args.num_utts_subset, + num_train_egs_combine=args.num_train_egs_combine, + num_valid_egs_combine=args.num_valid_egs_combine, + num_egs_diagnostic=args.num_egs_diagnostic, + samples_per_iter=args.samples_per_iter, + num_jobs=args.num_jobs, + srand=args.srand) + + +if __name__ == "__main__": + main() diff --git a/src/nnet3bin/Makefile b/src/nnet3bin/Makefile index d46c56a1044..aeb3dc1dc03 100644 --- a/src/nnet3bin/Makefile +++ b/src/nnet3bin/Makefile @@ -17,7 +17,7 @@ BINFILES = nnet3-init nnet3-info nnet3-get-egs nnet3-copy-egs nnet3-subset-egs \ nnet3-discriminative-merge-egs nnet3-discriminative-shuffle-egs \ nnet3-discriminative-compute-objf nnet3-discriminative-train \ discriminative-get-supervision nnet3-discriminative-subset-egs \ - nnet3-discriminative-compute-from-egs + nnet3-discriminative-compute-from-egs nnet3-get-egs-multiple-targets OBJFILES = diff --git a/src/nnet3bin/nnet3-get-egs-multiple-targets.cc b/src/nnet3bin/nnet3-get-egs-multiple-targets.cc new file mode 100644 index 00000000000..49f0dde4af7 --- /dev/null +++ b/src/nnet3bin/nnet3-get-egs-multiple-targets.cc @@ -0,0 +1,538 @@ +// nnet3bin/nnet3-get-egs-multiple-targets.cc + +// Copyright 2012-2015 Johns Hopkins University (author: Daniel 
Povey) +// 2014-2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "hmm/transition-model.h" +#include "hmm/posterior.h" +#include "nnet3/nnet-example.h" +#include "nnet3/nnet-example-utils.h" + +namespace kaldi { +namespace nnet3 { + +bool ToBool(std::string str) { + std::transform(str.begin(), str.end(), str.begin(), ::tolower); + + if ((str.compare("true") == 0) || (str.compare("t") == 0) + || (str.compare("1") == 0)) + return true; + if ((str.compare("false") == 0) || (str.compare("f") == 0) + || (str.compare("0") == 0)) + return false; + KALDI_ERR << "Invalid format for boolean argument [expected true or false]: " + << str; + return false; // never reached +} + +static void ProcessFile(const MatrixBase &feats, + const MatrixBase *ivector_feats, + const std::vector &output_names, + const std::vector &output_dims, + const std::vector* > &dense_target_matrices, + const std::vector &posteriors, + const std::vector* > &deriv_weights, + const std::string &utt_id, + bool compress_input, + int32 input_compress_format, + const std::vector &compress_targets, + const std::vector &targets_compress_formats, + int32 left_context, + int32 right_context, + int32 frames_per_eg, + std::vector *num_frames_written, + std::vector *num_egs_written, + NnetExampleWriter *example_writer) { + KALDI_ASSERT(output_names.size() > 0); + //KALDI_ASSERT(feats.NumRows() == static_cast(targets.NumRows())); + for (int32 t = 0; t < feats.NumRows(); t += frames_per_eg) { + + int32 tot_frames = left_context + frames_per_eg + right_context; + + Matrix input_frames(tot_frames, feats.NumCols(), kUndefined); + + // Set up "input_frames". + for (int32 j = -left_context; j < frames_per_eg + right_context; j++) { + int32 t2 = j + t; + if (t2 < 0) t2 = 0; + if (t2 >= feats.NumRows()) t2 = feats.NumRows() - 1; + SubVector src(feats, t2), + dest(input_frames, j + left_context); + dest.CopyFromVec(src); + } + + NnetExample eg; + + // call the regular input "input". + eg.io.push_back(NnetIo("input", - left_context, + input_frames)); + + if (compress_input) + eg.io.back().Compress(input_compress_format); + + // if applicable, add the iVector feature. + if (ivector_feats) { + int32 actual_frames_per_eg = std::min(frames_per_eg, + feats.NumRows() - t); + // try to get closest frame to middle of window to get + // a representative iVector. 
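+      // For example (illustrative numbers only): with frames_per_eg = 8,
+      // t = 16 and a 20-frame utterance, actual_frames_per_eg = min(8, 20 - 16) = 4,
+      // so closest_frame = 16 + 4 / 2 = 18 (clipped to the last iVector row if needed).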
+ int32 closest_frame = t + (actual_frames_per_eg / 2); + KALDI_ASSERT(ivector_feats->NumRows() > 0); + if (closest_frame >= ivector_feats->NumRows()) + closest_frame = ivector_feats->NumRows() - 1; + Matrix ivector(1, ivector_feats->NumCols()); + ivector.Row(0).CopyFromVec(ivector_feats->Row(closest_frame)); + eg.io.push_back(NnetIo("ivector", 0, ivector)); + } + + int32 num_outputs_added = 0; + + for (int32 n = 0; n < output_names.size(); n++) { + Vector this_deriv_weights(0); + if (deriv_weights[n]) { + // actual_frames_per_eg is the number of frames with actual targets. + // At the end of the file, we pad with the last frame repeated + // so that all examples have the same structure (prevents the need + // for recompilations). + int32 actual_frames_per_eg = std::min(std::min(frames_per_eg, + feats.NumRows() - t), deriv_weights[n]->Dim() - t); + + this_deriv_weights.Resize(frames_per_eg); + int32 frames_to_copy = std::min(t + actual_frames_per_eg, + deriv_weights[n]->Dim()) - t; + this_deriv_weights.Range(0, frames_to_copy).CopyFromVec(deriv_weights[n]->Range(t, frames_to_copy)); + if (this_deriv_weights.Sum() == 0) { + continue; // Ignore frames that have frame weights 0 + } + } + + if (dense_target_matrices[n]) { + const MatrixBase &targets = *dense_target_matrices[n]; + Matrix targets_dest(frames_per_eg, targets.NumCols()); + + // actual_frames_per_eg is the number of frames with actual targets. + // At the end of the file, we pad with the last frame repeated + // so that all examples have the same structure (prevents the need + // for recompilations). + int32 actual_frames_per_eg = std::min(std::min(frames_per_eg, + feats.NumRows() - t), targets.NumRows() - t); + + for (int32 i = 0; i < actual_frames_per_eg; i++) { + // Copy the i^th row of the target matrix from the (t+i)^th row of the + // input targets matrix + SubVector this_target_dest(targets_dest, i); + SubVector this_target_src(targets, t+i); + this_target_dest.CopyFromVec(this_target_src); + } + + // Copy the last frame's target to the padded frames + for (int32 i = actual_frames_per_eg; i < frames_per_eg; i++) { + // Copy the i^th row of the target matrix from the last row of the + // input targets matrix + KALDI_ASSERT(t + actual_frames_per_eg - 1 == targets.NumRows() - 1); + SubVector this_target_dest(targets_dest, i); + SubVector this_target_src(targets, t+actual_frames_per_eg-1); + this_target_dest.CopyFromVec(this_target_src); + } + + if (deriv_weights[n]) { + eg.io.push_back(NnetIo(output_names[n], this_deriv_weights, 0, targets_dest)); + } else { + eg.io.push_back(NnetIo(output_names[n], 0, targets_dest)); + } + } else if (posteriors[n]) { + const Posterior &pdf_post = *(posteriors[n]); + + // actual_frames_per_eg is the number of frames with actual targets. + // At the end of the file, we pad with the last frame repeated + // so that all examples have the same structure (prevents the need + // for recompilations). + int32 actual_frames_per_eg = std::min(std::min(frames_per_eg, + feats.NumRows() - t), static_cast(pdf_post.size()) - t); + + Posterior labels(frames_per_eg); + for (int32 i = 0; i < actual_frames_per_eg; i++) + labels[i] = pdf_post[t + i]; + // remaining posteriors for frames are empty. 
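+        // (Unlike the dense-target branch above, no padding by repeating the
+        // last frame is done here; the trailing frames simply get empty
+        // posteriors, i.e. no supervision.)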
+ + if (deriv_weights[n]) { + eg.io.push_back(NnetIo(output_names[n], this_deriv_weights, output_dims[n], 0, labels)); + } else { + eg.io.push_back(NnetIo(output_names[n], output_dims[n], 0, labels)); + } + } else + continue; + if (compress_targets[n]) + eg.io.back().Compress(targets_compress_formats[n]); + + num_outputs_added++; + (*num_frames_written)[n] += frames_per_eg; // Actually actual_frames_per_eg, but that depends on the different output. For simplification, frames_per_eg is used. + (*num_egs_written)[n] += 1; + } + + if (num_outputs_added == 0) continue; + + std::ostringstream os; + os << utt_id << "-" << t; + + std::string key = os.str(); // key is - + + KALDI_ASSERT(NumOutputs(eg) == num_outputs_added); + + example_writer->Write(key, eg); + } +} + + +} // namespace nnet2 +} // namespace kaldi + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace kaldi::nnet3; + typedef kaldi::int32 int32; + typedef kaldi::int64 int64; + + const char *usage = + "Get frame-by-frame examples of data for nnet3 neural network training.\n" + "This program is similar to nnet3-get-egs, but the targets here are " + "dense matrices instead of posteriors (sparse matrices).\n" + "This is useful when you want the targets to be continuous real-valued " + "with the neural network possibly trained with a quadratic objective\n" + "\n" + "Usage: nnet3-get-egs-multiple-targets [options] " + " ::[:] " + "[ :: ... ] \n" + "\n" + "Here is any random string for output node name, \n" + " is the rspecifier for either dense targets in matrix format or sparse targets in posterior format,\n" + "and is the target dimension of output node for sparse targets or -1 for dense targets\n" + "\n" + "An example [where $feats expands to the actual features]:\n" + "nnet-get-egs-multiple-targets --left-context=12 \\\n" + "--right-context=9 --num-frames=8 \"$feats\" \\\n" + "output-snr:\"ark:copy-matrix ark:exp/snrs/snr.1.ark ark:- |\":-1 \n" + " ark:- \n"; + + + bool compress_input = true; + int32 input_compress_format = 0; + int32 left_context = 0, right_context = 0, + num_frames = 1, length_tolerance = 2; + + std::string ivector_rspecifier, + targets_compress_formats_str, + compress_targets_str; + std::string output_dims_str; + std::string output_names_str; + + ParseOptions po(usage); + po.Register("compress-input", &compress_input, "If true, write egs in " + "compressed format."); + po.Register("input-compress-format", &input_compress_format, "Format for " + "compressing input feats e.g. 
Use 2 for compressing wave"); + po.Register("compress-targets", &compress_targets_str, "CSL of whether " + "targets must be compressed for each of the outputs"); + po.Register("targets-compress-formats", &targets_compress_formats_str, "Format for " + "compressing all feats in general"); + po.Register("left-context", &left_context, "Number of frames of left " + "context the neural net requires."); + po.Register("right-context", &right_context, "Number of frames of right " + "context the neural net requires."); + po.Register("num-frames", &num_frames, "Number of frames with labels " + "that each example contains."); + po.Register("ivectors", &ivector_rspecifier, "Rspecifier of ivector " + "features, as matrix."); + po.Register("length-tolerance", &length_tolerance, "Tolerance for " + "difference in num-frames between feat and ivector matrices"); + po.Register("output-dims", &output_dims_str, "CSL of output node dims"); + po.Register("output-names", &output_names_str, "CSL of output node names"); + //po.Register("deriv-weights-rspecifiers", &deriv_weights_rspecifiers_str, + // "CSL of per-frame weights (only binary - 0 or 1) that specifies " + // "whether a frame's gradient must be backpropagated or not. " + // "Not specifying this is equivalent to specifying a vector of " + // "all 1s."); + + po.Read(argc, argv); + + if (po.NumArgs() < 3) { + po.PrintUsage(); + exit(1); + } + + std::string feature_rspecifier = po.GetArg(1), + examples_wspecifier = po.GetArg(po.NumArgs()); + + // Read in all the training files. + SequentialBaseFloatMatrixReader feat_reader(feature_rspecifier); + RandomAccessBaseFloatMatrixReader ivector_reader(ivector_rspecifier); + NnetExampleWriter example_writer(examples_wspecifier); + + int32 num_outputs = (po.NumArgs() - 2) / 2; + KALDI_ASSERT(num_outputs > 0); + + std::vector deriv_weights_readers(num_outputs, + static_cast(NULL)); + std::vector dense_targets_readers(num_outputs, + static_cast(NULL)); + std::vector sparse_targets_readers(num_outputs, + static_cast(NULL)); + + + std::vector compress_targets(1, true); + std::vector compress_targets_vector; + + if (!compress_targets_str.empty()) { + SplitStringToVector(compress_targets_str, ":,", + true, &compress_targets_vector); + } + + if (compress_targets_vector.size() == 1 && num_outputs != 1) { + KALDI_WARN << "compress-targets is of size 1. " + << "Extending it to size num-outputs=" << num_outputs; + compress_targets[0] = ToBool(compress_targets_vector[0]); + compress_targets.resize(num_outputs, ToBool(compress_targets_vector[0])); + } else { + if (compress_targets_vector.size() != num_outputs) { + KALDI_ERR << "Mismatch in length of compress-targets and num-outputs; " + << compress_targets_vector.size() << " vs " << num_outputs; + } + for (int32 n = 0; n < num_outputs; n++) { + compress_targets[n] = ToBool(compress_targets_vector[n]); + } + } + + std::vector targets_compress_formats(1, 1); + if (!targets_compress_formats_str.empty()) { + SplitStringToIntegers(targets_compress_formats_str, ":,", + true, &targets_compress_formats); + } + + if (targets_compress_formats.size() == 1 && num_outputs != 1) { + KALDI_WARN << "targets-compress-formats is of size 1. 
" + << "Extending it to size num-outputs=" << num_outputs; + targets_compress_formats.resize(num_outputs, targets_compress_formats[0]); + } + + if (targets_compress_formats.size() != num_outputs) { + KALDI_ERR << "Mismatch in length of targets-compress-formats and num-outputs; " + << targets_compress_formats.size() << " vs " << num_outputs; + } + + std::vector output_dims(num_outputs); + SplitStringToIntegers(output_dims_str, ":,", + true, &output_dims); + + std::vector output_names(num_outputs); + SplitStringToVector(output_names_str, ":,", true, &output_names); + + //std::vector deriv_weights_rspecifiers; + //if (!deriv_weights_rspecifiers_str.empty()) { + // std::vector parts; + // SplitStringToVector(deriv_weights_rspecifiers_str, ":,", + // false, &deriv_weights_rspecifiers); + + // if (deriv_weights_rspecifiers.size() != num_outputs) { + // KALDI_ERR << "Expecting the number of deriv-weights-rspecifiers to " + // << "be equal to the number of outputs"; + // } + //} + + std::vector targets_rspecifiers(num_outputs); + std::vector deriv_weights_rspecifiers(num_outputs); + + for (int32 n = 0; n < num_outputs; n++) { + const std::string &targets_rspecifier = po.GetArg(2*n + 2); + const std::string &deriv_weights_rspecifier = po.GetArg(2*n + 3); + + targets_rspecifiers[n] = targets_rspecifier; + deriv_weights_rspecifiers[n] = deriv_weights_rspecifier; + + if (output_dims[n] >= 0) { + sparse_targets_readers[n] = new RandomAccessPosteriorReader(targets_rspecifier); + } else { + dense_targets_readers[n] = new RandomAccessBaseFloatMatrixReader(targets_rspecifier); + } + + if (!deriv_weights_rspecifier.empty()) + deriv_weights_readers[n] = new RandomAccessBaseFloatVectorReader(deriv_weights_rspecifier); + + KALDI_LOG << "output-name=" << output_names[n] + << " target-dim=" << output_dims[n] + << " targets-rspecifier=\"" << targets_rspecifiers[n] << "\"" + << " deriv-weights-rspecifier=\"" << deriv_weights_rspecifiers[n] << "\"" + << " compress-target=" << (compress_targets[n] ? "true" : "false") + << " target-compress-format=" << targets_compress_formats[n]; + } + + int32 num_done = 0, num_err = 0; + + std::vector num_frames_written(num_outputs, 0); + std::vector num_egs_written(num_outputs, 0); + + for (; !feat_reader.Done(); feat_reader.Next()) { + std::string key = feat_reader.Key(); + const Matrix &feats = feat_reader.Value(); + + const Matrix *ivector_feats = NULL; + if (!ivector_rspecifier.empty()) { + if (!ivector_reader.HasKey(key)) { + KALDI_WARN << "No iVectors for utterance " << key; + num_err++; + continue; + } else { + // this address will be valid until we call HasKey() or Value() + // again. 
+ ivector_feats = &(ivector_reader.Value(key)); + } + } + + if (ivector_feats && + (abs(feats.NumRows() - ivector_feats->NumRows()) > length_tolerance + || ivector_feats->NumRows() == 0)) { + KALDI_WARN << "Length difference between feats " << feats.NumRows() + << " and iVectors " << ivector_feats->NumRows() + << "exceeds tolerance " << length_tolerance; + num_err++; + continue; + } + + std::vector* > dense_targets(num_outputs, static_cast* >(NULL)); + std::vector sparse_targets(num_outputs, static_cast(NULL)); + std::vector* > deriv_weights(num_outputs, static_cast* >(NULL)); + + int32 num_outputs_found = 0; + for (int32 n = 0; n < num_outputs; n++) { + if (dense_targets_readers[n]) { + if (!dense_targets_readers[n]->HasKey(key)) { + KALDI_WARN << "No dense targets matrix for key " << key << " in " + << "rspecifier " << targets_rspecifiers[n] + << " for output " << output_names[n]; + continue; + } + const MatrixBase *target_matrix = &(dense_targets_readers[n]->Value(key)); + + if ((target_matrix->NumRows() - feats.NumRows()) > length_tolerance) { + KALDI_WARN << "Length difference between feats " << feats.NumRows() + << " and target matrix " << target_matrix->NumRows() + << "exceeds tolerance " << length_tolerance; + num_err++; + continue; + } + + dense_targets[n] = target_matrix; + } else { + if (!sparse_targets_readers[n]->HasKey(key)) { + KALDI_WARN << "No sparse target matrix for key " << key << " in " + << "rspecifier " << targets_rspecifiers[n] + << " for output " << output_names[n]; + continue; + } + const Posterior *posterior = &(sparse_targets_readers[n]->Value(key)); + + if (abs(static_cast(posterior->size()) - feats.NumRows()) > length_tolerance + || posterior->size() < feats.NumRows()) { + KALDI_WARN << "Posterior has wrong size " << posterior->size() + << " versus " << feats.NumRows(); + num_err++; + continue; + } + + sparse_targets[n] = posterior; + } + + if (deriv_weights_readers[n]) { + if (!deriv_weights_readers[n]->HasKey(key)) { + KALDI_WARN << "No deriv weights for key " << key << " in " + << "rspecifier " << deriv_weights_rspecifiers[n] + << " for output " << output_names[n]; + num_err++; + sparse_targets[n] = NULL; + dense_targets[n] = NULL; + continue; + } else { + // this address will be valid until we call HasKey() or Value() + // again. 
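+          // (The cached pointers collected for this utterance are only read
+          // inside ProcessFile() below, before any of these readers is
+          // queried again for the next utterance.)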
+ deriv_weights[n] = &(deriv_weights_readers[n]->Value(key)); + } + } + + if (deriv_weights[n] && + (abs(feats.NumRows() - deriv_weights[n]->Dim()) > length_tolerance + || deriv_weights[n]->Dim() == 0)) { + KALDI_WARN << "Length difference between feats " << feats.NumRows() + << " and deriv weights " << deriv_weights[n]->Dim() + << " exceeds tolerance " << length_tolerance; + num_err++; + sparse_targets[n] = NULL; + dense_targets[n] = NULL; + deriv_weights[n] = NULL; + continue; + } + + num_outputs_found++; + } + + if (num_outputs_found == 0) { + KALDI_WARN << "No output found for key " << key; + num_err++; + continue; + } + + ProcessFile(feats, ivector_feats, output_names, output_dims, + dense_targets, sparse_targets, + deriv_weights, key, + compress_input, input_compress_format, + compress_targets, targets_compress_formats, + left_context, right_context, num_frames, + &num_frames_written, &num_egs_written, + &example_writer); + num_done++; + } + + int64 max_num_egs_written = 0, max_num_frames_written = 0; + for (int32 n = 0; n < num_outputs; n++) { + delete dense_targets_readers[n]; + delete sparse_targets_readers[n]; + delete deriv_weights_readers[n]; + if (num_egs_written[n] == 0) return false; + if (num_egs_written[n] > max_num_egs_written) { + max_num_egs_written = num_egs_written[n]; + max_num_frames_written = num_frames_written[n]; + } + } + + KALDI_LOG << "Finished generating examples, " + << "successfully processed " << num_done + << " feature files, wrote at most " << max_num_egs_written << " examples, " + << " with at most " << max_num_frames_written << " egs in total; " + << num_err << " files had errors."; + + return (num_err > num_done ? 1 : 0); + } catch(const std::exception &e) { + std::cerr << e.what() << '\n'; + return -1; + } +} + + From 4baeb72a22f53cceba36afb55d6683b872443d34 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 23 Nov 2016 19:02:00 -0500 Subject: [PATCH 035/213] asr_diarization: Modify the egs binaries and utilities to support multiple outputs in egs --- src/nnet3/nnet-example-utils.cc | 9 +++- src/nnet3/nnet-example-utils.h | 2 + src/nnet3/nnet-example.cc | 3 ++ src/nnet3bin/nnet3-compute-from-egs.cc | 10 +++-- src/nnet3bin/nnet3-copy-egs.cc | 59 +++++++++++++++++++++++++- src/nnet3bin/nnet3-merge-egs.cc | 6 ++- 6 files changed, 81 insertions(+), 8 deletions(-) diff --git a/src/nnet3/nnet-example-utils.cc b/src/nnet3/nnet-example-utils.cc index 39922153db4..548fb842385 100644 --- a/src/nnet3/nnet-example-utils.cc +++ b/src/nnet3/nnet-example-utils.cc @@ -111,7 +111,7 @@ static void MergeIo(const std::vector &src, io.indexes.resize(size); } - std::vector::const_iterator names_begin = names.begin(); + std::vector::const_iterator names_begin = names.begin(), names_end = names.end(); std::vector::const_iterator eg_iter = src.begin(), eg_end = src.end(); @@ -318,6 +318,13 @@ void RoundUpNumFrames(int32 frame_subsampling_factor, } } +int32 NumOutputs(const NnetExample &eg) { + int32 num_outputs = 0; + for (size_t i = 0; i < eg.io.size(); i++) + if (eg.io[i].name.find("output") != std::string::npos) + num_outputs++; + return num_outputs; +} } // namespace nnet3 } // namespace kaldi diff --git a/src/nnet3/nnet-example-utils.h b/src/nnet3/nnet-example-utils.h index 3e309e18915..d223c5eb5d1 100644 --- a/src/nnet3/nnet-example-utils.h +++ b/src/nnet3/nnet-example-utils.h @@ -80,6 +80,8 @@ void RoundUpNumFrames(int32 frame_subsampling_factor, int32 *num_frames, int32 *num_frames_overlap); +// Returns the number of outputs in an eg +int32 NumOutputs(const 
NnetExample &eg); } // namespace nnet3 } // namespace kaldi diff --git a/src/nnet3/nnet-example.cc b/src/nnet3/nnet-example.cc index 89d40b9ef89..2ad90c0f11d 100644 --- a/src/nnet3/nnet-example.cc +++ b/src/nnet3/nnet-example.cc @@ -128,6 +128,9 @@ NnetIo::NnetIo(const std::string &name, } void NnetExample::Write(std::ostream &os, bool binary) const { +#ifdef KALDI_PARANOID + KALDI_ASSERT(NumOutputs(eg) > 0); +#endif // Note: weight, label, input_frames and spk_info are members. This is a // struct. WriteToken(os, binary, ""); diff --git a/src/nnet3bin/nnet3-compute-from-egs.cc b/src/nnet3bin/nnet3-compute-from-egs.cc index 66eace0dab5..e35e67bbeb5 100644 --- a/src/nnet3bin/nnet3-compute-from-egs.cc +++ b/src/nnet3bin/nnet3-compute-from-egs.cc @@ -36,7 +36,8 @@ class NnetComputerFromEg { // Compute the output (which will have the same number of rows as the number // of Indexes in the output of the eg), and put it in "output". - void Compute(const NnetExample &eg, Matrix *output) { + void Compute(const NnetExample &eg, const std::string &output_name, + Matrix *output) { ComputationRequest request; bool need_backprop = false, store_stats = false; GetComputationRequest(nnet_, eg, need_backprop, store_stats, &request); @@ -47,7 +48,7 @@ class NnetComputerFromEg { NnetComputer computer(options, computation, nnet_, NULL); computer.AcceptInputs(nnet_, eg.io); computer.Forward(); - const CuMatrixBase &nnet_output = computer.GetOutput("output"); + const CuMatrixBase &nnet_output = computer.GetOutput(output_name); output->Resize(nnet_output.NumRows(), nnet_output.NumCols()); nnet_output.CopyToMat(output); } @@ -80,11 +81,14 @@ int main(int argc, char *argv[]) { bool binary_write = true, apply_exp = false; std::string use_gpu = "yes"; + std::string output_name = "output"; ParseOptions po(usage); po.Register("binary", &binary_write, "Write output in binary mode"); po.Register("apply-exp", &apply_exp, "If true, apply exp function to " "output"); + po.Register("output-name", &output_name, "Do computation for " + "specified output"); po.Register("use-gpu", &use_gpu, "yes|no|optional|wait, only has effect if compiled with CUDA"); @@ -115,7 +119,7 @@ int main(int argc, char *argv[]) { for (; !example_reader.Done(); example_reader.Next(), num_egs++) { Matrix output; - computer.Compute(example_reader.Value(), &output); + computer.Compute(example_reader.Value(), output_name, &output); KALDI_ASSERT(output.NumRows() != 0); if (apply_exp) output.ApplyExp(); diff --git a/src/nnet3bin/nnet3-copy-egs.cc b/src/nnet3bin/nnet3-copy-egs.cc index ceb415ffe87..2702ae5fae9 100644 --- a/src/nnet3bin/nnet3-copy-egs.cc +++ b/src/nnet3bin/nnet3-copy-egs.cc @@ -23,10 +23,29 @@ #include "hmm/transition-model.h" #include "nnet3/nnet-example.h" #include "nnet3/nnet-example-utils.h" +#include namespace kaldi { namespace nnet3 { +bool KeepOutputs(const std::vector &keep_outputs, + NnetExample *eg) { + std::vector io_new; + int32 num_outputs = 0; + for (std::vector::iterator it = eg->io.begin(); + it != eg->io.end(); ++it) { + if (it->name.find("output") != std::string::npos) { + if (!std::binary_search(keep_outputs.begin(), keep_outputs.end(), it->name)) + continue; + num_outputs++; + } + io_new.push_back(*it); + } + eg->io.swap(io_new); + + return num_outputs; +} + // returns an integer randomly drawn with expected value "expected_count" // (will be either floor(expected_count) or ceil(expected_count)). 
int32 GetCount(double expected_count) { @@ -257,6 +276,22 @@ bool SelectFromExample(const NnetExample &eg, return true; } +bool RemoveZeroDerivOutputs(NnetExample *eg) { + std::vector io_new; + int32 num_outputs = 0; + for (std::vector::iterator it = eg->io.begin(); + it != eg->io.end(); ++it) { + if (it->name.find("output") != std::string::npos) { + if (it->deriv_weights.Dim() > 0 && it->deriv_weights.Sum() == 0) + continue; + num_outputs++; + } + io_new.push_back(*it); + } + eg->io.swap(io_new); + + return (num_outputs > 0); +} } // namespace nnet3 } // namespace kaldi @@ -284,6 +319,8 @@ int main(int argc, char *argv[]) { int32 srand_seed = 0; int32 frame_shift = 0; BaseFloat keep_proportion = 1.0; + std::string keep_outputs_str; + bool remove_zero_deriv_outputs = false; // The following config variables, if set, can be used to extract a single // frame of labels from a multi-frame example, and/or to reduce the amount @@ -315,7 +352,11 @@ int main(int argc, char *argv[]) { "feature left-context that we output."); po.Register("right-context", &right_context, "Can be used to truncate the " "feature right-context that we output."); - + po.Register("keep-outputs", &keep_outputs_str, "Comma separated list of " + "output nodes to keep"); + po.Register("remove-zero-deriv-outputs", &remove_zero_deriv_outputs, + "Remove outputs that do not contribute to the objective " + "because of zero deriv-weights"); po.Read(argc, argv); @@ -335,17 +376,29 @@ int main(int argc, char *argv[]) { for (int32 i = 0; i < num_outputs; i++) example_writers[i] = new NnetExampleWriter(po.GetArg(i+2)); + std::vector keep_outputs; + if (!keep_outputs_str.empty()) { + SplitStringToVector(keep_outputs_str, ",:", true, &keep_outputs); + std::sort(keep_outputs.begin(), keep_outputs.end()); + } int64 num_read = 0, num_written = 0; for (; !example_reader.Done(); example_reader.Next(), num_read++) { // count is normally 1; could be 0, or possibly >1. int32 count = GetCount(keep_proportion); std::string key = example_reader.Key(); - const NnetExample &eg = example_reader.Value(); + NnetExample eg(example_reader.Value()); + + if (!keep_outputs_str.empty()) { + if (!KeepOutputs(keep_outputs, &eg)) continue; + } + for (int32 c = 0; c < count; c++) { int32 index = (random ? Rand() : num_written) % num_outputs; if (frame_str == "" && left_context == -1 && right_context == -1 && frame_shift == 0) { + if (remove_zero_deriv_outputs) + if (!RemoveZeroDerivOutputs(&eg)) continue; example_writers[index]->Write(key, eg); num_written++; } else { // the --frame option or context options were set. @@ -354,6 +407,8 @@ int main(int argc, char *argv[]) { frame_shift, &eg_modified)) { // this branch of the if statement will almost always be taken (should only // not be taken for shorter-than-normal egs from the end of a file. + if (remove_zero_deriv_outputs) + if (!RemoveZeroDerivOutputs(&eg_modified)) continue; example_writers[index]->Write(key, eg_modified); num_written++; } diff --git a/src/nnet3bin/nnet3-merge-egs.cc b/src/nnet3bin/nnet3-merge-egs.cc index 7415db8d12a..30096ab9988 100644 --- a/src/nnet3bin/nnet3-merge-egs.cc +++ b/src/nnet3bin/nnet3-merge-egs.cc @@ -26,8 +26,10 @@ namespace kaldi { namespace nnet3 { -// returns the number of indexes/frames in the NnetIo named "output" in the eg, -// or crashes if it is not there. +// returns the number of indexes/frames in the output NnetIo +// assumes the output name starts with "output" and only looks at the +// first such output to get the indexes size. 
+// crashes if it there is no such output int32 NumOutputIndexes(const NnetExample &eg) { for (size_t i = 0; i < eg.io.size(); i++) if (eg.io[i].name.find("output") != std::string::npos) From 687b0f19864c3a706c6e4657b51d78809159435a Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 24 Nov 2016 00:28:58 -0500 Subject: [PATCH 036/213] asr_diarization: Adding local/snr/make_sad_tdnn_configs.py and stats component --- .../segmentation/make_sad_tdnn_configs.py | 616 ++++++++++++++++++ egs/wsj/s5/steps/nnet3/components.py | 112 +++- 2 files changed, 720 insertions(+), 8 deletions(-) create mode 100755 egs/aspire/s5/local/segmentation/make_sad_tdnn_configs.py diff --git a/egs/aspire/s5/local/segmentation/make_sad_tdnn_configs.py b/egs/aspire/s5/local/segmentation/make_sad_tdnn_configs.py new file mode 100755 index 00000000000..e859a3593ce --- /dev/null +++ b/egs/aspire/s5/local/segmentation/make_sad_tdnn_configs.py @@ -0,0 +1,616 @@ +#!/usr/bin/env python + +# we're using python 3.x style print but want it to work in python 2.x, +from __future__ import print_function +import os +import argparse +import shlex +import sys +import warnings +import copy +import imp +import ast + +nodes = imp.load_source('', 'steps/nnet3/components.py') +import libs.common as common_lib + +def GetArgs(): + # we add compulsary arguments as named arguments for readability + parser = argparse.ArgumentParser(description="Writes config files and variables " + "for TDNNs creation and training", + epilog="See steps/nnet3/tdnn/train.sh for example.") + + # Only one of these arguments can be specified, and one of them has to + # be compulsarily specified + feat_group = parser.add_mutually_exclusive_group(required = True) + feat_group.add_argument("--feat-dim", type=int, + help="Raw feature dimension, e.g. 13") + feat_group.add_argument("--feat-dir", type=str, + help="Feature directory, from which we derive the feat-dim") + + # only one of these arguments can be specified + ivector_group = parser.add_mutually_exclusive_group(required = False) + ivector_group.add_argument("--ivector-dim", type=int, + help="iVector dimension, e.g. 100", default=0) + ivector_group.add_argument("--ivector-dir", type=str, + help="iVector dir, which will be used to derive the ivector-dim ", default=None) + + num_target_group = parser.add_mutually_exclusive_group(required = True) + num_target_group.add_argument("--num-targets", type=int, + help="number of network targets (e.g. num-pdf-ids/num-leaves)") + num_target_group.add_argument("--ali-dir", type=str, + help="alignment directory, from which we derive the num-targets") + num_target_group.add_argument("--tree-dir", type=str, + help="directory with final.mdl, from which we derive the num-targets") + num_target_group.add_argument("--output-node-parameters", type=str, action='append', + dest='output_node_para_array', + help = "Define output nodes' and their parameters like output-suffix, dim, objective-type etc") + # CNN options + parser.add_argument('--cnn.layer', type=str, action='append', dest = "cnn_layer", + help="CNN parameters at each CNN layer, e.g. --filt-x-dim=3 --filt-y-dim=8 " + "--filt-x-step=1 --filt-y-step=1 --num-filters=256 --pool-x-size=1 --pool-y-size=3 " + "--pool-z-size=1 --pool-x-step=1 --pool-y-step=3 --pool-z-step=1, " + "when CNN layers are used, no LDA will be added", default = None) + parser.add_argument("--cnn.bottleneck-dim", type=int, dest = "cnn_bottleneck_dim", + help="Output dimension of the linear layer at the CNN output " + "for dimension reduction, e.g. 256." 
+ "The default zero means this layer is not needed.", default=0) + + # General neural network options + parser.add_argument("--splice-indexes", type=str, required = True, + help="Splice indexes at each layer, e.g. '-3,-2,-1,0,1,2,3' " + "If CNN layers are used the first set of splice indexes will be used as input " + "to the first CNN layer and later splice indexes will be interpreted as indexes " + "for the TDNNs.") + parser.add_argument("--add-lda", type=str, action=common_lib.StrToBoolAction, + help="If \"true\" an LDA matrix computed from the input features " + "(spliced according to the first set of splice-indexes) will be used as " + "the first Affine layer. This affine layer's parameters are fixed during training. " + "This variable needs to be set to \"false\" when using dense-targets.\n" + "If --cnn.layer is specified this option will be forced to \"false\".", + default=True, choices = ["false", "true"]) + + parser.add_argument("--include-log-softmax", type=str, action=common_lib.StrToBoolAction, + help="add the final softmax layer ", default=True, choices = ["false", "true"]) + parser.add_argument("--add-final-sigmoid", type=str, action=common_lib.StrToBoolAction, + help="add a final sigmoid layer as alternate to log-softmax-layer. " + "Can only be used if include-log-softmax is false. " + "This is useful in cases where you want the output to be " + "like probabilities between 0 and 1. Typically the nnet " + "is trained with an objective such as quadratic", + default=False, choices = ["false", "true"]) + + parser.add_argument("--objective-type", type=str, + help = "the type of objective; i.e. quadratic or linear", + default="linear", choices = ["linear", "quadratic"]) + parser.add_argument("--xent-regularize", type=float, + help="For chain models, if nonzero, add a separate output for cross-entropy " + "regularization (with learning-rate-factor equal to the inverse of this)", + default=0.0) + parser.add_argument("--final-layer-normalize-target", type=float, + help="RMS target for final layer (set to <1 if final layer learns too fast", + default=1.0) + parser.add_argument("--subset-dim", type=int, default=0, + help="dimension of the subset of units to be sent to the central frame") + parser.add_argument("--pnorm-input-dim", type=int, + help="input dimension to p-norm nonlinearities") + parser.add_argument("--pnorm-output-dim", type=int, + help="output dimension of p-norm nonlinearities") + relu_dim_group = parser.add_mutually_exclusive_group(required = False) + relu_dim_group.add_argument("--relu-dim", type=int, + help="dimension of all ReLU nonlinearity layers") + relu_dim_group.add_argument("--relu-dim-final", type=int, + help="dimension of the last ReLU nonlinearity layer. Dimensions increase geometrically from the first through the last ReLU layer.", default=None) + parser.add_argument("--relu-dim-init", type=int, + help="dimension of the first ReLU nonlinearity layer. Dimensions increase geometrically from the first through the last ReLU layer.", default=None) + + parser.add_argument("--self-repair-scale-nonlinearity", type=float, + help="A non-zero value activates the self-repair mechanism in the sigmoid and tanh non-linearities of the LSTM", default=None) + + + parser.add_argument("--use-presoftmax-prior-scale", type=str, action=common_lib.StrToBoolAction, + help="if true, a presoftmax-prior-scale is added", + choices=['true', 'false'], default = True) + + # Options to convert input MFCC into Fbank features. 
This is useful when a + # LDA layer is not added (such as when using dense targets) + parser.add_argument("--cnn.cepstral-lifter", type=float, dest = "cepstral_lifter", + help="The factor used for determining the liftering vector in the production of MFCC. " + "User has to ensure that it matches the lifter used in MFCC generation, " + "e.g. 22.0", default=22.0) + + parser.add_argument("config_dir", + help="Directory to write config files and variables") + + print(' '.join(sys.argv)) + + args = parser.parse_args() + args = CheckArgs(args) + + return args + +def CheckArgs(args): + if not os.path.exists(args.config_dir): + os.makedirs(args.config_dir) + + ## Check arguments. + if args.feat_dir is not None: + args.feat_dim = common_lib.get_feat_dim(args.feat_dir) + + if args.ivector_dir is not None: + args.ivector_dim = common_lib.get_ivector_dim(args.ivector_dir) + + if not args.feat_dim > 0: + raise Exception("feat-dim has to be postive") + + if len(args.output_node_para_array) == 0: + if args.ali_dir is not None: + args.num_targets = common_lib.get_number_of_leaves_from_tree(args.ali_dir) + elif args.tree_dir is not None: + args.num_targets = common_lib.get_number_of_leaves_from_tree(args.tree_dir) + if not args.num_targets > 0: + print(args.num_targets) + raise Exception("num_targets has to be positive") + args.output_node_para_array.append( + "--dim={0} --objective-type={1} --include-log-softmax={2} --add-final-sigmoid={3} --xent-regularize={4}".format( + args.num_targets, args.objective_type, + "true" if args.include_log_softmax else "false", + "true" if args.add_final_sigmoid else "false", + args.xent_regularize)) + + if not args.ivector_dim >= 0: + raise Exception("ivector-dim has to be non-negative") + + if (args.subset_dim < 0): + raise Exception("--subset-dim has to be non-negative") + + if not args.relu_dim is None: + if not args.pnorm_input_dim is None or not args.pnorm_output_dim is None or not args.relu_dim_init is None: + raise Exception("--relu-dim argument not compatible with " + "--pnorm-input-dim or --pnorm-output-dim or --relu-dim-init options"); + args.nonlin_input_dim = args.relu_dim + args.nonlin_output_dim = args.relu_dim + args.nonlin_output_dim_final = None + args.nonlin_output_dim_init = None + args.nonlin_type = 'relu' + + elif not args.relu_dim_final is None: + if not args.pnorm_input_dim is None or not args.pnorm_output_dim is None: + raise Exception("--relu-dim-final argument not compatible with " + "--pnorm-input-dim or --pnorm-output-dim options") + if args.relu_dim_init is None: + raise Exception("--relu-dim-init argument should also be provided with --relu-dim-final") + if args.relu_dim_init > args.relu_dim_final: + raise Exception("--relu-dim-init has to be no larger than --relu-dim-final") + args.nonlin_input_dim = None + args.nonlin_output_dim = None + args.nonlin_output_dim_final = args.relu_dim_final + args.nonlin_output_dim_init = args.relu_dim_init + args.nonlin_type = 'relu' + + else: + if not args.relu_dim_init is None: + raise Exception("--relu-dim-final argument not compatible with " + "--pnorm-input-dim or --pnorm-output-dim options") + if not args.pnorm_input_dim > 0 or not args.pnorm_output_dim > 0: + raise Exception("--relu-dim not set, so expected --pnorm-input-dim and " + "--pnorm-output-dim to be provided."); + args.nonlin_input_dim = args.pnorm_input_dim + args.nonlin_output_dim = args.pnorm_output_dim + if (args.nonlin_input_dim < args.nonlin_output_dim) or (args.nonlin_input_dim % args.nonlin_output_dim != 0): + raise Exception("Invalid 
--pnorm-input-dim {0} and --pnorm-output-dim {1}".format(args.nonlin_input_dim, args.nonlin_output_dim)) + args.nonlin_output_dim_final = None + args.nonlin_output_dim_init = None + args.nonlin_type = 'pnorm' + + if args.add_lda and args.cnn_layer is not None: + args.add_lda = False + warnings.warn("--add-lda is set to false as CNN layers are used.") + + return args + +def AddConvMaxpLayer(config_lines, name, input, args): + if '3d-dim' not in input: + raise Exception("The input to AddConvMaxpLayer() needs '3d-dim' parameters.") + + input = nodes.AddConvolutionLayer(config_lines, name, input, + input['3d-dim'][0], input['3d-dim'][1], input['3d-dim'][2], + args.filt_x_dim, args.filt_y_dim, + args.filt_x_step, args.filt_y_step, + args.num_filters, input['vectorization']) + + if args.pool_x_size > 1 or args.pool_y_size > 1 or args.pool_z_size > 1: + input = nodes.AddMaxpoolingLayer(config_lines, name, input, + input['3d-dim'][0], input['3d-dim'][1], input['3d-dim'][2], + args.pool_x_size, args.pool_y_size, args.pool_z_size, + args.pool_x_step, args.pool_y_step, args.pool_z_step) + + return input + +# The ivectors are processed through an affine layer parallel to the CNN layers, +# then concatenated with the CNN output and passed to the deeper part of the network. +def AddCnnLayers(config_lines, cnn_layer, cnn_bottleneck_dim, cepstral_lifter, config_dir, feat_dim, splice_indexes=[0], ivector_dim=0): + cnn_args = ParseCnnString(cnn_layer) + num_cnn_layers = len(cnn_args) + # We use an Idct layer here to convert MFCC to FBANK features + common_lib.write_idct_matrix(feat_dim, cepstral_lifter, config_dir.strip() + "/idct.mat") + prev_layer_output = {'descriptor': "input", + 'dimension': feat_dim} + prev_layer_output = nodes.AddFixedAffineLayer(config_lines, "Idct", prev_layer_output, config_dir.strip() + '/idct.mat') + + list = [('Offset({0}, {1})'.format(prev_layer_output['descriptor'],n) if n != 0 else prev_layer_output['descriptor']) for n in splice_indexes] + splice_descriptor = "Append({0})".format(", ".join(list)) + cnn_input_dim = len(splice_indexes) * feat_dim + prev_layer_output = {'descriptor': splice_descriptor, + 'dimension': cnn_input_dim, + '3d-dim': [len(splice_indexes), feat_dim, 1], + 'vectorization': 'yzx'} + + for cl in range(0, num_cnn_layers): + prev_layer_output = AddConvMaxpLayer(config_lines, "L{0}".format(cl), prev_layer_output, cnn_args[cl]) + + if cnn_bottleneck_dim > 0: + prev_layer_output = nodes.AddAffineLayer(config_lines, "cnn-bottleneck", prev_layer_output, cnn_bottleneck_dim, "") + + if ivector_dim > 0: + iv_layer_output = {'descriptor': 'ReplaceIndex(ivector, t, 0)', + 'dimension': ivector_dim} + iv_layer_output = nodes.AddAffineLayer(config_lines, "ivector", iv_layer_output, ivector_dim, "") + prev_layer_output['descriptor'] = 'Append({0}, {1})'.format(prev_layer_output['descriptor'], iv_layer_output['descriptor']) + prev_layer_output['dimension'] = prev_layer_output['dimension'] + iv_layer_output['dimension'] + + return prev_layer_output + +def PrintConfig(file_name, config_lines): + f = open(file_name, 'w') + f.write("\n".join(config_lines['components'])+"\n") + f.write("\n#Component nodes\n") + f.write("\n".join(config_lines['component-nodes'])+"\n") + f.close() + +def ParseCnnString(cnn_param_string_list): + cnn_parser = argparse.ArgumentParser(description="cnn argument parser") + + cnn_parser.add_argument("--filt-x-dim", required=True, type=int) + cnn_parser.add_argument("--filt-y-dim", required=True, type=int) + cnn_parser.add_argument("--filt-x-step", 
type=int, default = 1) + cnn_parser.add_argument("--filt-y-step", type=int, default = 1) + cnn_parser.add_argument("--num-filters", required=True, type=int) + cnn_parser.add_argument("--pool-x-size", type=int, default = 1) + cnn_parser.add_argument("--pool-y-size", type=int, default = 1) + cnn_parser.add_argument("--pool-z-size", type=int, default = 1) + cnn_parser.add_argument("--pool-x-step", type=int, default = 1) + cnn_parser.add_argument("--pool-y-step", type=int, default = 1) + cnn_parser.add_argument("--pool-z-step", type=int, default = 1) + + cnn_args = [] + for cl in range(0, len(cnn_param_string_list)): + cnn_args.append(cnn_parser.parse_args(shlex.split(cnn_param_string_list[cl]))) + + return cnn_args + +def ParseSpliceString(splice_indexes): + splice_array = [] + left_context = 0 + right_context = 0 + split_on_spaces = splice_indexes.split(); # we already checked the string is nonempty. + if len(split_on_spaces) < 1: + raise Exception("invalid splice-indexes argument, too short: " + + splice_indexes) + try: + for string in split_on_spaces: + this_splices = string.split(",") + if len(this_splices) < 1: + raise Exception("invalid splice-indexes argument, too-short element: " + + splice_indexes) + # the rest of this block updates left_context and right_context, and + # does some checking. + leftmost_splice = 10000 + rightmost_splice = -10000 + + int_list = [] + for s in this_splices: + try: + n = int(s) + if n < leftmost_splice: + leftmost_splice = n + if n > rightmost_splice: + rightmost_splice = n + int_list.append(n) + except ValueError: + #if len(splice_array) == 0: + # raise Exception("First dimension of splicing array must not have averaging [yet]") + try: + x = nodes.StatisticsConfig(s, { 'dimension':100, + 'descriptor': 'foo'} ) + int_list.append(s) + except Exception as e: + raise Exception("The following element of the splicing array is not a valid specifier " + "of statistics: {0}\nGot {1}".format(s, str(e))) + splice_array.append(int_list) + + if leftmost_splice == 10000 or rightmost_splice == -10000: + raise Exception("invalid element of --splice-indexes: " + string) + left_context += -leftmost_splice + right_context += rightmost_splice + except ValueError as e: + raise Exception("invalid --splice-indexes argument " + args.splice_indexes + " " + str(e)) + + left_context = max(0, left_context) + right_context = max(0, right_context) + + return {'left_context':left_context, + 'right_context':right_context, + 'splice_indexes':splice_array, + 'num_hidden_layers':len(splice_array) + } + +def AddPriorsAccumulator(config_lines, name, input): + components = config_lines['components'] + component_nodes = config_lines['component-nodes'] + + components.append("component name={0}_softmax type=SoftmaxComponent dim={1}".format(name, input['dimension'])) + component_nodes.append("component-node name={0}_softmax component={0}_softmax input={1}".format(name, input['descriptor'])) + + return {'descriptor': '{0}_softmax'.format(name), + 'dimension': input['dimension']} + +def AddFinalLayer(config_lines, input, output_dim, + ng_affine_options = " param-stddev=0 bias-stddev=0 ", + label_delay=None, + use_presoftmax_prior_scale = False, + prior_scale_file = None, + include_log_softmax = True, + add_final_sigmoid = False, + name_affix = None, + objective_type = "linear", + objective_scale = 1.0, + objective_scales_vec = None): + components = config_lines['components'] + component_nodes = config_lines['component-nodes'] + + if name_affix is not None: + final_node_prefix = 'Final-' + 
str(name_affix) + else: + final_node_prefix = 'Final' + + prev_layer_output = nodes.AddAffineLayer(config_lines, + final_node_prefix , input, output_dim, + ng_affine_options) + if include_log_softmax: + if use_presoftmax_prior_scale : + components.append('component name={0}-fixed-scale type=FixedScaleComponent scales={1}'.format(final_node_prefix, prior_scale_file)) + component_nodes.append('component-node name={0}-fixed-scale component={0}-fixed-scale input={1}'.format(final_node_prefix, + prev_layer_output['descriptor'])) + prev_layer_output['descriptor'] = "{0}-fixed-scale".format(final_node_prefix) + prev_layer_output = nodes.AddSoftmaxLayer(config_lines, final_node_prefix, prev_layer_output) + + elif add_final_sigmoid: + # Useful when you need the final outputs to be probabilities + # between 0 and 1. + # Usually used with an objective-type such as "quadratic" + prev_layer_output = nodes.AddSigmoidLayer(config_lines, final_node_prefix, prev_layer_output) + + # we use the same name_affix as a prefix in for affine/scale nodes but as a + # suffix for output node + if (objective_scale != 1.0 or objective_scales_vec is not None): + prev_layer_output = nodes.AddGradientScaleLayer(config_lines, final_node_prefix, prev_layer_output, objective_scale, objective_scales_vec) + + nodes.AddOutputLayer(config_lines, prev_layer_output, label_delay, suffix = name_affix, objective_type = objective_type) + +def AddOutputLayers(config_lines, prev_layer_output, output_nodes, + ng_affine_options = "", label_delay = 0): + + for o in output_nodes: + # make the intermediate config file for layerwise discriminative + # training + AddFinalLayer(config_lines, prev_layer_output, o.dim, + ng_affine_options, label_delay = label_delay, + include_log_softmax = o.include_log_softmax, + add_final_sigmoid = o.add_final_sigmoid, + objective_type = o.objective_type, + name_affix = o.output_suffix) + + if o.xent_regularize != 0.0: + nodes.AddFinalLayer(config_lines, prev_layer_output, o.dim, + include_log_softmax = True, + label_delay = label_delay, + name_affix = o.output_suffix + '_xent') + +# The function signature of MakeConfigs is changed frequently as it is intended for local use in this script. 
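A quick illustration of the --splice-indexes string that MakeConfigs() below passes to ParseSpliceString() (the values here are invented for illustration, not taken from any recipe): each space-separated field configures one hidden layer, each field is a comma-separated list of integer offsets (or statistics specifiers), and the per-layer extremes accumulate into the model context:

    parsed = ParseSpliceString("-2,-1,0,1,2 -1,2 -3,3 -7,2 0")
    # parsed['num_hidden_layers'] == 5    # one field per hidden layer
    # parsed['splice_indexes'][0] == [-2, -1, 0, 1, 2]
    # parsed['left_context']  == 2 + 1 + 3 + 7 + 0 == 13
    # parsed['right_context'] == 2 + 2 + 3 + 2 + 0 == 9
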
+def MakeConfigs(config_dir, splice_indexes_string, + cnn_layer, cnn_bottleneck_dim, cepstral_lifter, + feat_dim, ivector_dim, add_lda, + nonlin_type, nonlin_input_dim, nonlin_output_dim, subset_dim, + nonlin_output_dim_init, nonlin_output_dim_final, + use_presoftmax_prior_scale, final_layer_normalize_target, + output_nodes, self_repair_scale): + + parsed_splice_output = ParseSpliceString(splice_indexes_string.strip()) + + left_context = parsed_splice_output['left_context'] + right_context = parsed_splice_output['right_context'] + num_hidden_layers = parsed_splice_output['num_hidden_layers'] + splice_indexes = parsed_splice_output['splice_indexes'] + input_dim = len(parsed_splice_output['splice_indexes'][0]) + feat_dim + ivector_dim + + prior_scale_file = '{0}/presoftmax_prior_scale.vec'.format(config_dir) + + config_lines = {'components':[], 'component-nodes':[]} + + config_files={} + prev_layer_output = nodes.AddInputLayer(config_lines, feat_dim, splice_indexes[0], + ivector_dim) + + # Add the init config lines for estimating the preconditioning matrices + init_config_lines = copy.deepcopy(config_lines) + init_config_lines['components'].insert(0, '# Config file for initializing neural network prior to') + init_config_lines['components'].insert(0, '# preconditioning matrix computation') + + for o in output_nodes: + nodes.AddOutputLayer(init_config_lines, prev_layer_output, + objective_type = o.objective_type, suffix = o.output_suffix) + + config_files[config_dir + '/init.config'] = init_config_lines + + if cnn_layer is not None: + prev_layer_output = AddCnnLayers(config_lines, cnn_layer, cnn_bottleneck_dim, cepstral_lifter, config_dir, + feat_dim, splice_indexes[0], ivector_dim) + + # add_lda needs to be set "false" when using dense targets, + # or if the task is not a simple classification task + # (e.g. regression, multi-task) + if add_lda: + prev_layer_output = nodes.AddLdaLayer(config_lines, "L0", prev_layer_output, config_dir + '/lda.mat') + + left_context = 0 + right_context = 0 + # we moved the first splice layer to before the LDA.. + # so the input to the first affine layer is going to [0] index + splice_indexes[0] = [0] + + if not nonlin_output_dim is None: + nonlin_output_dims = [nonlin_output_dim] * num_hidden_layers + elif nonlin_output_dim_init < nonlin_output_dim_final and num_hidden_layers == 1: + raise Exception("num-hidden-layers has to be greater than 1 if relu-dim-init and relu-dim-final is different.") + else: + # computes relu-dim for each hidden layer. 
They increase geometrically across layers + factor = pow(float(nonlin_output_dim_final) / nonlin_output_dim_init, 1.0 / (num_hidden_layers - 1)) if num_hidden_layers > 1 else 1 + nonlin_output_dims = [int(round(nonlin_output_dim_init * pow(factor, i))) for i in range(0, num_hidden_layers)] + assert(nonlin_output_dims[-1] >= nonlin_output_dim_final - 1 and nonlin_output_dims[-1] <= nonlin_output_dim_final + 1) # due to rounding error + nonlin_output_dims[-1] = nonlin_output_dim_final # It ensures that the dim of the last hidden layer is exactly the same as what is specified + + for i in range(0, num_hidden_layers): + # make the intermediate config file for layerwise discriminative training + + # prepare the spliced input + if not (len(splice_indexes[i]) == 1 and splice_indexes[i][0] == 0): + try: + zero_index = splice_indexes[i].index(0) + except ValueError: + zero_index = None + # I just assume the prev_layer_output_descriptor is a simple forwarding descriptor + prev_layer_output_descriptor = prev_layer_output['descriptor'] + subset_output = prev_layer_output + if subset_dim > 0: + # if subset_dim is specified the script expects a zero in the splice indexes + assert(zero_index is not None) + subset_node_config = ("dim-range-node name=Tdnn_input_{0} " + "input-node={1} dim-offset={2} dim={3}".format( + i, prev_layer_output_descriptor, 0, subset_dim)) + subset_output = {'descriptor' : 'Tdnn_input_{0}'.format(i), + 'dimension' : subset_dim} + config_lines['component-nodes'].append(subset_node_config) + appended_descriptors = [] + appended_dimension = 0 + for j in range(len(splice_indexes[i])): + if j == zero_index: + appended_descriptors.append(prev_layer_output['descriptor']) + appended_dimension += prev_layer_output['dimension'] + continue + try: + offset = int(splice_indexes[i][j]) + # it's an integer offset. + appended_descriptors.append('Offset({0}, {1})'.format( + subset_output['descriptor'], splice_indexes[i][j])) + appended_dimension += subset_output['dimension'] + except ValueError: + # it's not an integer offset, so assume it specifies the + # statistics-extraction. 
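                # (Illustrative aside, not part of the patch: an entry such as
                # 'mean(-99:3:9:99)' at this position in --splice-indexes would be
                # handled by the StatisticsConfig call below, which appends a
                # StatisticsExtractionComponent with input-period 3 and a
                # StatisticsPoolingComponent covering [-99, 99] with stats computed
                # every 9 frames; the returned descriptor is wrapped in
                # Round(..., 9) so it is only evaluated at multiples of that period.)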
+ stats = nodes.StatisticsConfig(splice_indexes[i][j], prev_layer_output) + stats_layer = stats.AddLayer(config_lines, "Tdnn_stats_{0}".format(i)) + appended_descriptors.append(stats_layer['descriptor']) + appended_dimension += stats_layer['dimension'] + + prev_layer_output = {'descriptor' : "Append({0})".format(" , ".join(appended_descriptors)), + 'dimension' : appended_dimension} + else: + # this is a normal affine node + pass + + if nonlin_type == "relu": + prev_layer_output = nodes.AddAffRelNormLayer(config_lines, "Tdnn_{0}".format(i), + prev_layer_output, nonlin_output_dims[i], + self_repair_scale=self_repair_scale, + norm_target_rms=1.0 if i < num_hidden_layers -1 else final_layer_normalize_target) + elif nonlin_type == "pnorm": + prev_layer_output = nodes.AddAffPnormLayer(config_lines, "Tdnn_{0}".format(i), + prev_layer_output, nonlin_input_dim, nonlin_output_dim, + norm_target_rms=1.0 if i < num_hidden_layers -1 else final_layer_normalize_target) + else: + raise Exception("Unknown nonlinearity type") + # a final layer is added after each new layer as we are generating + # configs for layer-wise discriminative training + + AddOutputLayers(config_lines, prev_layer_output, output_nodes) + + config_files['{0}/layer{1}.config'.format(config_dir, i + 1)] = config_lines + config_lines = {'components':[], 'component-nodes':[]} + + left_context += int(parsed_splice_output['left_context']) + right_context += int(parsed_splice_output['right_context']) + + # write the files used by other scripts like steps/nnet3/get_egs.sh + f = open(config_dir + "/vars", "w") + print('model_left_context=' + str(left_context), file=f) + print('model_right_context=' + str(right_context), file=f) + print('num_hidden_layers=' + str(num_hidden_layers), file=f) + print('add_lda=' + ('true' if add_lda else 'false'), file=f) + f.close() + + # printing out the configs + # init.config used to train lda-mllt train + for key in config_files.keys(): + PrintConfig(key, config_files[key]) + +def ParseOutputNodesParameters(para_array): + output_parser = argparse.ArgumentParser() + output_parser.add_argument('--output-suffix', type=str, action=common_lib.NullstrToNoneAction, + help = "Name of the output node. e.g. output-xent") + output_parser.add_argument('--dim', type=int, required=True, + help = "Dimension of the output node") + output_parser.add_argument("--include-log-softmax", type=str, action=common_lib.StrToBoolAction, + help="add the final softmax layer ", + default=True, choices = ["false", "true"]) + output_parser.add_argument("--add-final-sigmoid", type=str, action=common_lib.StrToBoolAction, + help="add a sigmoid layer as the final layer. Applicable only if skip-final-softmax is true.", + choices=['true', 'false'], default = False) + output_parser.add_argument("--objective-type", type=str, default="linear", + choices = ["linear", "quadratic","xent-per-dim"], + help = "the type of objective; i.e. 
quadratic or linear") + output_parser.add_argument("--xent-regularize", type=float, + help="For chain models, if nonzero, add a separate output for cross-entropy " + "regularization (with learning-rate-factor equal to the inverse of this)", + default=0.0) + + output_nodes = [ output_parser.parse_args(shlex.split(x)) for x in para_array ] + + return output_nodes + +def Main(): + args = GetArgs() + + output_nodes = ParseOutputNodesParameters(args.output_node_para_array) + + MakeConfigs(config_dir = args.config_dir, + feat_dim = args.feat_dim, ivector_dim = args.ivector_dim, + add_lda = args.add_lda, + cepstral_lifter = args.cepstral_lifter, + splice_indexes_string = args.splice_indexes, + cnn_layer = args.cnn_layer, + cnn_bottleneck_dim = args.cnn_bottleneck_dim, + nonlin_type = args.nonlin_type, + nonlin_input_dim = args.nonlin_input_dim, + nonlin_output_dim = args.nonlin_output_dim, + subset_dim = args.subset_dim, + nonlin_output_dim_init = args.nonlin_output_dim_init, + nonlin_output_dim_final = args.nonlin_output_dim_final, + use_presoftmax_prior_scale = args.use_presoftmax_prior_scale, + final_layer_normalize_target = args.final_layer_normalize_target, + output_nodes = output_nodes, + self_repair_scale = args.self_repair_scale_nonlinearity) + +if __name__ == "__main__": + Main() + + diff --git a/egs/wsj/s5/steps/nnet3/components.py b/egs/wsj/s5/steps/nnet3/components.py index 82566d2e37d..c811297cda8 100644 --- a/egs/wsj/s5/steps/nnet3/components.py +++ b/egs/wsj/s5/steps/nnet3/components.py @@ -6,6 +6,7 @@ import sys import warnings import copy +import re from operator import itemgetter def GetSumDescriptor(inputs): @@ -30,17 +31,33 @@ def AddInputLayer(config_lines, feat_dim, splice_indexes=[0], ivector_dim=0): components = config_lines['components'] component_nodes = config_lines['component-nodes'] output_dim = 0 - components.append('input-node name=input dim=' + str(feat_dim)) - list = [('Offset(input, {0})'.format(n) if n != 0 else 'input') for n in splice_indexes] - output_dim += len(splice_indexes) * feat_dim + components.append('input-node name=input dim={0}'.format(feat_dim)) + prev_layer_output = {'descriptor': "input", + 'dimension': feat_dim} + inputs = [] + for n in splice_indexes: + try: + offset = int(n) + if offset == 0: + inputs.append(prev_layer_output['descriptor']) + else: + inputs.append('Offset({0}, {1})'.format( + prev_layer_output['descriptor'], offset)) + output_dim += prev_layer_output['dimension'] + except ValueError: + stats = StatisticsConfig(n, prev_layer_output) + stats_layer = stats.AddLayer(config_lines, "Tdnn_stats_{0}".format(0)) + inputs.append(stats_layer['descriptor']) + output_dim += stats_layer['dimension'] + if ivector_dim > 0: - components.append('input-node name=ivector dim=' + str(ivector_dim)) - list.append('ReplaceIndex(ivector, t, 0)') + components.append('input-node name=ivector dim={0}'.format(ivector_dim)) + inputs.append('ReplaceIndex(ivector, t, 0)') output_dim += ivector_dim - if len(list) > 1: - splice_descriptor = "Append({0})".format(", ".join(list)) + if len(inputs) > 1: + splice_descriptor = "Append({0})".format(", ".join(inputs)) else: - splice_descriptor = list[0] + splice_descriptor = inputs[0] print(splice_descriptor) return {'descriptor': splice_descriptor, 'dimension': output_dim} @@ -519,3 +536,82 @@ def AddBLstmLayer(config_lines, 'dimension':output_dim } +# this is a bit like a struct, initialized from a string, which describes how to +# set up the statistics-pooling and statistics-extraction components. 
+# An example string is 'mean(-99:3:9::99)', which means, compute the mean of +# data within a window of -99 to +99, with distinct means computed every 9 frames +# (we round to get the appropriate one), and with the input extracted on multiples +# of 3 frames (so this will force the input to this layer to be evaluated +# every 3 frames). Another example string is 'mean+stddev(-99:3:9:99)', +# which will also cause the standard deviation to be computed. +class StatisticsConfig: + # e.g. c = StatisticsConfig('mean+stddev(-99:3:9:99)', 400, 'jesus1-forward-output-affine') + def __init__(self, config_string, input): + + self.input_dim = input['dimension'] + self.input_descriptor = input['descriptor'] + + m = re.search("(mean|mean\+stddev|mean\+count|mean\+stddev\+count)\((-?\d+):(-?\d+):(-?\d+):(-?\d+)\)", + config_string) + if m == None: + raise Exception("Invalid splice-index or statistics-config string: " + config_string) + self.output_stddev = (m.group(1) in ['mean+stddev', 'mean+stddev+count']) + self.output_log_counts = (m.group(1) in ['mean+count', 'mean+stddev+count']) + self.left_context = -int(m.group(2)) + self.input_period = int(m.group(3)) + self.stats_period = int(m.group(4)) + self.right_context = int(m.group(5)) + if not (self.left_context > 0 and self.right_context > 0 and + self.input_period > 0 and self.stats_period > 0 and + self.left_context % self.stats_period == 0 and + self.right_context % self.stats_period == 0 and + self.stats_period % self.input_period == 0): + raise Exception("Invalid configuration of statistics-extraction: " + config_string) + + # OutputDim() returns the output dimension of the node that this produces. + def OutputDim(self): + return (self.input_dim * (2 if self.output_stddev else 1) + + 1 if self.output_log_counts else 0) + + # OutputDims() returns an array of output dimensions, consisting of + # [ input-dim ] if just "mean" was specified, otherwise + # [ input-dim input-dim ] + def OutputDims(self): + output_dims = [ self.input_dim ] + if self.output_stddev: + output_dims.append(self.input_dim) + if self.output_log_counts: + output_dims.append(1) + return output_dims + + # Descriptor() returns the textual form of the descriptor by which the + # output of this node is to be accessed. 
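    # A hypothetical usage sketch (editorial; the layer name and dimension are
    # invented):
    #   c = StatisticsConfig('mean+stddev(-99:3:9:99)',
    #                        {'dimension': 40, 'descriptor': 'Tdnn_2'})
    #   # -> left_context = right_context = 99, input_period = 3,
    #   #    stats_period = 9, output_stddev = True, output_log_counts = False
    #   c.OutputDims()   # -> [40, 40]  (mean block followed by stddev block)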
+ def Descriptor(self, name): + return 'Round({0}-pooling-{1}-{2}, {3})'.format(name, self.left_context, self.right_context, + self.stats_period) + + def AddLayer(self, config_lines, name): + components = config_lines['components'] + component_nodes = config_lines['component-nodes'] + + components.append('component name={name}-extraction-{lc}-{rc} type=StatisticsExtractionComponent input-dim={dim} ' + 'input-period={input_period} output-period={output_period} include-variance={var} '.format( + name = name, lc = self.left_context, rc = self.right_context, + dim = self.input_dim, input_period = self.input_period, output_period = self.stats_period, + var = ('true' if self.output_stddev else 'false'))) + component_nodes.append('component-node name={name}-extraction-{lc}-{rc} component={name}-extraction-{lc}-{rc} input={input} '.format( + name = name, lc = self.left_context, rc = self.right_context, input = self.input_descriptor)) + stats_dim = 1 + self.input_dim * (2 if self.output_stddev else 1) + components.append('component name={name}-pooling-{lc}-{rc} type=StatisticsPoolingComponent input-dim={dim} ' + 'input-period={input_period} left-context={lc} right-context={rc} num-log-count-features={count} ' + 'output-stddevs={var} '.format(name = name, lc = self.left_context, rc = self.right_context, + dim = stats_dim, input_period = self.stats_period, + count = 1 if self.output_log_counts else 0, + var = ('true' if self.output_stddev else 'false'))) + component_nodes.append('component-node name={name}-pooling-{lc}-{rc} component={name}-pooling-{lc}-{rc} input={name}-extraction-{lc}-{rc} '.format( + name = name, lc = self.left_context, rc = self.right_context)) + + return { 'dimension': self.OutputDim(), + 'descriptor': self.Descriptor(name), + 'dimensions': self.OutputDims() + } From fbc0333e6b79b6e4ffab7c3be2fa15a396ea4f6e Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 24 Nov 2016 00:09:46 -0500 Subject: [PATCH 037/213] asr_diarization: compute_output.sh, SAD decoding scripts and do_segmentation_data_dir.sh --- egs/wsj/s5/steps/nnet3/compute_output.sh | 179 ++++++++++++++++++ egs/wsj/s5/steps/segmentation/decode_sad.sh | 42 ++++ .../segmentation/decode_sad_to_segments.sh | 97 ++++++++++ .../segmentation/do_segmentation_data_dir.sh | 134 +++++++++++++ .../steps/segmentation/internal/make_G_fst.py | 52 +++++ .../segmentation/internal/make_sad_graph.sh | 83 ++++++++ .../internal/post_process_segments.sh | 41 +--- .../segmentation/internal/prepare_sad_lang.py | 94 +++++++++ .../post_process_sad_to_segments.sh | 69 +++++-- .../post_process_sad_to_subsegments.sh | 69 +++++++ 10 files changed, 804 insertions(+), 56 deletions(-) create mode 100755 egs/wsj/s5/steps/nnet3/compute_output.sh create mode 100755 egs/wsj/s5/steps/segmentation/decode_sad.sh create mode 100755 egs/wsj/s5/steps/segmentation/decode_sad_to_segments.sh create mode 100755 egs/wsj/s5/steps/segmentation/do_segmentation_data_dir.sh create mode 100755 egs/wsj/s5/steps/segmentation/internal/make_G_fst.py create mode 100755 egs/wsj/s5/steps/segmentation/internal/make_sad_graph.sh create mode 100755 egs/wsj/s5/steps/segmentation/internal/prepare_sad_lang.py create mode 100644 egs/wsj/s5/steps/segmentation/post_process_sad_to_subsegments.sh diff --git a/egs/wsj/s5/steps/nnet3/compute_output.sh b/egs/wsj/s5/steps/nnet3/compute_output.sh new file mode 100755 index 00000000000..f49790bc578 --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/compute_output.sh @@ -0,0 +1,179 @@ +#!/bin/bash + +# Copyright 2012-2015 Johns Hopkins University 
(Author: Daniel Povey). +# 2016 Vimal Manohar +# Apache 2.0. + +# This script does decoding with a neural-net. If the neural net was built on +# top of fMLLR transforms from a conventional system, you should provide the +# --transform-dir option. + +# Begin configuration section. +stage=1 +transform_dir= # dir to find fMLLR transforms. +nj=4 # number of jobs. If --transform-dir set, must match that number! +cmd=run.pl +use_gpu=false +frames_per_chunk=50 +ivector_scale=1.0 +iter=final +extra_left_context=0 +extra_right_context=0 +extra_left_context_initial=-1 +extra_right_context_final=-1 +frame_subsampling_factor=1 +feat_type= +compress=false +online_ivector_dir= +post_vec= +output_name= +get_raw_nnet_from_am=true +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +[ -f ./path.sh ] && . ./path.sh; # source the path. +. parse_options.sh || exit 1; + +if [ $# -ne 3 ]; then + echo "Usage: $0 [options] " + echo "e.g.: steps/nnet3/compute_output.sh --nj 8 \\" + echo "--online-ivector-dir exp/nnet3/ivectors_test_eval92 \\" + echo " data/test_eval92_hires exp/nnet3/tdnn exp/nnet3/tdnn/output" + echo "main options (for others, see top of script file)" + echo " --transform-dir # directory of previous decoding" + echo " # where we can find transforms for SAT systems." + echo " --config # config containing options" + echo " --nj # number of parallel jobs" + echo " --cmd # Command to run in parallel with" + echo " --iter # Iteration of model to decode; default is final." + exit 1; +fi + +data=$1 +srcdir=$2 +dir=$3 + +if $get_raw_nnet_from_am; then + [ ! -f $srcdir/$iter.mdl ] && echo "$0: no such file $srcdir/$iter.mdl" && exit 1 + model="nnet3-am-copy --raw=true $srcdir/$iter.mdl - |" +else + [ ! -f $srcdir/$iter.raw ] && echo "$0: no such file $srcdir/$iter.raw" && exit 1 + model="nnet3-copy $srcdir/$iter.raw - |" +fi + +mkdir -p $dir/log +echo "rename-node old-name=$output_name new-name=output" > $dir/edits.config + +if [ ! -z "$output_name" ]; then + model="$model nnet3-copy --edits-config=$dir/edits.config - - |" +else + output_name=output +fi + +[ ! -z "$online_ivector_dir" ] && \ + extra_files="$online_ivector_dir/ivector_online.scp $online_ivector_dir/ivector_period" + +for f in $data/feats.scp $extra_files; do + [ ! -f $f ] && echo "$0: no such file $f" && exit 1; +done + +sdata=$data/split$nj; +cmvn_opts=`cat $srcdir/cmvn_opts` || exit 1; + +[[ -d $sdata && $data/feats.scp -ot $sdata ]] || split_data.sh $data $nj || exit 1; +echo $nj > $dir/num_jobs + + +## Set up features. +if [ -z "$feat_type" ]; then + if [ -f $srcdir/final.mat ]; then feat_type=lda; else feat_type=raw; fi + echo "$0: feature type is $feat_type" +fi + +splice_opts=`cat $srcdir/splice_opts 2>/dev/null` + +case $feat_type in + raw) feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- |";; + lda) feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $srcdir/final.mat ark:- ark:- |" + ;; + *) echo "$0: invalid feature type $feat_type" && exit 1; +esac +if [ ! -z "$transform_dir" ]; then + echo "$0: using transforms from $transform_dir" + [ ! -s $transform_dir/num_jobs ] && \ + echo "$0: expected $transform_dir/num_jobs to contain the number of jobs." 
&& exit 1; + nj_orig=$(cat $transform_dir/num_jobs) + + if [ $feat_type == "raw" ]; then trans=raw_trans; + else trans=trans; fi + if [ $feat_type == "lda" ] && \ + ! cmp $transform_dir/../final.mat $srcdir/final.mat && \ + ! cmp $transform_dir/final.mat $srcdir/final.mat; then + echo "$0: LDA transforms differ between $srcdir and $transform_dir" + exit 1; + fi + if [ ! -f $transform_dir/$trans.1 ]; then + echo "$0: expected $transform_dir/$trans.1 to exist (--transform-dir option)" + exit 1; + fi + if [ $nj -ne $nj_orig ]; then + # Copy the transforms into an archive with an index. + for n in $(seq $nj_orig); do cat $transform_dir/$trans.$n; done | \ + copy-feats ark:- ark,scp:$dir/$trans.ark,$dir/$trans.scp || exit 1; + feats="$feats transform-feats --utt2spk=ark:$sdata/JOB/utt2spk scp:$dir/$trans.scp ark:- ark:- |" + else + feats="$feats transform-feats --utt2spk=ark:$sdata/JOB/utt2spk ark:$transform_dir/$trans.JOB ark:- ark:- |" + fi +elif grep 'transform-feats --utt2spk' $srcdir/log/train.1.log >&/dev/null; then + echo "$0: **WARNING**: you seem to be using a neural net system trained with transforms," + echo " but you are not providing the --transform-dir option in test time." +fi +## + +if [ ! -z "$online_ivector_dir" ]; then + ivector_period=$(cat $online_ivector_dir/ivector_period) || exit 1; + ivector_opts="--online-ivectors=scp:$online_ivector_dir/ivector_online.scp --online-ivector-period=$ivector_period" +fi + +frame_subsampling_opt= +if [ $frame_subsampling_factor -ne 1 ]; then + # e.g. for 'chain' systems + frame_subsampling_opt="--frame-subsampling-factor=$frame_subsampling_factor" +fi + +output_wspecifier="ark:| copy-feats --compress=$compress ark:- ark:- | gzip -c > $dir/nnet_output.JOB.gz" + +if [ ! -z $post_vec ]; then + if [ $stage -le 1 ]; then + copy-vector --binary=false $post_vec - | \ + awk '{for (i = 2; i < NF; i++) { sum += i; }; + printf ("["); + for (i = 2; i < NF; i++) { printf " "log(i/sum); }; + print (" ]");}' > $dir/log_priors.vec + fi + + output_wspecifier="ark:| matrix-add-offset ark:- 'vector-scale --scale=-1.0 $dir/log_priors.vec - |' ark:- | copy-feats --compress=$compress ark:- ark:- | gzip -c > $dir/log_likes.JOB.gz" +fi + +gpu_opt="--use-gpu=no" +gpu_queue_opt= + +if $use_gpu; then + gpu_queue_opt="--gpu 1" + gpu_opt="--use-gpu=yes" +fi + +if [ $stage -le 2 ]; then + $cmd $gpu_queue_opt JOB=1:$nj $dir/log/compute_output.JOB.log \ + nnet3-compute $gpu_opt $ivector_opts $frame_subsampling_opt \ + --frames-per-chunk=$frames_per_chunk \ + --extra-left-context=$extra_left_context \ + --extra-right-context=$extra_right_context \ + --extra-left-context-initial=$extra_left_context_initial \ + --extra-right-context-final=$extra_right_context_final \ + "$model" "$feats" "$output_wspecifier" || exit 1; +fi + +exit 0; + diff --git a/egs/wsj/s5/steps/segmentation/decode_sad.sh b/egs/wsj/s5/steps/segmentation/decode_sad.sh new file mode 100755 index 00000000000..9758d36e24e --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/decode_sad.sh @@ -0,0 +1,42 @@ +#! /bin/bash + +set -e +set -o pipefail + +cmd=run.pl +acwt=0.1 +beam=8 +max_active=1000 + +. path.sh + +. utils/parse_options.sh + +if [ $# -ne 3 ]; then + echo "Usage: $0 " + echo " e.g.: $0 " + exit 1 +fi + +graph_dir=$1 +log_likes_dir=$2 +dir=$3 + +nj=`cat $log_likes_dir/num_jobs` +echo $nj > $dir/num_jobs + +for f in $dir/trans.mdl $log_likes_dir/log_likes.1.gz $graph_dir/HCLG.fst; do + if [ ! 
-f $f ]; then + echo "$0: Could not find file $f" + fi +done + +decoder_opts+=(--acoustic-scale=$acwt --beam=$beam --max-active=$max_active) + +$cmd JOB=1:$nj $dir/log/decode.JOB.log \ + decode-faster-mapped ${decoder_opts[@]} \ + $dir/trans.mdl \ + $graph_dir/HCLG.fst "ark:gunzip -c $log_likes_dir/log_likes.JOB.gz |" \ + ark:/dev/null ark:- \| \ + ali-to-phones --per-frame $dir/trans.mdl ark:- \ + "ark:|gzip -c > $dir/ali.JOB.gz" diff --git a/egs/wsj/s5/steps/segmentation/decode_sad_to_segments.sh b/egs/wsj/s5/steps/segmentation/decode_sad_to_segments.sh new file mode 100755 index 00000000000..8f4ed60dfda --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/decode_sad_to_segments.sh @@ -0,0 +1,97 @@ +#! /bin/bash + +set -e +set -o pipefail +set -u + +stage=-1 +segmentation_config=conf/segmentation.conf +cmd=run.pl + +# Viterbi options +min_silence_duration=30 # minimum number of frames for silence +min_speech_duration=30 # minimum number of frames for speech +frame_subsampling_factor=1 +nonsil_transition_probability=0.1 +sil_transition_probability=0.1 +sil_prior=0.5 +speech_prior=0.5 + +# Decoding options +acwt=1 +beam=10 +max_active=7000 + +. utils/parse_options.sh + +if [ $# -ne 4 ]; then + echo "Usage: $0 " + echo " e.g.: $0 data/babel_bengali_dev10h exp/nnet3_sad_snr/tdnn_b_n4/sad_babel_bengali_dev10h exp/nnet3_sad_snr/tdnn_b_n4/segmentation_babel_bengali_dev10h exp/nnet3_sad_snr/tdnn_b_n4/segmentation_babel_bengali_dev10h/babel_bengali_dev10h.seg" + exit 1 +fi + +data=$1 +sad_likes_dir=$2 +dir=$3 +out_data=$4 + +t=sil${sil_prior}_sp${speech_prior} +lang=$dir/lang_test_${t} + +min_silence_duration=`perl -e "print (int($min_silence_duration / $frame_subsampling_factor))"` +min_speech_duration=`perl -e "print (int($min_speech_duration / $frame_subsampling_factor))"` + +if [ $stage -le 1 ]; then + mkdir -p $lang + + steps/segmentation/internal/prepare_sad_lang.py \ + --phone-transition-parameters="--phone-list=1 --min-duration=$min_silence_duration --end-transition-probability=$sil_transition_probability" \ + --phone-transition-parameters="--phone-list=2 --min-duration=$min_speech_duration --end-transition-probability=$nonsil_transition_probability" $lang + + cp $lang/phones.txt $lang/words.txt +fi + +feat_dim=2 # dummy. We don't need this. +if [ $stage -le 2 ]; then + $cmd $dir/log/create_transition_model.log gmm-init-mono \ + $lang/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans.mdl || exit 1 +fi + +if [ $stage -le 3 ]; then + cat > $lang/word2prior < $lang/G.fst +fi + +graph_dir=$dir/graph_test_${t} + +if [ $stage -le 4 ]; then + $cmd $dir/log/make_vad_graph.log \ + steps/segmentation/internal/make_sad_graph.sh --iter trans \ + $lang $dir $dir/graph_test_${t} || exit 1 +fi + +if [ $stage -le 5 ]; then + steps/segmentation/decode_sad.sh \ + --acwt $acwt --beam $beam --max-active $max_active \ + $graph_dir $sad_likes_dir $dir +fi + +if [ $stage -le 6 ]; then + cat > $lang/phone2sad_map < 8kHz sampling frequency. +do_downsampling=false + +# Segmentation configs +min_silence_duration=30 +min_speech_duration=30 +segmentation_config=conf/segmentation_speech.conf + +echo $* + +. 
utils/parse_options.sh + +if [ $# -ne 3 ]; then + echo "Usage: $0 " + echo " e.g.: $0 ~/workspace/egs/ami/s5b/data/sdm1/dev data/ami_sdm1_dev exp/nnet3_sad_snr/nnet_tdnn_j_n4" + exit 1 +fi + +src_data_dir=$1 +data_dir=$2 +sad_nnet_dir=$3 + +affix=${affix:+_$affix} +feat_affix=${feat_affix:+_$feat_affix} + +data_id=`basename $data_dir` +sad_dir=${sad_nnet_dir}/${sad_name}${affix}_${data_id}_whole${feat_affix} +seg_dir=${sad_nnet_dir}/${segmentation_name}${affix}_${data_id}_whole${feat_affix} + +export PATH="$KALDI_ROOT/tools/sph2pipe_v2.5/:$PATH" +[ ! -z `which sph2pipe` ] + +if [ $stage -le 0 ]; then + utils/data/convert_data_dir_to_whole.sh $src_data_dir ${data_dir}_whole + + if $do_downsampling; then + freq=`cat $mfcc_config | perl -pe 's/\s*#.*//g' | grep "sample-frequency=" | awk -F'=' '{if (NF == 0) print 16000; else print $2}'` + sox=`which sox` + + cat $src_data_dir/wav.scp | python -c "import sys +for line in sys.stdin.readlines(): + splits = line.strip().split() + if splits[-1] == '|': + out_line = line.strip() + ' $sox -t wav - -r $freq -c 1 -b 16 -t wav - downsample |' + else: + out_line = 'cat {0} {1} | $sox -t wav - -r $freq -c 1 -b 16 -t wav - downsample |'.format(splits[0], ' '.join(splits[1:])) + print (out_line)" > ${data_dir}_whole/wav.scp + fi + + utils/copy_data_dir.sh ${data_dir}_whole ${data_dir}_whole${feat_affix}_hires +fi + +test_data_dir=${data_dir}_whole${feat_affix}_hires + +if [ $stage -le 1 ]; then + steps/make_mfcc.sh --mfcc-config $mfcc_config --nj $reco_nj --cmd "$train_cmd" \ + ${data_dir}_whole${feat_affix}_hires exp/make_hires/${data_id}_whole${feat_affix} mfcc_hires + steps/compute_cmvn_stats.sh ${data_dir}_whole${feat_affix}_hires exp/make_hires/${data_id}_whole${feat_affix} mfcc_hires +fi + +post_vec=$sad_nnet_dir/post_${output_name}.vec +if [ ! -f $sad_nnet_dir/post_${output_name}.vec ]; then + echo "$0: Could not find $sad_nnet_dir/post_${output_name}.vec. 
See the last stage of local/segmentation/run_train_sad.sh" + exit 1 +fi + +if [ $stage -le 2 ]; then + steps/nnet3/compute_output.sh --nj $reco_nj --cmd "$train_cmd" \ + --post-vec "$post_vec" \ + --iter $iter \ + --extra-left-context $extra_left_context \ + --extra-right-context $extra_right_context \ + --frames-per-chunk 150 \ + --stage $sad_stage --output-name $output_name \ + --frame-subsampling-factor $frame_subsampling_factor \ + --get-raw-nnet-from-am false ${test_data_dir} $sad_nnet_dir $sad_dir +fi + +if [ $stage -le 3 ]; then + steps/segmentation/decode_sad_to_segments.sh \ + --frame-subsampling-factor $frame_subsampling_factor \ + --min-silence-duration $min_silence_duration \ + --min-speech-duration $min_speech_duration \ + --segmentation-config $segmentation_config --cmd "$train_cmd" \ + ${test_data_dir} $sad_dir $seg_dir $seg_dir/${data_id}_seg +fi + +# Subsegment data directory +if [ $stage -le 4 ]; then + rm $seg_dir/${data_id}_seg/feats.scp || true + utils/data/get_reco2num_frames.sh ${test_data_dir} + awk '{print $1" "$2}' ${seg_dir}/${data_id}_seg/segments | \ + utils/apply_map.pl -f 2 ${test_data_dir}/reco2num_frames > \ + $seg_dir/${data_id}_seg/utt2max_frames + + frame_shift_info=`cat $mfcc_config | steps/segmentation/get_frame_shift_info_from_config.pl` + utils/data/get_subsegment_feats.sh ${test_data_dir}/feats.scp \ + $frame_shift_info $seg_dir/${data_id}_seg/segments | \ + utils/data/fix_subsegmented_feats.pl ${seg_dir}/${data_id}_seg/utt2max_frames > \ + $seg_dir/${data_id}_seg/feats.scp + steps/compute_cmvn_stats.sh --fake $seg_dir/${data_id}_seg +fi diff --git a/egs/wsj/s5/steps/segmentation/internal/make_G_fst.py b/egs/wsj/s5/steps/segmentation/internal/make_G_fst.py new file mode 100755 index 00000000000..5ad7e867d10 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/internal/make_G_fst.py @@ -0,0 +1,52 @@ +#! /usr/bin/env python + +from __future__ import print_function +import argparse, math + +def ParseArgs(): + parser = argparse.ArgumentParser("""Make a simple unigram FST for +decoding for segmentation purpose.""") + + parser.add_argument("--word2prior-map", type=str, required=True, + help = "A file with priors for different words") + parser.add_argument("--end-probability", type=float, default=0.01, + help = "Ending probability") + + args = parser.parse_args() + + return args + +def ReadMap(map_file): + out_map = {} + sum_prob = 0 + for line in open(map_file): + parts = line.strip().split() + if len(parts) == 0: + continue + if len(parts) != 2: + raise Exception("Invalid line {0} in {1}".format(line.strip(), map_file)) + + if parts[0] in out_map: + raise Exception("Duplicate entry of {0} in {1}".format(parts[0], map_file)) + + prob = float(parts[1]) + out_map[parts[0]] = prob + + sum_prob += prob + + return (out_map, sum_prob) + +def Main(): + args = ParseArgs() + + word2prior, sum_prob = ReadMap(args.word2prior_map) + sum_prob += args.end_probability + + for w,p in word2prior.iteritems(): + print ("0 0 {word} {word} {log_p}".format(word = w, + log_p = -math.log(p / sum_prob))) + print ("0 {log_p}".format(word = w, + log_p = -math.log(args.end_probability / sum_prob))) + +if __name__ == '__main__': + Main() diff --git a/egs/wsj/s5/steps/segmentation/internal/make_sad_graph.sh b/egs/wsj/s5/steps/segmentation/internal/make_sad_graph.sh new file mode 100755 index 00000000000..5edb3eb2bb6 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/internal/make_sad_graph.sh @@ -0,0 +1,83 @@ +#!/bin/bash + +# Copyright 2016 Vimal Manohar + +# Begin configuration section. 
+stage=0 +cmd=run.pl +iter=final # use $iter.mdl from $model_dir +tree=tree +tscale=1.0 # transition scale. +loopscale=0.1 # scale for self-loops. +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +[ -f ./path.sh ] && . ./path.sh; # source the path. +. parse_options.sh || exit 1; + +if [ $# -ne 3 ]; then + echo "Usage: $0 [options] " + echo " e.g.: $0 exp/vad_dev/lang exp/vad_dev exp/vad_dev/graph" + echo "Makes the graph in \$dir, corresponding to the model in \$model_dir" + exit 1; +fi + +lang=$1 +model=$2/$iter.mdl +tree=$2/$tree +dir=$3 + +for f in $lang/G.fst $model $tree; do + if [ ! -f $f ]; then + echo "$0: expected $f to exist" + exit 1; + fi +done + +mkdir -p $dir $lang/tmp + +clg=$lang/tmp/CLG.fst + +if [[ ! -s $clg || $clg -ot $lang/G.fst ]]; then + echo "$0: creating CLG." + + fstcomposecontext --context-size=1 --central-position=0 \ + $lang/tmp/ilabels < $lang/G.fst | \ + fstarcsort --sort_type=ilabel > $clg + fstisstochastic $clg || echo "[info]: CLG not stochastic." +fi + +if [[ ! -s $dir/Ha.fst || $dir/Ha.fst -ot $model || $dir/Ha.fst -ot $lang/tmp/ilabels ]]; then + make-h-transducer --disambig-syms-out=$dir/disambig_tid.int \ + --transition-scale=$tscale $lang/tmp/ilabels $tree $model \ + > $dir/Ha.fst || exit 1; +fi + +if [[ ! -s $dir/HCLGa.fst || $dir/HCLGa.fst -ot $dir/Ha.fst || $dir/HCLGa.fst -ot $clg ]]; then + fsttablecompose $dir/Ha.fst $clg | fstdeterminizestar --use-log=true \ + | fstrmsymbols $dir/disambig_tid.int | fstrmepslocal | \ + fstminimizeencoded > $dir/HCLGa.fst || exit 1; + fstisstochastic $dir/HCLGa.fst || echo "HCLGa is not stochastic" +fi + +if [[ ! -s $dir/HCLG.fst || $dir/HCLG.fst -ot $dir/HCLGa.fst ]]; then + add-self-loops --self-loop-scale=$loopscale --reorder=true \ + $model < $dir/HCLGa.fst > $dir/HCLG.fst || exit 1; + + if [ $tscale == 1.0 -a $loopscale == 1.0 ]; then + # No point doing this test if transition-scale not 1, as it is bound to fail. + fstisstochastic $dir/HCLG.fst || echo "[info]: final HCLG is not stochastic." + fi +fi + +# keep a copy of the lexicon and a list of silence phones with HCLG... +# this means we can decode without reference to the $lang directory. + +cp $lang/words.txt $dir/ || exit 1; +cp $lang/phones.txt $dir/ 2> /dev/null # ignore the error if it's not there. + +# to make const fst: +# fstconvert --fst_type=const $dir/HCLG.fst $dir/HCLG_c.fst +am-info --print-args=false $model | grep pdfs | awk '{print $NF}' > $dir/num_pdfs + diff --git a/egs/wsj/s5/steps/segmentation/internal/post_process_segments.sh b/egs/wsj/s5/steps/segmentation/internal/post_process_segments.sh index c2750b4a895..e37d5dc2f62 100755 --- a/egs/wsj/s5/steps/segmentation/internal/post_process_segments.sh +++ b/egs/wsj/s5/steps/segmentation/internal/post_process_segments.sh @@ -51,31 +51,11 @@ for f in $dir/orig_segmentation.1.gz $data_dir/segments; do fi done -cat < $dir/segmentation.conf -pad_length=$pad_length # Pad speech segments by this many frames on either side -max_blend_length=$max_blend_length # Maximum duration of speech that will be removed as part - # of smoothing process. This is only if there are no other - # speech segments nearby. -max_intersegment_length=$max_intersegment_length # Merge nearby speech segments if the silence - # between them is less than this many frames. 
-post_pad_length=$post_pad_length # Pad speech segments by this many frames on either side - # after the merging process using max_intersegment_length -max_segment_length=$max_segment_length # Segments that are longer than this are split into - # overlapping frames. -overlap_length=$overlap_length # Overlapping frames when segments are split. - # See the above option. -min_silence_length=$min_silence_length # Min silence length at which to split very long segments - -frame_shift=$frame_shift -EOF - nj=`cat $dir/num_jobs` || exit 1 -if [ $stage -le 1 ]; then - rm -r $segmented_data_dir || true - utils/data/convert_data_dir_to_whole.sh $data_dir $segmented_data_dir || exit 1 - rm $segmented_data_dir/text -fi +[ $pad_length -eq -1 ] && pad_length= +[ $post_pad_length -eq -1 ] && post_pad_length= +[ $max_blend_length -eq -1 ] && max_blend_length= if [ $stage -le 2 ]; then # Post-process the orignal SAD segmentation using the following steps: @@ -94,10 +74,10 @@ if [ $stage -le 2 ]; then $cmd JOB=1:$nj $dir/log/post_process_segmentation.JOB.log \ gunzip -c $dir/orig_segmentation.JOB.gz \| \ segmentation-post-process --merge-adjacent-segments --max-intersegment-length=0 ark:- ark:- \| \ - segmentation-post-process --max-blend-length=$max_blend_length --blend-short-segments-class=1 ark:- ark:- \| \ - segmentation-post-process --remove-labels=0 --pad-label=1 --pad-length=$pad_length ark:- ark:- \| \ + segmentation-post-process ${max_blend_length:+--max-blend-length=$max_blend_length --blend-short-segments-class=1} ark:- ark:- \| \ + segmentation-post-process --remove-labels=0 ${pad_length:+--pad-label=1 --pad-length=$pad_length} ark:- ark:- \| \ segmentation-post-process --merge-adjacent-segments --max-intersegment-length=$max_intersegment_length ark:- ark:- \| \ - segmentation-post-process --pad-label=1 --pad-length=$post_pad_length ark:- ark:- \| \ + segmentation-post-process ${post_pad_length:+--pad-label=1 --pad-length=$post_pad_length} ark:- ark:- \| \ segmentation-split-segments --alignments="ark,s,cs:gunzip -c $dir/orig_segmentation.JOB.gz | segmentation-to-ali ark:- ark:- |" \ --max-segment-length=$max_segment_length --min-alignment-chunk-length=$min_silence_length --ali-label=0 ark:- ark:- \| \ segmentation-split-segments \ @@ -118,12 +98,3 @@ if [ ! -s $segmented_data_dir/utt2spk ] || [ ! -s $segmented_data_dir/segments ] echo "$0: Segmentation failed to generate segments or utt2spk!" exit 1 fi - -utils/utt2spk_to_spk2utt.pl $segmented_data_dir/utt2spk > $segmented_data_dir/spk2utt || exit 1 -utils/fix_data_dir.sh $segmented_data_dir - -if [ ! -s $segmented_data_dir/utt2spk ] || [ ! -s $segmented_data_dir/segments ]; then - echo "$0: Segmentation failed to generate segments or utt2spk!" - exit 1 -fi - diff --git a/egs/wsj/s5/steps/segmentation/internal/prepare_sad_lang.py b/egs/wsj/s5/steps/segmentation/internal/prepare_sad_lang.py new file mode 100755 index 00000000000..17b039015d2 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/internal/prepare_sad_lang.py @@ -0,0 +1,94 @@ +#! /usr/bin/env python + +from __future__ import print_function +import argparse, shlex + +def GetArgs(): + parser = argparse.ArgumentParser(description="""This script generates a lang +directory for purpose of segmentation. It takes as arguments the list of phones, +the corresponding min durations and end transition probability.""") + + parser.add_argument("--phone-transition-parameters", dest='phone_transition_para_array', + type=str, action='append', required = True, + help = "Options to build topology. 
\n" + "--phone-list= # Colon-separated list of phones\n" + "--min-duration= # Min duration for the phones\n" + "--end-transition-probability= # Probability of the end transition after the minimum duration\n") + parser.add_argument("dir", type=str, + help = "Output lang directory") + args = parser.parse_args() + return args + + +def ParsePhoneTransitionParameters(para_array): + parser = argparse.ArgumentParser() + + parser.add_argument("--phone-list", type=str, required=True, + help="Colon-separated list of phones") + parser.add_argument("--min-duration", type=int, default=3, + help="Minimum number of states for the phone") + parser.add_argument("--end-transition-probability", type=float, default=0.1, + help="Probability of the end transition after the minimum duration") + + phone_transition_parameters = [ parser.parse_args(shlex.split(x)) for x in para_array ] + + for t in phone_transition_parameters: + if (t.end_transition_probability > 1.0 or + t.end_transition_probability < 0.0): + raise ValueError("Expected --end-transition-probability to be " + "between 0 and 1, got {0} for phones {1}".format( + t.end_transition_probability, t.phone_list)) + if t.min_duration > 100 or t.min_duration < 1: + raise ValueError("Expected --min-duration to be " + "between 1 and 100, got {0} for phones {1}".format( + t.min_duration, t.phone_list)) + + t.phone_list = t.phone_list.split(":") + + return phone_transition_parameters + +def GetPhoneMap(phone_transition_parameters): + phone2int = {} + n = 1 + for t in phone_transition_parameters: + for p in t.phone_list: + if p in phone2int: + raise Exception("Phone {0} found in multiple topologies".format(p)) + phone2int[p] = n + n += 1 + + return phone2int + +def Main(): + args = GetArgs() + phone_transition_parameters = ParsePhoneTransitionParameters(args.phone_transition_para_array) + + phone2int = GetPhoneMap(phone_transition_parameters) + + topo = open("{0}/topo".format(args.dir), 'w') + + print ("", file = topo) + + for t in phone_transition_parameters: + print ("", file = topo) + print ("", file = topo) + print ("{0}".format(" ".join([str(phone2int[p]) for p in t.phone_list])), file = topo) + print ("", file = topo) + + for state in range(0, t.min_duration-1): + print(" {0} 0 {1} 1.0 ".format(state, state + 1), file = topo) + print(" {state} 0 {state} {self_prob} {next_state} {next_prob} ".format( + state = t.min_duration - 1, next_state = t.min_duration, + self_prob = 1 - t.end_transition_probability, + next_prob = t.end_transition_probability), file = topo) + print(" {state} ".format(state = t.min_duration), file = topo) # Final state + print ("", file = topo) + print ("", file = topo) + + phones_file = open("{0}/phones.txt".format(args.dir), 'w') + + for p,n in sorted(list(phone2int.items()), key = lambda x:x[1]): + print ("{0} {1}".format(p, n), file = phones_file) + +if __name__ == '__main__': + Main() diff --git a/egs/wsj/s5/steps/segmentation/post_process_sad_to_segments.sh b/egs/wsj/s5/steps/segmentation/post_process_sad_to_segments.sh index f4011f20a03..c1006d09678 100755 --- a/egs/wsj/s5/steps/segmentation/post_process_sad_to_segments.sh +++ b/egs/wsj/s5/steps/segmentation/post_process_sad_to_segments.sh @@ -16,7 +16,9 @@ frame_shift=0.01 weight_threshold=0.5 ali_suffix=_acwt0.1 -phone_map= +frame_subsampling_factor=1 + +phone2sad_map= . 
utils/parse_options.sh @@ -48,56 +50,81 @@ fi dir=$1 segmented_data_dir=$2 -cat $data_dir/segments | awk '{print $1" "$2}' | \ - utils/utt2spk_to_spk2utt.pl > $data_dir/reco2utt - -utils/split_data.sh $data_dir $nj - -for n in `seq $nj`; do - cat $data_dir/split$nj/$n/segments | awk '{print $1" "$2}' | \ - utils/utt2spk_to_spk2utt.pl > $data_dir/split$nj/$n/reco2utt -done - +utils/data/get_reco2utt.sh $data_dir mkdir -p $dir if [ ! -z "$vad_dir" ]; then nj=`cat $vad_dir/num_jobs` || exit 1 + + utils/split_data.sh $data_dir $nj - if [ -z "$phone_map" ]; then - phone_map=$dir/phone_map + for n in `seq $nj`; do + cat $data_dir/split$nj/$n/segments | awk '{print $1" "$2}' | \ + utils/utt2spk_to_spk2utt.pl > $data_dir/split$nj/$n/reco2utt + done + + if [ -z "$phone2sad_map" ]; then + phone2sad_map=$dir/phone2sad_map { cat $lang/phones/silence.int | awk '{print $1" 0"}'; cat $lang/phones/nonsilence.int | awk '{print $1" 1"}'; - } | sort -k1,1 -n > $dir/phone_map + } | sort -k1,1 -n > $dir/phone2sad_map fi + frame_shift_subsampled=`perl -e "print ($frame_subsampling_factor * $frame_shift)"` + if [ $stage -le 0 ]; then # Convert the original SAD into segmentation $cmd JOB=1:$nj $dir/log/segmentation.JOB.log \ - segmentation-init-from-ali --reco2utt-rspecifier="ark,t:$data_dir/split$nj/JOB/reco2utt" \ - --segmentation-rspecifier="ark:segmentation-init-from-segments --shift-to-zero=false --frame-shift=$frame_shift $data_dir/split$nj/JOB/segments ark:- |" \ + segmentation-init-from-ali \ "ark:gunzip -c $vad_dir/ali${ali_suffix}.JOB.gz |" ark:- \| \ - segmentation-copy --label-map=$phone_map ark:- \ + segmentation-combine-segments ark:- \ + "ark:segmentation-init-from-segments --shift-to-zero=false --frame-shift=$frame_shift_subsampled $data_dir/split$nj/JOB/segments ark:- |" \ + "ark,t:$data_dir/split$nj/JOB/reco2utt" ark:- \| \ + segmentation-copy --label-map=$phone2sad_map \ + --frame-subsampling-factor=$frame_subsampling_factor ark:- \ "ark:| gzip -c > $dir/orig_segmentation.JOB.gz" fi else + utils/split_data.sh $data_dir $nj + for n in `seq $nj`; do - utils/filter_scp.pl $data_dir/split$nj/$n/reco2utt $weights_scp > $dir/weights.$n.scp + utils/data/get_reco2utt.sh $data_dir/split$nj/$n + utils/filter_scp.pl $data_dir/split$nj/$n/reco2utt $weights_scp > \ + $dir/weights.$n.scp done $cmd JOB=1:$nj $dir/log/weights_to_segments.JOB.log \ copy-vector scp:$dir/weights.JOB.scp ark,t:- \| \ awk -v t=$weight_threshold '{printf $1; for (i=3; i < NF; i++) { if ($i >= t) printf (" 1"); else printf (" 0"); }; print "";}' \| \ - segmentation-init-from-ali --reco2utt-rspecifier="ark,t:$data_dir/split$nj/JOB/reco2utt" \ - --segmentation-rspecifier="ark:segmentation-init-from-segments --shift-to-zero=false --frame-shift=$frame_shift $data_dir/split$nj/JOB/segments ark:- |" \ - ark,t:- "ark:| gzip -c > $dir/orig_segmentation.JOB.gz" + segmentation-init-from-ali \ + ark,t:- ark:- \| segmentation-combine-segments ark:- \ + "ark:segmentation-init-from-segments --shift-to-zero=false --frame-shift=$frame_shift_subsampled $data_dir/split$nj/JOB/segments ark:- |" \ + "ark,t:$data_dir/split$nj/JOB/reco2utt" ark:- \| \ + segmentation-copy --frame-subsampling-factor=$frame_subsampling_factor \ + ark:- "ark:| gzip -c > $dir/orig_segmentation.JOB.gz" fi echo $nj > $dir/num_jobs +if [ $stage -le 1 ]; then + rm -r $segmented_data_dir || true + utils/data/convert_data_dir_to_whole.sh $data_dir $segmented_data_dir || exit 1 + rm $segmented_data_dir/text || true +fi + steps/segmentation/internal/post_process_segments.sh \ --stage 
$stage --cmd "$cmd" \ --config $segmentation_config --frame-shift $frame_shift \ $data_dir $dir $segmented_data_dir + +utils/utt2spk_to_spk2utt.pl $segmented_data_dir/utt2spk > $segmented_data_dir/spk2utt || exit 1 +utils/fix_data_dir.sh $segmented_data_dir + +if [ ! -s $segmented_data_dir/utt2spk ] || [ ! -s $segmented_data_dir/segments ]; then + echo "$0: Segmentation failed to generate segments or utt2spk!" + exit 1 +fi + diff --git a/egs/wsj/s5/steps/segmentation/post_process_sad_to_subsegments.sh b/egs/wsj/s5/steps/segmentation/post_process_sad_to_subsegments.sh new file mode 100644 index 00000000000..8cfcaa40cda --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/post_process_sad_to_subsegments.sh @@ -0,0 +1,69 @@ +#! /bin/bash + +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +set -e -o pipefail -u +. path.sh + +cmd=run.pl +stage=-10 + +segmentation_config=conf/segmentation.conf +nj=18 + +frame_shift=0.01 + +. utils/parse_options.sh + +if [ $# -ne 5 ]; then + echo "Usage: $0 " + echo " e.g.: $0 data/dev_aspire_whole exp/vad_dev_aspire data/dev_aspire_seg" + exit 1 +fi + +data_dir=$1 +phone2sad_map=$2 +vad_dir=$3 +dir=$4 +segmented_data_dir=$5 + +mkdir -p $dir + +nj=`cat $vad_dir/num_jobs` || exit 1 + +utils/split_data.sh $data_dir $nj + +if [ $stage -le 0 ]; then + # Convert the original SAD into segmentation + $cmd JOB=1:$nj $dir/log/segmentation.JOB.log \ + segmentation-init-from-ali \ + "ark:gunzip -c $vad_dir/ali.JOB.gz |" ark:- \| \ + segmentation-copy --label-map=$phone2sad_map ark:- \ + "ark:| gzip -c > $dir/orig_segmentation.JOB.gz" +fi + +echo $nj > $dir/num_jobs + +if [ $stage -le 1 ]; then + rm -r $segmented_data_dir || true + utils/data/convert_data_dir_to_whole.sh $data_dir $segmented_data_dir || exit 1 + rm $segmented_data_dir/text || true +fi + +steps/segmentation/internal/post_process_segments.sh \ + --stage $stage --cmd "$cmd" \ + --config $segmentation_config --frame-shift $frame_shift \ + $data_dir $dir $segmented_data_dir + +mv $segmented_data_dir/segments $segmented_data_dir/sub_segments +utils/data/subsegment_data_dir.sh $data_dir $segmented_data_dir/sub_segments $segmented_data_dir + +utils/utt2spk_to_spk2utt.pl $segmented_data_dir/utt2spk > $segmented_data_dir/spk2utt || exit 1 +utils/fix_data_dir.sh $segmented_data_dir + +if [ ! -s $segmented_data_dir/utt2spk ] || [ ! -s $segmented_data_dir/segments ]; then + echo "$0: Segmentation failed to generate segments or utt2spk!" 
+ exit 1 +fi + From e80e4a99d52f855e5c15fbc23d5c96593e65e410 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Fri, 18 Nov 2016 22:07:53 -0500 Subject: [PATCH 038/213] asr_diarization: Adding min-extra-left-context --- egs/wsj/s5/steps/libs/common.py | 1 + .../nnet3/train/frame_level_objf/common.py | 11 +++++++ egs/wsj/s5/steps/nnet3/train_raw_rnn.py | 32 +++++++++++++++++-- 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/egs/wsj/s5/steps/libs/common.py b/egs/wsj/s5/steps/libs/common.py index 1e0608525ba..f2a336cd640 100644 --- a/egs/wsj/s5/steps/libs/common.py +++ b/egs/wsj/s5/steps/libs/common.py @@ -315,6 +315,7 @@ def split_data(data, num_jobs): run_kaldi_command("utils/split_data.sh {data} {num_jobs}".format( data=data, num_jobs=num_jobs)) + return "{0}/split{1}".format(data, num_jobs) def read_kaldi_matrix(matrix_file): diff --git a/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py b/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py index d0cb2a52758..55508daf02c 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py @@ -31,6 +31,7 @@ def train_new_models(dir, iter, srand, num_jobs, cache_read_opt, run_opts, frames_per_eg=-1, min_deriv_time=None, max_deriv_time=None, + min_left_context=None, min_right_context=None, extra_egs_copy_cmd=""): """ Called from train_one_iteration(), this model does one iteration of training with 'num_jobs' jobs, and writes files like @@ -64,6 +65,13 @@ def train_new_models(dir, iter, srand, num_jobs, deriv_time_opts.append("--optimization.max-deriv-time={0}".format( max_deriv_time)) + this_random = random.Random(srand) + if min_left_context is not None: + left_context = this_random.randint(min_left_context, left_context) + + if min_right_context is not None: + right_context = this_random.randint(min_right_context, right_context) + context_opts = "--left-context={0} --right-context={1}".format( left_context, right_context) @@ -144,6 +152,7 @@ def train_one_iteration(dir, iter, srand, egs_dir, run_opts, cv_minibatch_size=256, frames_per_eg=-1, min_deriv_time=None, max_deriv_time=None, + min_left_context=None, min_right_context=None, shrinkage_value=1.0, get_raw_nnet_from_am=True, background_process_handler=None, @@ -283,6 +292,8 @@ def train_one_iteration(dir, iter, srand, egs_dir, frames_per_eg=frames_per_eg, min_deriv_time=min_deriv_time, max_deriv_time=max_deriv_time, + min_left_context=min_left_context, + min_right_context=min_right_context, extra_egs_copy_cmd=extra_egs_copy_cmd) [models_to_average, best_model] = common_train_lib.get_successful_models( diff --git a/egs/wsj/s5/steps/nnet3/train_raw_rnn.py b/egs/wsj/s5/steps/nnet3/train_raw_rnn.py index ae038445fc0..e4af318fb57 100755 --- a/egs/wsj/s5/steps/nnet3/train_raw_rnn.py +++ b/egs/wsj/s5/steps/nnet3/train_raw_rnn.py @@ -72,6 +72,18 @@ def get_args(): parser.add_argument("--egs.extra-copy-cmd", type=str, dest='extra_egs_copy_cmd', default = "", help="""Modify egs before passing it to training"""); + parser.add_argument("--trainer.min-chunk-left-context", type=int, + dest='min_chunk_left_context', default=None, + help="""If provided and is less than + --egs.chunk-left-context, then the chunk left context + is randomized between egs.chunk-left-context and + this value.""") + parser.add_argument("--trainer.min-chunk-right-context", type=int, + dest='min_chunk_right_context', default=None, + help="""If provided and is less than + --egs.chunk-right-context, then the chunk right 
context + is randomized between egs.chunk-right-context and + this value.""") # trainer options parser.add_argument("--trainer.samples-per-iter", type=int, @@ -184,6 +196,12 @@ def process_args(args): "--trainer.deriv-truncate-margin.".format( args.deriv_truncate_margin)) + if args.min_chunk_left_context is None: + args.min_chunk_left_context = args.chunk_left_context + + if args.min_chunk_right_context is None: + args.min_chunk_right_context = args.chunk_right_context + if (not os.path.exists(args.dir) or not os.path.exists(args.dir+"/configs")): raise Exception("This scripts expects {0} to exist and have a configs " @@ -254,12 +272,18 @@ def train(args, run_opts, background_process_handler): # discriminative pretraining num_hidden_layers = variables['num_hidden_layers'] add_lda = common_lib.str_to_bool(variables['add_lda']) - include_log_softmax = common_lib.str_to_bool( - variables['include_log_softmax']) except KeyError as e: raise Exception("KeyError {0}: Variables need to be defined in " "{1}".format(str(e), '{0}/configs'.format(args.dir))) + try: + include_log_softmax = common_lib.str_to_bool( + variables['include_log_softmax']) + except KeyError as e: + logger.warning("KeyError {0}: Using default include-log-softmax value " + "as False.".format(str(e))) + include_log_softmax = False + left_context = args.chunk_left_context + model_left_context right_context = args.chunk_right_context + model_right_context @@ -419,6 +443,10 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): add_layers_period=args.add_layers_period, left_context=left_context, right_context=right_context, + min_left_context=args.min_chunk_left_context + + model_left_context, + min_right_context=args.min_chunk_right_context + + model_right_context, min_deriv_time=min_deriv_time, max_deriv_time=max_deriv_time, momentum=args.momentum, From b79f0faa9f04522d0aa4589108da84a833f7102b Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Sun, 27 Nov 2016 01:47:22 -0500 Subject: [PATCH 039/213] asr_diarization: Segmentation tools --- src/Makefile | 11 +- src/segmenter/Makefile | 16 + src/segmenter/segment.cc | 35 + src/segmenter/segment.h | 78 ++ src/segmenter/segmentation-io-test.cc | 63 ++ src/segmenter/segmentation-post-processor.cc | 198 +++++ src/segmenter/segmentation-post-processor.h | 168 ++++ src/segmenter/segmentation-test.cc | 226 ++++++ src/segmenter/segmentation-utils.cc | 743 ++++++++++++++++++ src/segmenter/segmentation-utils.h | 337 ++++++++ src/segmenter/segmentation.cc | 201 +++++ src/segmenter/segmentation.h | 144 ++++ src/segmenterbin/Makefile | 36 + ...entation-combine-segments-to-recordings.cc | 114 +++ .../segmentation-combine-segments.cc | 128 +++ src/segmenterbin/segmentation-copy.cc | 232 ++++++ .../segmentation-create-subsegments.cc | 175 +++++ src/segmenterbin/segmentation-get-stats.cc | 125 +++ .../segmentation-init-from-ali.cc | 91 +++ .../segmentation-init-from-lengths.cc | 82 ++ .../segmentation-init-from-segments.cc | 179 +++++ .../segmentation-intersect-ali.cc | 99 +++ .../segmentation-intersect-segments.cc | 145 ++++ .../segmentation-merge-recordings.cc | 101 +++ src/segmenterbin/segmentation-merge.cc | 146 ++++ src/segmenterbin/segmentation-post-process.cc | 142 ++++ .../segmentation-remove-segments.cc | 155 ++++ .../segmentation-split-segments.cc | 194 +++++ src/segmenterbin/segmentation-to-ali.cc | 99 +++ src/segmenterbin/segmentation-to-rttm.cc | 255 ++++++ src/segmenterbin/segmentation-to-segments.cc | 133 ++++ tools/config/common_path.sh | 1 + 32 files changed, 4847 
insertions(+), 5 deletions(-) create mode 100644 src/segmenter/Makefile create mode 100644 src/segmenter/segment.cc create mode 100644 src/segmenter/segment.h create mode 100644 src/segmenter/segmentation-io-test.cc create mode 100644 src/segmenter/segmentation-post-processor.cc create mode 100644 src/segmenter/segmentation-post-processor.h create mode 100644 src/segmenter/segmentation-test.cc create mode 100644 src/segmenter/segmentation-utils.cc create mode 100644 src/segmenter/segmentation-utils.h create mode 100644 src/segmenter/segmentation.cc create mode 100644 src/segmenter/segmentation.h create mode 100644 src/segmenterbin/Makefile create mode 100644 src/segmenterbin/segmentation-combine-segments-to-recordings.cc create mode 100644 src/segmenterbin/segmentation-combine-segments.cc create mode 100644 src/segmenterbin/segmentation-copy.cc create mode 100644 src/segmenterbin/segmentation-create-subsegments.cc create mode 100644 src/segmenterbin/segmentation-get-stats.cc create mode 100644 src/segmenterbin/segmentation-init-from-ali.cc create mode 100644 src/segmenterbin/segmentation-init-from-lengths.cc create mode 100644 src/segmenterbin/segmentation-init-from-segments.cc create mode 100644 src/segmenterbin/segmentation-intersect-ali.cc create mode 100644 src/segmenterbin/segmentation-intersect-segments.cc create mode 100644 src/segmenterbin/segmentation-merge-recordings.cc create mode 100644 src/segmenterbin/segmentation-merge.cc create mode 100644 src/segmenterbin/segmentation-post-process.cc create mode 100644 src/segmenterbin/segmentation-remove-segments.cc create mode 100644 src/segmenterbin/segmentation-split-segments.cc create mode 100644 src/segmenterbin/segmentation-to-ali.cc create mode 100644 src/segmenterbin/segmentation-to-rttm.cc create mode 100644 src/segmenterbin/segmentation-to-segments.cc diff --git a/src/Makefile b/src/Makefile index 9905be869a0..a42f78f4742 100644 --- a/src/Makefile +++ b/src/Makefile @@ -6,16 +6,16 @@ SHELL := /bin/bash SUBDIRS = base matrix util feat tree thread gmm transform sgmm \ - fstext hmm lm decoder lat kws cudamatrix nnet \ + fstext hmm lm decoder lat kws cudamatrix nnet segmenter \ bin fstbin gmmbin fgmmbin sgmmbin featbin \ nnetbin latbin sgmm2 sgmm2bin nnet2 nnet3 chain nnet3bin nnet2bin kwsbin \ - ivector ivectorbin online2 online2bin lmbin chainbin + ivector ivectorbin online2 online2bin lmbin chainbin segmenterbin MEMTESTDIRS = base matrix util feat tree thread gmm transform sgmm \ - fstext hmm lm decoder lat nnet kws chain \ + fstext hmm lm decoder lat nnet kws chain segmenter \ bin fstbin gmmbin fgmmbin sgmmbin featbin \ nnetbin latbin sgmm2 nnet2 nnet3 nnet2bin nnet3bin sgmm2bin kwsbin \ - ivector ivectorbin online2 online2bin lmbin + ivector ivectorbin online2 online2bin lmbin segmenterbin CUDAMEMTESTDIR = cudamatrix @@ -155,7 +155,7 @@ $(EXT_SUBDIRS) : mklibdir bin fstbin gmmbin fgmmbin sgmmbin sgmm2bin featbin nnetbin nnet2bin nnet3bin chainbin latbin ivectorbin lmbin kwsbin online2bin: \ base matrix util feat tree thread gmm transform sgmm sgmm2 fstext hmm \ - lm decoder lat cudamatrix nnet nnet2 nnet3 ivector chain kws online2 + lm decoder lat cudamatrix nnet nnet2 nnet3 ivector chain kws online2 segmenter #2)The libraries have inter-dependencies base: base/.depend.mk @@ -179,6 +179,7 @@ nnet2: base util matrix thread lat gmm hmm tree transform cudamatrix nnet3: base util matrix thread lat gmm hmm tree transform cudamatrix chain fstext chain: lat hmm tree fstext matrix cudamatrix util thread base ivector: base util matrix 
thread transform tree gmm +segmenter: base matrix util gmm thread #3)Dependencies for optional parts of Kaldi onlinebin: base matrix util feat tree gmm transform sgmm sgmm2 fstext hmm lm decoder lat cudamatrix nnet nnet2 online thread # python-kaldi-decoding: base matrix util feat tree thread gmm transform sgmm sgmm2 fstext hmm decoder lat online diff --git a/src/segmenter/Makefile b/src/segmenter/Makefile new file mode 100644 index 00000000000..03df6132050 --- /dev/null +++ b/src/segmenter/Makefile @@ -0,0 +1,16 @@ +all: + +include ../kaldi.mk + +TESTFILES = segmentation-io-test + +OBJFILES = segment.o segmentation.o segmentation-utils.o \ + segmentation-post-processor.o + +LIBNAME = kaldi-segmenter + +ADDLIBS = ../gmm/kaldi-gmm.a \ + ../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a ../thread/kaldi-thread.a + +include ../makefiles/default_rules.mk + diff --git a/src/segmenter/segment.cc b/src/segmenter/segment.cc new file mode 100644 index 00000000000..b4f485c26bc --- /dev/null +++ b/src/segmenter/segment.cc @@ -0,0 +1,35 @@ +#include "segmenter/segment.h" + +namespace kaldi { +namespace segmenter { + +void Segment::Write(std::ostream &os, bool binary) const { + if (binary) { + os.write(reinterpret_cast(&start_frame), sizeof(start_frame)); + os.write(reinterpret_cast(&end_frame), sizeof(start_frame)); + os.write(reinterpret_cast(&class_id), sizeof(class_id)); + } else { + WriteBasicType(os, binary, start_frame); + WriteBasicType(os, binary, end_frame); + WriteBasicType(os, binary, Label()); + } +} + +void Segment::Read(std::istream &is, bool binary) { + if (binary) { + is.read(reinterpret_cast(&start_frame), sizeof(start_frame)); + is.read(reinterpret_cast(&end_frame), sizeof(end_frame)); + is.read(reinterpret_cast(&class_id), sizeof(class_id)); + } else { + ReadBasicType(is, binary, &start_frame); + ReadBasicType(is, binary, &end_frame); + int32 label; + ReadBasicType(is, binary, &label); + SetLabel(label); + } + + KALDI_ASSERT(end_frame >= start_frame && start_frame >= 0); +} + +} // end namespace segmenter +} // end namespace kaldi diff --git a/src/segmenter/segment.h b/src/segmenter/segment.h new file mode 100644 index 00000000000..1657affc875 --- /dev/null +++ b/src/segmenter/segment.h @@ -0,0 +1,78 @@ +#ifndef KALDI_SEGMENTER_SEGMENT_H_ +#define KALDI_SEGMENTER_SEGMENT_H_ + +#include "base/kaldi-common.h" +#include "matrix/kaldi-matrix.h" + +namespace kaldi { +namespace segmenter { + +/** + * This structure defines a single segment. It consists of the following basic + * properties: + * 1) start_frame : This is the frame index of the first frame in the + * segment. + * 2) end_frame : This is the frame index of the last frame in the segment. + * Note that the end_frame is included in the segment. + * 3) class_id : This is the class corresponding to the segments. For e.g., + * could be 0, 1 or 2 depending on whether the segment is + * silence, speech or noise. In general, it can be any + * integer class label. +**/ + +struct Segment { + int32 start_frame; + int32 end_frame; + int32 class_id; + + // Accessors for labels or class id. This is useful in the future when + // we might change the type of label. + inline int32 Label() const { return class_id; } + inline void SetLabel(int32 label) { class_id = label; } + inline int32 Length() const { return end_frame - start_frame + 1; } + + // This is the default constructor that sets everything to undefined values. 
+ Segment() : start_frame(-1), end_frame(-1), class_id(-1) { } + + // This constructor initializes the segmented with the provided start and end + // frames and the segment label. This is the main constructor. + Segment(int32 start, int32 end, int32 label) : + start_frame(start), end_frame(end), class_id(label) { } + + void Write(std::ostream &os, bool binary) const; + void Read(std::istream &is, bool binary); + + // This is a function that returns the size of the elements in the structure. + // It is used during I/O in binary mode, which checks for the total size + // required to store the segment. + static size_t SizeInBytes() { + return (sizeof(int32) + sizeof(int32) + sizeof(int32)); + } +}; + +/** + * Comparator to order segments based on start frame +**/ + +class SegmentComparator { + public: + bool operator() (const Segment &lhs, const Segment &rhs) const { + return lhs.start_frame < rhs.start_frame; + } +}; + +/** + * Comparator to order segments based on length +**/ + +class SegmentLengthComparator { + public: + bool operator() (const Segment &lhs, const Segment &rhs) const { + return lhs.Length() < rhs.Length(); + } +}; + +} // end namespace segmenter +} // end namespace kaldi + +#endif // KALDI_SEGMENTER_SEGMENT_H_ diff --git a/src/segmenter/segmentation-io-test.cc b/src/segmenter/segmentation-io-test.cc new file mode 100644 index 00000000000..f019a653a4a --- /dev/null +++ b/src/segmenter/segmentation-io-test.cc @@ -0,0 +1,63 @@ +// segmenter/segmentation-io-test.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "segmenter/segmentation.h" + +namespace kaldi { +namespace segmenter { + +void UnitTestSegmentationIo() { + Segmentation seg; + int32 max_length = RandInt(0, 1000), + max_segment_length = max_length / 10, + num_classes = RandInt(0, 3); + + if (max_segment_length == 0) + max_segment_length = 1; + + seg.GenRandomSegmentation(max_length, max_segment_length, num_classes); + + bool binary = ( RandInt(0,1) == 0 ); + std::ostringstream os; + + seg.Write(os, binary); + + Segmentation seg2; + std::istringstream is(os.str()); + seg2.Read(is, binary); + + std::ostringstream os2; + seg2.Write(os2, binary); + + KALDI_ASSERT(os2.str() == os.str()); +} + +} // namespace segmenter +} // namespace kaldi + +int main() { + using namespace kaldi; + using namespace kaldi::segmenter; + + for (int32 i = 0; i < 100; i++) + UnitTestSegmentationIo(); + return 0; +} + + diff --git a/src/segmenter/segmentation-post-processor.cc b/src/segmenter/segmentation-post-processor.cc new file mode 100644 index 00000000000..2c97e31db56 --- /dev/null +++ b/src/segmenter/segmentation-post-processor.cc @@ -0,0 +1,198 @@ +// segmenter/segmentation-post-processor.h + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "segmenter/segmentation-utils.h" +#include "segmenter/segmentation-post-processor.h" + +namespace kaldi { +namespace segmenter { + +static inline bool IsMergingLabelsToBeDone( + const SegmentationPostProcessingOptions &opts) { + return (!opts.merge_labels_csl.empty() || opts.merge_dst_label != -1); +} + +static inline bool IsPaddingSegmentsToBeDone( + const SegmentationPostProcessingOptions &opts) { + return (opts.pad_label != -1 || opts.pad_length != -1); +} + +static inline bool IsShrinkingSegmentsToBeDone( + const SegmentationPostProcessingOptions &opts) { + return (opts.shrink_label != -1 || opts.shrink_length != -1); +} + +static inline bool IsBlendingShortSegmentsToBeDone( + const SegmentationPostProcessingOptions &opts) { + return (opts.blend_short_segments_class != -1 || opts.max_blend_length != -1); +} + +static inline bool IsRemovingSegmentsToBeDone( + const SegmentationPostProcessingOptions &opts) { + return (!opts.remove_labels_csl.empty()); +} + +static inline bool IsMergingAdjacentSegmentsToBeDone( + const SegmentationPostProcessingOptions &opts) { + return (opts.merge_adjacent_segments); +} + +static inline bool IsSplittingSegmentsToBeDone( + const SegmentationPostProcessingOptions &opts) { + return (opts.max_segment_length != -1); +} + + +SegmentationPostProcessor::SegmentationPostProcessor( + const SegmentationPostProcessingOptions &opts) : opts_(opts) { + if (!opts_.remove_labels_csl.empty()) { + if (!SplitStringToIntegers(opts_.remove_labels_csl, ":", + false, &remove_labels_)) { + KALDI_ERR << "Bad value for --remove-labels option: " + << opts_.remove_labels_csl; + } + std::sort(remove_labels_.begin(), remove_labels_.end()); + } + + if (!opts_.merge_labels_csl.empty()) { + if (!SplitStringToIntegers(opts_.merge_labels_csl, ":", + false, &merge_labels_)) { + KALDI_ERR << "Bad value for --merge-labels option: " + << opts_.merge_labels_csl; + } + std::sort(merge_labels_.begin(), merge_labels_.end()); + } + + Check(); +} + +void SegmentationPostProcessor::Check() const { + if (IsPaddingSegmentsToBeDone(opts_) && opts_.pad_label < 0) { + KALDI_ERR << "Invalid value " << opts_.pad_label << " for option " + << "--pad-label. It must be non-negative."; + } + + if (IsPaddingSegmentsToBeDone(opts_) && opts_.pad_length <= 0) { + KALDI_ERR << "Invalid value " << opts_.pad_length << " for option " + << "--pad-length. It must be positive."; + } + + if (IsShrinkingSegmentsToBeDone(opts_) && opts_.shrink_label < 0) { + KALDI_ERR << "Invalid value " << opts_.shrink_label << " for option " + << "--shrink-label. It must be non-negative."; + } + + if (IsShrinkingSegmentsToBeDone(opts_) && opts_.shrink_length <= 0) { + KALDI_ERR << "Invalid value " << opts_.shrink_length << " for option " + << "--shrink-length. It must be positive."; + } + + if (IsBlendingShortSegmentsToBeDone(opts_) && + opts_.blend_short_segments_class < 0) { + KALDI_ERR << "Invalid value " << opts_.blend_short_segments_class + << " for option " << "--blend-short-segments-class. " + << "It must be non-negative."; + } + + if (IsBlendingShortSegmentsToBeDone(opts_) && opts_.max_blend_length <= 0) { + KALDI_ERR << "Invalid value " << opts_.max_blend_length << " for option " + << "--max-blend-length. It must be positive."; + } + + if (IsRemovingSegmentsToBeDone(opts_) && remove_labels_[0] < 0) { + KALDI_ERR << "Invalid value " << opts_.remove_labels_csl + << " for option " << "--remove-labels. 
" + << "The labels must be non-negative."; + } + + if (IsMergingAdjacentSegmentsToBeDone(opts_) && + opts_.max_intersegment_length < 0) { + KALDI_ERR << "Invalid value " << opts_.max_intersegment_length + << " for option " + << "--max-intersegment-length. It must be non-negative."; + } + + if (IsSplittingSegmentsToBeDone(opts_) && opts_.max_segment_length <= 0) { + KALDI_ERR << "Invalid value " << opts_.max_segment_length + << " for option " + << "--max-segment-length. It must be positive."; + } + + if (opts_.post_process_label != -1 && opts_.post_process_label < 0) { + KALDI_ERR << "Invalid value " << opts_.post_process_label << " for option " + << "--post-process-label. It must be non-negative."; + } +} + +bool SegmentationPostProcessor::PostProcess(Segmentation *seg) const { + DoMergingLabels(seg); + DoPaddingSegments(seg); + DoShrinkingSegments(seg); + DoBlendingShortSegments(seg); + DoRemovingSegments(seg); + DoMergingAdjacentSegments(seg); + DoSplittingSegments(seg); + + return true; +} + +void SegmentationPostProcessor::DoMergingLabels(Segmentation *seg) const { + if (!IsMergingLabelsToBeDone(opts_)) return; + MergeLabels(merge_labels_, opts_.merge_dst_label, seg); +} + +void SegmentationPostProcessor::DoPaddingSegments(Segmentation *seg) const { + if (!IsPaddingSegmentsToBeDone(opts_)) return; + PadSegments(opts_.pad_label, opts_.pad_length, seg); +} + +void SegmentationPostProcessor::DoShrinkingSegments(Segmentation *seg) const { + if (!IsShrinkingSegmentsToBeDone(opts_)) return; + ShrinkSegments(opts_.shrink_label, opts_.shrink_length, seg); +} + +void SegmentationPostProcessor::DoBlendingShortSegments( + Segmentation *seg) const { + if (!IsBlendingShortSegmentsToBeDone(opts_)) return; + BlendShortSegmentsWithNeighbors(opts_.blend_short_segments_class, + opts_.max_blend_length, + opts_.max_intersegment_length, seg); +} + +void SegmentationPostProcessor::DoRemovingSegments(Segmentation *seg) const { + if (!IsRemovingSegmentsToBeDone(opts_)) return; + RemoveSegments(remove_labels_, seg); +} + +void SegmentationPostProcessor::DoMergingAdjacentSegments( + Segmentation *seg) const { + if (!IsMergingAdjacentSegmentsToBeDone(opts_)) return; + MergeAdjacentSegments(opts_.max_intersegment_length, seg); +} + +void SegmentationPostProcessor::DoSplittingSegments(Segmentation *seg) const { + if (!IsSplittingSegmentsToBeDone(opts_)) return; + SplitSegments(opts_.max_segment_length, + opts_.max_segment_length / 2, + opts_.overlap_length, + opts_.post_process_label, seg); +} + +} // end namespace segmenter +} // end namespace kaldi diff --git a/src/segmenter/segmentation-post-processor.h b/src/segmenter/segmentation-post-processor.h new file mode 100644 index 00000000000..01a23b93b1b --- /dev/null +++ b/src/segmenter/segmentation-post-processor.h @@ -0,0 +1,168 @@ +// segmenter/segmentation-post-processor.h + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. 
+// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_SEGMENTER_SEGMENTATION_POST_PROCESSOR_H_ +#define KALDI_SEGMENTER_SEGMENTATION_POST_PROCESSOR_H_ + +#include "base/kaldi-common.h" +#include "itf/options-itf.h" +#include "segmenter/segmentation.h" + +namespace kaldi { +namespace segmenter { + +/** + * Structure for some common options related to segmentation that would be used + * in multiple segmentation programs. Some of the operations include merging, + * filtering etc. +**/ + +struct SegmentationPostProcessingOptions { + std::string merge_labels_csl; + int32 merge_dst_label; + + int32 pad_label; + int32 pad_length; + + int32 shrink_label; + int32 shrink_length; + + int32 blend_short_segments_class; + int32 max_blend_length; + + std::string remove_labels_csl; + + bool merge_adjacent_segments; + int32 max_intersegment_length; + + int32 max_segment_length; + int32 overlap_length; + + int32 post_process_label; + + SegmentationPostProcessingOptions() : + merge_dst_label(-1), + pad_label(-1), pad_length(-1), + shrink_label(-1), shrink_length(-1), + blend_short_segments_class(-1), max_blend_length(-1), + merge_adjacent_segments(false), max_intersegment_length(0), + max_segment_length(-1), overlap_length(0), + post_process_label(-1) { } + + void Register(OptionsItf *opts) { + opts->Register("merge-labels", &merge_labels_csl, "Merge labels into a " + "single label defined by merge-dst-label. " + "The labels are specified as a colon-separated list. " + "Refer to the MergeLabels() code for details. " + "Used in conjunction with the option --merge-dst-label"); + opts->Register("merge-dst-label", &merge_dst_label, + "Merge labels specified by merge-labels into this label. " + "Refer to the MergeLabels() code for details. " + "Used in conjunction with the option --merge-labels."); + opts->Register("pad-label", &pad_label, + "Pad segments of this label by pad_length frames." + "Refer to the PadSegments() code for details. " + "Used in conjunction with the option --pad-length."); + opts->Register("pad-length", &pad_length, "Pad segments by this many " + "frames on either side. " + "Refer to the PadSegments() code for details. " + "Used in conjunction with the option --pad-label."); + opts->Register("shrink-label", &shrink_label, + "Shrink segments of this label by shrink_length frames. " + "Refer to the ShrinkSegments() code for details. " + "Used in conjunction with the option --shrink-length."); + opts->Register("shrink-length", &shrink_length, "Shrink segments by this " + "many frames on either side. " + "Refer to the ShrinkSegments() code for details. " + "Used in conjunction with the option --shrink-label."); + opts->Register("blend-short-segments-class", &blend_short_segments_class, + "The label for which the short segments are to be " + "blended with the neighboring segments that are less than " + "max_intersegment_length frames away. " + "Refer to BlendShortSegments() code for details. " + "Used in conjunction with the option --max-blend-length " + "and --max-intersegment-length."); + opts->Register("max-blend-length", &max_blend_length, + "The maximum length of segment in number of frames that " + "will be blended with the neighboring segments provided " + "they both have the same label. " + "Refer to BlendShortSegments() code for details. 
" + "Used in conjunction with the option " + "--blend-short-segments-class"); + opts->Register("remove-labels", &remove_labels_csl, + "Remove any segment whose label is contained in " + "remove_labels_csl. " + "Refer to the RemoveLabels() code for details."); + opts->Register("merge-adjacent-segments", &merge_adjacent_segments, + "Merge adjacent segments of the same label if they are " + "within max-intersegment-length distance. " + "Refer to the MergeAdjacentSegments() code for details. " + "Used in conjunction with the option " + "--max-intersegment-length\n"); + opts->Register("max-intersegment-length", &max_intersegment_length, + "The maximum intersegment length that is allowed for " + "two adjacent segments to be merged. " + "Refer to the MergeAdjacentSegments() code for details. " + "Used in conjunction with the option " + "--merge-adjacent-segments or " + "--blend-short-segments-class\n"); + opts->Register("max-segment-length", &max_segment_length, + "If segment is longer than this length, split it into " + "pieces with less than these many frames. " + "Refer to the SplitSegments() code for details. " + "Used in conjunction with the option --overlap-length."); + opts->Register("overlap-length", &overlap_length, + "When splitting segments longer than max-segment-length, " + "have the pieces overlap by these many frames. " + "Refer to the SplitSegments() code for details. " + "Used in conjunction with the option --max-segment-length."); + opts->Register("post-process-label", &post_process_label, + "Do post processing only on this label. This option is " + "applicable to only a few operations including " + "SplitSegments"); + } +}; + +class SegmentationPostProcessor { + public: + explicit SegmentationPostProcessor( + const SegmentationPostProcessingOptions &opts); + + bool PostProcess(Segmentation *seg) const; + + void DoMergingLabels(Segmentation *seg) const; + void DoPaddingSegments(Segmentation *seg) const; + void DoShrinkingSegments(Segmentation *seg) const; + void DoBlendingShortSegments(Segmentation *seg) const; + void DoRemovingSegments(Segmentation *seg) const; + void DoMergingAdjacentSegments(Segmentation *seg) const; + void DoSplittingSegments(Segmentation *seg) const; + + private: + const SegmentationPostProcessingOptions &opts_; + std::vector merge_labels_; + std::vector remove_labels_; + + void Check() const; +}; + +} // end namespace segmenter +} // end namespace kaldi + +#endif // KALDI_SEGMENTER_SEGMENTATION_POST_PROCESSOR_H_ diff --git a/src/segmenter/segmentation-test.cc b/src/segmenter/segmentation-test.cc new file mode 100644 index 00000000000..7654b23b119 --- /dev/null +++ b/src/segmenter/segmentation-test.cc @@ -0,0 +1,226 @@ +// segmenter/segmentation-test.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "segmenter/segmentation.h" + +namespace kaldi { +namespace segmenter { + +void GenerateRandomSegmentation(int32 max_length, int32 num_classes, + Segmentation *segmentation) { + Clear(); + int32 s = max_length; + int32 e = max_length; + + while (s >= 0) { + int32 chunk_size = rand() % (max_length / 10); + s = e - chunk_size + 1; + int32 k = rand() % num_classes; + + if (k != 0) { + segmentation.Emplace(s, e, k); + } + e = s - 1; + } + Check(); +} + + +int32 GenerateRandomAlignment(int32 max_length, int32 num_classes, + std::vector *ali) { + int32 N = RandInt(1, max_length); + int32 C = RandInt(1, num_classes); + + ali->clear(); + + int32 len = 0; + while (len < N) { + int32 c = RandInt(0, C-1); + int32 n = std::min(RandInt(1, N), N - len); + ali->insert(ali->begin() + len, n, c); + len += n; + } + KALDI_ASSERT(ali->size() == N && len == N); + + int32 state = -1, num_segments = 0; + for (std::vector::const_iterator it = ali->begin(); + it != ali->end(); ++it) { + if (*it != state) num_segments++; + state = *it; + } + + return num_segments; +} + +void TestConversionToAlignment() { + std::vector ali; + int32 max_length = 1000, num_classes = 3; + int32 num_segments = GenerateRandomAlignment(max_length, num_classes, &ali); + + Segmentation seg; + KALDI_ASSERT(num_segments == seg.InsertFromAlignment(ali, 0)); + + std::vector out_ali; + { + seg.ConvertToAlignment(&out_ali); + KALDI_ASSERT(ali == out_ali); + } + + { + seg.ConvertToAlignment(&out_ali, num_classes, max_length * 2); + std::vector tmp_ali(out_ali.begin(), out_ali.begin() + ali.size()); + KALDI_ASSERT(ali == tmp_ali); + for (std::vector::const_iterator it = out_ali.begin() + ali.size(); + it != out_ali.end(); ++it) { + KALDI_ASSERT(*it == num_classes); + } + } + + seg.Clear(); + KALDI_ASSERT(num_segments == seg.InsertFromAlignment(ali, max_length)); + { + seg.ConvertToAlignment(&out_ali, num_classes, max_length * 2); + + for (std::vector::const_iterator it = out_ali.begin(); + it != out_ali.begin() + max_length; ++it) { + KALDI_ASSERT(*it == num_classes); + } + std::vector tmp_ali(out_ali.begin() + max_length, out_ali.begin() + max_length + ali.size()); + KALDI_ASSERT(tmp_ali == ali); + + for (std::vector::const_iterator it = out_ali.begin() + max_length + ali.size(); + it != out_ali.end(); ++it) { + KALDI_ASSERT(*it == num_classes); + } + } +} + +void TestRemoveSegments() { + std::vector ali; + int32 max_length = 1000, num_classes = 10; + int32 num_segments = GenerateRandomAlignment(max_length, num_classes, &ali); + + Segmentation seg; + KALDI_ASSERT(num_segments == seg.InsertFromAlignment(ali, 0)); + + for (int32 i = 0; i < num_classes; i++) { + Segmentation out_seg(seg); + out_seg.RemoveSegments(i); + std::vector out_ali; + out_seg.ConvertToAlignment(&out_ali, i, ali.size()); + KALDI_ASSERT(ali == out_ali); + } + + { + std::vector classes; + for (int32 i = 0; i < 3; i++) + classes.push_back(RandInt(0, num_classes - 1)); + std::sort(classes.begin(), classes.end()); + + Segmentation out_seg1(seg); + out_seg1.RemoveSegments(classes); + + Segmentation out_seg2(seg); + for (std::vector::const_iterator it = classes.begin(); + it != classes.end(); ++it) + out_seg2.RemoveSegments(*it); + + std::vector out_ali1, out_ali2; + out_seg1.ConvertToAlignment(&out_ali1); + out_seg2.ConvertToAlignment(&out_ali2); + + KALDI_ASSERT(out_ali1 == out_ali2); + } +} + +void TestIntersectSegments() { + int32 max_length = 100, num_classes = 3; + + std::vector primary_ali; + GenerateRandomAlignment(max_length, num_classes, &primary_ali); + + 
std::vector secondary_ali; + GenerateRandomAlignment(max_length, num_classes, &secondary_ali); + + Segmentation primary_seg; + primary_seg.InsertFromAlignment(primary_ali); + + Segmentation secondary_seg; + secondary_seg.InsertFromAlignment(secondary_ali); + + { + Segmentation out_seg; + primary_seg.IntersectSegments(secondary_seg, &out_seg, num_classes); + + std::vector out_ali; + out_seg.ConvertToAlignment(&out_ali); + + std::vector oracle_ali(primary_ali.size()); + + for (size_t i = 0; i < oracle_ali.size(); i++) { + int32 p = (i < primary_ali.size()) ? primary_ali[i] : -1; + int32 s = (i < secondary_ali.size()) ? secondary_ali[i] : -2; + + oracle_ali[i] = (p == s) ? p : num_classes; + } + + KALDI_ASSERT(oracle_ali == out_ali); + } + + { + Segmentation out_seg; + primary_seg.IntersectSegments(secondary_seg, &out_seg); + + std::vector out_ali; + out_seg.ConvertToAlignment(&out_ali, num_classes); + + std::vector oracle_ali(out_ali.size()); + + for (size_t i = 0; i < oracle_ali.size(); i++) { + int32 p = (i < primary_ali.size()) ? primary_ali[i] : -1; + int32 s = (i < secondary_ali.size()) ? secondary_ali[i] : -2; + + oracle_ali[i] = (p == s) ? p : num_classes; + } + + KALDI_ASSERT(oracle_ali == out_ali); + } + +} + +void UnitTestSegmentation() { + TestConversionToAlignment(); + TestRemoveSegments(); + TestIntersectSegments(); +} + +} // namespace segmenter +} // namespace kaldi + +int main() { + using namespace kaldi; + using namespace kaldi::segmenter; + + for (int32 i = 0; i < 10; i++) + UnitTestSegmentation(); + return 0; +} + + + diff --git a/src/segmenter/segmentation-utils.cc b/src/segmenter/segmentation-utils.cc new file mode 100644 index 00000000000..3adc178d66d --- /dev/null +++ b/src/segmenter/segmentation-utils.cc @@ -0,0 +1,743 @@ +// segmenter/segmentation-utils.cc + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "segmenter/segmentation-utils.h" + +namespace kaldi { +namespace segmenter { + +void MergeLabels(const std::vector &merge_labels, + int32 dest_label, + Segmentation *segmentation) { + KALDI_ASSERT(segmentation); + + // Check if sorted and unique + KALDI_ASSERT(std::adjacent_find(merge_labels.begin(), + merge_labels.end(), std::greater()) + == merge_labels.end()); + + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ++it) { + if (std::binary_search(merge_labels.begin(), merge_labels.end(), + it->Label())) { + it->SetLabel(dest_label); + } + } +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +void RelabelSegmentsUsingMap(const unordered_map &label_map, + Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ++it) { + unordered_map::const_iterator map_it = label_map.find( + it->Label()); + if (map_it == label_map.end()) + KALDI_ERR << "Could not find label " << it->Label() << " in label map."; + + it->SetLabel(map_it->second); + } +} + +void RelabelAllSegments(int32 label, Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ++it) + it->SetLabel(label); +} + +void ScaleFrameShift(BaseFloat factor, Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ++it) { + it->start_frame *= factor; + it->end_frame *= factor; + } +} + +void RemoveSegments(int32 label, Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ) { + if (it->Label() == label) { + it = segmentation->Erase(it); + } else { + ++it; + } + } +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +void RemoveSegments(const std::vector &labels, + Segmentation *segmentation) { + // Check if sorted and unique + KALDI_ASSERT(std::adjacent_find(labels.begin(), + labels.end(), std::greater()) == labels.end()); + + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ) { + if (std::binary_search(labels.begin(), labels.end(), it->Label())) { + it = segmentation->Erase(it); + } else { + ++it; + } + } +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +// Opposite of RemoveSegments() +void KeepSegments(int32 label, Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ) { + if (it->Label() != label) { + it = segmentation->Erase(it); + } else { + ++it; + } + } +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +// TODO(Vimal): Write test function for this. +void SplitInputSegmentation(const Segmentation &in_segmentation, + int32 segment_length, + Segmentation *out_segmentation) { + out_segmentation->Clear(); + for (SegmentList::const_iterator it = in_segmentation.Begin(); + it != in_segmentation.End(); ++it) { + int32 length = it->Length(); + + // Since ceil is used, this results in all pieces to be smaller than + // segment_length rather than being larger. 
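+    // For example, a 45-frame segment with segment_length = 20 gives
+    // num_chunks = 3 and actual_segment_length = 15, i.e. three 15-frame
+    // pieces, each no longer than segment_length.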
+ int32 num_chunks = std::ceil(static_cast(length) + / segment_length); + int32 actual_segment_length = static_cast(length) / num_chunks; + + int32 start_frame = it->start_frame; + for (int32 j = 0; j < num_chunks; j++) { + int32 end_frame = std::min(start_frame + actual_segment_length - 1, + it->end_frame); + out_segmentation->EmplaceBack(start_frame, end_frame, it->Label()); + start_frame = end_frame + 1; + } + } +#ifdef KALDI_PARANOID + out_segmentation->Check(); +#endif +} + +// TODO(Vimal): Write test function for this. +void SplitSegments(int32 segment_length, int32 min_remainder, + int32 overlap_length, int32 segment_label, + Segmentation *segmentation) { + KALDI_ASSERT(segmentation); + KALDI_ASSERT(segment_length > 0 && min_remainder > 0); + KALDI_ASSERT(overlap_length >= 0); + + KALDI_ASSERT(overlap_length < segment_length); + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ++it) { + if (segment_label != -1 && it->Label() != segment_label) continue; + + int32 start_frame = it->start_frame; + int32 length = it->Length(); + + if (length > segment_length + min_remainder) { + // Split segment + // To show what this is doing, consider the following example, where it is + // currently pointing to B. + // A <--> B <--> C + + // Modify the start_frame of the current frame. This prepares the current + // segment to be used as the "next segment" when we move the iterator in + // the next statement. + // In the example, the start_frame for B has just been modified. + it->start_frame = start_frame + segment_length - overlap_length; + + // Create a new segment and add it to the where the current iterator is. + // The statement below results in this: + // A <--> B1 <--> B <--> C + // with the iterator it pointing at B1. So when the iterator is + // incremented in the for loop, it will point to B again, but whose + // start_frame had been modified. + it = segmentation->Emplace(it, start_frame, + start_frame + segment_length - 1, + it->Label()); + } + } +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +// TODO(Vimal): Write test code for this +void SplitSegmentsUsingAlignment(int32 segment_length, + int32 segment_label, + const std::vector &ali, + int32 ali_label, + int32 min_silence_length, + Segmentation *segmentation) { + KALDI_ASSERT(segmentation); + KALDI_ASSERT(segment_length > 0); + + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End();) { + // Safety check. In practice, should never fail. + KALDI_ASSERT(segmentation->Dim() <= ali.size()); + + if (segment_label != -1 && it->Label() != segment_label) { + ++it; + continue; + } + + int32 start_frame = it->start_frame; + int32 length = it->Length(); + int32 label = it->Label(); + + if (length <= segment_length) { + ++it; + continue; + } + + // Split segment + // To show what this is doing, consider the following example, where it is + // currently pointing to B. 
+ // A <--> B <--> C + + Segmentation ali_segmentation; + InsertFromAlignment(ali, start_frame, + start_frame + length, + 0, &ali_segmentation, NULL); + KeepSegments(ali_label, &ali_segmentation); + MergeAdjacentSegments(0, &ali_segmentation); + + // Get largest alignment chunk where label == ali_label + SegmentList::iterator s_it = ali_segmentation.MaxElement(); + + if (s_it == ali_segmentation.End() || s_it->Length() < min_silence_length) { + ++it; + continue; + } + + KALDI_ASSERT(s_it->start_frame >= start_frame); + KALDI_ASSERT(s_it->end_frame <= start_frame + length); + + // Modify the start_frame of the current frame. This prepares the current + // segment to be used as the "next segment" when we move the iterator in + // the next statement. + // In the example, the start_frame for B has just been modified. + int32 end_frame; + if (s_it->Length() > 1) { + end_frame = s_it->start_frame + s_it->Length() / 2 - 2; + it->start_frame = end_frame + 2; + } else { + end_frame = s_it->start_frame - 1; + it->start_frame = s_it->end_frame + 1; + } + + // end_frame is within this current segment + KALDI_ASSERT(end_frame < start_frame + length); + // The first new segment length is smaller than the old segment length + KALDI_ASSERT(end_frame - start_frame + 1 < length); + + // The second new segment length is smaller than the old segment length + KALDI_ASSERT(it->end_frame - end_frame - 1 < length); + + if (it->Length() < 0) { + // This is possible when the beginning of the segment is silence + it = segmentation->Erase(it); + } + + // Create a new segment and add it to the where the current iterator is. + // The statement below results in this: + // A <--> B1 <--> B <--> C + // with the iterator it pointing at B1. + if (end_frame >= start_frame) { + it = segmentation->Emplace(it, start_frame, end_frame, label); + } + } +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +// TODO(Vimal): Write test code for this +void IntersectSegmentationAndAlignment(const Segmentation &in_segmentation, + const std::vector &alignment, + int32 ali_label, + int32 min_align_chunk_length, + Segmentation *out_segmentation) { + KALDI_ASSERT(out_segmentation); + + for (SegmentList::const_iterator it = in_segmentation.Begin(); + it != in_segmentation.End(); ++it) { + Segmentation filter_segmentation; + InsertFromAlignment(alignment, it->start_frame, + std::min(it->end_frame + 1, + static_cast(alignment.size())), + 0, &filter_segmentation, NULL); + + for (SegmentList::const_iterator f_it = filter_segmentation.Begin(); + f_it != filter_segmentation.End(); ++f_it) { + if (f_it->Length() < min_align_chunk_length) continue; + if (ali_label != -1 && f_it->Label() != ali_label) continue; + out_segmentation->EmplaceBack(f_it->start_frame, f_it->end_frame, + it->Label()); + } + } +} + +void SubSegmentUsingNonOverlappingSegments( + const Segmentation &primary_segmentation, + const Segmentation &secondary_segmentation, int32 secondary_label, + int32 subsegment_label, int32 unmatched_label, + Segmentation *out_segmentation) { + KALDI_ASSERT(out_segmentation); + KALDI_ASSERT(secondary_segmentation.Dim() > 0); + + std::vector alignment; + ConvertToAlignment(secondary_segmentation, -1, -1, 0, &alignment); + + for (SegmentList::const_iterator it = primary_segmentation.Begin(); + it != primary_segmentation.End(); ++it) { + if (it->end_frame >= alignment.size()) { + alignment.resize(it->end_frame + 1, -1); + } + Segmentation filter_segmentation; + InsertFromAlignment(alignment, it->start_frame, it->end_frame + 1, + 0, 
&filter_segmentation, NULL); + + for (SegmentList::const_iterator f_it = filter_segmentation.Begin(); + f_it != filter_segmentation.End(); ++f_it) { + int32 label = (unmatched_label > 0 ? unmatched_label : it->Label()); + if (f_it->Label() == secondary_label) { + if (subsegment_label >= 0) { + label = subsegment_label; + } else { + label = f_it->Label(); + } + } + out_segmentation->EmplaceBack(f_it->start_frame, f_it->end_frame, + label); + } + } +} + +// TODO(Vimal): Write test code for this +void MergeAdjacentSegments(int32 max_intersegment_length, + Segmentation *segmentation) { + SegmentList::iterator it = segmentation->Begin(), + prev_it = segmentation->Begin(); + + while (it != segmentation->End()) { + KALDI_ASSERT(it->start_frame >= prev_it->start_frame); + + if (it != segmentation->Begin() && + it->Label() == prev_it->Label() && + prev_it->end_frame + max_intersegment_length + 1 >= it->start_frame) { + // merge segments + if (prev_it->end_frame < it->end_frame) { + // If the previous segment end before the current segment, then + // extend the previous segment to the end_frame of the current + // segment and remove the current segment. + prev_it->end_frame = it->end_frame; + } // else simply remove the current segment. + it = segmentation->Erase(it); + } else { + // no merging of segments + prev_it = it; + ++it; + } + } + +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +void PadSegments(int32 label, int32 length, Segmentation *segmentation) { + KALDI_ASSERT(segmentation); + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ++it) { + if (it->Label() != label) continue; + + it->start_frame -= length; + it->end_frame += length; + + if (it->start_frame < 0) it->start_frame = 0; + } +} + +void WidenSegments(int32 label, int32 length, Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ++it) { + if (it->Label() == label) { + if (it != segmentation->Begin()) { + // it is not the beginning of the segmentation, so we can widen it on + // the start_frame side + SegmentList::iterator prev_it = it; + --prev_it; + it->start_frame -= length; + if (prev_it->Label() == label && it->start_frame < prev_it->end_frame) { + // After widening this segment, it overlaps the previous segment that + // also has the same class_id. Then turn this segment into a composite + // one + it->start_frame = prev_it->start_frame; + // and remove the previous segment from the list. + segmentation->Erase(prev_it); + } else if (prev_it->Label() != label && + it->start_frame < prev_it->end_frame) { + // Previous segment is not the same class_id, so we cannot turn this + // into a composite segment. + if (it->start_frame <= prev_it->start_frame) { + // The extended segment absorbs the previous segment into it + // So remove the previous segment + segmentation->Erase(prev_it); + } else { + // The extended segment reduces the length of the previous + // segment. But does not completely overlap it. + prev_it->end_frame -= length; + if (prev_it->end_frame < prev_it->start_frame) + segmentation->Erase(prev_it); + } + } + if (it->start_frame < 0) it->start_frame = 0; + } else { + it->start_frame -= length; + if (it->start_frame < 0) it->start_frame = 0; + } + + SegmentList::iterator next_it = it; + ++next_it; + + if (next_it != segmentation->End()) + // We do not know the length of the file. + // So we don't want to extend the last one. 
+ it->end_frame += length; // Line (1) + } else { // if (it->Label() != label) + if (it != segmentation->Begin()) { + SegmentList::iterator prev_it = it; + --prev_it; + if (prev_it->end_frame >= it->end_frame) { + // The extended previous segment in Line (1) completely + // overlaps the current segment. So remove the current segment. + it = segmentation->Erase(it); + // So that we can increment in the for loop + --it; // TODO(Vimal): This is buggy. + } else if (prev_it->end_frame >= it->start_frame) { + // The extended previous segment in Line (1) reduces the length of + // this segment. + it->start_frame = prev_it->end_frame + 1; + } + } + } + } +} + +void ShrinkSegments(int32 label, int32 length, Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ) { + if (it->Label() == label) { + if (it->Length() <= 2 * length) { + it = segmentation->Erase(it); + } else { + it->start_frame += length; + it->end_frame -= length; + ++it; + } + } else { + ++it; + } + } + +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +void BlendShortSegmentsWithNeighbors(int32 label, int32 max_length, + int32 max_intersegment_length, + Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ) { + if (it == segmentation->Begin()) { + // Can't blend the first segment + ++it; + continue; + } + + SegmentList::iterator next_it = it; + ++next_it; + + if (next_it == segmentation->End()) // End of segmentation + break; + + SegmentList::iterator prev_it = it; + --prev_it; + + // If the previous and current segments have different labels, + // then ensure that they are not overlapping + KALDI_ASSERT(it->start_frame >= prev_it->start_frame && + (prev_it->Label() == it->Label() || + prev_it->end_frame < it->start_frame)); + + KALDI_ASSERT(next_it->start_frame >= it->start_frame && + (it->Label() == next_it->Label() || + it->end_frame < next_it->start_frame)); + + if (next_it->Label() != prev_it->Label() || it->Label() != label || + it->Length() >= max_length || + next_it->start_frame - it->end_frame - 1 > max_intersegment_length || + it->start_frame - prev_it->end_frame - 1 > max_intersegment_length) { + ++it; + continue; + } + + prev_it->end_frame = next_it->end_frame; + segmentation->Erase(it); + it = segmentation->Erase(next_it); + } +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +bool ConvertToAlignment(const Segmentation &segmentation, + int32 default_label, int32 length, + int32 tolerance, + std::vector *alignment) { + KALDI_ASSERT(alignment); + alignment->clear(); + + if (length != -1) { + KALDI_ASSERT(length >= 0); + alignment->resize(length, default_label); + } + + SegmentList::const_iterator it = segmentation.Begin(); + for (; it != segmentation.End(); ++it) { + if (length != -1 && it->end_frame >= length + tolerance) { + KALDI_WARN << "End frame (" << it->end_frame << ") " + << ">= length (" << length + << ") + tolerance (" << tolerance << ")." 
+ << "Conversion failed."; + return false; + } + + int32 end_frame = it->end_frame; + if (length == -1) { + alignment->resize(it->end_frame + 1, default_label); + } else { + if (it->end_frame >= length) + end_frame = length - 1; + } + + KALDI_ASSERT(end_frame < alignment->size()); + for (int32 i = it->start_frame; i <= end_frame; i++) { + (*alignment)[i] = it->Label(); + } + } + return true; +} + +int32 InsertFromAlignment(const std::vector &alignment, + int32 start, int32 end, + int32 start_time_offset, + Segmentation *segmentation, + std::vector *frame_counts_per_class) { + KALDI_ASSERT(segmentation); + + if (end <= start) return 0; // nothing to insert + + // Correct boundaries + if (end > alignment.size()) end = alignment.size(); + if (start < 0) start = 0; + + KALDI_ASSERT(end > start); // This is possible if end was originally + // greater than alignment.size(). + // The user must resize alignment appropriately + // before passing to this function. + + int32 num_segments = 0; + int32 state = -100, start_frame = -1; + for (int32 i = start; i < end; i++) { + KALDI_ASSERT(alignment[i] >= -1); + if (alignment[i] != state) { + // Change of state i.e. a different class id. + // So the previous segment has ended. + if (start_frame != -1) { + // start_frame == -1 in the beginning of the alignment. That is just + // initialization step and hence no creation of segment. + segmentation->EmplaceBack(start_frame + start_time_offset, + i-1 + start_time_offset, state); + num_segments++; + + if (frame_counts_per_class && state > 0) { + if (frame_counts_per_class->size() <= state) { + frame_counts_per_class->resize(state + 1, 0); + } + (*frame_counts_per_class)[state] += i - start_frame; + } + } + start_frame = i; + state = alignment[i]; + } + } + + KALDI_ASSERT(state >= -1 && start_frame >= 0 && start_frame < end); + segmentation->EmplaceBack(start_frame + start_time_offset, + end-1 + start_time_offset, state); + num_segments++; + if (frame_counts_per_class && state > 0) { + if (frame_counts_per_class->size() <= state) { + frame_counts_per_class->resize(state + 1, 0); + } + (*frame_counts_per_class)[state] += end - start_frame; + } + +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif + + return num_segments; +} + +int32 InsertFromSegmentation( + const Segmentation &in_segmentation, int32 start_time_offset, + bool sort, + Segmentation *out_segmentation, + std::vector *frame_counts_per_class) { + KALDI_ASSERT(out_segmentation); + + if (in_segmentation.Dim() == 0) return 0; // nothing to insert + + int32 num_segments = 0; + + for (SegmentList::const_iterator it = in_segmentation.Begin(); + it != in_segmentation.End(); ++it) { + out_segmentation->EmplaceBack(it->start_frame + start_time_offset, + it->end_frame + start_time_offset, + it->Label()); + num_segments++; + if (frame_counts_per_class) { + if (frame_counts_per_class->size() <= it->Label()) { + frame_counts_per_class->resize(it->Label() + 1, 0); + } + (*frame_counts_per_class)[it->Label()] += it->Length(); + } + } + + if (sort) out_segmentation->Sort(); + +#ifdef KALDI_PARANOID + out_segmentation->Check(); +#endif + + return num_segments; +} + +void ExtendSegmentation(const Segmentation &in_segmentation, + bool sort, + Segmentation *segmentation) { + InsertFromSegmentation(in_segmentation, 0, sort, segmentation, NULL); +} + +bool GetClassCountsPerFrame( + const Segmentation &segmentation, + int32 length, int32 tolerance, + std::vector > *class_counts_per_frame) { + KALDI_ASSERT(class_counts_per_frame); + + if (length != -1) { + 
KALDI_ASSERT(length >= 0); + class_counts_per_frame->resize(length, std::map()); + } + + SegmentList::const_iterator it = segmentation.Begin(); + for (; it != segmentation.End(); ++it) { + if (length != -1 && it->end_frame >= length + tolerance) { + KALDI_WARN << "End frame (" << it->end_frame << ") " + << ">= length + tolerance (" << length + tolerance << ")." + << "Conversion failed."; + return false; + } + + int32 end_frame = it->end_frame; + if (length == -1) { + class_counts_per_frame->resize(it->end_frame + 1, + std::map()); + } else { + if (it->end_frame >= length) + end_frame = length - 1; + } + + KALDI_ASSERT(end_frame < class_counts_per_frame->size()); + for (int32 i = it->start_frame; i <= end_frame; i++) { + std::map &this_class_counts = (*class_counts_per_frame)[i]; + std::map::iterator c_it = this_class_counts.lower_bound( + it->Label()); + if (c_it == this_class_counts.end() || it->Label() < c_it->first) { + this_class_counts.insert(c_it, std::make_pair(it->Label(), 1)); + } else { + c_it->second++; + } + } + } + + return true; +} + +bool IsNonOverlapping(const Segmentation &segmentation) { + std::vector vec; + for (SegmentList::const_iterator it = segmentation.Begin(); + it != segmentation.End(); ++it) { + vec.resize(it->end_frame + 1, false); + for (int32 i = it->start_frame; i <= it->end_frame; i++) { + if (vec[i]) return false; + vec[i] = true; + } + } + return true; +} + +void Sort(Segmentation *segmentation) { + segmentation->Sort(); +} + +void TruncateToLength(int32 length, Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ) { + if (it->start_frame >= length) { + it = segmentation->Erase(it); + continue; + } + + if (it->end_frame >= length) + it->end_frame = length - 1; + ++it; + } +} + +} // end namespace segmenter +} // end namespace kaldi diff --git a/src/segmenter/segmentation-utils.h b/src/segmenter/segmentation-utils.h new file mode 100644 index 00000000000..9401722ccb7 --- /dev/null +++ b/src/segmenter/segmentation-utils.h @@ -0,0 +1,337 @@ +// segmenter/segmentation-utils.h + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_SEGMENTER_SEGMENTATION_UTILS_H_ +#define KALDI_SEGMENTER_SEGMENTATION_UTILS_H_ + +#include "segmenter/segmentation.h" + +namespace kaldi { +namespace segmenter { + +/** + * This function is very straight forward. It just merges the labels in + * merge_labels to the class-id dest_label. This means any segment that + * originally had the class-id as any of the labels in merge_labels would end + * up having the class-id dest_label. + **/ +void MergeLabels(const std::vector &merge_labels, + int32 dest_label, Segmentation *segmentation); + +// Relabel segments using a map from old to new label. +// If segment label is not found in the map, the function exits with +// an error. 
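+// For example, a label_map of {0 -> 0, 1 -> 0, 2 -> 1} merges label 1 into
+// label 0 and relabels 2 as 1; a segment carrying a label outside the map
+// (e.g. 3) causes an error.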
+void RelabelSegmentsUsingMap(const unordered_map &label_map, + Segmentation *segmentation); + +// Relabel all segments to class-id label +void RelabelAllSegments(int32 label, Segmentation *segmentation); + +// Scale frame shift by this factor. +// Usually frame length is 0.01 and frame shift 0.015. But sometimes +// the alignments are obtained using a subsampling factor of 3. This +// function can be used to maintain consistency among different +// alignments and segmentations. +void ScaleFrameShift(BaseFloat factor, Segmentation *segmentation); + +/** + * This is very straight forward. It removes all segments of label "label" +**/ +void RemoveSegments(int32 label, Segmentation *segmentation); + +/** + * This is very straight forward. It removes any segment whose label is + * contained in the vector "labels" +**/ +void RemoveSegments(const std::vector &labels, + Segmentation *segmentation); + +// Keep only segments of label "label" +void KeepSegments(int32 label, Segmentation *segmentation); + +/** + * This function splits an input segmentation in_segmentation into pieces of + * approximately segment_length. Each piece is given the same class id as the + * original segment. + * + * The way this function is written is that it first figures out the number of + * pieces that the segment must be broken into. Then it creates that many pieces + * of equal size (actual_segment_length). This mimics some of the approaches + * used at script level +**/ +void SplitInputSegmentation(const Segmentation &in_segmentation, + int32 segment_length, + Segmentation *out_segmentation); + +/** + * This function splits the segments in the the segmentation + * into pieces of segment_length. + * But if the last remaining piece is smaller than min_remainder, then the last + * piece is merged to the piece before it, resulting in a piece that is of + * length < segment_length + min_remainder. + * If overlap_length > 0, then the created pieces overlap by these many frames. + * If segment_label == -1, then all segments are split. + * Otherwise, only the segments with this label are split. + * + * The way this function works it is it looks at the current segment length and + * checks if it is larger than segment_length + min_remainder. If it is larger, + * then it must be split. To do this, it first modifies the start_frame of + * the current frame to start_frame + segment_length - overlap. + * It then creates a new segment of length segment_length from the original + * start_frame to start_frame + segment_length - 1 and adds it just before the + * current segment. So in the next iteration, we would actually be back to the + * same segment, but whose start_frame had just been modified. +**/ +void SplitSegments(int32 segment_length, + int32 min_remainder, int32 overlap_length, + int32 segment_label, + Segmentation *segmentation); + +/** + * Split this segmentation into pieces of size segment_length, + * but only if possible by creating split points at the + * middle of the chunk where alignment == ali_label and + * the chunk is at least min_segment_length frames long + * + * min_remainder, segment_label serve the same purpose as in the + * above SplitSegments function. 
+**/ +void SplitSegmentsUsingAlignment(int32 segment_length, + int32 segment_label, + const std::vector &alignment, + int32 alignment_label, + int32 min_align_chunk_length, + Segmentation *segmentation); + +/** + * This function is a standard intersection of the set of times represented by + * the segmentation in_segmentation and the set of times of where + * alignment contains ali_label for at least min_align_chunk_length + * consecutive frames +**/ +void IntersectSegmentationAndAlignment(const Segmentation &in_segmentation, + const std::vector &alignment, + int32 ali_label, + int32 min_align_chunk_length, + Segmentation *out_segmentation); + +/** + * This function is a little complicated in what it does. But this is required + * for one of the applications. + * This function creates a new segmentation by sub-segmenting an arbitrary + * "primary_segmentation" and assign new label "subsegment_label" to regions + * where the "primary_segmentation" intersects the non-overlapping + * "secondary_segmentation" segments with label "secondary_label". + * This is similar to the function "IntersectSegments", but instead of keeping + * only the filtered subsegments, all the subsegments are kept, while only + * changing the class_id of the filtered sub-segments. + * The label for the newly created subsegments is determined as follows: + * if secondary segment's label == secondary_label: + * if subsegment_label > 0: + * label = subsegment_label + * else: + * label = secondary_label + * else: + * if unmatched_label > 0: + * label = unmatched_label + * else: + * label = primary_label +**/ +void SubSegmentUsingNonOverlappingSegments( + const Segmentation &primary_segmentation, + const Segmentation &secondary_segmentation, int32 secondary_label, + int32 subsegment_label, int32 unmatched_label, + Segmentation *out_segmentation); + +/** + * This function is used to merge segments next to each other in the SegmentList + * and within a distance of max_intersegment_length frames from each other, + * provided the segments are of the same label. + * This function requires the segmentation to be sorted before passing it. + **/ +void MergeAdjacentSegments(int32 max_intersegment_length, + Segmentation *segmentation); + +/** + * This function is used to pad segments of label "label" by "length" + * frames on either side of the segment. + * This is useful to pad segments of speech. +**/ +void PadSegments(int32 label, int32 length, Segmentation *segmentation); + +/** + * This function is used to widen segments of label "label" by "length" + * frames on either side of the segment. + * This is similar to PadSegments, but while widening, it also reduces the + * length of the segment adjacent to it. + * This may not be required in some applications, but it is ok for speech / + * silence. By this process, we are calling frames within a "length" number of + * frames near the speech segment as speech and hence we reduce the width of the + * silence segment before it. +**/ +void WidenSegments(int32 label, int32 length, Segmentation *segmentation); + +/** + * This function is used to shrink segments of class_id "label" by "length" + * frames on either side of the segment. + * If the whole segment is smaller than 2*length, then the segment is + * removed entirely. 
+**/ +void ShrinkSegments(int32 label, int32 length, Segmentation *segmentation); + +/** + * This function blends segments of label "label" that are shorter than + * "max_length" frames, provided the segments before and after it are of the + * same label "other_label" and the distance to the neighbor is less than + * "max_intersegment_distance". + * After blending, the three segments have the same label "other_label" and + * hence can be merged into a composite segment. + * An example where this is useful is when there is a short segment of silence + * with speech segments on either sides. Then the short segment of silence is + * removed and called speech instead. The three continguous segments of speech + * are merged into a single composite segment. +**/ +void BlendShortSegmentsWithNeighbors(int32 label, int32 max_length, + int32 max_intersegment_distance, + Segmentation *segmentation); + +/** + * This function is used to convert the segmentation into frame-level alignment + * with the label for each frame begin the class_id of segment the frame belongs + * to. + * The arguments are used to provided extended functionality that are required + * for most cases. + * default_label : the label that is used as filler in regions where the frame + * is not in any of the segments. In most applications, certain + * segments are removed, such as the ones that are silence. Then + * the segments would not span the entire duration of the file. + * e.g. + * 10 35 1 + * 41 190 2 + * ... + * Here there is no segment from 36-40. These frames are + * filled with default_label. + * length : the number of frames required in the alignment. + * If set to -1, then this length is ignored. + * In most applications, the length of the alignment required is + * known. Usually it must match the length of the features + * (obtained using feat-to-len). Then the alignment is resized + * to this length and filled with default_label. The segments + * are then read and the frames corresponding to the segments + * are relabeled with the class_id of the respective segments. + * tolerance : the tolerance in number of frames that we allow for the + * frame index corresponding to the end_frame of the last + * segment. Applicable when length != -1. + * Since, we use 25 ms widows with 10 ms frame shift, + * it is possible that the features length is 2 frames less than + * the end of the last segment. So the user can set the + * tolerance to 2 in order to avoid returning with error in this + * function. + * Function returns true is successful. +**/ +bool ConvertToAlignment(const Segmentation &segmentation, + int32 default_label, int32 length, + int32 tolerance, + std::vector *alignment); + +/** + * Insert segments created from alignment starting from frame index "start" + * until and excluding frame index "end". + * The inserted segments are shifted by "start_time_offset". + * "start_time_offset" is useful when the "alignment" is per-utterance, in which + * case the start time of the utterance can be provided as the + * "start_time_offset" + * The function returns the number of segments created. + * If "frame_counts_per_class" is provided, then the number of frames per class + * is accumulated there. +**/ +int32 InsertFromAlignment(const std::vector &alignment, + int32 start, int32 end, + int32 start_time_offset, + Segmentation *segmentation, + std::vector *frame_counts_per_class = NULL); + +/** + * Insert segments from in_segmentation, but shift them by + * start_time offset. 
+ * If sort is true, then the final segmentation is sorted. + * It is useful in some applications to set sort to false. + * Returns number of segments inserted. +**/ +int32 InsertFromSegmentation(const Segmentation &in_segmentation, + int32 start_time_offset, bool sort, + Segmentation *segmentation, + std::vector *frame_counts_per_class = NULL); + +/** + * Extend a segmentation by adding another one. + * If "sort" is set to true, then resultant segmentation would be sorted. + * If its known that the other segmentation must all be after this segmentation, + * then the user may set "sort" false. +**/ +void ExtendSegmentation(const Segmentation &in_segmentation, bool sort, + Segmentation *segmentation); + +/** + * This function is used to get per-frame count of number of classes. + * The output is in the format of a vector of maps. + * class_counts_per_frame: A pointer to a vector of maps use to get the output. + * The size of the vector is the number of frames. + * For each frame, there is a map from the "class_id" + * to the number of segments where the label the + * corresponding "class_id". + * The size of the map gives the number of unique + * labels in this frame e.g. number of speakers. + * The count for each "class_id" is the number + * of segments with that "class_id" at that frame. + * length : the number of frames required in the output. + * In most applications, this length is known. + * Usually it must match the length of the features (obtained + * using feat-to-len). Then the output is resized to this + * length. The map is empty for frames where no segments are + * seen. + * tolerance : the tolerance in number of frames that we allow for the + * frame index corresponding to the end_frame of the last + * segment. Since, we use 25 ms widows with 10 ms frame shift, + * it is possible that the features length is 2 frames less than + * the end of the last segment. So the user can set the + * tolerance to 2 in order to avoid returning an error in this + * function. + * Function returns true is successful. +**/ +bool GetClassCountsPerFrame( + const Segmentation &segmentation, + int32 length, int32 tolerance, + std::vector > *class_counts_per_frame); + +// Checks if segmentation is non-overlapping +bool IsNonOverlapping(const Segmentation &segmentation); + +// Sorts segments on start frame. +void Sort(Segmentation *segmentation); + +// Truncate segmentation to "length". +// Removes any segments with "start_time" >= "length" +// and truncates any segments with "end_time" >= "length" +void TruncateToLength(int32 length, Segmentation *segmentation); + +} // end namespace segmenter +} // end namespace kaldi + +#endif // KALDI_SEGMENTER_SEGMENTATION_UTILS_H_ diff --git a/src/segmenter/segmentation.cc b/src/segmenter/segmentation.cc new file mode 100644 index 00000000000..fb83ed5476b --- /dev/null +++ b/src/segmenter/segmentation.cc @@ -0,0 +1,201 @@ +// segmenter/segmentation.cc + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. 
+// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "segmenter/segmentation.h" +#include + +namespace kaldi { +namespace segmenter { + +void Segmentation::PushBack(const Segment &seg) { + dim_++; + segments_.push_back(seg); +} + +SegmentList::iterator Segmentation::Insert(SegmentList::iterator it, + const Segment &seg) { + dim_++; + return segments_.insert(it, seg); +} + +void Segmentation::EmplaceBack(int32 start_frame, int32 end_frame, + int32 class_id) { + dim_++; + Segment seg(start_frame, end_frame, class_id); + segments_.push_back(seg); +} + +SegmentList::iterator Segmentation::Emplace(SegmentList::iterator it, + int32 start_frame, int32 end_frame, + int32 class_id) { + dim_++; + Segment seg(start_frame, end_frame, class_id); + return segments_.insert(it, seg); +} + +SegmentList::iterator Segmentation::Erase(SegmentList::iterator it) { + dim_--; + return segments_.erase(it); +} + +void Segmentation::Clear() { + segments_.clear(); + dim_ = 0; +} + +void Segmentation::Read(std::istream &is, bool binary) { + Clear(); + + if (binary) { + int32 sz = is.peek(); + if (sz == Segment::SizeInBytes()) { + is.get(); + } else { + KALDI_ERR << "Segmentation::Read: expected to see Segment of size " + << Segment::SizeInBytes() << ", saw instead " << sz + << ", at file position " << is.tellg(); + } + + int32 segmentssz; + is.read(reinterpret_cast(&segmentssz), sizeof(segmentssz)); + if (is.fail() || segmentssz < 0) + KALDI_ERR << "Segmentation::Read: read failure at file position " + << is.tellg(); + + for (int32 i = 0; i < segmentssz; i++) { + Segment seg; + seg.Read(is, binary); + segments_.push_back(seg); + } + dim_ = segmentssz; + } else { + if (int c = is.peek() != static_cast('[')) { + KALDI_ERR << "Segmentation::Read: expected to see [, saw " + << static_cast(c) << ", at file position " << is.tellg(); + } + is.get(); // consume the '[' + is >> std::ws; + while (is.peek() != static_cast(']')) { + KALDI_ASSERT(!is.eof()); + Segment seg; + seg.Read(is, binary); + segments_.push_back(seg); + dim_++; + is >> std::ws; + } + is.get(); + KALDI_ASSERT(!is.eof()); + } +#ifdef KALDI_PARANOID + Check(); +#endif +} + +void Segmentation::Write(std::ostream &os, bool binary) const { +#ifdef KALDI_PARANOID + Check(); +#endif + + SegmentList::const_iterator it = Begin(); + if (binary) { + char sz = Segment::SizeInBytes(); + os.write(&sz, 1); + + int32 segmentssz = static_cast(Dim()); + KALDI_ASSERT((size_t)segmentssz == Dim()); + + os.write(reinterpret_cast(&segmentssz), sizeof(segmentssz)); + + for (; it != End(); ++it) { + it->Write(os, binary); + } + } else { + os << "[ "; + for (; it != End(); ++it) { + it->Write(os, binary); + os << std::endl; + } + os << "]" << std::endl; + } +} + +void Segmentation::Check() const { + int32 dim = 0; + for (SegmentList::const_iterator it = Begin(); it != End(); ++it, dim++) { + KALDI_ASSERT(it->start_frame >= 0); + KALDI_ASSERT(it->end_frame >= 0); + KALDI_ASSERT(it->Label() >= 0); + } + KALDI_ASSERT(dim == dim_); +} + +void Segmentation::Sort() { + segments_.sort(SegmentComparator()); +} + +void Segmentation::SortByLength() { + segments_.sort(SegmentLengthComparator()); +} + +SegmentList::iterator Segmentation::MinElement() { + return std::min_element(segments_.begin(), segments_.end(), + SegmentLengthComparator()); +} + +SegmentList::iterator Segmentation::MaxElement() { + return std::max_element(segments_.begin(), segments_.end(), + SegmentLengthComparator()); +} + +Segmentation::Segmentation() 
{ + Clear(); +} + + +void Segmentation::GenRandomSegmentation(int32 max_length, + int32 max_segment_length, + int32 num_classes) { + Clear(); + int32 st = 0; + int32 end = 0; + + while (st > max_length) { + int32 segment_length = RandInt(0, max_segment_length); + + end = st + segment_length - 1; + + // Choose random class id + int32 k = RandInt(-1, num_classes - 1); + + if (k >= 0) { + Segment seg(st, end, k); + segments_.push_back(seg); + dim_++; + } + + // Choose random shift i.e. the distance between two adjacent segments + int32 shift = RandInt(0, max_segment_length); + st = end + shift; + } + + Check(); +} + +} // namespace segmenter +} // namespace kaldi diff --git a/src/segmenter/segmentation.h b/src/segmenter/segmentation.h new file mode 100644 index 00000000000..aa408374751 --- /dev/null +++ b/src/segmenter/segmentation.h @@ -0,0 +1,144 @@ +// segmenter/segmentation.h + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_SEGMENTER_SEGMENTATION_H_ +#define KALDI_SEGMENTER_SEGMENTATION_H_ + +#include +#include "base/kaldi-common.h" +#include "matrix/kaldi-matrix.h" +#include "util/kaldi-table.h" +#include "segmenter/segment.h" + +namespace kaldi { +namespace segmenter { + +// Segments are stored as a doubly-linked-list. This could be changed later +// if needed. Hence defining a typedef SegmentList. +typedef std::list SegmentList; + +// Declare class +class SegmentationPostProcessor; + +/** + * The main class to store segmentation and do operations on it. The segments + * are stored in the structure SegmentList, which is currently a doubly-linked + * list. + * See the .cc file for details of implementation of the different functions. + * This file gives only a small description of the functions. +**/ + +class Segmentation { + public: + // Inserts the segment at the back of the list. + void PushBack(const Segment &seg); + + // Inserts the segment before the segment at the position specified by the + // iterator "it". + SegmentList::iterator Insert(SegmentList::iterator it, + const Segment &seg); + + // The following function is a wrapper to the + // emplace_back functionality of a STL list of Segments + // and inserts a new segment to the back of the list. + void EmplaceBack(int32 start_frame, int32 end_frame, int32 class_id); + + // The following function is a wrapper to the + // emplace functionality of a STL list of segments + // and inserts a segment at the position specified by the iterator "it". + // Returns an iterator to the inserted segment. + SegmentList::iterator Emplace(SegmentList::iterator it, + int32 start_frame, int32 end_frame, + int32 class_id); + + // Call erase operation on the SegmentList and returns the iterator pointing + // to the next segment in the SegmentList and also decrements dim_. + SegmentList::iterator Erase(SegmentList::iterator it); + + // Reset segmentation i.e. 
clear all values + void Clear(); + + // Read segmentation object from input stream + void Read(std::istream &is, bool binary); + + // Write segmentation object to output stream + void Write(std::ostream &os, bool binary) const; + + // Check if all segments have class_id >=0 and if dim_ matches the number of + // segments. + void Check() const; + + // Sort the segments on the start_frame + void Sort(); + + // Sort the segments on the length + void SortByLength(); + + // Returns an iterator to the smallest segment akin to std::min_element + SegmentList::iterator MinElement(); + + // Returns an iterator to the largest segment akin to std::max_element + SegmentList::iterator MaxElement(); + + // Generate a random segmentation for debugging purposes. + // Arguments: + // max_length: The maximum length of the random segmentation to be + // generated. + // max_segment_length: Maximum length of a segment in the segmentation + // num_classes: Maximum number of classes in the generated segmentation + void GenRandomSegmentation(int32 max_length, int32 max_segment_length, + int32 num_classes); + + // Public accessors + inline int32 Dim() const { return dim_; } + SegmentList::iterator Begin() { return segments_.begin(); } + SegmentList::const_iterator Begin() const { return segments_.begin(); } + SegmentList::iterator End() { return segments_.end(); } + SegmentList::const_iterator End() const { return segments_.end(); } + + Segment& Back() { return segments_.back(); } + const Segment& Back() const { return segments_.back(); } + + const SegmentList* Data() const { return &segments_; } + + // Default constructor + Segmentation(); + + private: + // number of segments in the segmentation + int32 dim_; + + // list of segments in the segmentation + SegmentList segments_; + + friend class SegmentationPostProcessor; +}; + +typedef TableWriter > SegmentationWriter; +typedef SequentialTableReader > + SequentialSegmentationReader; +typedef RandomAccessTableReader > + RandomAccessSegmentationReader; +typedef RandomAccessTableReaderMapped > + RandomAccessSegmentationReaderMapped; + +} // end namespace segmenter +} // end namespace kaldi + +#endif // KALDI_SEGMENTER_SEGMENTATION_H_ diff --git a/src/segmenterbin/Makefile b/src/segmenterbin/Makefile new file mode 100644 index 00000000000..1f0efe71181 --- /dev/null +++ b/src/segmenterbin/Makefile @@ -0,0 +1,36 @@ + +all: + +EXTRA_CXXFLAGS = -Wno-sign-compare +include ../kaldi.mk + +BINFILES = segmentation-copy segmentation-get-stats \ + segmentation-init-from-ali segmentation-to-ali \ + segmentation-init-from-segments segmentation-to-segments \ + segmentation-combine-segments segmentation-merge-recordings \ + segmentation-create-subsegments segmentation-intersect-ali \ + segmentation-to-rttm segmentation-post-process \ + segmentation-merge segmentation-split-segments \ + segmentation-remove-segments \ + segmentation-init-from-lengths \ + segmentation-combine-segments-to-recordings \ + segmentation-create-overlapped-subsegments \ + segmentation-intersect-segments \ + segmentation-init-from-additive-signals-info #\ + gmm-acc-pdf-stats-segmentation \ + gmm-est-segmentation gmm-update-segmentation \ + segmentation-init-from-diarization \ + segmentation-compute-class-ctm-conf \ + combine-vector-segments + +OBJFILES = + + + +TESTFILES = + +ADDLIBS = ../hmm/kaldi-hmm.a ../gmm/kaldi-gmm.a ../segmenter/kaldi-segmenter.a ../tree/kaldi-tree.a \ + ../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a ../thread/kaldi-thread.a + +include ../makefiles/default_rules.mk + 
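Before moving on to the command-line tools, the following is a minimal sketch of how the library pieces declared above are meant to fit together: a per-frame alignment is turned into a Segmentation with InsertFromAlignment(), lightly post-processed with helpers from segmentation-utils.h, and written out via the SegmentationWriter typedef from segmentation.h. This sketch is not part of the patch; the function name, archive name, utterance id, the "speech" label value 1, and the 10-frame merge tolerance are all assumptions chosen for illustration.

    #include <vector>
    #include "segmenter/segmentation.h"
    #include "segmenter/segmentation-utils.h"

    // Sketch (illustrative only): alignment -> Segmentation -> Kaldi archive.
    void ExampleAlignmentToSegmentation(const std::vector<kaldi::int32> &ali) {
      using namespace kaldi;
      using namespace segmenter;

      Segmentation seg;
      // One segment is created per run of identical labels in 'ali';
      // no time offset is applied, so segments are in utterance frame indices.
      InsertFromAlignment(ali, 0, static_cast<int32>(ali.size()),
                          0 /* start_time_offset */, &seg);

      KeepSegments(1, &seg);            // keep only label 1 ("speech", assumed)
      MergeAdjacentSegments(10, &seg);  // merge same-label segments separated
                                        // by at most 10 frames
      // InsertFromAlignment emits segments in start-frame order, which is the
      // sorted order MergeAdjacentSegments expects.

      SegmentationWriter writer("ark:speech_seg.ark");  // assumed wspecifier
      writer.Write("utt-0001", seg);                    // assumed utterance id
    }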
diff --git a/src/segmenterbin/segmentation-combine-segments-to-recordings.cc b/src/segmenterbin/segmentation-combine-segments-to-recordings.cc new file mode 100644 index 00000000000..acf71265577 --- /dev/null +++ b/src/segmenterbin/segmentation-combine-segments-to-recordings.cc @@ -0,0 +1,114 @@ +// segmenterbin/segmentation-combine-segments-to-recordings.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Combine kaldi segments in segmentation format to " + "recording-level segmentation\n" + "A reco2utt file is used to specify which utterances are contained " + "in a recording.\n" + "This program expects the input segmentation to be a kaldi segment " + "converted to segmentation using segmentation-init-from-segments. " + "For other segmentations, the user can use the binary " + "segmentation-combine-segments instead.\n" + "\n" + "Usage: segmentation-combine-segments-to-recording [options] " + " " + "\n" + " e.g.: segmentation-combine-segments-to-recording \\\n" + "'ark:segmentation-init-from-segments --shift-to-zero=false " + "data/dev/segments ark:- |' ark,t:data/dev/reco2utt ark:file.seg\n" + "See also: segmentation-combine-segments, " + "segmentation-merge, segmentation-merge-recordings, " + "segmentation-post-process --merge-adjacent-segments\n"; + + ParseOptions po(usage); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_rspecifier = po.GetArg(1), + reco2utt_rspecifier = po.GetArg(2), + segmentation_wspecifier = po.GetArg(3); + + SequentialTokenVectorReader reco2utt_reader(reco2utt_rspecifier); + RandomAccessSegmentationReader segmentation_reader( + segmentation_rspecifier); + SegmentationWriter segmentation_writer(segmentation_wspecifier); + + int32 num_done = 0, num_segmentations = 0, num_err = 0; + + for (; !reco2utt_reader.Done(); reco2utt_reader.Next()) { + const std::vector &utts = reco2utt_reader.Value(); + const std::string &reco_id = reco2utt_reader.Key(); + + Segmentation out_segmentation; + + for (std::vector::const_iterator it = utts.begin(); + it != utts.end(); ++it) { + if (!segmentation_reader.HasKey(*it)) { + KALDI_WARN << "Could not find utterance " << *it << " in " + << "segments segmentation " + << segmentation_rspecifier; + num_err++; + continue; + } + + const Segmentation &segmentation = segmentation_reader.Value(*it); + if (segmentation.Dim() != 1) { + KALDI_ERR << "Segments segmentation for utt " << *it << " is not " + << "kaldi segment converted to segmentation format " + << "in " << segmentation_rspecifier; + } + const Segment &segment = *(segmentation.Begin()); 
+ + out_segmentation.PushBack(segment); + + num_done++; + } + + Sort(&out_segmentation); + segmentation_writer.Write(reco_id, out_segmentation); + num_segmentations++; + } + + KALDI_LOG << "Combined " << num_done << " utterance-level segments " + << "into " << num_segmentations + << " recording-level segmentations; failed with " + << num_err << " utterances."; + + return ((num_done > 0 && num_err < num_done) ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-combine-segments.cc b/src/segmenterbin/segmentation-combine-segments.cc new file mode 100644 index 00000000000..7034a8a1734 --- /dev/null +++ b/src/segmenterbin/segmentation-combine-segments.cc @@ -0,0 +1,128 @@ +// segmenterbin/segmentation-combine-segments.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Combine utterance-level segmentations in an archive to " + "recording-level segmentations using the kaldi segments to map " + "utterances to their positions in the recordings.\n" + "A reco2utt file is used to specify which utterances belong to each " + "recording.\n" + "\n" + "Usage: segmentation-combine-segments [options] " + " " + " " + " \n" + " e.g.: segmentation-combine-segments ark:utt.seg " + "'ark:segmentation-init-from-segments --shift-to-zero=false " + "data/dev/segments ark:- |' ark,t:data/dev/reco2utt ark:file.seg\n" + "See also: segmentation-combine-segments-to-recording, " + "segmentation-merge, segmentatin-merge-recordings, " + "segmentation-post-process --merge-adjacent-segments\n"; + + ParseOptions po(usage); + + po.Read(argc, argv); + + if (po.NumArgs() != 4) { + po.PrintUsage(); + exit(1); + } + + std::string utt_segmentation_rspecifier = po.GetArg(1), + segments_segmentation_rspecifier = po.GetArg(2), + reco2utt_rspecifier = po.GetArg(3), + segmentation_wspecifier = po.GetArg(4); + + SequentialTokenVectorReader reco2utt_reader(reco2utt_rspecifier); + RandomAccessSegmentationReader segments_segmentation_reader( + segments_segmentation_rspecifier); + RandomAccessSegmentationReader utt_segmentation_reader( + utt_segmentation_rspecifier); + SegmentationWriter segmentation_writer(segmentation_wspecifier); + + int32 num_done = 0, num_segmentations = 0, num_err = 0; + int64 num_segments = 0; + + for (; !reco2utt_reader.Done(); reco2utt_reader.Next()) { + const std::vector &utts = reco2utt_reader.Value(); + const std::string &reco_id = reco2utt_reader.Key(); + + Segmentation out_segmentation; + + for (std::vector::const_iterator it = utts.begin(); + it != utts.end(); ++it) { + if 
(!segments_segmentation_reader.HasKey(*it)) { + KALDI_WARN << "Could not find utterance " << *it << " in " + << "segments segmentation " + << segments_segmentation_rspecifier; + num_err++; + continue; + } + + const Segmentation &segments_segmentation = + segments_segmentation_reader.Value(*it); + if (segments_segmentation.Dim() != 1) { + KALDI_ERR << "Segments segmentation for utt " << *it << " is not " + << "kaldi segment converted to segmentation format " + << "in " << segments_segmentation_rspecifier; + } + const Segment &segment = *(segments_segmentation.Begin()); + + if (!utt_segmentation_reader.HasKey(*it)) { + KALDI_WARN << "Could not find utterance " << *it << " in " + << "segmentation " << utt_segmentation_rspecifier; + num_err++; + continue; + } + const Segmentation &utt_segmentation + = utt_segmentation_reader.Value(*it); + + num_segments += InsertFromSegmentation(utt_segmentation, + segment.start_frame, false, + &out_segmentation, NULL); + num_done++; + } + + Sort(&out_segmentation); + segmentation_writer.Write(reco_id, out_segmentation); + num_segmentations++; + } + + KALDI_LOG << "Combined " << num_done << " utterance-level segmentations " + << "into " << num_segmentations + << " recording-level segmentations; failed with " + << num_err << " utterances; " + << "wrote a total of " << num_segments << " segments."; + + return ((num_done > 0 && num_err < num_done) ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-copy.cc b/src/segmenterbin/segmentation-copy.cc new file mode 100644 index 00000000000..26d0f47682d --- /dev/null +++ b/src/segmenterbin/segmentation-copy.cc @@ -0,0 +1,232 @@ +// segmenterbin/segmentation-copy.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
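The core of segmentation-combine-segments above is the offset step: each utterance-level segmentation is shifted by that utterance's start frame in the recording (taken from the segments-derived segmentation) before being inserted into the recording-level output. A minimal sketch of that step follows; it is not part of the patch, and the two utterance lengths, class labels, and start frames 0 and 500 are made-up values for illustration.

    #include "segmenter/segmentation.h"
    #include "segmenter/segmentation-utils.h"

    // Sketch (illustrative only): the shift-and-insert step that
    // segmentation-combine-segments performs for every utterance of a recording.
    void ExampleCombineTwoUtterances() {
      using namespace kaldi;
      using namespace segmenter;

      Segmentation utt1, utt2, reco;
      utt1.EmplaceBack(0, 99, 1);    // utterance 1: frames 0-99, class 1
      utt2.EmplaceBack(0, 149, 2);   // utterance 2: frames 0-149, class 2

      // Assumed recording start frames 0 and 500; in the binary these offsets
      // come from the start_frame of the segments-derived segmentation.
      InsertFromSegmentation(utt1, 0, false /* sort */, &reco);
      InsertFromSegmentation(utt2, 500, false /* sort */, &reco);
      Sort(&reco);
      // 'reco' should now hold [ 0 99 1 ] and [ 500 649 2 ] in recording frames.
    }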
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Copy segmentation or archives of segmentation.\n" + "If label-map is supplied, then apply the mapping to the labels \n" + "when copying.\n" + "If utt2label-rspecifier is supplied, then ignore the \n" + "original labels, and map all the segments of an utterance using \n" + "the supplied utt2label map.\n" + "\n" + "Usage: segmentation-copy [options] " + "\n" + " e.g.: segmentation-copy ark:1.seg ark,t:-\n" + " or \n" + " segmentation-copy [options] " + "\n" + " e.g.: segmentation-copy --binary=false foo -\n"; + + bool binary = true; + std::string label_map_rxfilename, utt2label_rspecifier; + std::string include_rxfilename, exclude_rxfilename; + int32 keep_label = -1; + BaseFloat frame_subsampling_factor = 1; + + ParseOptions po(usage); + + po.Register("binary", &binary, + "Write in binary mode " + "(only relevant if output is a wxfilename)"); + po.Register("label-map", &label_map_rxfilename, + "File with mapping from old to new labels"); + po.Register("frame-subsampling-factor", &frame_subsampling_factor, + "Change frame rate by this factor"); + po.Register("utt2label-rspecifier", &utt2label_rspecifier, + "Mapping for each utterance to an integer label"); + po.Register("keep-label", &keep_label, + "If supplied, only segments of this label are written out"); + po.Register("include", &include_rxfilename, + "Text file, the first field of each" + " line being interpreted as an " + "utterance-id whose features will be included"); + po.Register("exclude", &exclude_rxfilename, + "Text file, the first field of each " + "line being interpreted as an utterance-id" + " whose features will be excluded"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + // all these "fn"'s are either rspecifiers or filenames. 
+ + std::string segmentation_in_fn = po.GetArg(1), + segmentation_out_fn = po.GetArg(2); + + // Read mapping from old to new labels + unordered_map label_map; + if (!label_map_rxfilename.empty()) { + Input ki(label_map_rxfilename); + std::string line; + while (std::getline(ki.Stream(), line)) { + std::vector splits; + SplitStringToVector(line, " ", true, &splits); + + if (splits.size() != 2) + KALDI_ERR << "Invalid format of line " << line + << " in " << label_map_rxfilename; + + label_map[std::atoi(splits[0].c_str())] = std::atoi(splits[1].c_str()); + } + } + + unordered_set include_set; + if (include_rxfilename != "") { + if (exclude_rxfilename != "") { + KALDI_ERR << "should not have both --exclude and --include option!"; + } + Input ki(include_rxfilename); + std::string line; + while (std::getline(ki.Stream(), line)) { + std::vector split_line; + SplitStringToVector(line, " \t\r", true, &split_line); + KALDI_ASSERT(!split_line.empty() && + "Empty line encountered in input from --include option"); + include_set.insert(split_line[0]); + } + } + + unordered_set exclude_set; + if (exclude_rxfilename != "") { + if (include_rxfilename != "") { + KALDI_ERR << "should not have both --exclude and --include option!"; + } + Input ki(exclude_rxfilename); + std::string line; + while (std::getline(ki.Stream(), line)) { + std::vector split_line; + SplitStringToVector(line, " \t\r", true, &split_line); + KALDI_ASSERT(!split_line.empty() && + "Empty line encountered in input from --exclude option"); + exclude_set.insert(split_line[0]); + } + } + + bool in_is_rspecifier = + (ClassifyRspecifier(segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier) + KALDI_ERR << "Cannot mix regular files and archives"; + + int64 num_done = 0, num_err = 0; + + if (!in_is_rspecifier) { + Segmentation segmentation; + { + bool binary_in; + Input ki(segmentation_in_fn, &binary_in); + segmentation.Read(ki.Stream(), binary_in); + } + + if (!label_map_rxfilename.empty()) + RelabelSegmentsUsingMap(label_map, &segmentation); + + if (keep_label != -1) + KeepSegments(keep_label, &segmentation); + + if (frame_subsampling_factor != 1.0) { + ScaleFrameShift(frame_subsampling_factor, &segmentation); + } + + if (!utt2label_rspecifier.empty()) + KALDI_ERR << "It makes no sense to specify utt2label-rspecifier " + << "when not reading segmentation archives."; + + Output ko(segmentation_out_fn, binary); + segmentation.Write(ko.Stream(), binary); + + KALDI_LOG << "Copied segmentation to " << segmentation_out_fn; + return 0; + } else { + RandomAccessInt32Reader utt2label_reader(utt2label_rspecifier); + + SegmentationWriter writer(segmentation_out_fn); + SequentialSegmentationReader reader(segmentation_in_fn); + + for (; !reader.Done(); reader.Next()) { + const std::string &key = reader.Key(); + + if (include_rxfilename != "" && include_set.count(key) == 0) { + continue; + } + + if (exclude_rxfilename != "" && include_set.count(key) > 0) { + continue; + } + + if (label_map_rxfilename.empty() && + frame_subsampling_factor == 1.0 && + utt2label_rspecifier.empty() && + keep_label == -1) { + writer.Write(key, reader.Value()); + } else { + Segmentation segmentation = reader.Value(); + if (!label_map_rxfilename.empty()) + RelabelSegmentsUsingMap(label_map, &segmentation); + if (!utt2label_rspecifier.empty()) { + if (!utt2label_reader.HasKey(key)) { + KALDI_WARN << "Utterance " << key + << " not found in 
utt2label map " + << utt2label_rspecifier; + num_err++; + continue; + } + + RelabelAllSegments(utt2label_reader.Value(key), &segmentation); + } + if (keep_label != -1) + KeepSegments(keep_label, &segmentation); + + if (frame_subsampling_factor != 1.0) + ScaleFrameShift(frame_subsampling_factor, &segmentation); + + writer.Write(key, segmentation); + } + + num_done++; + } + + KALDI_LOG << "Copied " << num_done << " segmentation; failed with " + << num_err << " segmentations"; + return (num_done != 0 ? 0 : 1); + } + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-create-subsegments.cc b/src/segmenterbin/segmentation-create-subsegments.cc new file mode 100644 index 00000000000..9d7f4c08b6d --- /dev/null +++ b/src/segmenterbin/segmentation-create-subsegments.cc @@ -0,0 +1,175 @@ +// segmenterbin/segmentation-create-subsegments.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Create sub-segmentation of a segmentation by intersecting with " + "segments from a 'filter' segmentation. 
\n" + "The labels for the new subsegments are decided " + "depending on whether the label of 'filter' segment " + "matches the specified 'filter_label' or not:\n" + " if filter segment's label == filter_label: \n" + " if subsegment_label is specified:\n" + " label = subsegment_label\n" + " else: \n" + " label = filter_label \n" + " else: \n" + " if unmatched_label is specified:\n" + " label = unmatched_label\n" + " else\n:" + " label = primary_label\n" + "See the function SubSegmentUsingNonOverlappingSegments() " + "for more details.\n" + "\n" + "Usage: segmentation-create-subsegments [options] " + " " + " \n" + " or : segmentation-create-subsegments [options] " + " " + " \n" + " e.g.: segmentation-create-subsegments --binary=false " + "--filter-label=1 --subsegment-label=1000 foo bar -\n" + " segmentation-create-subsegments --filter-label=1 " + "--subsegment-label=1000 ark:1.foo ark:1.bar ark:-\n"; + + bool binary = true, ignore_missing = false; + int32 filter_label = -1, subsegment_label = -1, unmatched_label = -1; + ParseOptions po(usage); + + po.Register("binary", &binary, + "Write in binary mode " + "(only relevant if output is a wxfilename)"); + po.Register("filter-label", &filter_label, + "The label on which filtering is done."); + po.Register("subsegment-label", &subsegment_label, + "If non-negative, change the class-id of the matched regions " + "in the intersection of the two segmentations to this label."); + po.Register("unmatched-label", &unmatched_label, + "If non-negative, change the class-id of the unmatched " + "regions in the intersection of the two segmentations " + "to this label."); + po.Register("ignore-missing", &ignore_missing, "Ignore missing " + "segmentations in filter. If this is set true, then the " + "segmentations with missing key in filter are written " + "without any modification."); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_in_fn = po.GetArg(1), + secondary_segmentation_in_fn = po.GetArg(2), + segmentation_out_fn = po.GetArg(3); + + // all these "fn"'s are either rspecifiers or filenames. 
+ + bool in_is_rspecifier = + (ClassifyRspecifier(segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + filter_is_rspecifier = + (ClassifyRspecifier(secondary_segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier || + in_is_rspecifier != filter_is_rspecifier) + KALDI_ERR << "Cannot mix regular files and archives"; + + int64 num_done = 0, num_err = 0; + + if (!in_is_rspecifier) { + Segmentation segmentation; + { + bool binary_in; + Input ki(segmentation_in_fn, &binary_in); + segmentation.Read(ki.Stream(), binary_in); + } + Segmentation secondary_segmentation; + { + bool binary_in; + Input ki(secondary_segmentation_in_fn, &binary_in); + secondary_segmentation.Read(ki.Stream(), binary_in); + } + + Segmentation new_segmentation; + SubSegmentUsingNonOverlappingSegments( + segmentation, secondary_segmentation, filter_label, subsegment_label, + unmatched_label, &new_segmentation); + Output ko(segmentation_out_fn, binary); + new_segmentation.Write(ko.Stream(), binary); + + KALDI_LOG << "Created subsegments of " << segmentation_in_fn + << " based on " << secondary_segmentation_in_fn + << " and wrote to " << segmentation_out_fn; + return 0; + } else { + SegmentationWriter writer(segmentation_out_fn); + SequentialSegmentationReader reader(segmentation_in_fn); + RandomAccessSegmentationReader filter_reader( + secondary_segmentation_in_fn); + + for (; !reader.Done(); reader.Next(), num_done++) { + const Segmentation &segmentation = reader.Value(); + const std::string &key = reader.Key(); + + if (!filter_reader.HasKey(key)) { + KALDI_WARN << "Could not find filter segmentation for utterance " + << key; + if (!ignore_missing) + num_err++; + else + writer.Write(key, segmentation); + continue; + } + const Segmentation &secondary_segmentation = filter_reader.Value(key); + + Segmentation new_segmentation; + SubSegmentUsingNonOverlappingSegments(segmentation, + secondary_segmentation, + filter_label, subsegment_label, + unmatched_label, + &new_segmentation); + + writer.Write(key, new_segmentation); + } + + KALDI_LOG << "Created subsegments for " << num_done << " segmentations; " + << "failed with " << num_err << " segmentations"; + + return ((num_done != 0 && num_err < num_done) ? 0 : 1); + } + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-get-stats.cc b/src/segmenterbin/segmentation-get-stats.cc new file mode 100644 index 00000000000..b25d6913f06 --- /dev/null +++ b/src/segmenterbin/segmentation-get-stats.cc @@ -0,0 +1,125 @@ +// segmenterbin/segmentation-get-per-frame-stats.cc + +// Copyright 2016 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
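To make the label-selection rule of segmentation-create-subsegments above concrete, here is a small worked example built directly on the declared SubSegmentUsingNonOverlappingSegments() function. It is not part of the patch; the segment boundaries, class ids 5, 1, 2, and the option values (--filter-label=1, --subsegment-label=1000, --unmatched-label unset) are illustrative, and the expected output follows the rule documented in segmentation-utils.h.

    #include "segmenter/segmentation.h"
    #include "segmenter/segmentation-utils.h"

    // Worked example (illustrative only): label selection as performed by
    // segmentation-create-subsegments --filter-label=1 --subsegment-label=1000
    // with --unmatched-label left unset, for one utterance.
    void ExampleCreateSubsegments() {
      using namespace kaldi;
      using namespace segmenter;

      Segmentation primary, filter, out;
      primary.EmplaceBack(0, 99, 5);   // one primary segment, class 5
      filter.EmplaceBack(0, 39, 1);    // filter label 1 -> matched region
      filter.EmplaceBack(40, 99, 2);   // filter label 2 -> unmatched region

      SubSegmentUsingNonOverlappingSegments(primary, filter,
                                            1    /* secondary_label */,
                                            1000 /* subsegment_label */,
                                            -1   /* unmatched_label: unset */,
                                            &out);
      // Expected per the header documentation:
      //   [ 0 39 1000 ]  (filter label matched -> subsegment_label)
      //   [ 40 99 5 ]    (no match, no unmatched_label -> primary label kept)
    }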
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Get per-frame stats from segmentation. \n" + "Currently supported stats are \n" + " num-overlaps: Number of overlapping segments common to this frame\n" + " num-classes: Number of distinct classes common to this frame\n" + "\n" + "Usage: segmentation-get-stats [options] " + " \n" + " e.g.: segmentation-get-stats ark:1.seg ark:/dev/null " + "ark:num_classes.ark\n"; + + ParseOptions po(usage); + + std::string lengths_rspecifier; + int32 length_tolerance = 2; + + po.Register("lengths-rspecifier", &lengths_rspecifier, + "Archive of frame lengths of the utterances. " + "Fills up any extra length with zero stats."); + po.Register("length-tolerance", &length_tolerance, + "Tolerate shortage of this many frames in the specified " + "lengths file"); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_rspecifier = po.GetArg(1), + num_overlaps_wspecifier = po.GetArg(2), + num_classes_wspecifier = po.GetArg(3); + + int64 num_done = 0, num_err = 0; + + SequentialSegmentationReader reader(segmentation_rspecifier); + Int32VectorWriter num_overlaps_writer(num_overlaps_wspecifier); + Int32VectorWriter num_classes_writer(num_classes_wspecifier); + + RandomAccessInt32Reader lengths_reader(lengths_rspecifier); + + for (; !reader.Done(); reader.Next(), num_done++) { + const Segmentation &segmentation = reader.Value(); + const std::string &key = reader.Key(); + + int32 length = -1; + if (!lengths_rspecifier.empty()) { + if (!lengths_reader.HasKey(key)) { + KALDI_WARN << "Could not find length for key " << key; + num_err++; + continue; + } + length = lengths_reader.Value(key); + } + + std::vector > class_counts_per_frame; + if (!GetClassCountsPerFrame(segmentation, length, + length_tolerance, + &class_counts_per_frame)) { + KALDI_WARN << "Failed getting stats for key " << key; + num_err++; + continue; + } + + if (length == -1) + length = class_counts_per_frame.size(); + + std::vector num_classes_per_frame(length, 0); + std::vector num_overlaps_per_frame(length, 0); + + for (int32 i = 0; i < class_counts_per_frame.size(); i++) { + std::map &class_counts = class_counts_per_frame[i]; + + for (std::map::const_iterator it = class_counts.begin(); + it != class_counts.end(); ++it) { + if (it->second > 0) + num_classes_per_frame[i]++; + num_overlaps_per_frame[i] += it->second; + } + } + + num_classes_writer.Write(key, num_classes_per_frame); + num_overlaps_writer.Write(key, num_overlaps_per_frame); + + num_done++; + } + + KALDI_LOG << "Got stats for " << num_done << " segmentations; failed with " + << num_err << " segmentations"; + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-init-from-ali.cc b/src/segmenterbin/segmentation-init-from-ali.cc new file mode 100644 index 00000000000..a98a54368c9 --- /dev/null +++ b/src/segmenterbin/segmentation-init-from-ali.cc @@ -0,0 +1,91 @@ +// segmenterbin/segmentation-init-from-ali.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Initialize utterance-level segmentations from alignments file. \n" + "The user can pass this to segmentation-combine-segments to " + "create recording-level segmentations." + "\n" + "Usage: segmentation-init-from-ali [options] " + " \n" + " e.g.: segmentation-init-from-ali ark:1.ali ark:-\n" + "See also: segmentation-init-from-segments, " + "segmentation-combine-segments\n"; + + ParseOptions po(usage); + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string ali_rspecifier = po.GetArg(1), + segmentation_wspecifier = po.GetArg(2); + + SegmentationWriter segmentation_writer(segmentation_wspecifier); + + int32 num_done = 0, num_segmentations = 0; + int64 num_segments = 0; + int64 num_err = 0; + + std::vector frame_counts_per_class; + + SequentialInt32VectorReader alignment_reader(ali_rspecifier); + + for (; !alignment_reader.Done(); alignment_reader.Next()) { + const std::string &key = alignment_reader.Key(); + const std::vector &alignment = alignment_reader.Value(); + + Segmentation segmentation; + + num_segments += InsertFromAlignment(alignment, 0, alignment.size(), + 0, &segmentation, + &frame_counts_per_class); + + Sort(&segmentation); + segmentation_writer.Write(key, segmentation); + + num_done++; + num_segmentations++; + } + + KALDI_LOG << "Processed " << num_done << " utterances; failed with " + << num_err << " utterances; " + << "wrote " << num_segmentations << " segmentations " + << "with a total of " << num_segments << " segments."; + KALDI_LOG << "Number of frames for the different classes are : "; + WriteIntegerVector(KALDI_LOG, false, frame_counts_per_class); + + return ((num_done > 0 && num_err < num_done) ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-init-from-lengths.cc b/src/segmenterbin/segmentation-init-from-lengths.cc new file mode 100644 index 00000000000..28c998c220b --- /dev/null +++ b/src/segmenterbin/segmentation-init-from-lengths.cc @@ -0,0 +1,82 @@ +// segmenterbin/segmentation-init-from-lengths.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. 
+// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Initialize segmentations from frame lengths file\n" + "\n" + "Usage: segmentation-init-from-lengths [options] " + " \n" + " e.g.: segmentation-init-from-lengths " + "\"ark:feat-to-len scp:feats.scp ark:- |\" ark:-\n" + "\n" + "See also: segmentation-init-from-ali, " + "segmentation-init-from-segments\n"; + + int32 label = 1; + + ParseOptions po(usage); + + po.Register("label", &label, "Label to assign to the created segments"); + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string lengths_rspecifier = po.GetArg(1), + segmentation_wspecifier = po.GetArg(2); + + SequentialInt32Reader lengths_reader(lengths_rspecifier); + SegmentationWriter segmentation_writer(segmentation_wspecifier); + + int32 num_done = 0; + + for (; !lengths_reader.Done(); lengths_reader.Next()) { + const std::string &key = lengths_reader.Key(); + const int32 &length = lengths_reader.Value(); + + Segmentation segmentation; + + if (length > 0) { + segmentation.EmplaceBack(0, length - 1, label); + } + + segmentation_writer.Write(key, segmentation); + num_done++; + } + + KALDI_LOG << "Created " << num_done << " segmentations."; + + return (num_done > 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-init-from-segments.cc b/src/segmenterbin/segmentation-init-from-segments.cc new file mode 100644 index 00000000000..c39996b5ef4 --- /dev/null +++ b/src/segmenterbin/segmentation-init-from-segments.cc @@ -0,0 +1,179 @@ +// segmenterbin/segmentation-init-from-segments.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation.h" + +// If segments file contains +// Alpha-001 Alpha 0.00 0.16 +// Alpha-002 Alpha 1.50 4.10 +// Beta-001 Beta 0.50 2.66 +// Beta-002 Beta 3.50 5.20 +// the output segmentation will contain +// Alpha-001 [ 0 16 1 ] +// Alpha-002 [ 0 360 1 ] +// Beta-001 [ 0 216 1 ] +// Beta-002 [ 0 170 1 ] +// If --shift-to-zero=false is provided, then the output will contain +// Alpha-001 [ 0 16 1 ] +// Alpha-002 [ 150 410 1 ] +// Beta-001 [ 50 266 1 ] +// Beta-002 [ 350 520 1 ] +// +// If the following utt2label-rspecifier was provided: +// Alpha-001 2 +// Alpha-002 2 +// Beta-001 4 +// Beta-002 4 +// then the output segmentation will contain +// Alpha-001 [ 0 16 2 ] +// Alpha-002 [ 0 360 2 ] +// Beta-001 [ 0 216 4 ] +// Beta-002 [ 0 170 4 ] + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Convert segments from segments file into utterance-level " + "segmentation format. \n" + "The user can convert the segmenation to recording-level using " + "the binary segmentation-combine-segments-to-recording.\n" + "\n" + "Usage: segmentation-init-from-segments [options] " + " \n" + " e.g.: segmentation-init-from-segments segments ark:-\n"; + + int32 segment_label = 1; + BaseFloat frame_shift = 0.01, frame_overlap = 0.015; + std::string utt2label_rspecifier; + bool shift_to_zero = true; + + ParseOptions po(usage); + + po.Register("segment-label", &segment_label, + "Label for all the segments in the segmentations"); + po.Register("utt2label-rspecifier", &utt2label_rspecifier, + "Mapping for each utterance to an integer label. " + "If supplied, these labels will be used as the segment " + "labels"); + po.Register("shift-to-zero", &shift_to_zero, + "Shift all segments to 0th frame"); + po.Register("frame-shift", &frame_shift, "Frame shift in seconds"); + po.Register("frame-overlap", &frame_overlap, "Frame overlap in seconds"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string segments_rxfilename = po.GetArg(1), + segmentation_wspecifier = po.GetArg(2); + + SegmentationWriter writer(segmentation_wspecifier); + RandomAccessInt32Reader utt2label_reader(utt2label_rspecifier); + + Input ki(segments_rxfilename); + + int64 num_lines = 0, num_done = 0; + + std::string line; + + while (std::getline(ki.Stream(), line)) { + num_lines++; + + std::vector split_line; + // Split the line by space or tab and check the number of fields in each + // line. There must be 4 fields--segment name , reacording wav file name, + // start time, end time; 5th field (channel info) is optional. + SplitStringToVector(line, " \t\r", true, &split_line); + if (split_line.size() != 4 && split_line.size() != 5) { + KALDI_WARN << "Invalid line in segments file: " << line; + continue; + } + std::string utt = split_line[0], + reco = split_line[1], + start_str = split_line[2], + end_str = split_line[3]; + + // Convert the start time and endtime to real from string. Segment is + // ignored if start or end time cannot be converted to real. 
+ double start, end; + if (!ConvertStringToReal(start_str, &start)) { + KALDI_WARN << "Invalid line in segments file [bad start]: " << line; + continue; + } + if (!ConvertStringToReal(end_str, &end)) { + KALDI_WARN << "Invalid line in segments file [bad end]: " << line; + continue; + } + + // start time must not be negative; start time must not be greater than + // end time, except if end time is -1 + if (start < 0 || (end != -1.0 && end <= 0) || + ((start >= end) && (end > 0))) { + KALDI_WARN << "Invalid line in segments file " + << "[empty or invalid segment]: " << line; + continue; + } + + if (split_line.size() >= 5) + KALDI_ERR << "Not supporting channel in segments file"; + + Segmentation segmentation; + + if (!utt2label_rspecifier.empty()) { + if (!utt2label_reader.HasKey(utt)) { + KALDI_WARN << "Could not find utterance " << utt << " in " + << utt2label_rspecifier; + continue; + } + + segment_label = utt2label_reader.Value(utt); + } + + int32 length = round((end - frame_overlap)/ frame_shift) + - round(start / frame_shift); + + if (shift_to_zero) + segmentation.EmplaceBack(0, length, segment_label); + else + segmentation.EmplaceBack(round(start / frame_shift), + round((end-frame_overlap) / frame_shift) - 1, + segment_label); + + writer.Write(utt, segmentation); + num_done++; + } + + KALDI_LOG << "Successfully processed " << num_done << " lines out of " + << num_lines << " in the segments file"; + + return (num_done > num_lines / 2 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-intersect-ali.cc b/src/segmenterbin/segmentation-intersect-ali.cc new file mode 100644 index 00000000000..a551eee02ce --- /dev/null +++ b/src/segmenterbin/segmentation-intersect-ali.cc @@ -0,0 +1,99 @@ +// segmenterbin/segmentation-intersect-ali.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Intersect (like sets) segmentation with an alignment and retain \n" + "only segments where the alignment is the specified label. 
\n" + "\n" + "Usage: segmentation-intersect-alignment [options] " + " " + "\n" + " e.g.: segmentation-intersect-alignment --binary=false ark:foo.seg " + "ark:filter.ali ark,t:-\n" + "See also: segmentation-combine-segments, " + "segmentation-intersect-segments, segmentation-create-subsegments\n"; + + ParseOptions po(usage); + + int32 ali_label = 0, min_alignment_chunk_length = 0; + + po.Register("ali-label", &ali_label, + "Intersect only at this label of alignments"); + po.Register("min-alignment-chunk-length", &min_alignment_chunk_length, + "The minimmum number of consecutive frames of ali_label in " + "alignment at which the segments can be intersected."); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_rspecifier = po.GetArg(1), + ali_rspecifier = po.GetArg(2), + segmentation_wspecifier = po.GetArg(3); + + int32 num_done = 0, num_err = 0; + + SegmentationWriter writer(segmentation_wspecifier); + SequentialSegmentationReader segmentation_reader(segmentation_rspecifier); + RandomAccessInt32VectorReader alignment_reader(ali_rspecifier); + + for (; !segmentation_reader.Done(); segmentation_reader.Next()) { + const Segmentation &segmentation = segmentation_reader.Value(); + const std::string &key = segmentation_reader.Key(); + + if (!alignment_reader.HasKey(key)) { + KALDI_WARN << "Could not find segmentation for key " << key + << " in " << ali_rspecifier; + num_err++; + continue; + } + const std::vector &ali = alignment_reader.Value(key); + + Segmentation out_segmentation; + IntersectSegmentationAndAlignment(segmentation, ali, ali_label, + min_alignment_chunk_length, + &out_segmentation); + out_segmentation.Sort(); + + writer.Write(key, out_segmentation); + num_done++; + } + + KALDI_LOG << "Intersected " << num_done + << " segmentations with alignments; failed with " + << num_err << " segmentations"; + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-intersect-segments.cc b/src/segmenterbin/segmentation-intersect-segments.cc new file mode 100644 index 00000000000..1c9861ba453 --- /dev/null +++ b/src/segmenterbin/segmentation-intersect-segments.cc @@ -0,0 +1,145 @@ +// segmenterbin/segmentation-intersect-segments.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +namespace kaldi { +namespace segmenter { + +void IntersectSegmentationsNonOverlapping( + const Segmentation &in_segmentation, + const Segmentation &secondary_segmentation, + int32 mismatch_label, + Segmentation *out_segmentation) { + KALDI_ASSERT(out_segmentation); + KALDI_ASSERT(secondary_segmentation.Dim() > 0); + + std::vector alignment; + ConvertToAlignment(secondary_segmentation, -1, -1, 0, &alignment); + + for (SegmentList::const_iterator it = in_segmentation.Begin(); + it != in_segmentation.End(); ++it) { + if (it->end_frame >= alignment.size()) { + alignment.resize(it->end_frame + 1, -1); + } + Segmentation filter_segmentation; + InsertFromAlignment(alignment, it->start_frame, it->end_frame + 1, + 0, &filter_segmentation, NULL); + + for (SegmentList::const_iterator f_it = filter_segmentation.Begin(); + f_it != filter_segmentation.End(); ++f_it) { + int32 label = it->Label(); + if (f_it->Label() != it->Label()) { + if (mismatch_label == -1) continue; + label = mismatch_label; + } + + out_segmentation->EmplaceBack(f_it->start_frame, f_it->end_frame, + label); + } + } +} + +} +} + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Intersect segments from two archives by retaining only regions .\n" + "where the primary and secondary segments match on label\n" + "\n" + "Usage: segmentation-intersect-segments [options] " + " " + "\n" + " e.g.: segmentation-intersect-segments ark:foo.seg ark:bar.seg " + "ark,t:-\n" + "See also: segmentation-create-subsegments, " + "segmentation-intersect-ali\n"; + + int32 mismatch_label = -1; + bool assume_non_overlapping_secondary = true; + + ParseOptions po(usage); + + po.Register("mismatch-label", &mismatch_label, + "Intersect only where secondary segment has this label"); + po.Register("assume-non-overlapping-secondary", & + assume_non_overlapping_secondary, + "Assume secondary segments are non-overlapping"); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string primary_rspecifier = po.GetArg(1), + secondary_rspecifier = po.GetArg(2), + segmentation_writer = po.GetArg(3); + + if (!assume_non_overlapping_secondary) { + KALDI_ERR << "Secondary segment must be non-overlapping for now"; + } + + int64 num_done = 0, num_err = 0; + + SegmentationWriter writer(segmentation_writer); + SequentialSegmentationReader primary_reader(primary_rspecifier); + RandomAccessSegmentationReader secondary_reader(secondary_rspecifier); + + for (; !primary_reader.Done(); primary_reader.Next()) { + const Segmentation &segmentation = primary_reader.Value(); + const std::string &key = primary_reader.Key(); + + if (!secondary_reader.HasKey(key)) { + KALDI_WARN << "Could not find segmentation for key " << key + << " in " << secondary_rspecifier; + num_err++; + continue; + } + const Segmentation &secondary_segmentation = secondary_reader.Value(key); + + Segmentation out_segmentation; + IntersectSegmentationsNonOverlapping(segmentation, + secondary_segmentation, + mismatch_label, + &out_segmentation); + + Sort(&out_segmentation); + + writer.Write(key, out_segmentation); + num_done++; + } + + KALDI_LOG << "Intersected " << num_done << " segmentations; failed with " + << num_err << " segmentations"; + return (num_done != 0 ? 
0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-merge-recordings.cc b/src/segmenterbin/segmentation-merge-recordings.cc new file mode 100644 index 00000000000..85b5108be29 --- /dev/null +++ b/src/segmenterbin/segmentation-merge-recordings.cc @@ -0,0 +1,101 @@ +// segmenterbin/segmentation-merge-recordings.cc + +// Copyright 2016 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Merge segmentations of different recordings into one segmentation " + "using a mapping from new to old recording name\n" + "\n" + "Usage: segmentation-merge-recordings [options] " + " \n" + " e.g.: segmentation-merge-recordings ark:sdm2ihm_reco.map " + "ark:ihm_seg.ark ark:sdm_seg.ark\n"; + + ParseOptions po(usage); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string new2old_list_rspecifier = po.GetArg(1); + std::string segmentation_rspecifier = po.GetArg(2), + segmentation_wspecifier = po.GetArg(3); + + SequentialTokenVectorReader new2old_reader(new2old_list_rspecifier); + RandomAccessSegmentationReader segmentation_reader( + segmentation_rspecifier); + SegmentationWriter segmentation_writer(segmentation_wspecifier); + + int32 num_new_segmentations = 0, num_old_segmentations = 0; + int64 num_segments = 0, num_err = 0; + + for (; !new2old_reader.Done(); new2old_reader.Next()) { + const std::vector &old_key_list = new2old_reader.Value(); + const std::string &new_key = new2old_reader.Key(); + + KALDI_ASSERT(old_key_list.size() > 0); + + Segmentation segmentation; + + for (std::vector::const_iterator it = old_key_list.begin(); + it != old_key_list.end(); ++it) { + num_old_segmentations++; + + if (!segmentation_reader.HasKey(*it)) { + KALDI_WARN << "Could not find key " << *it << " in " + << "old segmentation " << segmentation_rspecifier; + num_err++; + continue; + } + + const Segmentation &this_segmentation = segmentation_reader.Value(*it); + + num_segments += InsertFromSegmentation(this_segmentation, 0, NULL, + &segmentation); + } + Sort(&segmentation); + + segmentation_writer.Write(new_key, segmentation); + + num_new_segmentations++; + } + + KALDI_LOG << "Merged " << num_old_segmentations << " old segmentations " + << "into " << num_new_segmentations << " new segmentations; " + << "created overall " << num_segments << " segments; " + << "failed to merge " << num_err << " old segmentations"; + + return (num_new_segmentations > 0 && num_err < num_old_segmentations / 2); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git 
a/src/segmenterbin/segmentation-merge.cc b/src/segmenterbin/segmentation-merge.cc new file mode 100644 index 00000000000..21e9a410e15 --- /dev/null +++ b/src/segmenterbin/segmentation-merge.cc @@ -0,0 +1,146 @@ +// segmenterbin/segmentation-merge.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Merge corresponding segments from multiple archives or files.\n" + "i.e. for each utterance in the first segmentation, the segments " + "from all the supplied segmentations are merged and put in a single " + "segmentation." + "\n" + "Usage: segmentation-merge [options] " + " ... " + "\n" + " e.g.: segmentation-merge ark:foo.seg ark:bar.seg ark,t:-\n" + " or \n" + " segmentation-merge " + " ... " + "\n" + " e.g.: segmentation-merge --binary=false foo bar -\n" + "See also: segmentation-copy, segmentation-merge-recordings, " + "segmentation-post-process --merge-labels\n"; + + bool binary = true; + bool sort = true; + + ParseOptions po(usage); + + po.Register("binary", &binary, + "Write in binary mode " + "(only relevant if output is a wxfilename)"); + po.Register("sort", &sort, "Sort the segements after merging"); + + po.Read(argc, argv); + + if (po.NumArgs() <= 2) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_in_fn = po.GetArg(1), + segmentation_out_fn = po.GetArg(po.NumArgs()); + + // all these "fn"'s are either rspecifiers or filenames. 
+ bool in_is_rspecifier = + (ClassifyRspecifier(segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier) + KALDI_ERR << "Cannot mix regular files and archives"; + + int64 num_done = 0, num_err = 0; + + if (!in_is_rspecifier) { + Segmentation segmentation; + { + bool binary_in; + Input ki(segmentation_in_fn, &binary_in); + segmentation.Read(ki.Stream(), binary_in); + } + + for (int32 i = 2; i < po.NumArgs(); i++) { + bool binary_in; + Input ki(po.GetArg(i), &binary_in); + Segmentation other_segmentation; + other_segmentation.Read(ki.Stream(), binary_in); + ExtendSegmentation(other_segmentation, false, + &segmentation); + } + + Sort(&segmentation); + + Output ko(segmentation_out_fn, binary); + segmentation.Write(ko.Stream(), binary); + + KALDI_LOG << "Merged segmentations to " << segmentation_out_fn; + return 0; + } else { + SegmentationWriter writer(segmentation_out_fn); + SequentialSegmentationReader reader(segmentation_in_fn); + std::vector other_readers( + po.NumArgs()-2, + static_cast(NULL)); + + for (size_t i = 0; i < po.NumArgs()-2; i++) { + other_readers[i] = new RandomAccessSegmentationReader(po.GetArg(i+2)); + } + + for (; !reader.Done(); reader.Next()) { + Segmentation segmentation(reader.Value()); + std::string key = reader.Key(); + + for (size_t i = 0; i < po.NumArgs()-2; i++) { + if (!other_readers[i]->HasKey(key)) { + KALDI_WARN << "Could not find segmentation for key " << key + << " in " << po.GetArg(i+2); + num_err++; + } + const Segmentation &other_segmentation = + other_readers[i]->Value(key); + ExtendSegmentation(other_segmentation, false, + &segmentation); + } + + Sort(&segmentation); + + writer.Write(key, segmentation); + num_done++; + } + + KALDI_LOG << "Merged " << num_done << " segmentation; failed with " + << num_err << " segmentations"; + return (num_done != 0 ? 0 : 1); + } + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-post-process.cc b/src/segmenterbin/segmentation-post-process.cc new file mode 100644 index 00000000000..921ee5dc5d8 --- /dev/null +++ b/src/segmenterbin/segmentation-post-process.cc @@ -0,0 +1,142 @@ +// segmenterbin/segmentation-post-process.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
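To illustrate segmentation-merge above: for each key, the segments from all the input archives are essentially pooled and re-sorted by start frame; nothing is combined into longer segments or relabelled at this stage (segmentation-post-process --merge-adjacent-segments does that). A stand-alone sketch with plain std types instead of the Kaldi classes, using made-up segments:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

struct Segment { int32_t start_frame, end_frame, label; };

int main() {
  // Segments for the same utterance coming from two hypothetical archives.
  std::vector<Segment> from_foo = {{0, 99, 1}, {300, 399, 1}};
  std::vector<Segment> from_bar = {{150, 249, 2}};

  std::vector<Segment> merged(from_foo);
  merged.insert(merged.end(), from_bar.begin(), from_bar.end());  // "extend"
  std::sort(merged.begin(), merged.end(),
            [](const Segment &a, const Segment &b) {
              return a.start_frame < b.start_frame;
            });

  for (const Segment &s : merged)
    std::cout << "[ " << s.start_frame << " " << s.end_frame << " " << s.label << " ] ";
  std::cout << "\n";  // [ 0 99 1 ] [ 150 249 2 ] [ 300 399 1 ]
  return 0;
}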
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-post-processor.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Post processing of segmentation that does the following operations " + "in order: \n" + "1) Merge labels: Merge labels specified in --merge-labels into a " + "single label specified by --merge-dst-label. \n" + "2) Padding segments: Pad segments of label specified by --pad-label " + "by a few frames as specified by --pad-length. \n" + "3) Shrink segments: Shrink segments of label specified by " + "--shrink-label by a few frames as specified by --shrink-length. \n" + "4) Blend segments with neighbors: Blend short segments of class-id " + "specified by --blend-short-segments-class that are " + "shorter than --max-blend-length frames with their " + "respective neighbors if both the neighbors are within " + "a distance of --max-intersegment-length frames.\n" + "5) Remove segments: Remove segments of class-ids contained " + "in --remove-labels.\n" + "6) Merge adjacent segments: Merge adjacent segments of the same " + "label if they are within a distance of --max-intersegment-length " + "frames.\n" + "7) Split segments: Split segments that are longer than " + "--max-segment-length frames into overlapping segments " + "with an overlap of --overlap-length frames. \n" + "Usage: segmentation-post-process [options] " + "\n" + " or : segmentation-post-process [options] " + "\n" + " e.g.: segmentation-post-process --binary=false foo -\n" + " segmentation-post-process ark:foo.seg ark,t:-\n" + "See also: segmentation-merge, segmentation-copy, " + "segmentation-remove-segments\n"; + + bool binary = true; + + ParseOptions po(usage); + + SegmentationPostProcessingOptions opts; + + po.Register("binary", &binary, + "Write in binary mode " + "(only relevant if output is a wxfilename)"); + + opts.Register(&po); + + po.Read(argc, argv); + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + SegmentationPostProcessor post_processor(opts); + + std::string segmentation_in_fn = po.GetArg(1), + segmentation_out_fn = po.GetArg(2); + + bool in_is_rspecifier = + (ClassifyRspecifier(segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier) + KALDI_ERR << "Cannot mix regular files and archives"; + + int64 num_done = 0, num_err = 0; + + if (!in_is_rspecifier) { + Segmentation segmentation; + { + bool binary_in; + Input ki(segmentation_in_fn, &binary_in); + segmentation.Read(ki.Stream(), binary_in); + } + if (post_processor.PostProcess(&segmentation)) { + Output ko(segmentation_out_fn, binary); + Sort(&segmentation); + segmentation.Write(ko.Stream(), binary); + KALDI_LOG << "Post-processed segmentation " << segmentation_in_fn + << " and wrote " << segmentation_out_fn; + return 0; + } + KALDI_LOG << "Failed post-processing segmentation " + << segmentation_in_fn; + return 1; + } + + SegmentationWriter writer(segmentation_out_fn); + SequentialSegmentationReader reader(segmentation_in_fn); + for (; !reader.Done(); reader.Next()) { + Segmentation segmentation(reader.Value()); + const std::string &key = reader.Key(); + + if (!post_processor.PostProcess(&segmentation)) { + num_err++; + continue; + } + + Sort(&segmentation); + + writer.Write(key, segmentation); + num_done++; + } + + KALDI_LOG << 
"Successfully post-processed " << num_done + << " segmentations; " + << "failed with " << num_err << " segmentations"; + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-remove-segments.cc b/src/segmenterbin/segmentation-remove-segments.cc new file mode 100644 index 00000000000..ce3ef2de6fd --- /dev/null +++ b/src/segmenterbin/segmentation-remove-segments.cc @@ -0,0 +1,155 @@ +// segmenterbin/segmentation-remove-segments.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Remove segments of particular class_id (e.g silence or noise) " + "or a set of class_ids.\n" + "The labels to removed can be made utterance-specific by passing " + "--remove-labels-rspecifier option.\n" + "\n" + "Usage: segmentation-remove-segments [options] " + " \n" + " or : segmentation-remove-segments [options] " + " \n" + "\n" + " e.g.: segmentation-remove-segments --remove-label=0 ark:foo.ark " + "ark:foo.speech.ark\n" + "See also: segmentation-post-process --remove-labels, " + "segmentation-post-process --max-blend-length, segmentation-copy\n"; + + bool binary = true; + + int32 remove_label = -1; + std::string remove_labels_rspecifier = ""; + + ParseOptions po(usage); + + po.Register("binary", &binary, + "Write in binary mode " + "(only relevant if output is a wxfilename)"); + po.Register("remove-label", &remove_label, "Remove segments of this label"); + po.Register("remove-labels-rspecifier", &remove_labels_rspecifier, + "Specify colon separated list of labels for each key"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_in_fn = po.GetArg(1), + segmentation_out_fn = po.GetArg(2); + + // all these "fn"'s are either rspecifiers or filenames. 
+ + bool in_is_rspecifier = + (ClassifyRspecifier(segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier) + KALDI_ERR << "Cannot mix regular files and archives"; + + int64 num_done = 0, num_missing = 0; + + if (!in_is_rspecifier) { + Segmentation segmentation; + { + bool binary_in; + Input ki(segmentation_in_fn, &binary_in); + segmentation.Read(ki.Stream(), binary_in); + } + if (!remove_labels_rspecifier.empty()) { + KALDI_ERR << "It does not make sense to specify " + << "--remove-labels-rspecifier " + << "for single segmentation"; + } + + RemoveSegments(remove_label, &segmentation); + + { + Output ko(segmentation_out_fn, binary); + segmentation.Write(ko.Stream(), binary); + } + + KALDI_LOG << "Removed segments and wrote segmentation to " + << segmentation_out_fn; + + return 0; + } else { + SegmentationWriter writer(segmentation_out_fn); + SequentialSegmentationReader reader(segmentation_in_fn); + + RandomAccessTokenReader remove_labels_reader(remove_labels_rspecifier); + + for (; !reader.Done(); reader.Next(), num_done++) { + Segmentation segmentation(reader.Value()); + std::string key = reader.Key(); + + if (!remove_labels_rspecifier.empty()) { + if (!remove_labels_reader.HasKey(key)) { + KALDI_WARN << "No remove-labels found for recording " << key; + num_missing++; + writer.Write(key, segmentation); + continue; + } + + std::vector remove_labels; + const std::string& remove_labels_str = + remove_labels_reader.Value(key); + + if (!SplitStringToIntegers(remove_labels_str, ":,", false, + &remove_labels)) { + KALDI_ERR << "Bad colon-separated list " + << remove_labels_str << " for key " << key + << " in " << remove_labels_rspecifier; + } + + remove_label = remove_labels[0]; + + RemoveSegments(remove_labels, &segmentation); + } else { + RemoveSegments(remove_label, &segmentation); + } + writer.Write(key, segmentation); + } + + KALDI_LOG << "Removed segments " << "from " << num_done + << " segmentations; " + << "remove-labels list missing for " << num_missing; + return (num_done != 0 ? 0 : 1); + } + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-split-segments.cc b/src/segmenterbin/segmentation-split-segments.cc new file mode 100644 index 00000000000..a45211b28ca --- /dev/null +++ b/src/segmenterbin/segmentation-split-segments.cc @@ -0,0 +1,194 @@ +// segmenterbin/segmentation-split-segments.cc + +// Copyright 2016 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
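To illustrate segmentation-remove-segments above: removal is a plain filter over the segment list, and with --remove-labels-rspecifier the set of labels to drop can differ per key (given as a colon- or comma-separated list such as "0:2"). The sketch below shows only the filtering step with std types; it is not the actual RemoveSegments code, and the labels are made up.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

struct Segment { int32_t start_frame, end_frame, label; };

int main() {
  std::vector<Segment> segmentation = {{0, 99, 0}, {100, 249, 1}, {250, 299, 2}};
  std::set<int32_t> remove_labels = {0, 2};  // e.g. silence and noise classes

  segmentation.erase(
      std::remove_if(segmentation.begin(), segmentation.end(),
                     [&](const Segment &s) { return remove_labels.count(s.label) > 0; }),
      segmentation.end());

  for (const Segment &s : segmentation)
    std::cout << "[ " << s.start_frame << " " << s.end_frame << " " << s.label << " ]\n";
  // Only "[ 100 249 1 ]" survives.
  return 0;
}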
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Split long segments optionally using alignment.\n" + "The splitting works in two possible ways:\n" + " 1) If alignment is not provided: The segments are split if they\n" + " are longer than --max-segment-length frames into overlapping\n" + " segments with an overlap of --overlap-length frames.\n" + " 2) If alignment is provided: The segments are split if they\n" + " are longer than --max-segment-length frames at the region \n" + " where there is a contiguous segment of --ali-label in the \n" + " alignment that is at least --min-alignment-chunk-length frames \n" + " long.\n" + "Usage: segmentation-split-segments [options] " + " \n" + " or : segmentation-split-segments [options] " + " \n" + " e.g.: segmentation-split-segments --binary=false foo -\n" + " segmentation-split-segments ark:foo.seg ark,t:-\n" + "See also: segmentation-post-process\n"; + + bool binary = true; + int32 max_segment_length = -1; + int32 min_remainder = -1; + int32 overlap_length = 0; + int32 split_label = -1; + int32 ali_label = 0; + int32 min_alignment_chunk_length = 2; + + std::string alignments_in_fn; + + ParseOptions po(usage); + + po.Register("binary", &binary, + "Write in binary mode " + "(only relevant if output is a wxfilename)"); + po.Register("max-segment-length", &max_segment_length, + "If segment is longer than this length, split it into " + "pieces with less than these many frames. " + "Refer to the SplitSegments() code for details. " + "Used in conjunction with the option --overlap-length."); + po.Register("min-remainder", &min_remainder, + "The minimum remainder left after splitting that will " + "prevent a splitting from begin done. " + "Set to max-segment-length / 2, if not specified. " + "Applicable only when alignments is not specified."); + po.Register("overlap-length", &overlap_length, + "When splitting segments longer than max-segment-length, " + "have the pieces overlap by these many frames. " + "Refer to the SplitSegments() code for details. " + "Used in conjunction with the option --max-segment-length."); + po.Register("split-label", &split_label, + "If supplied, split only segments of these labels. 
" + "Otherwise, split all segments."); + po.Register("alignments", &alignments_in_fn, + "A single alignment file or archive of alignment used " + "for splitting, " + "depending on whether the input segmentation is single file " + "or archive"); + po.Register("ali-label", &ali_label, + "Split at this label of alignments"); + po.Register("min-alignment-chunk-length", &min_alignment_chunk_length, + "The minimum number of frames of alignment with ali_label " + "at which to split the segments"); + + po.Read(argc, argv); + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_in_fn = po.GetArg(1), + segmentation_out_fn = po.GetArg(2); + + bool in_is_rspecifier = + (ClassifyRspecifier(segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier) + KALDI_ERR << "Cannot mix regular files and archives"; + + if (min_remainder == -1) { + min_remainder = max_segment_length / 2; + } + + int64 num_done = 0, num_err = 0; + + if (!in_is_rspecifier) { + std::vector ali; + + Segmentation segmentation; + { + bool binary_in; + Input ki(segmentation_in_fn, &binary_in); + segmentation.Read(ki.Stream(), binary_in); + } + + if (!alignments_in_fn.empty()) { + { + bool binary_in; + Input ki(alignments_in_fn, &binary_in); + ReadIntegerVector(ki.Stream(), binary_in, &ali); + } + SplitSegmentsUsingAlignment(max_segment_length, + split_label, ali, ali_label, + min_alignment_chunk_length, + &segmentation); + } else { + SplitSegments(max_segment_length, min_remainder, + overlap_length, split_label, &segmentation); + } + + Sort(&segmentation); + + { + Output ko(segmentation_out_fn, binary); + segmentation.Write(ko.Stream(), binary); + } + + KALDI_LOG << "Split segmentation " << segmentation_in_fn + << " and wrote " << segmentation_out_fn; + return 0; + } + + SegmentationWriter writer(segmentation_out_fn); + SequentialSegmentationReader reader(segmentation_in_fn); + RandomAccessInt32VectorReader ali_reader(alignments_in_fn); + + for (; !reader.Done(); reader.Next()) { + Segmentation segmentation(reader.Value()); + const std::string &key = reader.Key(); + + if (!alignments_in_fn.empty()) { + if (!ali_reader.HasKey(key)) { + KALDI_WARN << "Could not find key " << key + << " in alignments " << alignments_in_fn; + num_err++; + continue; + } + SplitSegmentsUsingAlignment(max_segment_length, split_label, + ali_reader.Value(key), ali_label, + min_alignment_chunk_length, + &segmentation); + } else { + SplitSegments(max_segment_length, min_remainder, + overlap_length, split_label, + &segmentation); + } + + Sort(&segmentation); + + writer.Write(key, segmentation); + num_done++; + } + + KALDI_LOG << "Successfully split " << num_done + << " segmentations; " + << "failed with " << num_err << " segmentations"; + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-to-ali.cc b/src/segmenterbin/segmentation-to-ali.cc new file mode 100644 index 00000000000..9a618247a42 --- /dev/null +++ b/src/segmenterbin/segmentation-to-ali.cc @@ -0,0 +1,99 @@ +// segmenterbin/segmentation-to-ali.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Convert segmentation to alignment\n" + "\n" + "Usage: segmentation-to-ali [options] " + "\n" + " e.g.: segmentation-to-ali ark:1.seg ark:1.ali\n"; + + std::string lengths_rspecifier; + int32 default_label = 0, length_tolerance = 2; + + ParseOptions po(usage); + + po.Register("lengths-rspecifier", &lengths_rspecifier, + "Archive of frame lengths " + "of the utterances. Fills up any extra length with " + "the specified default-label"); + po.Register("default-label", &default_label, "Fill any extra length " + "with this label"); + po.Register("length-tolerance", &length_tolerance, "Tolerate shortage of " + "this many frames in the specified lengths file"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_rspecifier = po.GetArg(1); + std::string alignment_wspecifier = po.GetArg(2); + + RandomAccessInt32Reader lengths_reader(lengths_rspecifier); + + SequentialSegmentationReader segmentation_reader(segmentation_rspecifier); + Int32VectorWriter alignment_writer(alignment_wspecifier); + + int32 num_err = 0, num_done = 0; + for (; !segmentation_reader.Done(); segmentation_reader.Next()) { + const Segmentation &segmentation = segmentation_reader.Value(); + const std::string &key = segmentation_reader.Key(); + + int32 length = -1; + if (lengths_rspecifier != "") { + if (!lengths_reader.HasKey(key)) { + KALDI_WARN << "Could not find length for utterance " << key; + num_err++; + continue; + } + length = lengths_reader.Value(key); + } + + std::vector ali; + if (!ConvertToAlignment(segmentation, default_label, length, + length_tolerance, &ali)) { + KALDI_WARN << "Conversion failed for utterance " << key; + num_err++; + continue; + } + alignment_writer.Write(key, ali); + num_done++; + } + + KALDI_LOG << "Converted " << num_done << " segmentations into alignments; " + << "failed with " << num_err << " segmentations"; + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-to-rttm.cc b/src/segmenterbin/segmentation-to-rttm.cc new file mode 100644 index 00000000000..6ffd1a8b1e8 --- /dev/null +++ b/src/segmenterbin/segmentation-to-rttm.cc @@ -0,0 +1,255 @@ +// segmenterbin/segmentation-to-rttm.cc + +// Copyright 2015-16 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
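To illustrate segmentation-to-ali above: the conversion just paints each segment's label onto a frame vector that is initialised with --default-label, with --lengths-rspecifier fixing the total number of frames. A stand-alone sketch of that idea with std types (the real ConvertToAlignment also applies the --length-tolerance check):

#include <cstdint>
#include <iostream>
#include <vector>

struct Segment { int32_t start_frame, end_frame, label; };

std::vector<int32_t> ToAlignment(const std::vector<Segment> &segmentation,
                                 int32_t num_frames, int32_t default_label) {
  std::vector<int32_t> ali(num_frames, default_label);
  for (const Segment &s : segmentation)
    for (int32_t t = s.start_frame; t <= s.end_frame && t < num_frames; ++t)
      ali[t] = s.label;
  return ali;
}

int main() {
  std::vector<Segment> segmentation = {{2, 4, 1}, {7, 8, 2}};
  for (int32_t label : ToAlignment(segmentation, 10, 0))
    std::cout << label << " ";
  std::cout << "\n";  // 0 0 1 1 1 0 0 2 2 0
  return 0;
}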
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation.h" + +namespace kaldi { +namespace segmenter { + +/** + * This function is used to write the segmentation in RTTM format. Each class is + * treated as a "SPEAKER". If map_to_speech_and_sil is true, then the class_id 0 + * is treated as SILENCE and every other class_id as SPEECH. The argument + * start_time is used to set what the time corresponding to the 0 frame in the + * segment. Each segment is converted into the following line, + * SPEAKER 1 + * ,where + * is the file_id supplied as an argument + * is the start time of the segment in seconds + * is the length of the segment in seconds + * is the class_id stored in the segment. If map_to_speech_and_sil is + * set true then is either SPEECH or SILENCE. + * The function retunns the largest class_id that it encounters. +**/ + +int32 WriteRttm(const Segmentation &segmentation, + std::ostream &os, const std::string &file_id, + const std::string &channel, + BaseFloat frame_shift, BaseFloat start_time, + bool map_to_speech_and_sil) { + SegmentList::const_iterator it = segmentation.Begin(); + int32 largest_class = 0; + for (; it != segmentation.End(); ++it) { + os << "SPEAKER " << file_id << " " << channel << " " + << it->start_frame * frame_shift + start_time << " " + << (it->Length()) * frame_shift << " "; + if (map_to_speech_and_sil) { + switch (it->Label()) { + case 1: + os << "SPEECH "; + break; + default: + os << "SILENCE "; + break; + } + largest_class = 1; + } else { + if (it->Label() >= 0) { + os << it->Label() << " "; + if (it->Label() > largest_class) + largest_class = it->Label(); + } + } + os << "" << std::endl; + } + return largest_class; +} + +} +} + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Convert segmentation into RTTM\n" + "\n" + "Usage: segmentation-to-rttm [options] \n" + " e.g.: segmentation-to-rttm ark:1.seg -\n"; + + bool map_to_speech_and_sil = true; + + BaseFloat frame_shift = 0.01; + std::string segments_rxfilename; + std::string reco2file_and_channel_rxfilename; + ParseOptions po(usage); + + po.Register("frame-shift", &frame_shift, "Frame shift in seconds"); + po.Register("segments", &segments_rxfilename, "Segments file"); + po.Register("reco2file-and-channel", &reco2file_and_channel_rxfilename, "reco2file_and_channel file"); + po.Register("map-to-speech-and-sil", &map_to_speech_and_sil, "Map all classes to SPEECH and SILENCE"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + unordered_map utt2file; + unordered_map utt2start_time; + + if (!segments_rxfilename.empty()) { + Input ki(segments_rxfilename); // no binary argment: never binary. + int32 i = 0; + std::string line; + /* read each line from segments file */ + while (std::getline(ki.Stream(), line)) { + std::vector split_line; + // Split the line by space or tab and check the number of fields in each + // line. 
There must be 4 fields--segment name , reacording wav file name, + // start time, end time; 5th field (channel info) is optional. + SplitStringToVector(line, " \t\r", true, &split_line); + if (split_line.size() != 4 && split_line.size() != 5) { + KALDI_WARN << "Invalid line in segments file: " << line; + continue; + } + std::string segment = split_line[0], + utterance = split_line[1], + start_str = split_line[2], + end_str = split_line[3]; + + // Convert the start time and endtime to real from string. Segment is + // ignored if start or end time cannot be converted to real. + double start, end; + if (!ConvertStringToReal(start_str, &start)) { + KALDI_WARN << "Invalid line in segments file [bad start]: " << line; + continue; + } + if (!ConvertStringToReal(end_str, &end)) { + KALDI_WARN << "Invalid line in segments file [bad end]: " << line; + continue; + } + // start time must not be negative; start time must not be greater than + // end time, except if end time is -1 + if (start < 0 || end <= 0 || start >= end) { + KALDI_WARN << "Invalid line in segments file [empty or invalid segment]: " + << line; + continue; + } + int32 channel = -1; // means channel info is unspecified. + // if each line has 5 elements then 5th element must be channel identifier + if(split_line.size() == 5) { + if (!ConvertStringToInteger(split_line[4], &channel) || channel < 0) { + KALDI_WARN << "Invalid line in segments file [bad channel]: " << line; + continue; + } + } + + utt2file.insert(std::make_pair(segment, utterance)); + utt2start_time.insert(std::make_pair(segment, start)); + i++; + } + KALDI_LOG << "Read " << i << " lines from " << segments_rxfilename; + } + + unordered_map , StringHasher> reco2file_and_channel; + + if (!reco2file_and_channel_rxfilename.empty()) { + Input ki(reco2file_and_channel_rxfilename); // no binary argment: never binary. 
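// Illustrative note: each line of the reco2file_and_channel file is expected
// to have exactly three fields, <recording-id> <file-id> <channel>, e.g.
// (hypothetical recording ids):
//   sw02001-A sw02001 A
//   sw02001-B sw02001 B
// The file-id and channel read here replace the recording-id and the default
// channel "1" in the RTTM output written further below.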
+ + int32 i = 0; + std::string line; + /* read each line from reco2file_and_channel file */ + while (std::getline(ki.Stream(), line)) { + std::vector split_line; + SplitStringToVector(line, " \t\r", true, &split_line); + if (split_line.size() != 3) { + KALDI_WARN << "Invalid line in reco2file_and_channel file: " << line; + continue; + } + + const std::string &reco_id = split_line[0]; + const std::string &file_id = split_line[1]; + const std::string &channel = split_line[2]; + + reco2file_and_channel.insert(std::make_pair(reco_id, std::make_pair(file_id, channel))); + i++; + } + + KALDI_LOG << "Read " << i << " lines from " << reco2file_and_channel_rxfilename; + } + + unordered_set seen_files; + + std::string segmentation_rspecifier = po.GetArg(1), + rttm_out_wxfilename = po.GetArg(2); + + int64 num_done = 0, num_err = 0; + + Output ko(rttm_out_wxfilename, false); + SequentialSegmentationReader reader(segmentation_rspecifier); + for (; !reader.Done(); reader.Next(), num_done++) { + Segmentation segmentation(reader.Value()); + const std::string &key = reader.Key(); + + std::string reco_id = key; + BaseFloat start_time = 0.0; + if (!segments_rxfilename.empty()) { + if (utt2file.count(key) == 0 || utt2start_time.count(key) == 0) + KALDI_ERR << "Could not find key " << key << " in segments " + << segments_rxfilename; + KALDI_ASSERT(utt2file.count(key) > 0 && utt2start_time.count(key) > 0); + reco_id = utt2file[key]; + start_time = utt2start_time[key]; + } + + std::string file_id, channel; + if (!reco2file_and_channel_rxfilename.empty()) { + if (reco2file_and_channel.count(reco_id) == 0) + KALDI_ERR << "Could not find recording " << reco_id + << " in " << reco2file_and_channel_rxfilename; + file_id = reco2file_and_channel[reco_id].first; + channel = reco2file_and_channel[reco_id].second; + } else { + file_id = reco_id; + channel = "1"; + } + + int32 largest_class = WriteRttm(segmentation, ko.Stream(), file_id, channel, frame_shift, start_time, map_to_speech_and_sil); + + if (map_to_speech_and_sil) { + if (seen_files.count(reco_id) == 0) { + ko.Stream() << "SPKR-INFO " << file_id << " " << channel << " unknown SILENCE \n"; + ko.Stream() << "SPKR-INFO " << file_id << " " << channel << " unknown SPEECH \n"; + seen_files.insert(reco_id); + } + } else { + for (int32 i = 0; i < largest_class; i++) { + ko.Stream() << "SPKR-INFO " << file_id << " " << channel << " unknown " << i << " \n"; + } + } + } + + KALDI_LOG << "Copied " << num_done << " segmentation; failed with " + << num_err << " segmentations"; + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + + + diff --git a/src/segmenterbin/segmentation-to-segments.cc b/src/segmenterbin/segmentation-to-segments.cc new file mode 100644 index 00000000000..c57aa827ead --- /dev/null +++ b/src/segmenterbin/segmentation-to-segments.cc @@ -0,0 +1,133 @@ +// segmenterbin/segmentation-to-segments.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
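For concreteness, segmentation-to-rttm above emits one SPEAKER line per segment in the standard NIST RTTM column layout (type, file, channel, onset, duration, two placeholder fields, label, confidence), plus SPKR-INFO lines per label. With --map-to-speech-and-sil=true, a recording reco1 whose speech occupies frames 150-409 at a 0.01 s frame shift would give output along these lines (file and values illustrative):

SPKR-INFO reco1 1 <NA> <NA> <NA> unknown SILENCE <NA>
SPKR-INFO reco1 1 <NA> <NA> <NA> unknown SPEECH <NA>
SPEAKER reco1 1 1.5 2.6 <NA> <NA> SPEECH <NA>

The onset is start_frame * frame_shift (plus the segment's start time when a segments file is supplied) and the duration is the segment length in frames times the frame shift, here 260 frames giving 2.6 seconds.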
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Convert segmentation to a segments file and utt2spk file." + "Assumes that the input segmentations are indexed by reco-id and " + "treats speakers from different recording as distinct speakers." + "\n" + "Usage: segmentation-to-segments [options] " + " \n" + " e.g.: segmentation-to-segments ark:foo.seg ark,t:utt2spk segments\n"; + + BaseFloat frame_shift = 0.01, frame_overlap = 0.015; + bool single_speaker = false, per_utt_speaker = false; + ParseOptions po(usage); + + po.Register("frame-shift", &frame_shift, "Frame shift in seconds"); + po.Register("frame-overlap", &frame_overlap, "Frame overlap in seconds"); + po.Register("single-speaker", &single_speaker, "If this is set true, " + "then all the utterances in a recording are mapped to the " + "same speaker"); + po.Register("per-utt-speaker", &per_utt_speaker, + "If this is set true, then each utterance is mapped to distint " + "speaker with spkr_id = utt_id"); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + if (frame_shift < 0.001 || frame_shift > 1) { + KALDI_ERR << "Invalid frame-shift " << frame_shift << "; must be in " + << "the range [0.001,1]"; + } + + if (frame_overlap < 0 || frame_overlap > 1) { + KALDI_ERR << "Invalid frame-overlap " << frame_overlap << "; must be in " + << "the range [0,1]"; + } + + std::string segmentation_rspecifier = po.GetArg(1), + utt2spk_wspecifier = po.GetArg(2), + segments_wxfilename = po.GetArg(3); + + SequentialSegmentationReader reader(segmentation_rspecifier); + TokenWriter utt2spk_writer(utt2spk_wspecifier); + + Output ko(segments_wxfilename, false); + + int32 num_done = 0; + int64 num_segments = 0; + + for (; !reader.Done(); reader.Next(), num_done++) { + const Segmentation &segmentation = reader.Value(); + const std::string &key = reader.Key(); + + for (SegmentList::const_iterator it = segmentation.Begin(); + it != segmentation.End(); ++it) { + BaseFloat start_time = it->start_frame * frame_shift; + BaseFloat end_time = (it->end_frame + 1) * frame_shift + frame_overlap; + + std::ostringstream oss; + + if (!single_speaker) { + oss << key << "-" << it->Label(); + } else { + oss << key; + } + + std::string spk = oss.str(); + + oss << "-"; + oss << std::setw(6) << std::setfill('0') << it->start_frame; + oss << std::setw(1) << "-"; + oss << std::setw(6) << std::setfill('0') + << it->end_frame + 1 + + static_cast(frame_overlap / frame_shift); + + std::string utt = oss.str(); + + if (per_utt_speaker) + utt2spk_writer.Write(utt, utt); + else + utt2spk_writer.Write(utt, spk); + + ko.Stream() << utt << " " << key << " "; + ko.Stream() << std::fixed << std::setprecision(3) << start_time << " "; + ko.Stream() << std::setprecision(3) << end_time << "\n"; + + num_segments++; + } + } + + KALDI_LOG << "Converted " << num_done << " segmentations 
to segments; " + << "wrote " << num_segments << " segments"; + + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/tools/config/common_path.sh b/tools/config/common_path.sh index 3e2ea50d685..36b5350dd8e 100644 --- a/tools/config/common_path.sh +++ b/tools/config/common_path.sh @@ -20,4 +20,5 @@ ${KALDI_ROOT}/src/online2bin:\ ${KALDI_ROOT}/src/onlinebin:\ ${KALDI_ROOT}/src/sgmm2bin:\ ${KALDI_ROOT}/src/sgmmbin:\ +${KALDI_ROOT}/src/segmenterbin:\ $PATH From 8c11f77f8a2af11c69e142992b6fa80b3dbc845d Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Tue, 29 Nov 2016 23:08:33 -0500 Subject: [PATCH 040/213] asr_diarization: Adding do_corruption_data_dir.sh for corruption with MUSAN noise --- .../segmentation/do_corruption_data_dir.sh | 140 ++++++++++++++++++ 1 file changed, 140 insertions(+) create mode 100755 egs/aspire/s5/local/segmentation/do_corruption_data_dir.sh diff --git a/egs/aspire/s5/local/segmentation/do_corruption_data_dir.sh b/egs/aspire/s5/local/segmentation/do_corruption_data_dir.sh new file mode 100755 index 00000000000..36bf4c93306 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/do_corruption_data_dir.sh @@ -0,0 +1,140 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + +set -e +set -u +set -o pipefail + +. path.sh + +stage=0 +corruption_stage=-10 +corrupt_only=false + +# Data options +data_dir=data/train_si284 # Expecting whole data directory. +speed_perturb=true +num_data_reps=5 # Number of corrupted versions +snrs="20:10:15:5:0:-5" +foreground_snrs="20:10:15:5:0:-5" +background_snrs="20:10:15:5:0:-5" +base_rirs=simulated + +# Parallel options +reco_nj=40 +cmd=queue.pl + +# Options for feature extraction +mfcc_config=conf/mfcc_hires_bp_vh.conf +feat_suffix=hires_bp_vh + +reco_vad_dir= # Output of prepare_unsad_data.sh. + # If provided, the speech labels and deriv weights will be + # copied into the output data directory. + +. 
utils/parse_options.sh + +if [ $# -ne 0 ]; then + echo "Usage: $0" + exit 1 +fi + +data_id=`basename ${data_dir}` + +rvb_opts=() +if [ "$base_rirs" == "simulated" ]; then + # This is the config for the system using simulated RIRs and point-source noises + rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/smallroom/rir_list") + rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/mediumroom/rir_list") + rvb_opts+=(--noise-set-parameters RIRS_NOISES/pointsource_noises/noise_list) +else + # This is the config for the JHU ASpIRE submission system + rvb_opts+=(--rir-set-parameters "1.0, RIRS_NOISES/real_rirs_isotropic_noises/rir_list") + rvb_opts+=(--noise-set-parameters RIRS_NOISES/real_rirs_isotropic_noises/noise_list) +fi + +corrupted_data_id=${data_id}_corrupted + +if [ $stage -le 1 ]; then + python steps/data/reverberate_data_dir.py \ + "${rvb_opts[@]}" \ + --prefix="rev" \ + --foreground-snrs=$foreground_snrs \ + --background-snrs=$background_snrs \ + --speech-rvb-probability=1 \ + --pointsource-noise-addition-probability=1 \ + --isotropic-noise-addition-probability=1 \ + --num-replications=$num_data_reps \ + --max-noises-per-minute=1 \ + data/${data_id} data/${corrupted_data_id} +fi + +corrupted_data_dir=data/${corrupted_data_id} + +if $speed_perturb; then + if [ $stage -le 2 ]; then + ## Assuming whole data directories + for x in $clean_data_dir $corrupted_data_dir $noise_data_dir; do + cp $x/reco2dur $x/utt2dur + utils/data/perturb_data_dir_speed_3way.sh $x ${x}_sp + done + fi + + corrupted_data_dir=${corrupted_data_dir}_sp + corrupted_data_id=${corrupted_data_id}_sp + + if [ $stage -le 3 ]; then + utils/data/perturb_data_dir_volume.sh --scale-low 0.03125 --scale-high 2 \ + ${corrupted_data_dir} + fi +fi + +if $corrupt_only; then + echo "$0: Got corrupted data directory in ${corrupted_data_dir}" + exit 0 +fi + +mfccdir=`basename $mfcc_config` +mfccdir=${mfccdir%%.conf} + +if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage +fi + +if [ $stage -le 4 ]; then + if [ ! -z $feat_suffix ]; then + utils/copy_data_dir.sh $corrupted_data_dir ${corrupted_data_dir}_$feat_suffix + corrupted_data_dir=${corrupted_data_dir}_$feat_suffix + fi + steps/make_mfcc.sh --mfcc-config $mfcc_config \ + --cmd "$cmd" --nj $reco_nj \ + $corrupted_data_dir exp/make_${mfccdir}/${corrupted_data_id} $mfccdir + steps/compute_cmvn_stats.sh --fake \ + $corrupted_data_dir exp/make_${mfccdir}/${corrupted_data_id} $mfccdir +else + if [ ! -z $feat_suffix ]; then + corrupted_data_dir=${corrupted_data_dir}_$feat_suffix + fi +fi + +if [ $stage -le 8 ]; then + if [ ! -z "$reco_vad_dir" ]; then + if [ ! 
-f $reco_vad_dir/speech_feat.scp ]; then + echo "$0: Could not find file $reco_vad_dir/speech_feat.scp" + exit 1 + fi + + cat $reco_vad_dir/speech_feat.scp | \ + steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps | \ + sort -k1,1 > ${corrupted_data_dir}/speech_feat.scp + + cat $reco_vad_dir/deriv_weights.scp | \ + steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps | \ + sort -k1,1 > ${corrupted_data_dir}/deriv_weights.scp + fi +fi + +exit 0 From 5fccac18df0d48db0d5611f8edbf7708f35ac67f Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 30 Nov 2016 00:38:56 -0500 Subject: [PATCH 041/213] asr_diarization: Add do_corruption_data_dir_music.sh for corruption with MUSAN music --- .../do_corruption_data_dir_music.sh | 203 ++++++++++++++++++ .../s5/local/segmentation/make_musan_music.py | 69 ++++++ .../segmentation/train_stats_sad_music.sh | 172 +++++++++++++++ 3 files changed, 444 insertions(+) create mode 100755 egs/aspire/s5/local/segmentation/do_corruption_data_dir_music.sh create mode 100755 egs/aspire/s5/local/segmentation/make_musan_music.py create mode 100644 egs/aspire/s5/local/segmentation/train_stats_sad_music.sh diff --git a/egs/aspire/s5/local/segmentation/do_corruption_data_dir_music.sh b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_music.sh new file mode 100755 index 00000000000..214cba347da --- /dev/null +++ b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_music.sh @@ -0,0 +1,203 @@ +#!/bin/bash +set -e +set -u +set -o pipefail + +. path.sh +. cmd.sh + +num_data_reps=5 +data_dir=data/train_si284 + +nj=40 +reco_nj=40 + +stage=0 +corruption_stage=-10 + +pad_silence=false + +mfcc_config=conf/mfcc_hires_bp_vh.conf +feat_suffix=hires_bp_vh +mfcc_irm_config=conf/mfcc_hires_bp.conf + +dry_run=false +corrupt_only=false +speed_perturb=true + +reco_vad_dir= + +max_jobs_run=20 + +foreground_snrs="5:2:1:0:-2:-5:-10:-20" +background_snrs="5:2:1:0:-2:-5:-10:-20" + +. 
utils/parse_options.sh + +if [ $# -ne 0 ]; then + echo "Usage: $0" + exit 1 +fi + +data_id=`basename ${data_dir}` + +rvb_opts=() +# This is the config for the system using simulated RIRs and point-source noises +rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/smallroom/rir_list") +rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/mediumroom/rir_list") +rvb_opts+=(--noise-set-parameters RIRS_NOISES/music/music_list) + +music_utt2num_frames=RIRS_NOISES/music/split_utt2num_frames + +corrupted_data_id=${data_id}_music_corrupted +orig_corrupted_data_id=$corrupted_data_id + +if [ $stage -le 1 ]; then + python steps/data/reverberate_data_dir.py \ + "${rvb_opts[@]}" \ + --prefix="music" \ + --foreground-snrs=$foreground_snrs \ + --background-snrs=$background_snrs \ + --speech-rvb-probability=1 \ + --pointsource-noise-addition-probability=1 \ + --isotropic-noise-addition-probability=1 \ + --num-replications=$num_data_reps \ + --max-noises-per-minute=5 \ + data/${data_id} data/${corrupted_data_id} +fi + +if $dry_run; then + exit 0 +fi + +corrupted_data_dir=data/${corrupted_data_id} +orig_corrupted_data_dir=$corrupted_data_dir + +if $speed_perturb; then + if [ $stage -le 2 ]; then + ## Assuming whole data directories + for x in $corrupted_data_dir; do + cp $x/reco2dur $x/utt2dur + utils/data/perturb_data_dir_speed_3way.sh $x ${x}_sp + done + fi + + corrupted_data_dir=${corrupted_data_dir}_sp + corrupted_data_id=${corrupted_data_id}_sp + + if [ $stage -le 3 ]; then + utils/data/perturb_data_dir_volume.sh --scale-low 0.03125 --scale-high 2 \ + ${corrupted_data_dir} + fi +fi + +if $corrupt_only; then + echo "$0: Got corrupted data directory in ${corrupted_data_dir}" + exit 0 +fi + +mfccdir=`basename $mfcc_config` +mfccdir=${mfccdir%%.conf} + +if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage +fi + +if [ $stage -le 4 ]; then + if [ ! -z $feat_suffix ]; then + utils/copy_data_dir.sh $corrupted_data_dir ${corrupted_data_dir}_$feat_suffix + corrupted_data_dir=${corrupted_data_dir}_$feat_suffix + fi + steps/make_mfcc.sh --mfcc-config $mfcc_config \ + --cmd "$train_cmd" --nj $reco_nj \ + $corrupted_data_dir exp/make_${mfccdir}/${corrupted_data_id} $mfccdir + steps/compute_cmvn_stats.sh --fake \ + $corrupted_data_dir exp/make_${mfccdir}/${corrupted_data_id} $mfccdir +else + if [ ! -z $feat_suffix ]; then + corrupted_data_dir=${corrupted_data_dir}_$feat_suffix + fi +fi + +if [ $stage -le 8 ]; then + if [ ! -z "$reco_vad_dir" ]; then + if [ ! 
-f $reco_vad_dir/speech_feat.scp ]; then + echo "$0: Could not find file $reco_vad_dir/speech_feat.scp" + exit 1 + fi + + cat $reco_vad_dir/speech_feat.scp | \ + steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps "music" | \ + sort -k1,1 > ${corrupted_data_dir}/speech_feat.scp + + cat $reco_vad_dir/deriv_weights.scp | \ + steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps "music" | \ + sort -k1,1 > ${corrupted_data_dir}/deriv_weights.scp + fi +fi + +# music_dir is without speed perturbation +music_dir=exp/make_music_labels/${orig_corrupted_data_id} +music_data_dir=$music_dir/music_data + +mkdir -p $music_data_dir + +if [ $stage -le 10 ]; then + utils/data/get_utt2num_frames.sh $corrupted_data_dir + utils/split_data.sh --per-reco ${orig_corrupted_data_dir} $reco_nj + + cp $orig_corrupted_data_dir/wav.scp $music_data_dir + + # Combine the VAD from the base recording and the VAD from the overlapping segments + # to create per-frame labels of the number of overlapping speech segments + # Unreliable segments are regions where no VAD labels were available for the + # overlapping segments. These can be later removed by setting deriv weights to 0. + $train_cmd JOB=1:$reco_nj $music_dir/log/get_music_seg.JOB.log \ + segmentation-init-from-additive-signals-info --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ + --additive-signals-segmentation-rspecifier="ark:segmentation-init-from-lengths ark:$music_utt2num_frames ark:- |" \ + "ark:utils/filter_scp.pl ${orig_corrupted_data_dir}/split${reco_nj}reco/JOB/utt2spk $corrupted_data_dir/utt2num_frames | segmentation-init-from-lengths --label=1 ark:- ark:- | segmentation-post-process --remove-labels=1 ark:- ark:- |" \ + ark,t:$orig_corrupted_data_dir/additive_signals_info.txt \ + ark:- \| \ + segmentation-post-process --merge-adjacent-segments ark:- \ + ark:- \| \ + segmentation-to-segments ark:- ark:$music_data_dir/utt2spk.JOB \ + $music_data_dir/segments.JOB + + for n in `seq $reco_nj`; do cat $music_data_dir/utt2spk.$n; done > $music_data_dir/utt2spk + for n in `seq $reco_nj`; do cat $music_data_dir/segments.$n; done > $music_data_dir/segments + + utils/fix_data_dir.sh $music_data_dir + + if $speed_perturb; then + utils/data/perturb_data_dir_speed_3way.sh $music_data_dir ${music_data_dir}_sp + fi +fi + +if $speed_perturb; then + music_data_dir=${music_data_dir}_sp +fi + +label_dir=music_labels + +mkdir -p $label_dir +label_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $label_dir ${PWD}` + +if [ $stage -le 11 ]; then + utils/split_data.sh --per-reco ${music_data_dir} $reco_nj + + $train_cmd JOB=1:$reco_nj $music_dir/log/get_music_labels.JOB.log \ + utils/data/get_reco2utt.sh ${music_data_dir}/split${reco_nj}reco/JOB '&&' \ + segmentation-init-from-segments --shift-to-zero=false \ + ${music_data_dir}/split${reco_nj}reco/JOB/segments ark:- \| \ + segmentation-combine-segments-to-recordings ark:- ark,t:${music_data_dir}/split${reco_nj}reco/JOB/reco2utt \ + ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- \ + ark,scp:$label_dir/music_labels_${corrupted_data_id}.JOB.ark,$label_dir/music_labels_${corrupted_data_id}.JOB.scp +fi + +for n in `seq $reco_nj`; do + cat $label_dir/music_labels_${corrupted_data_id}.$n.scp +done > ${corrupted_data_dir}/music_labels.scp + +exit 0 diff --git a/egs/aspire/s5/local/segmentation/make_musan_music.py b/egs/aspire/s5/local/segmentation/make_musan_music.py new file mode 100755 index 00000000000..5d13078de63 
--- /dev/null +++ b/egs/aspire/s5/local/segmentation/make_musan_music.py @@ -0,0 +1,69 @@ +#! /usr/bin/env python + +from __future__ import print_function +import argparse +import os + + +def _get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--use-vocals", type=str, default="false", + choices=["true", "false"], + help="If true, also add music with vocals in the " + "output music-set-parameters") + parser.add_argument("root_dir", type=str, + help="Root directory of MUSAN corpus") + parser.add_argument("music_list", type=argparse.FileType('w'), + help="Convert music list into noise-set-paramters " + "for steps/data/reverberate_data_dir.py") + + args = parser.parse_args() + + args.use_vocals = True if args.use_vocals == "true" else False + return args + + +def read_vocals(annotations): + vocals = {} + for line in open(annotations): + parts = line.strip().split() + if parts[2] == "Y": + vocals[parts[0]] = True + return vocals + + +def write_music(utt, file_path, music_list): + print ('{utt} {file_path}'.format( + utt=utt, file_path=file_path), file=music_list) + + +def prepare_music_set(root_dir, use_vocals, music_list): + vocals = {} + music_dir = os.path.join(root_dir, "music") + for root, dirs, files in os.walk(music_dir): + if os.path.exists(os.path.join(root, "ANNOTATIONS")): + vocals = read_vocals(os.path.join(root, "ANNOTATIONS")) + + for f in files: + file_path = os.path.join(root, f) + if f.endswith(".wav"): + utt = str(f).replace(".wav", "") + if not use_vocals and utt in vocals: + continue + write_music(utt, file_path, music_list) + music_list.close() + + +def main(): + args = _get_args() + + try: + prepare_music_set(args.root_dir, args.use_vocals, + args.music_list) + finally: + args.music_list.close() + + +if __name__ == '__main__': + main() diff --git a/egs/aspire/s5/local/segmentation/train_stats_sad_music.sh b/egs/aspire/s5/local/segmentation/train_stats_sad_music.sh new file mode 100644 index 00000000000..8242b83c747 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/train_stats_sad_music.sh @@ -0,0 +1,172 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +splice_indexes="-3,-2,-1,0,1,2,3 -6,0,mean+count(-99:3:9:99) -9,0,3 0" +relu_dim=256 +chunk_width=20 # We use chunk training for training TDNN +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. 
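+# For illustration only (this sketch is not part of the recipe; the actual
+# sampling is assumed to happen inside the nnet3 egs/training scripts): a
+# hypothetical per-job draw of the extra context could look like
+#   extra_left=$((min_extra_left_context + RANDOM % (extra_left_context - min_extra_left_context + 1)))
+#   extra_right=$((min_extra_right_context + RANDOM % (extra_right_context - min_extra_right_context + 1)))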
+min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +num_utts_subset_valid=50 # "utts" is actually recording. So this is prettly small. +num_utts_subset_train=50 + +# target options +train_data_dir=data/train_azteec_whole_sp_corrupted_hires + +speech_feat_scp= +music_labels_scp= + +deriv_weights_scp= + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_hidden_layers=`echo $splice_indexes | perl -ane 'print scalar @F'` || exit 1 +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_music/nnet_tdnn +fi + +dir=$dir${affix:+_$affix}_n${num_hidden_layers} + +if ! cuda-compiled; then + cat < Date: Wed, 30 Nov 2016 16:29:14 -0500 Subject: [PATCH 042/213] asr_diarization: Recipe for music-id on broadcast news --- .../v1/local/run_nnet3_music_id.sh | 217 ++++++++++++++++++ 1 file changed, 217 insertions(+) create mode 100644 egs/bn_music_speech/v1/local/run_nnet3_music_id.sh diff --git a/egs/bn_music_speech/v1/local/run_nnet3_music_id.sh b/egs/bn_music_speech/v1/local/run_nnet3_music_id.sh new file mode 100644 index 00000000000..d96acdabaaa --- /dev/null +++ b/egs/bn_music_speech/v1/local/run_nnet3_music_id.sh @@ -0,0 +1,217 @@ +#!/bin/bash + +set -e +set -o pipefail +set -u + +. path.sh +. cmd.sh + +feat_affix=bp_vh +affix= +reco_nj=32 + +stage=-1 + +# SAD network config +iter=final +extra_left_context=100 # Set to some large value +extra_right_context=20 + + +# Configs +frame_subsampling_factor=1 + +min_silence_duration=3 # minimum number of frames for silence +min_speech_duration=3 # minimum number of frames for speech +min_music_duration=3 # minimum number of frames for music +music_transition_probability=0.1 +sil_transition_probability=0.1 +speech_transition_probability=0.1 +sil_prior=0.3 +speech_prior=0.4 +music_prior=0.3 + +# Decoding options +acwt=1 +beam=10 +max_active=7000 + +mfcc_config=conf/mfcc_hires_bp.conf + +echo $* + +. utils/parse_options.sh + +if [ $# -ne 3 ]; then + echo "Usage: $0 " + echo " e.g.: $0 data/bn exp/nnet3_sad_snr/tdnn_j_n4 exp/dnn_music_id" + exit 1 +fi + +# Set to true if the test data has > 8kHz sampling frequency. +do_downsampling=true + +data_dir=$1 +sad_nnet_dir=$2 +dir=$3 + +data_id=`basename $data_dir` + +export PATH="$KALDI_ROOT/tools/sph2pipe_v2.5/:$PATH" +[ ! -z `which sph2pipe` ] + +for f in $sad_nnet_dir/$iter.raw $sad_nnet_dir/post_output-speech.vec $sad_nnet_dir/post_output-music.vec; do + if [ ! -f $f ]; then + echo "$0: Could not find $f. 
See the local/segmentation/run_train_sad.sh" + exit 1 + fi +done + +mkdir -p $dir + +new_data_dir=$dir/${data_id} +if [ $stage -le 0 ]; then + utils/data/convert_data_dir_to_whole.sh $data_dir ${new_data_dir}_whole + + freq=`cat $mfcc_config | perl -pe 's/\s*#.*//g' | grep "sample-frequency=" | awk -F'=' '{if (NF == 0) print 16000; else print $2}'` + sox=`which sox` + + cat $data_dir/wav.scp | python -c "import sys +for line in sys.stdin.readlines(): + splits = line.strip().split() + if splits[-1] == '|': + out_line = line.strip() + ' $sox -t wav - -r $freq -c 1 -b 16 -t wav - downsample |' + else: + out_line = 'cat {0} {1} | $sox -t wav - -r $freq -c 1 -b 16 -t wav - downsample |'.format(splits[0], ' '.join(splits[1:])) + print (out_line)" > ${new_data_dir}_whole/wav.scp + + utils/copy_data_dir.sh ${new_data_dir}_whole ${new_data_dir}_whole_bp_hires +fi + +test_data_dir=${new_data_dir}_whole_bp_hires + +if [ $stage -le 1 ]; then + steps/make_mfcc.sh --mfcc-config $mfcc_config --nj $reco_nj --cmd "$train_cmd" \ + ${new_data_dir}_whole_bp_hires exp/make_hires/${data_id}_whole_bp mfcc_hires + steps/compute_cmvn_stats.sh ${new_data_dir}_whole_bp_hires exp/make_hires/${data_id}_whole_bp mfcc_hires +fi + +if [ $stage -le 2 ]; then + output_name=output-speech + post_vec=$sad_nnet_dir/post_${output_name}.vec + steps/nnet3/compute_output.sh --nj $reco_nj --cmd "$train_cmd" \ + --post-vec "$post_vec" \ + --iter $iter \ + --extra-left-context $extra_left_context \ + --extra-right-context $extra_right_context \ + --frames-per-chunk 150 \ + --output-name $output_name \ + --frame-subsampling-factor $frame_subsampling_factor \ + --get-raw-nnet-from-am false ${test_data_dir} $sad_nnet_dir $dir/sad_${data_id}_whole_bp +fi + +if [ $stage -le 3 ]; then + output_name=output-music + post_vec=$sad_nnet_dir/post_${output_name}.vec + steps/nnet3/compute_output.sh --nj $reco_nj --cmd "$train_cmd" \ + --post-vec "$post_vec" \ + --iter $iter \ + --extra-left-context $extra_left_context \ + --extra-right-context $extra_right_context \ + --frames-per-chunk 150 \ + --output-name $output_name \ + --frame-subsampling-factor $frame_subsampling_factor \ + --get-raw-nnet-from-am false ${test_data_dir} $sad_nnet_dir $dir/music_${data_id}_whole_bp +fi + +if [ $stage -le 4 ]; then + $train_cmd JOB=1:$reco_nj $dir/get_average_likes.JOB.log \ + paste-feats \ + "ark:gunzip -c $dir/sad_${data_id}_whole_bp/log_likes.JOB.gz | extract-feature-segments ark:- 'utils/filter_scp.pl -f 2 ${test_data_dir}/split$reco_nj/JOB/utt2spk $data_dir/segments |' ark:- |" \ + "ark:gunzip -c $dir/music_${data_id}_whole_bp/log_likes.JOB.gz | select-feats 1 ark:- ark:- | extract-feature-segments ark:- 'utils/filter_scp.pl -f 2 ${test_data_dir}/split$reco_nj/JOB/utt2spk $data_dir/segments |' ark:- |" \ + ark:- \| \ + matrix-sum-rows --do-average ark:- ark,t:$dir/average_likes.JOB.ark + + for n in `seq $reco_nj`; do + cat $dir/average_likes.$n.ark + done | awk '{print $1" "( exp($3) + exp($5) + 0.01) / (exp($4) + 0.01)}' | \ + local/print_scores.py /dev/stdin | compute-eer - +fi + +lang=$dir/lang + +if [ $stage -le 5 ]; then + mkdir -p $lang + + # Create a lang directory with phones.txt and topo with + # silence, music and speech phones. 
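+  # For illustration only (hypothetical content; the actual files are written
+  # by prepare_sad_lang.py below): phones.txt is expected to map the three
+  # classes to integer ids, roughly
+  #   <eps> 0
+  #   1 1   # silence
+  #   2 2   # speech
+  #   3 3   # music
+  # which is why it can simply be copied to words.txt afterwards.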
+ steps/segmentation/internal/prepare_sad_lang.py \ + --phone-transition-parameters="--phone-list=1 --min-duration=$min_silence_duration --end-transition-probability=$sil_transition_probability" \ + --phone-transition-parameters="--phone-list=2 --min-duration=$min_speech_duration --end-transition-probability=$speech_transition_probability" \ + --phone-transition-parameters="--phone-list=3 --min-duration=$min_music_duration --end-transition-probability=$music_transition_probability" \ + $lang + + cp $lang/phones.txt $lang/words.txt +fi + +feat_dim=2 # dummy. We don't need this. +if [ $stage -le 6 ]; then + $train_cmd $dir/log/create_transition_model.log gmm-init-mono \ + $lang/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans.mdl || exit 1 +fi + +# Make unigram G.fst +if [ $stage -le 7 ]; then + cat > $lang/word2prior < $lang/G.fst +fi + +graph_dir=$dir/graph_test + +if [ $stage -le 8 ]; then + $train_cmd $dir/log/make_vad_graph.log \ + steps/segmentation/internal/make_sad_graph.sh --iter trans \ + $lang $dir $dir/graph_test || exit 1 +fi + +seg_dir=$dir/segmentation_${data_id}_whole_bp +mkdir -p $seg_dir + +if [ $stage -le 9 ]; then + decoder_opts+=(--acoustic-scale=$acwt --beam=$beam --max-active=$max_active) + $train_cmd JOB=1:$reco_nj $dir/decode.JOB.log \ + paste-feats \ + "ark:gunzip -c $dir/sad_${data_id}_whole_bp/log_likes.JOB.gz | extract-feature-segments ark:- 'utils/filter_scp.pl -f 2 ${test_data_dir}/split$reco_nj/JOB/utt2spk $data_dir/segments |' ark:- |" \ + "ark:gunzip -c $dir/music_${data_id}_whole_bp/log_likes.JOB.gz | select-feats 1 ark:- ark:- | extract-feature-segments ark:- 'utils/filter_scp.pl -f 2 ${test_data_dir}/split$reco_nj/JOB/utt2spk $data_dir/segments |' ark:- |" \ + ark:- \| decode-faster-mapped ${decoder_opts[@]} \ + $dir/trans.mdl $graph_dir/HCLG.fst ark:- \ + ark:/dev/null ark:- \| \ + ali-to-phones --per-frame $dir/trans.mdl ark:- \ + "ark:|gzip -c > $seg_dir/ali.JOB.gz" +fi + +include_silence=true +if [ $stage -le 10 ]; then + $train_cmd JOB=1:$reco_nj $dir/log/get_class_id.JOB.log \ + ali-to-post "ark:gunzip -c $seg_dir/ali.JOB.gz |" ark:- \| \ + post-to-feats --post-dim=4 ark:- ark:- \| \ + matrix-sum-rows --do-average ark:- ark,t:- \| \ + sid/vector_to_music_labels.pl ${include_silence:+--include-silence-in-music} '>' $dir/ratio.JOB + + for n in `seq $reco_nj`; do + cat $dir/ratio.$n + done > $dir/ratio + + cat $dir/ratio | local/print_scores.py /dev/stdin | compute-eer - +fi + +# LOG (compute-eer:main():compute-eer.cc:136) Equal error rate is 0.860585%, at threshold 1.99361 From 82bfc5a60512a064787e494a5ab8fe1a173caf68 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Fri, 25 Nov 2016 01:06:26 -0500 Subject: [PATCH 043/213] asr_diarization: Utilities invert_vector.pl and vector_get_max.pl --- .../s5/steps/segmentation/invert_vector.pl | 20 ++++++++++++++ .../s5/steps/segmentation/vector_get_max.pl | 26 +++++++++++++++++++ 2 files changed, 46 insertions(+) create mode 100755 egs/wsj/s5/steps/segmentation/invert_vector.pl create mode 100644 egs/wsj/s5/steps/segmentation/vector_get_max.pl diff --git a/egs/wsj/s5/steps/segmentation/invert_vector.pl b/egs/wsj/s5/steps/segmentation/invert_vector.pl new file mode 100755 index 00000000000..c16243a0b93 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/invert_vector.pl @@ -0,0 +1,20 @@ +#! 
/usr/bin/perl +use strict; +use warnings; + +while () { + chomp; + my @F = split; + my $utt = shift @F; + shift @F; + + print "$utt [ "; + for (my $i = 0; $i < $#F; $i++) { + if ($F[$i] == 0) { + print "1 "; + } else { + print 1.0/$F[$i] . " "; + } + } + print "]\n"; +} diff --git a/egs/wsj/s5/steps/segmentation/vector_get_max.pl b/egs/wsj/s5/steps/segmentation/vector_get_max.pl new file mode 100644 index 00000000000..abb8ea977a2 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/vector_get_max.pl @@ -0,0 +1,26 @@ +#! /usr/bin/perl + +use warnings; +use strict; + +while (<>) { + chomp; + if (m/^\S+\s+\[.+\]\s*$/) { + my @F = split; + my $utt = shift @F; + shift; + + my $max_id = 0; + my $max = $F[0]; + for (my $i = 1; $i < $#F; $i++) { + if ($F[$i] > $max) { + $max_id = $i; + $max = $F[$i]; + } + } + + print "$utt $max_id\n"; + } else { + die "Invalid line $_\n"; + } +} From 709ac923ddc07ff2ca568fafa7e62a6a50f5ea8c Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 30 Nov 2016 17:03:12 -0500 Subject: [PATCH 044/213] asr_diarization: Recipe for segmentation on AMI SDM dev set --- .../s5b/local/prepare_parallel_train_data.sh | 24 ++-- .../segmentation/run_segmentation_ami.sh | 128 ++++++++++++++++++ 2 files changed, 140 insertions(+), 12 deletions(-) create mode 100755 egs/aspire/s5/local/segmentation/run_segmentation_ami.sh diff --git a/egs/ami/s5b/local/prepare_parallel_train_data.sh b/egs/ami/s5b/local/prepare_parallel_train_data.sh index b049c906c3b..b551bacfb92 100755 --- a/egs/ami/s5b/local/prepare_parallel_train_data.sh +++ b/egs/ami/s5b/local/prepare_parallel_train_data.sh @@ -5,6 +5,10 @@ # but the wav data is copied from data/ihm. This is a little tricky because the # utterance ids are different between the different mics +train_set=train + +. utils/parse_options.sh + if [ $# != 1 ]; then echo "Usage: $0 [sdm1|mdm8]" @@ -18,12 +22,10 @@ if [ $mic == "ihm" ]; then exit 1; fi -train_set=train - . cmd.sh . ./path.sh -for f in data/ihm/train/utt2spk data/$mic/train/utt2spk; do +for f in data/ihm/${train_set}/utt2spk data/$mic/${train_set}/utt2spk; do if [ ! -f $f ]; then echo "$0: expected file $f to exist" exit 1 @@ -32,12 +34,12 @@ done set -e -o pipefail -mkdir -p data/$mic/train_ihmdata +mkdir -p data/$mic/${train_set}_ihmdata # the utterance-ids and speaker ids will be from the SDM or MDM data -cp data/$mic/train/{spk2utt,text,utt2spk} data/$mic/train_ihmdata/ +cp data/$mic/${train_set}/{spk2utt,text,utt2spk} data/$mic/${train_set}_ihmdata/ # the recording-ids will be from the IHM data. -cp data/ihm/train/{wav.scp,reco2file_and_channel} data/$mic/train_ihmdata/ +cp data/ihm/${train_set}/{wav.scp,reco2file_and_channel} data/$mic/${train_set}_ihmdata/ # map sdm/mdm segments to the ihm segments @@ -47,19 +49,17 @@ mic_base_upcase=$(echo $mic | sed 's/[0-9]//g' | tr 'a-z' 'A-Z') # It has lines like: # AMI_EN2001a_H02_FEO065_0021133_0021442 AMI_EN2001a_SDM_FEO065_0021133_0021442 -tmpdir=data/$mic/train_ihmdata/ +tmpdir=data/$mic/${train_set}_ihmdata/ -awk '{print $1, $1}' $tmpdir/ihmutt2utt # Map the 1st field of the segments file from the ihm data (the 1st field being # the utterance-id) to the corresponding SDM or MDM utterance-id. The other # fields remain the same (e.g. we want the recording-ids from the IHM data). 
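+# For illustration only (the recording-id and times below are hypothetical;
+# the utterance-ids follow the example in the comment above): an IHM segments
+# line such as
+#   AMI_EN2001a_H02_FEO065_0021133_0021442 AMI_EN2001a_H02 211.33 214.42
+# would come out as
+#   AMI_EN2001a_SDM_FEO065_0021133_0021442 AMI_EN2001a_H02 211.33 214.42
+# i.e. only the first field is rewritten.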
-utils/apply_map.pl -f 1 $tmpdir/ihmutt2utt data/$mic/train_ihmdata/segments - -utils/fix_data_dir.sh data/$mic/train_ihmdata +utils/apply_map.pl -f 1 $tmpdir/ihmutt2utt data/$mic/${train_set}_ihmdata/segments -rm $tmpdir/ihmutt2utt +utils/fix_data_dir.sh data/$mic/${train_set}_ihmdata exit 0; diff --git a/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh b/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh new file mode 100755 index 00000000000..46ebf013b82 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh @@ -0,0 +1,128 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +. cmd.sh +. path.sh + +set -e +set -o pipefail +set -u + +stage=-1 +nnet_dir=exp/nnet3_sad_snr/nnet_tdnn_k_n4 + +. utils/parse_options.sh + +export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH + +src_dir=/export/a09/vmanoha1/workspace_asr_diarization/egs/ami/s5b # AMI src_dir +dir=exp/sad_ami_sdm1_dev/ref + +mkdir -p $dir + +# Expecting user to have done run.sh to run the AMI recipe in $src_dir for +# both sdm and ihm microphone conditions + +if [ $stage -le 1 ]; then + ( + cd $src_dir + local/prepare_parallel_train_data.sh --train-set dev sdm1 + + awk '{print $1" "$2}' $src_dir/data/ihm/dev/segments > \ + $src_dir/data/ihm/dev/utt2reco + awk '{print $1" "$2}' $src_dir/data/sdm1/dev/segments > \ + $src_dir/data/sdm1/dev/utt2reco + + cat $src_dir/data/sdm1/dev_ihmdata/ihmutt2utt | \ + utils/apply_map.pl -f 1 $src_dir/data/ihm/dev/utt2reco | \ + utils/apply_map.pl -f 2 $src_dir/data/sdm1/dev/utt2reco | \ + sort -u > $src_dir/data/sdm1/dev_ihmdata/ihm2sdm_reco + ) +fi + +if [ $stage -le 2 ]; then + ( + cd $src_dir + utils/data/get_reco2utt.sh $src_dir/data/sdm1/dev + ) + + phone_map=$dir/phone_map + steps/segmentation/get_sad_map.py \ + $src_dir/data/lang | utils/sym2int.pl -f 1 $src_dir/data/lang/phones.txt > \ + $phone_map +fi + +if [ $stage -le 3 ]; then + # Expecting user to have run local/run_cleanup_segmentation.sh in $src_dir + ( + cd $src_dir + steps/align_fmllr.sh --nj 18 --cmd "$train_cmd" \ + data/sdm1/dev_ihmdata data/lang \ + exp/ihm/tri3_cleaned \ + exp/sdm1/tri3_cleaned_dev_ihmdata + ) +fi + +if [ $stage -le 4 ]; then + steps/segmentation/internal/convert_ali_to_vad.sh --cmd "$train_cmd" \ + $src_dir/exp/sdm1/tri3_cleaned_dev_ihmdata $phone_map $dir +fi + +echo "A 1" > $dir/channel_map +cat $src_dir/data/sdm1/dev/reco2file_and_channel | \ + utils/apply_map.pl -f 3 $dir/channel_map > $dir/reco2file_and_channel + +if [ $stage -le 5 ]; then + $train_cmd $dir/log/get_ref_rttm.log \ + segmentation-combine-segments scp:$dir/sad_seg.scp \ + "ark:segmentation-init-from-segments --shift-to-zero=false $src_dir/data/sdm1/dev_ihmdata/segments ark:- |" \ + ark,t:$src_dir/data/sdm1/dev_ihmdata/reco2utt ark:- \| \ + segmentation-merge-recordings \ + "ark,t:utils/utt2spk_to_spk2utt.pl $src_dir/data/sdm1/dev_ihmdata/ihm2sdm_reco |" \ + ark:- ark:- \| \ + segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ + ark:- $dir/ref.rttm +fi + +if [ $stage -le 6 ]; then + $train_cmd $dir/log/get_uem.log \ + segmentation-init-from-segments --shift-to-zero=false $src_dir/data/sdm1/dev/segments ark:- \| \ + segmentation-combine-segments-to-recordings ark:- ark,t:$src_dir/data/sdm1/dev/reco2utt ark:- \| \ + segmentation-post-process --remove-labels=0 --merge-adjacent-segments \ + --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ + ark:- - \| grep SPEECH \| grep SPEAKER \| \ + rttmSmooth.pl -s 0 
\| awk '{ print $2" "$3" "$4" "$5+$4 }' '>' $dir/uem +fi + +hyp_dir=$nnet_dir/segmentation_ami_sdm1_dev_whole_bp + +if [ $stage -le 7 ]; then + steps/segmentation/do_segmentation_data_dir.sh --reco-nj 18 \ + --mfcc-config conf/mfcc_hires_bp.conf --feat-affix bp --do-downsampling true \ + --extra-left-context 100 --extra-right-context 20 \ + --output-name output-speech --frame-subsampling-factor 6 \ + $src_dir/data/sdm1/dev data/ami_sdm1_dev $nnet_dir +fi + + +if [ $stage -le 8 ]; then + utils/data/get_reco2utt.sh $src_dir/data/sdm1/dev_ihmdata + + steps/segmentation/convert_utt2spk_and_segments_to_rttm.py \ + $hyp_dir/ami_sdm1_dev_seg/utt2spk \ + $hyp_dir/ami_sdm1_dev_seg/segments \ + $dir/reco2file_and_channel \ + /dev/stdout | spkr2sad.pl > $hyp_dir/sys.rttm +fi + +if [ $stage -le 9 ]; then + md-eval.pl -s <(cat $hyp_dir/sys.rttm | grep speech | rttmSmooth.pl -s 0) \ + -r <(cat $dir/ref.rttm | grep SPEECH | rttmSmooth.pl -s 0 ) \ + -u $dir/uem -c 0.25 +fi + +#md-eval.pl -s <( segmentation-init-from-segments --shift-to-zero=false exp/nnet3_sad_snr/nnet_tdnn_j_n4/segmentation_ami_sdm1_dev_whole_bp/ami_sdm1_dev_seg/segments ark:- | segmentation-combine-segments-to-recordings ark:- ark,t:exp/nnet3_sad_snr/nnet_tdnn_j_n4/segmentation_ami_sdm1_dev_whole_bp/ami_sdm1_dev_seg/reco2utt ark:- | segmentation-to-ali --length-tolerance=1000 --lengths-rspecifier=ark,t:data/ami_sdm1_dev_whole_bp_hires/utt2num_frames ark:- ark:- | +#segmentation-init-from-ali ark:- ark:- | segmentation-to-rttm ark:- - | grep SPEECH | rttmSmooth.pl -s 0) From 64d1456d831ea5cd33e61fc384a85bff16855d6e Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 30 Nov 2016 17:01:59 -0500 Subject: [PATCH 045/213] asr_diarization: Fisher recipe from data preparation, training nnet and testing on AMI --- .../local/segmentation/prepare_fisher_data.sh | 88 +++++++++++++++++++ .../s5/local/segmentation/run_fisher.sh | 23 +++++ 2 files changed, 111 insertions(+) create mode 100644 egs/aspire/s5/local/segmentation/prepare_fisher_data.sh create mode 100644 egs/aspire/s5/local/segmentation/run_fisher.sh diff --git a/egs/aspire/s5/local/segmentation/prepare_fisher_data.sh b/egs/aspire/s5/local/segmentation/prepare_fisher_data.sh new file mode 100644 index 00000000000..1344e185a02 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/prepare_fisher_data.sh @@ -0,0 +1,88 @@ +#! /bin/bash + +# This script prepares Fisher data for training a speech activity detection +# and music detection system + +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +. path.sh +. cmd.sh + +if [ $# -ne 0 ]; then + echo "Usage: $0" + echo "This script is to serve as an example recipe." + echo "Edit the script to change variables if needed." + exit 1 +fi + +dir=exp/unsad/make_unsad_fisher_train_100k # Work dir +subset=150 + +# All the paths below can be modified to any absolute path. + +# The original data directory which will be converted to a whole (recording-level) directory. 
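+# For illustration only (the utterance/recording ids here are hypothetical):
+# after conversion, per-utterance entries such as fe_03_00001-A-000098-000335
+# disappear and each recording (e.g. fe_03_00001-A) is treated as one long
+# "utterance", so that the VAD labels produced below (see $reco_vad_dir)
+# refer to whole recordings.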
+train_data_dir=data/fisher_train_100k
+
+model_dir=exp/tri3a   # Model directory used for decoding
+sat_model_dir=exp/tri4a   # Model directory used for getting alignments
+lang=data/lang  # Language directory
+lang_test=data/lang_test  # Language directory used to build graph
+
+# Hard code the mapping from phones to SAD labels
+# 0 for silence, 1 for speech, 2 for noise, 3 for unk
+cat <<EOF > $dir/fisher_sad.map
+sil 0
+sil_B 0
+sil_E 0
+sil_I 0
+sil_S 0
+laughter 2
+laughter_B 2
+laughter_E 2
+laughter_I 2
+laughter_S 2
+noise 2
+noise_B 2
+noise_E 2
+noise_I 2
+noise_S 2
+oov 3
+oov_B 3
+oov_E 3
+oov_I 3
+oov_S 3
+EOF
+
+# Expecting the user to have done run.sh to have $model_dir,
+# $sat_model_dir, $lang, $lang_test, $train_data_dir
+local/segmentation/prepare_unsad_data.sh \
+  --sad-map $dir/fisher_sad.map \
+  --config-dir conf \
+  --reco-nj 40 --nj 100 --cmd "$train_cmd" \
+  --sat-model $sat_model_dir \
+  --lang-test $lang_test \
+  $train_data_dir $lang $model_dir $dir
+
+data_dir=${train_data_dir}_whole
+
+if [ ! -z $subset ]; then
+  # Work on a subset
+  utils/subset_data_dir.sh ${data_dir} $subset \
+    ${data_dir}_$subset
+  data_dir=${data_dir}_$subset
+fi
+
+reco_vad_dir=$dir/`basename $model_dir`_reco_vad_`basename $train_data_dir`_sp
+
+# Add noise from MUSAN corpus to data directory and create a new data directory
+local/segmentation/do_corruption_data_dir.sh \
+  --data-dir $data_dir \
+  --reco-vad-dir $reco_vad_dir \
+  --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf
+
+# Add music from MUSAN corpus to data directory and create a new data directory
+local/segmentation/do_corruption_data_dir_music.sh \
+  --data-dir $data_dir \
+  --reco-vad-dir $reco_vad_dir \
+  --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf
diff --git a/egs/aspire/s5/local/segmentation/run_fisher.sh b/egs/aspire/s5/local/segmentation/run_fisher.sh
new file mode 100644
index 00000000000..e39ef5f3a91
--- /dev/null
+++ b/egs/aspire/s5/local/segmentation/run_fisher.sh
@@ -0,0 +1,23 @@
+#! /bin/bash
+
+# Copyright 2016 Vimal Manohar
+# Apache 2.0. 
+ +local/segmentation/prepare_fisher_data.sh + +utils/combine_data.sh --extra-files "speech_feat.scp deriv_weights.scp deriv_weights_manual_seg.scp music_labels.scp" \ + data/fisher_train_100k_whole_all_corrupted_sp_hires_bp \ + data/fisher_train_100k_whole_corrupted_sp_hires_bp \ + data/fisher_train_100k_whole_music_corrupted_sp_hires_bp + +local/segmentation/train_stats_sad_music.sh \ + --train-data-dir data/fisher_train_100k_whole_all_corrupted_sp_hires_bp \ + --speech-feat-scp data/fisher_train_100k_whole_corrupted_sp_hires_bp/speech_feat.scp \ + --deriv-weights-scp data/fisher_train_100k_whole_corrupted_sp_hires_bp/deriv_weights.scp \ + --music-labels-scp data/fisher_train-100k_whole_music_corrupted_sp_hires_bp/music_labels.scp \ + --max-param-change 0.2 \ + --num-epochs 2 --affix k \ + --splice-indexes "-3,-2,-1,0,1,2,3 -6,0,mean+count(-99:3:9:99) -9,0,3 0" + +local/segmentation/run_segmentation_ami.sh \ + --nnet-dir exp/nnet3_sad_snr/nnet_tdnn_k_n4 From 99ce6c816ca70670f9b555e9d0a79667c24f8b91 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 24 Nov 2016 01:18:19 -0500 Subject: [PATCH 046/213] asr_diarization: created compute-snr-targets --- src/featbin/Makefile | 2 +- src/featbin/compute-snr-targets.cc | 273 +++++++++++++++++++++++++++++ src/matrix/kaldi-matrix.cc | 81 +++++++++ src/matrix/kaldi-matrix.h | 4 + 4 files changed, 359 insertions(+), 1 deletion(-) create mode 100644 src/featbin/compute-snr-targets.cc diff --git a/src/featbin/Makefile b/src/featbin/Makefile index e1a9a1ebe0d..aaa4abca24c 100644 --- a/src/featbin/Makefile +++ b/src/featbin/Makefile @@ -16,7 +16,7 @@ BINFILES = compute-mfcc-feats compute-plp-feats compute-fbank-feats \ compute-and-process-kaldi-pitch-feats modify-cmvn-stats wav-copy \ wav-reverberate append-vector-to-feats detect-sinusoids shift-feats \ concat-feats append-post-to-feats post-to-feats vector-to-feat \ - extract-column + extract-column compute-snr-targets OBJFILES = diff --git a/src/featbin/compute-snr-targets.cc b/src/featbin/compute-snr-targets.cc new file mode 100644 index 00000000000..cdb7ef66c2a --- /dev/null +++ b/src/featbin/compute-snr-targets.cc @@ -0,0 +1,273 @@ +// featbin/compute-snr-targets.cc + +// Copyright 2015-2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
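+//
+// Note on the math (both inputs are log filterbank energies, as stated in the
+// usage message below); per time-frequency bin the targets reduce to simple
+// differences in the log domain:
+//   log Irm       = log_clean - LogAdd(log_clean, log_noise)
+//   log Snr       = log_clean - log_noise
+//   log FbankMask = log_clean - log_noisy
+// which is what the LogAddExpMat() and AddMat(-1.0, ...) calls implement.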
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Compute snr targets using clean and noisy speech features.\n" + "The targets can be of 3 types -- \n" + "Irm (Ideal Ratio Mask) = Clean fbank / (Clean fbank + Noise fbank)\n" + "FbankMask = Clean fbank / Noisy fbank\n" + "Snr (Signal To Noise Ratio) = Clean fbank / Noise fbank\n" + "Both input and output features are assumed to be in log domain.\n" + "ali-rspecifier and silence-phones are used to identify whether " + "a particular frame is \"clean\" or not. Silence frames in " + "\"clean\" fbank are treated as \"noise\" and hence the SNR for those " + "frames are -inf in log scale.\n" + "Usage: compute-snr-targets [options] \n" + " or compute-snr-targets [options] --binary-targets \n" + "e.g.: compute-snr-targets scp:clean.scp scp:noisy.scp ark:targets.ark\n"; + + std::string target_type = "Irm"; + std::string ali_rspecifier; + std::string silence_phones_str; + std::string floor_str = "-inf", ceiling_str = "inf"; + int32 length_tolerance = 0; + bool binary_targets = false; + int32 target_dim = -1; + + ParseOptions po(usage); + po.Register("target_type", &target_type, "Target type can be FbankMask or IRM"); + po.Register("ali-rspecifier", &ali_rspecifier, "If provided, all the " + "energy in the silence region of clean file is considered noise"); + po.Register("silence-phones", &silence_phones_str, "Comma-separated list of " + "silence phones"); + po.Register("floor", &floor_str, "If specified, the target is floored at " + "this value. You may want to do this if you are using targets " + "in original log form as is usual in the case of Snr, but may " + "not if you are applying Exp() as is usual in the case of Irm"); + po.Register("ceiling", &ceiling_str, "If specified, the target is ceiled " + "at this value. You may want to do this if you expect " + "infinities or very large values, particularly for Snr targets."); + po.Register("length-tolerance", &length_tolerance, "Tolerate differences " + "in utterance lengths of these many frames"); + po.Register("binary-targets", &binary_targets, "If specified, then the " + "targets are created considering each frame to be either " + "completely signal or completely noise as decided by the " + "ali-rspecifier option. When ali-rspecifier is not specified, " + "then the entire utterance is considered to be just signal." + "If this option is specified, then only a single argument " + "-- the clean features -- is must be specified."); + po.Register("target-dim", &target_dim, "Overrides the target dimension. 
" + "Applicable only with --binary-targets is specified"); + + po.Read(argc, argv); + + if (po.NumArgs() != 3 && po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::vector silence_phones; + if (!silence_phones_str.empty()) { + if (!SplitStringToIntegers(silence_phones_str, ":", false, &silence_phones)) { + KALDI_ERR << "Invalid silence-phones string " << silence_phones_str; + } + std::sort(silence_phones.begin(), silence_phones.end()); + } + + double floor = kLogZeroDouble, ceiling = -kLogZeroDouble; + + if (floor_str != "-inf") + if (!ConvertStringToReal(floor_str, &floor)) { + KALDI_ERR << "Invalid --floor value " << floor_str; + } + + if (ceiling_str != "inf") + if (!ConvertStringToReal(ceiling_str, &ceiling)) { + KALDI_ERR << "Invalid --ceiling value " << ceiling_str; + } + + int32 num_done = 0, num_err = 0, num_success = 0; + int64 num_sil_frames = 0; + int64 num_speech_frames = 0; + + if (!binary_targets) { + // This is the 'normal' case, where we have both clean and + // noise/corrupted input features. + // The word 'noisy' in the variable names is used to mean 'corrupted'. + std::string clean_rspecifier = po.GetArg(1), + noisy_rspecifier = po.GetArg(2), + targets_wspecifier = po.GetArg(3); + + SequentialBaseFloatMatrixReader noisy_reader(noisy_rspecifier); + RandomAccessBaseFloatMatrixReader clean_reader(clean_rspecifier); + BaseFloatMatrixWriter kaldi_writer(targets_wspecifier); + + RandomAccessInt32VectorReader alignment_reader(ali_rspecifier); + + for (; !noisy_reader.Done(); noisy_reader.Next(), num_done++) { + const std::string &key = noisy_reader.Key(); + Matrix total_energy(noisy_reader.Value()); + // Although this is called 'energy', it is actually log filterbank + // features of noise or corrupted files + // Actually noise feats in the case of Irm and Snr + + // TODO: Support multiple corrupted version for a particular clean file + std::string uniq_key = key; + if (!clean_reader.HasKey(uniq_key)) { + KALDI_WARN << "Could not find uniq key " << uniq_key << " " + << "in clean feats " << clean_rspecifier; + num_err++; + continue; + } + + Matrix clean_energy(clean_reader.Value(uniq_key)); + + if (target_type == "Irm") { + total_energy.LogAddExpMat(1.0, clean_energy, kNoTrans); + } + + if (!ali_rspecifier.empty()) { + if (!alignment_reader.HasKey(uniq_key)) { + KALDI_WARN << "Could not find uniq key " << uniq_key + << "in alignment " << ali_rspecifier; + num_err++; + continue; + } + const std::vector &ali = alignment_reader.Value(key); + + if (std::abs(static_cast (ali.size()) - clean_energy.NumRows()) > length_tolerance) { + KALDI_WARN << "Mismatch in number of frames in alignment " + << "and feats; " << static_cast(ali.size()) + << " vs " << clean_energy.NumRows(); + num_err++; + continue; + } + + int32 length = std::min(static_cast(ali.size()), clean_energy.NumRows()); + if (ali.size() < length) + // TODO: Support this case + KALDI_ERR << "This code currently does not support the case " + << "where alignment smaller than features because " + << "it is not expected to happen"; + + KALDI_ASSERT(clean_energy.NumRows() == length); + KALDI_ASSERT(total_energy.NumRows() == length); + + if (clean_energy.NumRows() < length) clean_energy.Resize(length, clean_energy.NumCols(), kCopyData); + if (total_energy.NumRows() < length) total_energy.Resize(length, total_energy.NumCols(), kCopyData); + + for (int32 i = 0; i < clean_energy.NumRows(); i++) { + if (std::binary_search(silence_phones.begin(), silence_phones.end(), ali[i])) { + clean_energy.Row(i).Set(kLogZeroDouble); + 
num_sil_frames++; + } else num_speech_frames++; + } + } + + clean_energy.AddMat(-1.0, total_energy); + if (ceiling_str != "inf") { + clean_energy.ApplyCeiling(ceiling); + } + + if (floor_str != "-inf") { + clean_energy.ApplyFloor(floor); + } + + kaldi_writer.Write(key, Matrix(clean_energy)); + num_success++; + } + } else { + // Copying tables of features. + std::string feats_rspecifier = po.GetArg(1), + targets_wspecifier = po.GetArg(2); + + SequentialBaseFloatMatrixReader feats_reader(feats_rspecifier); + BaseFloatMatrixWriter kaldi_writer(targets_wspecifier); + + RandomAccessInt32VectorReader alignment_reader(ali_rspecifier); + + int64 num_sil_frames = 0; + int64 num_speech_frames = 0; + + for (; !feats_reader.Done(); feats_reader.Next(), num_done++) { + const std::string &key = feats_reader.Key(); + const Matrix &feats = feats_reader.Value(); + + Matrix targets; + + if (target_dim < 0) + targets.Resize(feats.NumRows(), feats.NumCols()); + else + targets.Resize(feats.NumRows(), target_dim); + + if (target_type == "Snr") + targets.Set(-kLogZeroDouble); + + if (!ali_rspecifier.empty()) { + if (!alignment_reader.HasKey(key)) { + KALDI_WARN << "Could not find uniq key " << key + << " in alignment " << ali_rspecifier; + num_err++; + continue; + } + + const std::vector &ali = alignment_reader.Value(key); + + if (std::abs(static_cast (ali.size()) - feats.NumRows()) > length_tolerance) { + KALDI_WARN << "Mismatch in number of frames in alignment " + << "and feats; " << static_cast(ali.size()) + << " vs " << feats.NumRows(); + num_err++; + continue; + } + + int32 length = std::min(static_cast(ali.size()), feats.NumRows()); + KALDI_ASSERT(ali.size() >= length); + + for (int32 i = 0; i < feats.NumRows(); i++) { + if (std::binary_search(silence_phones.begin(), silence_phones.end(), ali[i])) { + targets.Row(i).Set(kLogZeroDouble); + num_sil_frames++; + } else { + num_speech_frames++; + } + } + + if (ceiling_str != "inf") { + targets.ApplyCeiling(ceiling); + } + + if (floor_str != "-inf") { + targets.ApplyFloor(floor); + } + + kaldi_writer.Write(key, targets); + } + } + } + + KALDI_LOG << "Computed SNR targets for " << num_success + << " out of " << num_done << " utterances; failed for " + << num_err; + KALDI_LOG << "Got [ " << num_speech_frames << "," + << num_sil_frames << "] frames of silence and speech"; + return (num_success > 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/matrix/kaldi-matrix.cc b/src/matrix/kaldi-matrix.cc index 4c3948ba2f5..0b5191e1e7a 100644 --- a/src/matrix/kaldi-matrix.cc +++ b/src/matrix/kaldi-matrix.cc @@ -396,6 +396,87 @@ void MatrixBase::AddMat(const Real alpha, const MatrixBase& A, } } +template +void MatrixBase::LogAddExpMat(const Real alpha, const MatrixBase& A, + MatrixTransposeType transA) { + if (alpha == 0) return; + + if (&A == this) { + if (transA == kNoTrans) { + Add(alpha + 1.0); + } else { + KALDI_ASSERT(num_rows_ == num_cols_ && "AddMat: adding to self (transposed): not symmetric."); + Real *data = data_; + if (alpha == 1.0) { // common case-- handle separately. + for (MatrixIndexT row = 0; row < num_rows_; row++) { + for (MatrixIndexT col = 0; col < row; col++) { + Real *lower = data + (row * stride_) + col, + *upper = data + (col * stride_) + row; + Real sum = LogAdd(*lower, *upper); + *lower = *upper = sum; + } + *(data + (row * stride_) + row) += Log(2.0); // diagonal. 
+ } + } else { + for (MatrixIndexT row = 0; row < num_rows_; row++) { + for (MatrixIndexT col = 0; col < row; col++) { + Real *lower = data + (row * stride_) + col, + *upper = data + (col * stride_) + row; + Real lower_tmp = *lower; + if (alpha > 0) { + *lower = LogAdd(*lower, Log(alpha) + *upper); + *upper = LogAdd(*upper, Log(alpha) + lower_tmp); + } else { + KALDI_ASSERT(alpha < 0); + *lower = LogSub(*lower, Log(-alpha) + *upper); + *upper = LogSub(*upper, Log(-alpha) + lower_tmp); + } + } + if (alpha > -1.0) + *(data + (row * stride_) + row) += Log(1.0 + alpha); // diagonal. + else + KALDI_ERR << "Cannot subtract log-matrices if the difference is " + << "negative"; + } + } + } + } else { + int aStride = (int) A.stride_; + Real *adata = A.data_, *data = data_; + if (transA == kNoTrans) { + KALDI_ASSERT(A.num_rows_ == num_rows_ && A.num_cols_ == num_cols_); + if (num_rows_ == 0) return; + for (MatrixIndexT row = 0; row < num_rows_; row++) { + for (MatrixIndexT col = 0; col < num_cols_; col++) { + Real *value = data + (row * stride_) + col, + *aValue = adata + (row * aStride) + col; + if (alpha > 0) + *value = LogAdd(*value, Log(alpha) + *aValue); + else { + KALDI_ASSERT(alpha < 0); + *value = LogSub(*value, Log(-alpha) + *aValue); + } + } + } + } else { + KALDI_ASSERT(A.num_cols_ == num_rows_ && A.num_rows_ == num_cols_); + if (num_rows_ == 0) return; + for (MatrixIndexT row = 0; row < num_rows_; row++) { + for (MatrixIndexT col = 0; col < num_cols_; col++) { + Real *value = data + (row * stride_) + col, + *aValue = adata + (col * aStride) + row; + if (alpha > 0) + *value = LogAdd(*value, Log(alpha) + *aValue); + else { + KALDI_ASSERT(alpha < 0); + *value = LogSub(*value, Log(-alpha) + *aValue); + } + } + } + } + } +} + template template void MatrixBase::AddSp(const Real alpha, const SpMatrix &S) { diff --git a/src/matrix/kaldi-matrix.h b/src/matrix/kaldi-matrix.h index dccd52a9af4..b5a6bc7521d 100644 --- a/src/matrix/kaldi-matrix.h +++ b/src/matrix/kaldi-matrix.h @@ -548,6 +548,10 @@ class MatrixBase { /// *this += alpha * M [or M^T] void AddMat(const Real alpha, const MatrixBase &M, MatrixTransposeType transA = kNoTrans); + + /// *this += alpha * M [or M^T] when the matrices are stored as log + void LogAddExpMat(const Real alpha, const MatrixBase &M, + MatrixTransposeType transA = kNoTrans); /// *this = beta * *this + alpha * M M^T, for symmetric matrices. It only /// updates the lower triangle of *this. It will leave the matrix asymmetric; From ef36cf5f41d3800ff3768ceb34816c3c53159151 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 23 Nov 2016 21:56:24 -0500 Subject: [PATCH 047/213] asr_diarization: make_snr_targets.sh --- .../s5/steps/segmentation/make_snr_targets.sh | 104 ++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100755 egs/wsj/s5/steps/segmentation/make_snr_targets.sh diff --git a/egs/wsj/s5/steps/segmentation/make_snr_targets.sh b/egs/wsj/s5/steps/segmentation/make_snr_targets.sh new file mode 100755 index 00000000000..71f603a690e --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/make_snr_targets.sh @@ -0,0 +1,104 @@ +#!/bin/bash + +# Copyright 2015-16 Vimal Manohar +# Apache 2.0 +set -e +set -o pipefail + +nj=4 +cmd=run.pl +stage=0 + +data_id= + +compress=true +target_type=Irm +apply_exp=false + +ali_rspecifier= +silence_phones_str=0 + +ignore_noise_dir=false + +ceiling=inf +floor=-inf + +length_tolerance=2 +transform_matrix= + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. 
parse_options.sh || exit 1; + +if [ $# != 5 ]; then + echo "Usage: $0 [options] --target-type (Irm|Snr) "; + echo " or : $0 [options] --target-type FbankMask "; + echo "e.g.: $0 data/train_clean_fbank data/train_noise_fbank data/train_corrupted_hires exp/make_snr_targets/train snr_targets" + echo "options: " + echo " --nj # number of parallel jobs" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + exit 1; +fi + +clean_data=$1 +noise_or_noisy_data=$2 +data=$3 +tmpdir=$4 +targets_dir=$5 + +mkdir -p $targets_dir + +[ -z "$data_id" ] && data_id=`basename $data` + +utils/split_data.sh $clean_data $nj + +for n in `seq $nj`; do + utils/subset_data_dir.sh --utt-list $clean_data/split$nj/$n/utt2spk $noise_or_noisy_data $noise_or_noisy_data/subset${nj}/$n +done + +$ignore_noise_dir && utils/split_data.sh $data $nj + +targets_dir=`perl -e '($data,$pwd)= @ARGV; if($data!~m:^/:) { $data = "$pwd/$data"; } print $data; ' $targets_dir ${PWD}` + +for n in `seq $nj`; do + utils/create_data_link.pl $targets_dir/${data_id}.$n.ark +done + +apply_exp_opts= +if $apply_exp; then + apply_exp_opts=" copy-matrix --apply-exp=true ark:- ark:- |" +fi + +copy_feats_opts="copy-feats" +if [ ! -z "$transform_matrix" ]; then + copy_feats_opts="transform-feats $transform_matrix" +fi + +if [ $stage -le 1 ]; then + if ! $ignore_noise_dir; then + $cmd JOB=1:$nj $tmpdir/make_`basename $targets_dir`_${data_id}.JOB.log \ + compute-snr-targets --length-tolerance=$length_tolerance --target-type=$target_type \ + ${ali_rspecifier:+--ali-rspecifier="$ali_rspecifier" --silence-phones=$silence_phones_str} \ + --floor=$floor --ceiling=$ceiling \ + "ark:$copy_feats_opts scp:$clean_data/split$nj/JOB/feats.scp ark:- |" \ + "ark,s,cs:$copy_feats_opts scp:$noise_or_noisy_data/subset$nj/JOB/feats.scp ark:- |" \ + ark:- \|$apply_exp_opts \ + copy-feats --compress=$compress ark:- \ + ark,scp:$targets_dir/${data_id}.JOB.ark,$targets_dir/${data_id}.JOB.scp || exit 1 + else + feat_dim=$(feat-to-dim scp:$data/feats.scp -) || exit 1 + $cmd JOB=1:$nj $tmpdir/make_`basename $targets_dir`_${data_id}.JOB.log \ + compute-snr-targets --length-tolerance=$length_tolerance --target-type=$target_type \ + ${ali_rspecifier:+--ali-rspecifier="$ali_rspecifier" --silence-phones=$silence_phones_str} \ + --floor=$floor --ceiling=$ceiling --binary-targets --target-dim=$feat_dim \ + scp:$data/split$nj/JOB/feats.scp \ + ark:- \|$apply_exp_opts \ + copy-feats --compress=$compress ark:- \ + ark,scp:$targets_dir/${data_id}.JOB.ark,$targets_dir/${data_id}.JOB.scp || exit 1 + fi +fi + +for n in `seq $nj`; do + cat $targets_dir/${data_id}.$n.scp +done > $data/`basename $targets_dir`.scp From c3da17a90f0a8c22e0abee6e49f672d986d6bf80 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 24 Nov 2016 03:09:24 -0500 Subject: [PATCH 048/213] asr_diarization: Added script to get DCT matrix --- egs/wsj/s5/utils/data/get_dct_matrix.py | 108 ++++++++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100755 egs/wsj/s5/utils/data/get_dct_matrix.py diff --git a/egs/wsj/s5/utils/data/get_dct_matrix.py b/egs/wsj/s5/utils/data/get_dct_matrix.py new file mode 100755 index 00000000000..88b28b5dd5c --- /dev/null +++ b/egs/wsj/s5/utils/data/get_dct_matrix.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python + +# we're using python 3.x style print but want it to work in python 2.x, +from __future__ import print_function +import os, argparse, sys, math, warnings + +import numpy as np + +def ComputeLifterCoeffs(Q, dim): + coeffs = np.zeros((dim)) + for i in range(0, dim): + 
coeffs[i] = 1.0 + 0.5 * Q * math.sin(math.pi * i / Q); + + return coeffs + +def ComputeIDctMatrix(K, N, cepstral_lifter=0): + matrix = np.zeros((K, N)) + # normalizer for X_0 + normalizer = math.sqrt(1.0 / N); + for j in range(0, N): + matrix[0, j] = normalizer; + # normalizer for other elements + normalizer = math.sqrt(2.0 / N); + for k in range(1, K): + for n in range(0, N): + matrix[k, n] = normalizer * math.cos(math.pi/N * (n + 0.5) * k); + + if cepstral_lifter != 0: + lifter_coeffs = ComputeLifterCoeffs(cepstral_lifter, K) + for k in range(0, K): + matrix[k, :] = matrix[k, :] / lifter_coeffs[k]; + + return matrix.T + +def ComputeDctMatrix(K, N, cepstral_lifter=0): + matrix = np.zeros((K, N)) + # normalizer for X_0 + normalizer = math.sqrt(1.0 / N); + for j in range(0, N): + matrix[0, j] = normalizer; + # normalizer for other elements + normalizer = math.sqrt(2.0 / N); + for k in range(1, K): + for n in range(0, N): + matrix[k, n] = normalizer * math.cos(math.pi/N * (n + 0.5) * k); + + if cepstral_lifter != 0: + lifter_coeffs = ComputeLifterCoeffs(cepstral_lifter, K) + for k in range(0, K): + matrix[k, :] = matrix[k, :] * lifter_coeffs[k]; + + return matrix + +def GetArgs(): + parser = argparse.ArgumentParser(description="Write DCT/IDCT matrix") + parser.add_argument("--cepstral-lifter", type=float, + help="Here we need the scaling factor on cepstra in the production of MFCC" + "to cancel out the effect of lifter, e.g. 22.0", default=22.0) + parser.add_argument("--num-ceps", type=int, + default=13, + help="Number of cepstral dimensions") + parser.add_argument("--num-filters", type=int, + default=23, + help="Number of mel filters") + parser.add_argument("--get-idct-matrix", type=str, default="false", + choices=["true","false"], + help="Get IDCT matrix instead of DCT matrix") + parser.add_argument("--add-zero-column", type=str, default="true", + choices=["true","false"], + help="Add a column to convert the matrix from a linear transform to affine transform") + parser.add_argument("out_file", type=str, + help="Output file") + + args = parser.parse_args() + + return args + +def CheckArgs(args): + if args.num_ceps > args.num_filters: + raise Exception("num-ceps must not be larger than num-filters") + + args.out_file_handle = open(args.out_file, 'w') + + return args + +def Main(): + args = GetArgs() + args = CheckArgs(args) + + if args.get_idct_matrix == "false": + matrix = ComputeDctMatrix(args.num_ceps, args.num_filters, + args.cepstral_lifter) + if args.add_zero_column == "true": + matrix = np.append(matrix, np.zeros((args.num_ceps,1)), 1) + else: + matrix = ComputeIDctMatrix(args.num_ceps, args.num_filters, + args.cepstral_lifter) + + if args.add_zero_column == "true": + matrix = np.append(matrix, np.zeros((args.num_filters,1)), 1) + + print('[ ', file=args.out_file_handle) + np.savetxt(args.out_file_handle, matrix, fmt='%.6e') + print(' ]', file=args.out_file_handle) + +if __name__ == "__main__": + Main() + From 83cbdd6a695c2523396a25f7294692936a34da9b Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Tue, 29 Nov 2016 22:58:53 -0500 Subject: [PATCH 049/213] asr_diarization_clean: Adding run_train_sad.sh --- .../s5/local/segmentation/run_train_sad.sh | 150 ++++++++++++++++++ 1 file changed, 150 insertions(+) create mode 100755 egs/aspire/s5/local/segmentation/run_train_sad.sh diff --git a/egs/aspire/s5/local/segmentation/run_train_sad.sh b/egs/aspire/s5/local/segmentation/run_train_sad.sh new file mode 100755 index 00000000000..9b1f104939a --- /dev/null +++ 
b/egs/aspire/s5/local/segmentation/run_train_sad.sh @@ -0,0 +1,150 @@ +#!/bin/bash + +# this is the standard "tdnn" system, built in nnet3; it's what we use to +# call multi-splice. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= + +splice_indexes="-3,-2,-1,0,1,2,3 -6,0 -9,0,3 0" +relu_dim=256 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=1 +extra_egs_copy_cmd= + +num_utts_subset_valid=40 +num_utts_subset_train=40 +add_idct=true + +# target options +train_data_dir=data/train_azteec_whole_sp_corrupted_hires + +snr_scp= +speech_feat_scp= + +deriv_weights_scp= +deriv_weights_for_irm_scp= + +egs_dir= +nj=40 +feat_type=raw +config_dir= +compute_objf_opts= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_hidden_layers=`echo $splice_indexes | perl -ane 'print scalar @F'` || exit 1 +if [ -z "$dir" ]; then + dir=exp/nnet3_sad_snr/nnet_tdnn +fi + +dir=$dir${affix:+_$affix}_n${num_hidden_layers} + +if ! cuda-compiled; then + cat < Date: Tue, 6 Sep 2016 14:18:40 -0400 Subject: [PATCH 050/213] asr_diarization: Modified online ivector extraction to accept frame weights option --- .../get_ivector_weights_from_ctm_conf.pl | 77 +++++++++++++++++++ .../online/nnet2/extract_ivectors_online.sh | 35 +++++++-- src/online2bin/ivector-extract-online2.cc | 35 ++++++++- 3 files changed, 140 insertions(+), 7 deletions(-) create mode 100755 egs/ami/s5b/local/get_ivector_weights_from_ctm_conf.pl diff --git a/egs/ami/s5b/local/get_ivector_weights_from_ctm_conf.pl b/egs/ami/s5b/local/get_ivector_weights_from_ctm_conf.pl new file mode 100755 index 00000000000..96db9af3638 --- /dev/null +++ b/egs/ami/s5b/local/get_ivector_weights_from_ctm_conf.pl @@ -0,0 +1,77 @@ +#! 
/usr/bin/perl +use strict; +use warnings; +use Getopt::Long; + +my $pad_frames = 0; +my $silence_weight = 0.00001; +my $scale_weights_by_ctm_conf = "false"; +my $frame_shift = 0.01; + +GetOptions('pad-frames:i' => \$pad_frames, + 'silence-weight:f' => \$frame_shift, + 'scale-weights-by-ctm-conf:s' => \$scale_weights_by_ctm_conf, + 'frame-shift:f' => \$frame_shift); + +if (scalar @ARGV != 1) { + die "Usage: get_ivector_weights_from_ctm_conf.pl < > "; +} + +my $utt2dur = shift @ARGV; + +$pad_frames >= 0 || die "Bad pad-frames value $pad_frames; must be >= 0"; +($scale_weights_by_ctm_conf eq 'false') || ($scale_weights_by_ctm_conf eq 'true') || die "Bad scale-weights-by-ctm-conf $scale_weights_by_ctm_conf; must be true/false"; + +open(L, "<$utt2dur") || die "unable to open utt2dur file $utt2dur"; + +my @all_utts = (); +my %utt2weights; + +while () { + chomp; + my @A = split; + @A == 2 || die "Incorrent format of utt2dur file $_"; + my ($utt, $len) = @A; + + push @all_utts, $utt; + $len = int($len / $frame_shift); + + # Initialize weights for each utterance + my $weights = []; + for (my $n = 0; $n < $len; $n++) { + push @$weights, $silence_weight; + } + $utt2weights{$utt} = $weights; +} +close(L); + +while () { + chomp; + my @A = split; + @A == 6 || die "bad ctm line $_"; + + my $utt = $A[0]; + my $beg = $A[2]; + my $len = $A[3]; + my $beg_int = int($beg / $frame_shift) - $pad_frames; + my $len_int = int($len / $frame_shift) + 2*$pad_frames; + my $conf = $A[5]; + + my $array_ref = $utt2weights{$utt}; + defined $array_ref || die "No length info for utterance $utt"; + + for (my $t = $beg_int; $t < $beg_int + $len_int; $t++) { + if ($t >= 0 && $t < @$array_ref) { + if ($scale_weights_by_ctm_conf eq "false") { + ${$array_ref}[$t] = 1; + } else { + ${$array_ref}[$t] = $conf; + } + } + } +} + +foreach my $utt (keys %utt2weights) { + my $array_ref = $utt2weights{$utt}; + print ($utt, " [ ", join(" ", @$array_ref), " ]\n"); +} diff --git a/egs/wsj/s5/steps/online/nnet2/extract_ivectors_online.sh b/egs/wsj/s5/steps/online/nnet2/extract_ivectors_online.sh index b52de1f516b..f1edd874fa6 100755 --- a/egs/wsj/s5/steps/online/nnet2/extract_ivectors_online.sh +++ b/egs/wsj/s5/steps/online/nnet2/extract_ivectors_online.sh @@ -42,6 +42,9 @@ max_count=0 # The use of this option (e.g. --max-count 100) can make # posterior-scaling, so assuming the posterior-scale is 0.1, # --max-count 100 starts having effect after 1000 frames, or # 10 seconds of data. +weights= +use_most_recent_ivector=true +max_remembered_frames=1000 # End configuration section. @@ -89,6 +92,8 @@ splice_opts=$(cat $srcdir/splice_opts) # involved in online decoding. We need to create a config file for iVector # extraction. +absdir=$(readlink -f $dir) + ieconf=$dir/conf/ivector_extractor.conf echo -n >$ieconf cp $srcdir/online_cmvn.conf $dir/conf/ || exit 1; @@ -103,12 +108,19 @@ echo "--ivector-extractor=$srcdir/final.ie" >>$ieconf echo "--num-gselect=$num_gselect" >>$ieconf echo "--min-post=$min_post" >>$ieconf echo "--posterior-scale=$posterior_scale" >>$ieconf -echo "--max-remembered-frames=1000" >>$ieconf # the default +echo "--max-remembered-frames=$max_remembered_frames" >>$ieconf # the default echo "--max-count=$max_count" >>$ieconf +echo "--use-most-recent-ivector=$use_most_recent_ivector" >>$use_most_recent_ivector +if [ ! 
-z "$weights" ]; then + if [ -f $weights ] && gunzip -c $weights > /dev/null; then + cp -f $weights $absdir/weights.gz || exit 1 + else + echo "Could not open file $weights" + exit 1 + fi +fi -absdir=$(readlink -f $dir) - for n in $(seq $nj); do # This will do nothing unless the directory $dir/storage exists; # it can be used to distribute the data among multiple machines. @@ -117,10 +129,21 @@ done if [ $stage -le 0 ]; then echo "$0: extracting iVectors" - $cmd JOB=1:$nj $dir/log/extract_ivectors.JOB.log \ - ivector-extract-online2 --config=$ieconf ark:$sdata/JOB/spk2utt scp:$sdata/JOB/feats.scp ark:- \| \ - copy-feats --compress=$compress ark:- \ + if [ ! -z "$weights" ]; then + $cmd JOB=1:$nj $dir/log/extract_ivectors.JOB.log \ + ivector-extract-online2 --config=$ieconf \ + --frame-weights-rspecifier="ark:gunzip -c $absdir/weights.gz |" \ + --length-tolerance=1 \ + ark:$sdata/JOB/spk2utt scp:$sdata/JOB/feats.scp ark:- \| \ + copy-feats --compress=$compress ark:- \ + ark,scp:$absdir/ivector_online.JOB.ark,$absdir/ivector_online.JOB.scp || exit 1; + else + $cmd JOB=1:$nj $dir/log/extract_ivectors.JOB.log \ + ivector-extract-online2 --config=$ieconf \ + ark:$sdata/JOB/spk2utt scp:$sdata/JOB/feats.scp ark:- \| \ + copy-feats --compress=$compress ark:- \ ark,scp:$absdir/ivector_online.JOB.ark,$absdir/ivector_online.JOB.scp || exit 1; + fi fi if [ $stage -le 1 ]; then diff --git a/src/online2bin/ivector-extract-online2.cc b/src/online2bin/ivector-extract-online2.cc index 3251d93b5dd..f597f66763b 100644 --- a/src/online2bin/ivector-extract-online2.cc +++ b/src/online2bin/ivector-extract-online2.cc @@ -55,6 +55,8 @@ int main(int argc, char *argv[]) { g_num_threads = 8; bool repeat = false; + int32 length_tolerance = 0; + std::string frame_weights_rspecifier; po.Register("num-threads", &g_num_threads, "Number of threads to use for computing derived variables " @@ -62,6 +64,12 @@ int main(int argc, char *argv[]) { po.Register("repeat", &repeat, "If true, output the same number of iVectors as input frames " "(including repeated data)."); + po.Register("frame-weights-rspecifier", &frame_weights_rspecifier, + "Archive of frame weights to scale stats"); + po.Register("length-tolerance", &length_tolerance, + "Tolerance on the difference in number of frames " + "for feats and weights"); + po.Read(argc, argv); if (po.NumArgs() != 3) { @@ -82,9 +90,9 @@ int main(int argc, char *argv[]) { SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier); RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier); + RandomAccessBaseFloatVectorReader frame_weights_reader(frame_weights_rspecifier); BaseFloatMatrixWriter ivector_writer(ivectors_wspecifier); - for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) { std::string spk = spk2utt_reader.Key(); const std::vector &uttlist = spk2utt_reader.Value(); @@ -105,6 +113,31 @@ int main(int argc, char *argv[]) { &matrix_feature); ivector_feature.SetAdaptationState(adaptation_state); + + if (!frame_weights_rspecifier.empty()) { + if (!frame_weights_reader.HasKey(utt)) { + KALDI_WARN << "Did not find weights for utterance " << utt; + num_err++; + continue; + } + const Vector &weights = frame_weights_reader.Value(utt); + + if (std::abs(weights.Dim() - feats.NumRows()) > length_tolerance) { + num_err++; + continue; + } + + std::vector > frame_weights; + for (int32 i = 0; i < feats.NumRows(); i++) { + if (i < weights.Dim()) + frame_weights.push_back(std::make_pair(i, weights(i))); + else + frame_weights.push_back(std::make_pair(i, 0.0)); + } + + + 
ivector_feature.UpdateFrameWeights(frame_weights); + } int32 T = feats.NumRows(), n = (repeat ? 1 : ivector_config.ivector_period), From 88970c81cd4f581660d828a51a335ed5b7161674 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Fri, 2 Sep 2016 14:52:24 -0400 Subject: [PATCH 051/213] asr_diarization: Added script to resolve CTM overlaps --- egs/wsj/s5/steps/resolve_ctm_overlaps.py | 149 +++++++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100755 egs/wsj/s5/steps/resolve_ctm_overlaps.py diff --git a/egs/wsj/s5/steps/resolve_ctm_overlaps.py b/egs/wsj/s5/steps/resolve_ctm_overlaps.py new file mode 100755 index 00000000000..aaee767e7e4 --- /dev/null +++ b/egs/wsj/s5/steps/resolve_ctm_overlaps.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python +# Copyright 2014 Johns Hopkins University (Authors: Daniel Povey, Vijayaditya Peddinti). +# 2016 Vimal Manohar +# Apache 2.0. + +# Script to combine ctms with overlapping segments + +import sys, math, numpy as np, argparse +break_threshold = 0.01 + +def ReadSegments(segments_file): + segments = {} + for line in open(segments_file).readlines(): + parts = line.strip().split() + segments[parts[0]] = (parts[1], float(parts[2]), float(parts[3])) + return segments + +#def get_breaks(ctm, prev_end): +# breaks = [] +# for i in xrange(0, len(ctm)): +# if ctm[i][2] - prev_end > break_threshold: +# breaks.append([i, ctm[i][2]]) +# prev_end = ctm[i][2] + ctm[i][3] +# return np.array(breaks) + +# Resolve overlaps within segments of the same recording +def ResolveOverlaps(ctms, segments): + total_ctm = [] + if len(ctms) == 0: + raise Exception('Something wrong with the input ctms') + + next_utt = ctms[0][0][0] + for ctm_index in range(len(ctms) - 1): + # Assumption here is that the segments are written in consecutive order? + cur_ctm = ctms[ctm_index] + next_ctm = ctms[ctm_index + 1] + + cur_utt = next_utt + next_utt = next_ctm[0][0] + if (next_utt not in segments): + raise Exception('Could not find utterance %s in segments' % next_utt) + + if len(cur_ctm) > 0: + assert(cur_utt == cur_ctm[0][0]) + + assert(next_utt > cur_utt) + if (cur_utt not in segments): + raise Exception('Could not find utterance %s in segments' % cur_utt) + + # length of this segment + window_length = segments[cur_utt][2] - segments[cur_utt][1] + + # overlap of this segment with the next segment + # Note: It is possible for this to be negative when there is actually + # no overlap between consecutive segments. + overlap = segments[cur_utt][2] - segments[next_utt][1] + + # find the breaks after overlap starts + index = len(cur_ctm) + + for i in xrange(len(cur_ctm)): + if (cur_ctm[i][2] + cur_ctm[i][3]/2.0 > (window_length - overlap/2.0)): + # if midpoint of a hypothesis word is beyond the midpoint of the + # overlap region + index = i + break + + # Ignore the hypotheses beyond this midpoint. They will be considered as + # part of the next segment. 
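+      # (Illustrative numbers, not from the original recipe: if the current
+      # segment spans 0-10s and the next one starts at 8s, the overlap is 2s,
+      # so a hypothesis whose midpoint falls after 9s into the current segment
+      # is dropped here and the copy kept by the next segment is used instead.)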
+ total_ctm += cur_ctm[:index] + + # Ignore the hypotheses of the next utterance that overlaps with the + # current utterance + index = -1 + for i in xrange(len(next_ctm)): + if (next_ctm[i][2] + next_ctm[i][3]/2.0 > (overlap/2.0)): + index = i + break + + if index >= 0: + ctms[ctm_index + 1] = next_ctm[index:] + else: + ctms[ctm_index + 1] = [] + + # merge the last ctm entirely + total_ctm += ctms[-1] + + return total_ctm + +def ReadCtm(ctm_file_lines, segments): + ctms = {} + for key in [ x[0] for x in segments.values() ]: + ctms[key] = [] + + ctm = [] + prev_utt = ctm_file_lines[0].split()[0] + for line in ctm_file_lines: + parts = line.split() + if (prev_utt == parts[0]): + ctm.append([parts[0], parts[1], float(parts[2]), + float(parts[3])] + parts[4:]) + else: + # New utterance. Append the previous utterance's CTM + # into the list for the utterance's recording + ctms[segments[ctm[0][0]][0]].append(ctm) + + assert(parts[0] > prev_utt) + + prev_utt = parts[0] + ctm = [] + ctm.append([parts[0], parts[1], float(parts[2]), + float(parts[3])] + parts[4:]) + + # append the last ctm + ctms[segments[ctm[0][0]][0]].append(ctm) + return ctms + +def WriteCtm(ctm_lines, out_file): + for line in ctm_lines: + out_file.write("{0} {1} {2} {3} {4}\n".format(line[0], line[1], line[2], line[3], " ".join(line[4:]))) + +if __name__ == "__main__": + usage = """ Python script to resolve overlaps in ctms """ + parser = argparse.ArgumentParser(usage) + parser.add_argument('segments', type=str, help = 'use segments to resolve overlaps') + parser.add_argument('ctm_in', type=str, help='input_ctm_file') + parser.add_argument('ctm_out', type=str, help='output_ctm_file') + params = parser.parse_args() + + if params.ctm_in == "-": + params.ctm_in = sys.stdin + else: + params.ctm_in = open(params.ctm_in) + if params.ctm_out == "-": + params.ctm_out = sys.stdout + else: + params.ctm_out = open(params.ctm_out, 'w') + + segments = ReadSegments(params.segments) + + # Read CTMs into a dictionary indexed by the recording + ctms = ReadCtm(params.ctm_in.readlines(), segments) + + for key in sorted(ctms.keys()): + # Process CTMs in the sorted order of recordings + ctm_reco = ctms[key] + ctm_reco = ResolveOverlaps(ctm_reco, segments) + WriteCtm(ctm_reco, params.ctm_out) + params.ctm_out.close() From eb727f101347bf2486b071e8060ff9679d392e4a Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Fri, 2 Sep 2016 14:47:24 -0400 Subject: [PATCH 052/213] asr_diarization: AMI script without ivectors --- egs/ami/s5b/local/chain/run_tdnn_noivec.sh | 245 +++++++++++++++++++++ 1 file changed, 245 insertions(+) create mode 100755 egs/ami/s5b/local/chain/run_tdnn_noivec.sh diff --git a/egs/ami/s5b/local/chain/run_tdnn_noivec.sh b/egs/ami/s5b/local/chain/run_tdnn_noivec.sh new file mode 100755 index 00000000000..d1329dc2bd1 --- /dev/null +++ b/egs/ami/s5b/local/chain/run_tdnn_noivec.sh @@ -0,0 +1,245 @@ +#!/bin/bash + +# This is a chain-training script with TDNN neural networks. +# Please see RESULTS_* for examples of command lines invoking this script. 
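+# A typical invocation might look like the following (hypothetical option
+# values; any option from the configuration section below can be overridden
+# the same way):
+# local/chain/run_tdnn_noivec.sh --mic sdm1 --use-ihm-ali true \
+#   --train-set train_cleaned --gmm tri3_cleaned --stage 12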
+ + +# local/nnet3/run_tdnn.sh --stage 8 --use-ihm-ali true --mic sdm1 # rerunning with biphone +# local/nnet3/run_tdnn.sh --stage 8 --use-ihm-ali false --mic sdm1 + +# local/chain/run_tdnn.sh --use-ihm-ali true --mic sdm1 --train-set train --gmm tri3 --nnet3-affix "" --stage 12 & + +# local/chain/run_tdnn.sh --use-ihm-ali true --mic mdm8 --stage 12 & +# local/chain/run_tdnn.sh --use-ihm-ali true --mic mdm8 --train-set train --gmm tri3 --nnet3-affix "" --stage 12 & + +# local/chain/run_tdnn.sh --mic sdm1 --use-ihm-ali true --train-set train_cleaned --gmm tri3_cleaned& + + +set -e -o pipefail + +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). +stage=0 +mic=ihm +nj=30 +min_seg_len=1.55 +use_ihm_ali=false +train_set=train_cleaned +gmm=tri3_cleaned # the gmm for the target data +ihm_gmm=tri3 # the gmm for the IHM system (if --use-ihm-ali true). +num_threads_ubm=32 +nnet3_affix=_cleaned # cleanup affix for nnet3 and chain dirs, e.g. _cleaned + +# The rest are configs specific to this script. Most of the parameters +# are just hardcoded at this level, in the commands below. +train_stage=-10 +tree_affix= # affix for tree directory, e.g. "a" or "b", in case we change the configuration. +tdnn_affix= #affix for TDNN directory, e.g. "a" or "b", in case we change the configuration. +common_egs_dir= # you can set this to use previously dumped egs. + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + + +if ! cuda-compiled; then + cat <data/lang_chain/topo + fi +fi + +if [ $stage -le 13 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/align_fmllr_lats.sh --nj 100 --cmd "$train_cmd" ${lores_train_data_dir} \ + data/lang $gmm_dir $lat_dir + rm $lat_dir/fsts.*.gz # save space +fi + +if [ $stage -le 14 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." + exit 1; + fi + steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --leftmost-questions-truncate -1 \ + --cmd "$train_cmd" 4200 ${lores_train_data_dir} data/lang_chain $ali_dir $tree_dir +fi + +if [ $stage -le 15 ]; then + mkdir -p $dir + + echo "$0: creating neural net configs"; + + steps/nnet3/tdnn/make_configs.py \ + --self-repair-scale-nonlinearity 0.00001 \ + --feat-dir data/$mic/${train_set}_sp_hires_comb \ + --tree-dir $tree_dir \ + --relu-dim 450 \ + --splice-indexes "-1,0,1 -1,0,1,2 -3,0,3 -3,0,3 -3,0,3 -6,-3,0 0" \ + --use-presoftmax-prior-scale false \ + --xent-regularize 0.1 \ + --xent-separate-forward-affine true \ + --include-log-softmax false \ + --final-layer-normalize-target 1.0 \ + $dir/configs || exit 1; +fi + +if [ $stage -le 16 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/ami-$(date +'%m_%d_%H_%M')/s5b/$dir/egs/storage $dir/egs/storage + fi + + touch $dir/egs/.nodelete # keep egs around when that run dies. 
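+  # Editorial note: the steps/nnet3/chain/train.py call below is the usual
+  # LF-MMI (chain) training setup: a small cross-entropy regularizer
+  # (--chain.xent-regularize 0.1), leaky-HMM and l2 regularization, 4 epochs,
+  # and the number of parallel jobs growing from 2 to 12; the options passed
+  # below are authoritative.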
+ + steps/nnet3/chain/train.py --stage $train_stage \ + --cmd "$decode_cmd" \ + --feat.cmvn-opts "--norm-means=true --norm-vars=false" \ + --chain.xent-regularize 0.1 \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00005 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --egs.dir "$common_egs_dir" \ + --egs.opts "--frames-overlap-per-eg 0" \ + --egs.chunk-width 150 \ + --trainer.num-chunk-per-minibatch 128 \ + --trainer.frames-per-iter 1500000 \ + --trainer.num-epochs 4 \ + --trainer.optimization.num-jobs-initial 2 \ + --trainer.optimization.num-jobs-final 12 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs true \ + --feat-dir $train_data_dir \ + --tree-dir $tree_dir \ + --lat-dir $lat_dir \ + --dir $dir +fi + + +graph_dir=$dir/graph_${LM} +if [ $stage -le 17 ]; then + # Note: it might appear that this data/lang_chain directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. + utils/mkgraph.sh --left-biphone --self-loop-scale 1.0 data/lang_${LM} $dir $graph_dir +fi + +if [ $stage -le 18 ]; then + rm $dir/.error 2>/dev/null || true + for decode_set in dev eval; do + ( + nj_dev=`cat data/$mic/${decode_set}_hires/spk2utt | wc -l` + if [ $nj_dev -gt 30 ]; then + nj_dev=30 + fi + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $nj_dev --cmd "$decode_cmd" \ + --scoring-opts "--min-lmwt 5 " \ + $graph_dir data/$mic/${decode_set}_hires $dir/decode_${decode_set} || exit 1; + ) || touch $dir/.error & + done + wait + if [ -f $dir/.error ]; then + echo "$0: something went wrong in decoding" + exit 1 + fi +fi +exit 0 + From 8d529d98e76ce066889b4c3d634028e5459d7e2d Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Sat, 10 Dec 2016 17:49:41 -0500 Subject: [PATCH 053/213] asr_diarization: Adding run_tdnn_1a.sh --- egs/ami/s5b/local/chain/tuning/run_tdnn_1a.sh | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_1a.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_1a.sh index 8df62af8bad..7a38dc80b26 100755 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_1a.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_1a.sh @@ -226,8 +226,12 @@ if [ $stage -le 18 ]; then rm $dir/.error 2>/dev/null || true for decode_set in dev eval; do ( + nj_dev=`cat data/$mic/${decode_set}_hires/spk2utt | wc -l` + if [ $nj_dev -gt $nj ]; then + nj_dev=$nj + fi steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ - --nj $nj --cmd "$decode_cmd" \ + --nj $nj_dev --cmd "$decode_cmd" \ --online-ivector-dir exp/$mic/nnet3${nnet3_affix}/ivectors_${decode_set}_hires \ --scoring-opts "--min-lmwt 5 " \ $graph_dir data/$mic/${decode_set}_hires $dir/decode_${decode_set} || exit 1; @@ -239,4 +243,4 @@ if [ $stage -le 18 ]; then exit 1 fi fi -exit 0 \ No newline at end of file +exit 0 From 37f57138b582efc7e44a1908d23660ee7a0f5545 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Sat, 10 Sep 2016 19:57:36 -0400 Subject: [PATCH 054/213] asr_diarization: Modified AMI scoring script to add overlap resolution and non-mbr decoding --- egs/ami/s5b/local/score.sh | 3 ++ egs/ami/s5b/local/score_asclite.sh | 75 +++++++++++++++++++++--------- 2 files changed, 57 insertions(+), 21 deletions(-) diff --git a/egs/ami/s5b/local/score.sh b/egs/ami/s5b/local/score.sh index 6a077c39644..c186c4b303d 100755 --- 
a/egs/ami/s5b/local/score.sh +++ b/egs/ami/s5b/local/score.sh @@ -15,6 +15,9 @@ min_lmwt=7 # unused, max_lmwt=15 # unused, iter=final asclite=true +overlap_spk=4 +resolve_overlaps=false # unused +decode_mbr=true #end configuration section. [ -f ./path.sh ] && . ./path.sh diff --git a/egs/ami/s5b/local/score_asclite.sh b/egs/ami/s5b/local/score_asclite.sh index 7327f6246af..0682722214a 100755 --- a/egs/ami/s5b/local/score_asclite.sh +++ b/egs/ami/s5b/local/score_asclite.sh @@ -12,6 +12,8 @@ max_lmwt=15 asclite=true iter=final overlap_spk=4 +resolve_overlaps=false +stm_suffix= # end configuration section. [ -f ./path.sh ] && . ./path.sh . parse_options.sh || exit 1; @@ -36,7 +38,7 @@ hubscr=$KALDI_ROOT/tools/sctk/bin/hubscr.pl [ ! -f $hubscr ] && echo "Cannot find scoring program at $hubscr" && exit 1; hubdir=`dirname $hubscr` -for f in $data/stm $data/glm $lang/words.txt $lang/phones/word_boundary.int \ +for f in $data/stm${stm_suffix} $data/glm $lang/words.txt $lang/phones/word_boundary.int \ $model $data/segments $data/reco2file_and_channel $dir/lat.1.gz; do [ ! -f $f ] && echo "$0: expecting file $f to exist" && exit 1; done @@ -55,36 +57,66 @@ nj=$(cat $dir/num_jobs) mkdir -p $dir/ascoring/log +copy_ctm_script="cat -" +if $resolve_overlaps; then + copy_ctm_script="steps/resolve_ctm_overlaps.py $data/segments - -" +fi + if [ $stage -le 0 ]; then for LMWT in $(seq $min_lmwt $max_lmwt); do rm -f $dir/.error ( - $cmd JOB=1:$nj $dir/ascoring/log/get_ctm.${LMWT}.JOB.log \ - mkdir -p $dir/ascore_${LMWT}/ '&&' \ - lattice-scale --inv-acoustic-scale=${LMWT} "ark:gunzip -c $dir/lat.JOB.gz|" ark:- \| \ - lattice-limit-depth ark:- ark:- \| \ - lattice-push --push-strings=false ark:- ark:- \| \ - lattice-align-words-lexicon --max-expand=10.0 \ - $lang/phones/align_lexicon.int $model ark:- ark:- \| \ - lattice-to-ctm-conf $frame_shift_opt --decode-mbr=$decode_mbr ark:- - \| \ - utils/int2sym.pl -f 5 $lang/words.txt \| \ - utils/convert_ctm.pl $data/segments $data/reco2file_and_channel \ - '>' $dir/ascore_${LMWT}/${name}.JOB.ctm || touch $dir/.error; + if $decode_mbr; then + $cmd JOB=1:$nj $dir/ascoring/log/get_ctm.${LMWT}.JOB.log \ + mkdir -p $dir/ascore_${LMWT}/ '&&' \ + lattice-scale --inv-acoustic-scale=${LMWT} "ark:gunzip -c $dir/lat.JOB.gz|" ark:- \| \ + lattice-limit-depth ark:- ark:- \| \ + lattice-push --push-strings=false ark:- ark:- \| \ + lattice-align-words-lexicon --max-expand=10.0 \ + $lang/phones/align_lexicon.int $model ark:- ark:- \| \ + lattice-to-ctm-conf $frame_shift_opt --decode-mbr=$decode_mbr ark:- - \| \ + utils/int2sym.pl -f 5 $lang/words.txt \ + '>' $dir/ascore_${LMWT}/${name}.JOB.ctm || touch $dir/.error; + else + $cmd JOB=1:$nj $dir/ascoring/log/get_ctm.${LMWT}.JOB.log \ + mkdir -p $dir/ascore_${LMWT}/ '&&' \ + lattice-scale --inv-acoustic-scale=${LMWT} "ark:gunzip -c $dir/lat.JOB.gz|" ark:- \| \ + lattice-limit-depth ark:- ark:- \| \ + lattice-1best ark:- ark:- \| \ + lattice-push --push-strings=false ark:- ark:- \| \ + lattice-align-words-lexicon --max-expand=10.0 \ + $lang/phones/align_lexicon.int $model ark:- ark:- \| \ + nbest-to-ctm $frame_shift_opt ark:- - \| \ + utils/int2sym.pl -f 5 $lang/words.txt \ + '>' $dir/ascore_${LMWT}/${name}.JOB.ctm || touch $dir/.error; + fi + # Merge and clean, - for ((n=1; n<=nj; n++)); do cat $dir/ascore_${LMWT}/${name}.${n}.ctm; done > $dir/ascore_${LMWT}/${name}.ctm - rm -f $dir/ascore_${LMWT}/${name}.*.ctm + for ((n=1; n<=nj; n++)); do + cat $dir/ascore_${LMWT}/${name}.${n}.ctm; + rm -f $dir/ascore_${LMWT}/${name}.${n}.ctm + done > 
$dir/ascore_${LMWT}/${name}.utt.ctm )& done wait; [ -f $dir/.error ] && echo "$0: error during ctm generation. check $dir/ascoring/log/get_ctm.*.log" && exit 1; fi +if [ $stage -le 1 ]; then + for LMWT in $(seq $min_lmwt $max_lmwt); do + cat $dir/ascore_${LMWT}/${name}.utt.ctm | \ + $copy_ctm_script | utils/convert_ctm.pl $data/segments $data/reco2file_and_channel \ + > $dir/ascore_${LMWT}/${name}.ctm || exit 1 + done +fi + if [ $stage -le 1 ]; then # Remove some stuff we don't want to score, from the ctm. # - we remove hesitations here, otherwise the CTM would have a bug! # (confidences in place of the removed hesitations), - for x in $dir/ascore_*/${name}.ctm; do - cp $x $x.tmpf; + for LMWT in $(seq $min_lmwt $max_lmwt); do + x=$dir/ascore_${LMWT}/${name}.ctm + mv $x $x.tmpf; cat $x.tmpf | grep -i -v -E '\[noise|laughter|vocalized-noise\]' | \ grep -i -v -E ' (ACH|AH|EEE|EH|ER|EW|HA|HEE|HM|HMM|HUH|MM|OOF|UH|UM) ' | \ grep -i -v -E '' > $x; @@ -94,8 +126,9 @@ fi if [ $stage -le 2 ]; then if [ "$asclite" == "true" ]; then - oname=$name + oname=${name} [ ! -z $overlap_spk ] && oname=${name}_o$overlap_spk + oname=${oname}${stm_suffix} echo "asclite is starting" # Run scoring, meaning of hubscr.pl options: # -G .. produce alignment graphs, @@ -109,10 +142,10 @@ if [ $stage -le 2 ]; then # -V .. skip validation of input transcripts, # -h rt-stt .. removes non-lexical items from CTM, $cmd LMWT=$min_lmwt:$max_lmwt $dir/ascoring/log/score.LMWT.log \ - cp $data/stm $dir/ascore_LMWT/ '&&' \ + cp $data/stm${stm_suffix} $dir/ascore_LMWT/ '&&' \ cp $dir/ascore_LMWT/${name}.ctm $dir/ascore_LMWT/${oname}.ctm '&&' \ $hubscr -G -v -m 1:2 -o$overlap_spk -a -C -B 8192 -p $hubdir -V -l english \ - -h rt-stt -g $data/glm -r $dir/ascore_LMWT/stm $dir/ascore_LMWT/${oname}.ctm || exit 1 + -h rt-stt -g $data/glm -r $dir/ascore_LMWT/stm${stm_suffix} $dir/ascore_LMWT/${oname}.ctm || exit 1 # Compress some scoring outputs : alignment info and graphs, echo -n "compressing asclite outputs " for LMWT in $(seq $min_lmwt $max_lmwt); do @@ -126,8 +159,8 @@ if [ $stage -le 2 ]; then echo done else $cmd LMWT=$min_lmwt:$max_lmwt $dir/ascoring/log/score.LMWT.log \ - cp $data/stm $dir/ascore_LMWT/ '&&' \ - $hubscr -p $hubdir -v -V -l english -h hub5 -g $data/glm -r $dir/ascore_LMWT/stm $dir/ascore_LMWT/${name}.ctm || exit 1 + cp $data/stm${stm_suffix} $dir/ascore_LMWT/ '&&' \ + $hubscr -p $hubdir -v -V -l english -h hub5 -g $data/glm -r $dir/ascore_LMWT/stm${suffix} $dir/ascore_LMWT/${name}${stm_suffix}.ctm || exit 1 fi fi From c5aeb9159d7a023683cca1cb9c3704ad578893d3 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Sat, 10 Sep 2016 19:58:18 -0400 Subject: [PATCH 055/213] asr_diarization: Get times per word from mbr sausage links --- src/lat/sausages.cc | 57 ++++++++++++++++++++++++++++++++++++--------- src/lat/sausages.h | 10 ++++++-- 2 files changed, 54 insertions(+), 13 deletions(-) diff --git a/src/lat/sausages.cc b/src/lat/sausages.cc index 53678efe844..0af1c0f6620 100644 --- a/src/lat/sausages.cc +++ b/src/lat/sausages.cc @@ -51,10 +51,19 @@ void MinimumBayesRisk::MbrDecode() { R_[q] = rhat; } if (R_[q] != 0) { - one_best_times_.push_back(times_[q]); BaseFloat confidence = 0.0; + bool first_time = true; for (int32 j = 0; j < gamma_[q].size(); j++) - if (gamma_[q][j].first == R_[q]) confidence = gamma_[q][j].second; + if (gamma_[q][j].first == R_[q]) { + KALDI_ASSERT(first_time); first_time = false; + confidence = gamma_[q][j].second; + KALDI_ASSERT(confidence > 0); + KALDI_ASSERT(begin_times_[q].count(R_[q]) > 0); + 
KALDI_ASSERT(end_times_[q].count(R_[q]) > 0); + one_best_times_.push_back(make_pair( + begin_times_[q][R_[q]] / confidence, + end_times_[q][R_[q]] / confidence)); + } one_best_confidences_.push_back(confidence); } } @@ -145,11 +154,13 @@ void MinimumBayesRisk::AccStats() { vector > gamma(Q+1); // temp. form of gamma. // index 1...Q [word] -> occ. + vector > tau_b(Q+1), tau_e(Q+1); + // The tau arrays below are the sums over words of the tau_b // and tau_e timing quantities mentioned in Appendix C of // the paper... we are using these to get averaged times for // the sausage bins, not specifically for the 1-best output. - Vector tau_b(Q+1), tau_e(Q+1); + //Vector tau_b(Q+1), tau_e(Q+1); double Ltmp = EditDistance(N, Q, alpha, alpha_dash, alpha_dash_arc); if (L_ != 0 && Ltmp > L_) { // L_ != 0 is to rule out 1st iter. @@ -189,8 +200,11 @@ void MinimumBayesRisk::AccStats() { // next: gamma(q, w(a)) += beta_dash_arc(q) AddToMap(w_a, beta_dash_arc(q), &(gamma[q])); // next: accumulating times, see decl for tau_b,tau_e - tau_b(q) += state_times_[s_a] * beta_dash_arc(q); - tau_e(q) += state_times_[n] * beta_dash_arc(q); + AddToMap(w_a, state_times_[s_a] * beta_dash_arc(q), &(tau_b[q]), false); + AddToMap(w_a, state_times_[n] * beta_dash_arc(q), &(tau_e[q]), false); + KALDI_ASSERT(tau_b[q].size() == tau_e[q].size()); + //tau_b(q) += state_times_[s_a] * beta_dash_arc(q); + //tau_e(q) += state_times_[n] * beta_dash_arc(q); break; case 2: beta_dash(s_a, q) += beta_dash_arc(q); @@ -203,8 +217,11 @@ void MinimumBayesRisk::AccStats() { // WARNING: there was an error in Appendix C. If we followed // the instructions there the next line would say state_times_[sa], but // it would be wrong. I will try to publish an erratum. - tau_b(q) += state_times_[n] * beta_dash_arc(q); - tau_e(q) += state_times_[n] * beta_dash_arc(q); + AddToMap(0, state_times_[n] * beta_dash_arc(q), &(tau_b[q]), false); + AddToMap(0, state_times_[n] * beta_dash_arc(q), &(tau_e[q]), false); + KALDI_ASSERT(tau_b[q].size() == tau_e[q].size()); + //tau_b(q) += state_times_[n] * beta_dash_arc(q); + //tau_e(q) += state_times_[n] * beta_dash_arc(q); break; default: KALDI_ERR << "Invalid b_arc value"; // error in code. @@ -221,8 +238,11 @@ void MinimumBayesRisk::AccStats() { AddToMap(0, beta_dash_arc(q), &(gamma[q])); // the statements below are actually redundant because // state_times_[1] is zero. - tau_b(q) += state_times_[1] * beta_dash_arc(q); - tau_e(q) += state_times_[1] * beta_dash_arc(q); + //tau_b(q) += state_times_[1] * beta_dash_arc(q); + //tau_e(q) += state_times_[1] * beta_dash_arc(q); + AddToMap(0, state_times_[1] * beta_dash_arc(q), &(tau_b[q]), false); + AddToMap(0, state_times_[1] * beta_dash_arc(q), &(tau_e[q]), false); + KALDI_ASSERT(tau_b[q].size() == tau_e[q].size()); } for (int32 q = 1; q <= Q; q++) { // a check (line 35) double sum = 0.0; @@ -249,9 +269,24 @@ void MinimumBayesRisk::AccStats() { // indexing. 
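+  // Editorial note: in addition to the per-bin average times kept in times_,
+  // the code below now accumulates per-word sums in begin_times_ and
+  // end_times_; MbrDecode() later divides these by the word's posterior
+  // (its confidence) to obtain word-specific begin/end times.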
times_.clear(); times_.resize(Q); + begin_times_.clear(); + begin_times_.resize(Q); + end_times_.clear(); + end_times_.resize(Q); for (int32 q = 1; q <= Q; q++) { - times_[q-1].first = tau_b(q); - times_[q-1].second = tau_e(q); + KALDI_ASSERT(tau_b[q].size() == tau_e[q].size()); + for (map::iterator iter = tau_b[q].begin(); + iter != tau_b[q].end(); ++iter) { + times_[q-1].first += iter->second; + begin_times_[q-1].insert(make_pair(iter->first, iter->second)); + } + + for (map::iterator iter = tau_e[q].begin(); + iter != tau_e[q].end(); ++iter) { + times_[q-1].second += iter->second; + end_times_[q-1].insert(make_pair(iter->first, iter->second)); + } + if (times_[q-1].first > times_[q-1].second) // this is quite bad. KALDI_WARN << "Times out of order"; if (q > 1 && times_[q-2].second > times_[q-1].first) { diff --git a/src/lat/sausages.h b/src/lat/sausages.h index 8ada15e64b5..4f709bf1703 100644 --- a/src/lat/sausages.h +++ b/src/lat/sausages.h @@ -133,8 +133,8 @@ class MinimumBayesRisk { // used in the algorithm. /// Function used to increment map. - static inline void AddToMap(int32 i, double d, std::map *gamma) { - if (d == 0) return; + static inline void AddToMap(int32 i, double d, std::map *gamma, bool return_if_zero = true) { + if (return_if_zero && d == 0) return; std::pair pr(i, d); std::pair::iterator, bool> ret = gamma->insert(pr); if (!ret.second) // not inserted, so add to contents. @@ -178,6 +178,12 @@ class MinimumBayesRisk { // paper. We sort in reverse order on the second member (posterior), so more // likely word is first. + std::vector > begin_times_; + std::vector > end_times_; + // The average start and end times for each word in a confusion-network bin. + // These are the tau_b and tau_e quantities in Appendix C of the paper. + // Indexed from zero, like gamma_ and R_. + std::vector > times_; // The average start and end times for each confusion-network bin. This // is like an average over words, of the tau_b and tau_e quantities in From ae14b71367da1d51274a7cf664f1e321dac72f97 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Sat, 10 Sep 2016 19:58:36 -0400 Subject: [PATCH 056/213] asr_diarization: Adding a general wrapper to chain decoding script --- egs/ami/s5b/local/chain/run_decode.sh | 131 ++++++++++++++++++++++++++ 1 file changed, 131 insertions(+) create mode 100755 egs/ami/s5b/local/chain/run_decode.sh diff --git a/egs/ami/s5b/local/chain/run_decode.sh b/egs/ami/s5b/local/chain/run_decode.sh new file mode 100755 index 00000000000..545bdc7b157 --- /dev/null +++ b/egs/ami/s5b/local/chain/run_decode.sh @@ -0,0 +1,131 @@ +#!/bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +set -e +set -o pipefail +set -u + +stage=-1 +decode_stage=1 + +mic=ihm +use_ihm_ali=false +exp_name=tdnn + +nj=20 + +cleanup_affix= +graph_dir= + +decode_set=dev +decode_suffix= + +extractor= +use_ivectors=true +use_offline_ivectors=false +frames_per_chunk=50 + +scoring_opts= + +. path.sh +. cmd.sh + +. parse_options.sh + +new_mic=$mic +if [ $use_ihm_ali == "true" ]; then + new_mic=${mic}_cleanali +fi + +dir=exp/$new_mic/chain${cleanup_affix:+_$cleanup_affix}/${exp_name} + +if [ $stage -le -1 ]; then + mfccdir=mfcc_${mic} + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $mfccdir/storage ]; then + utils/create_split_dir.pl /export/b0{1,2,3,4}/$USER/kaldi-data/egs/ami-$mic-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage + fi + + steps/make_mfcc.sh --nj $nj --mfcc-config conf/mfcc.conf \ + --cmd "$train_cmd" data/$mic/${decode_set} exp/make_${mic}/$decode_set $mfccdir || exit 1; + + steps/compute_cmvn_stats.sh data/$mic/${decode_set} exp/make_${mic}/$mic/$decode_set $mfccdir || exit 1; + + utils/fix_data_dir.sh data/$mic/${decode_set} +fi + +if [ $stage -le 0 ]; then + mfccdir=mfcc_${mic}_hires + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then + utils/create_split_dir.pl /export/b0{1,2,3,4}/$USER/kaldi-data/egs/ami-$mic-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage + fi + + utils/copy_data_dir.sh data/$mic/$decode_set data/$mic/${decode_set}_hires + + steps/make_mfcc.sh --nj $nj --mfcc-config conf/mfcc_hires.conf \ + --cmd "$train_cmd" data/$mic/${decode_set}_hires exp/make_${mic}_hires/$decode_set $mfccdir || exit 1; + + steps/compute_cmvn_stats.sh data/$mic/${decode_set}_hires exp/make_${mic}_hires/$mic/$decode_set $mfccdir || exit 1; + + utils/fix_data_dir.sh data/$mic/${decode_set}_hires +fi + +if $use_ivectors && [ $stage -le 1 ]; then + if [ -z "$extractor" ]; then + echo "--extractor must be supplied when using ivectors" + exit 1 + fi + + if $use_offline_ivectors; then + steps/online/nnet2/extract_ivectors.sh \ + --cmd "$train_cmd" --nj 8 \ + data/$mic/${decode_set}_hires data/lang $extractor \ + exp/$mic/nnet3${cleanup_affix:+_$cleanup_affix}/ivectors_offline_${decode_set} || exit 1 + else + steps/online/nnet2/extract_ivectors_online.sh \ + --cmd "$train_cmd" --nj 8 \ + data/$mic/${decode_set}_hires $extractor \ + exp/$mic/nnet3${cleanup_affix:+_$cleanup_affix}/ivectors_${decode_set} || exit 1 + fi +fi + +final_lm=`cat data/local/lm/final_lm` +LM=$final_lm.pr1-7 + +if [ -z "$graph_dir" ]; then + graph_dir=$dir/graph_${LM} + if [ $stage -le 2 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. 
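+    # For 'chain' models the graph is built with --self-loop-scale 1.0, which
+    # matches the --acwt 1.0 / --post-decode-acwt 10.0 used in the decoding
+    # stage below.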
+ utils/mkgraph.sh --self-loop-scale 1.0 data/lang_${LM} $dir $graph_dir + fi +fi + +nj=`cat data/$mic/${decode_set}/utt2spk|cut -d' ' -f2|sort -u|wc -l` + +if [ $nj -gt 50 ]; then + nj=50 +fi + +if [ "$frames_per_chunk" -ne 50 ]; then + decode_suffix=${decode_suffix}_cs${frames_per_chunk} +fi + +if [ $stage -le 3 ]; then + ivector_opts= + if $use_ivectors; then + if $use_offline_ivectors; then + ivector_opts="--online-ivector-dir exp/$mic/nnet3${cleanup_affix:+_$cleanup_affix}/ivectors_offline_${decode_set}" + decode_suffix=${decode_suffix}_offline + else + ivector_opts="--online-ivector-dir exp/$mic/nnet3${cleanup_affix:+_$cleanup_affix}/ivectors_${decode_set}" + fi + fi + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --stage $decode_stage --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$decode_cmd" $ivector_opts \ + --scoring-opts "--min-lmwt 5 --decode-mbr false $scoring_opts" \ + $graph_dir data/$mic/${decode_set}_hires $dir/decode${decode_suffix}_${decode_set} || exit 1; +fi From 194030d78d05fdec857642126882388b4470406b Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Tue, 22 Nov 2016 11:18:29 -0500 Subject: [PATCH 057/213] asr_diarization: Adding RT --- egs/ami/s5b/local/make_rt_2004_dev.pl | 1 + egs/ami/s5b/local/make_rt_2004_eval.pl | 1 + egs/ami/s5b/local/make_rt_2005_eval.pl | 1 + egs/ami/s5b/local/run_prepare_rt.sh | 1 + egs/ami/s5b/path.sh | 4 +- egs/rt/s5/cmd.sh | 1 + egs/rt/s5/conf/fbank.conf | 3 + egs/rt/s5/conf/librispeech_mfcc.conf | 1 + egs/rt/s5/conf/mfcc_hires.conf | 10 ++ egs/rt/s5/conf/mfcc_vad.conf | 5 + egs/rt/s5/conf/pitch.conf | 1 + egs/rt/s5/conf/vad_decode_icsi.conf | 40 +++++++ egs/rt/s5/conf/vad_decode_pitch.conf | 55 ++++++++++ egs/rt/s5/conf/vad_icsi_babel.conf | 39 +++++++ egs/rt/s5/conf/vad_icsi_babel_3models.conf | 54 ++++++++++ egs/rt/s5/conf/vad_icsi_rt.conf | 40 +++++++ egs/rt/s5/conf/vad_snr_rt.conf | 35 ++++++ egs/rt/s5/conf/zc_vad.conf | 5 + egs/rt/s5/diarization | 1 + egs/rt/s5/local/make_rt_2004_dev.pl | 64 +++++++++++ egs/rt/s5/local/make_rt_2004_eval.pl | 64 +++++++++++ egs/rt/s5/local/make_rt_2005_eval.pl | 64 +++++++++++ egs/rt/s5/local/run_prepare_rt.sh | 87 +++++++++++++++ egs/rt/s5/local/score.sh | 53 +++++++++ egs/rt/s5/local/score_asclite.sh | 120 +++++++++++++++++++++ egs/rt/s5/local/snr | 1 + egs/rt/s5/path.sh | 5 + egs/rt/s5/sid | 1 + egs/rt/s5/steps | 1 + egs/rt/s5/utils | 1 + src/ivectorbin/Makefile | 4 +- 31 files changed, 761 insertions(+), 2 deletions(-) create mode 120000 egs/ami/s5b/local/make_rt_2004_dev.pl create mode 120000 egs/ami/s5b/local/make_rt_2004_eval.pl create mode 120000 egs/ami/s5b/local/make_rt_2005_eval.pl create mode 120000 egs/ami/s5b/local/run_prepare_rt.sh create mode 120000 egs/rt/s5/cmd.sh create mode 100644 egs/rt/s5/conf/fbank.conf create mode 100644 egs/rt/s5/conf/librispeech_mfcc.conf create mode 100644 egs/rt/s5/conf/mfcc_hires.conf create mode 100644 egs/rt/s5/conf/mfcc_vad.conf create mode 100644 egs/rt/s5/conf/pitch.conf create mode 100644 egs/rt/s5/conf/vad_decode_icsi.conf create mode 100644 egs/rt/s5/conf/vad_decode_pitch.conf create mode 100644 egs/rt/s5/conf/vad_icsi_babel.conf create mode 100644 egs/rt/s5/conf/vad_icsi_babel_3models.conf create mode 100644 egs/rt/s5/conf/vad_icsi_rt.conf create mode 100644 egs/rt/s5/conf/vad_snr_rt.conf create mode 100644 egs/rt/s5/conf/zc_vad.conf create mode 120000 egs/rt/s5/diarization create mode 100755 egs/rt/s5/local/make_rt_2004_dev.pl create mode 100755 egs/rt/s5/local/make_rt_2004_eval.pl create mode 100755 
egs/rt/s5/local/make_rt_2005_eval.pl create mode 100755 egs/rt/s5/local/run_prepare_rt.sh create mode 100755 egs/rt/s5/local/score.sh create mode 100755 egs/rt/s5/local/score_asclite.sh create mode 120000 egs/rt/s5/local/snr create mode 100755 egs/rt/s5/path.sh create mode 120000 egs/rt/s5/sid create mode 120000 egs/rt/s5/steps create mode 120000 egs/rt/s5/utils diff --git a/egs/ami/s5b/local/make_rt_2004_dev.pl b/egs/ami/s5b/local/make_rt_2004_dev.pl new file mode 120000 index 00000000000..a0d27619369 --- /dev/null +++ b/egs/ami/s5b/local/make_rt_2004_dev.pl @@ -0,0 +1 @@ +../../../rt/s5/local/make_rt_2004_dev.pl \ No newline at end of file diff --git a/egs/ami/s5b/local/make_rt_2004_eval.pl b/egs/ami/s5b/local/make_rt_2004_eval.pl new file mode 120000 index 00000000000..8b951f9c940 --- /dev/null +++ b/egs/ami/s5b/local/make_rt_2004_eval.pl @@ -0,0 +1 @@ +../../../rt/s5/local/make_rt_2004_eval.pl \ No newline at end of file diff --git a/egs/ami/s5b/local/make_rt_2005_eval.pl b/egs/ami/s5b/local/make_rt_2005_eval.pl new file mode 120000 index 00000000000..6185b83a5a3 --- /dev/null +++ b/egs/ami/s5b/local/make_rt_2005_eval.pl @@ -0,0 +1 @@ +../../../rt/s5/local/make_rt_2005_eval.pl \ No newline at end of file diff --git a/egs/ami/s5b/local/run_prepare_rt.sh b/egs/ami/s5b/local/run_prepare_rt.sh new file mode 120000 index 00000000000..e10f1d53a19 --- /dev/null +++ b/egs/ami/s5b/local/run_prepare_rt.sh @@ -0,0 +1 @@ +../../../rt/s5/local/run_prepare_rt.sh \ No newline at end of file diff --git a/egs/ami/s5b/path.sh b/egs/ami/s5b/path.sh index ad2c93b309b..d8f46e6b8a0 100644 --- a/egs/ami/s5b/path.sh +++ b/egs/ami/s5b/path.sh @@ -10,4 +10,6 @@ SRILM=$KALDI_ROOT/tools/srilm/bin/i686-m64 BEAMFORMIT=$KALDI_ROOT/tools/BeamformIt export PATH=$PATH:$LMBIN:$BEAMFORMIT:$SRILM - +export PATH=$PATH:$KALDI_ROOT/tools/sph2pipe_v2.5 +export PATH=$PATH:/home/vmanoha1/kaldi-waveform/src/segmenterbin +export PATH=$PATH:$KALDI_ROOT/tools/sctk/bin diff --git a/egs/rt/s5/cmd.sh b/egs/rt/s5/cmd.sh new file mode 120000 index 00000000000..19f7e836644 --- /dev/null +++ b/egs/rt/s5/cmd.sh @@ -0,0 +1 @@ +../../wsj/s5/cmd.sh \ No newline at end of file diff --git a/egs/rt/s5/conf/fbank.conf b/egs/rt/s5/conf/fbank.conf new file mode 100644 index 00000000000..07e1639e6ee --- /dev/null +++ b/egs/rt/s5/conf/fbank.conf @@ -0,0 +1,3 @@ +# No non-default options for now. +--num-mel-bins=40 # similar to Google's setup. + diff --git a/egs/rt/s5/conf/librispeech_mfcc.conf b/egs/rt/s5/conf/librispeech_mfcc.conf new file mode 100644 index 00000000000..45d284ad05c --- /dev/null +++ b/egs/rt/s5/conf/librispeech_mfcc.conf @@ -0,0 +1 @@ +--use-energy=false diff --git a/egs/rt/s5/conf/mfcc_hires.conf b/egs/rt/s5/conf/mfcc_hires.conf new file mode 100644 index 00000000000..434834a6725 --- /dev/null +++ b/egs/rt/s5/conf/mfcc_hires.conf @@ -0,0 +1,10 @@ +# config for high-resolution MFCC features, intended for neural network training +# Note: we keep all cepstra, so it has the same info as filterbank features, +# but MFCC is more easily compressible (because less correlated) which is why +# we prefer this method. +--use-energy=false # use average of log energy, not energy. +--num-mel-bins=40 # similar to Google's setup. +--num-ceps=40 # there is no dimensionality reduction. +--low-freq=20 # low cutoff frequency for mel bins... this is high-bandwidth data, so + # there might be some information at the low end. 
+--high-freq=-400 # high cutoff frequently, relative to Nyquist of 8000 (=7600) diff --git a/egs/rt/s5/conf/mfcc_vad.conf b/egs/rt/s5/conf/mfcc_vad.conf new file mode 100644 index 00000000000..22765c6280e --- /dev/null +++ b/egs/rt/s5/conf/mfcc_vad.conf @@ -0,0 +1,5 @@ +--sample-frequency=16000 +--frame-length=25 # the default is 25. +--low-freq=20 # the default. +--high-freq=-600 # the default is zero meaning use the Nyquist (4k in this case). +--num-ceps=13 # higher than the default which is 12. diff --git a/egs/rt/s5/conf/pitch.conf b/egs/rt/s5/conf/pitch.conf new file mode 100644 index 00000000000..e959a19d5b8 --- /dev/null +++ b/egs/rt/s5/conf/pitch.conf @@ -0,0 +1 @@ +--sample-frequency=16000 diff --git a/egs/rt/s5/conf/vad_decode_icsi.conf b/egs/rt/s5/conf/vad_decode_icsi.conf new file mode 100644 index 00000000000..15ba288e3af --- /dev/null +++ b/egs/rt/s5/conf/vad_decode_icsi.conf @@ -0,0 +1,40 @@ +## Features paramters +window_size=100 # 1s +frames_per_gaussian=2000 + +## Phase 1 parameters +num_frames_init_silence=2000 +num_frames_init_sound=10000 +num_frames_init_sound_next=2000 +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=6 +sil_gauss_incr=0 +sound_gauss_incr=2 +num_iters=5 +min_sil_variance=0.1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +num_frames_init_speech=10000 +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 + +## Phase 3 parameters +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +speech_max_gauss_phase3=12 +sil_gauss_incr_phase3=1 +speech_gauss_incr_phase3=2 +num_iters_phase3=7 + + diff --git a/egs/rt/s5/conf/vad_decode_pitch.conf b/egs/rt/s5/conf/vad_decode_pitch.conf new file mode 100644 index 00000000000..d7ba1d40093 --- /dev/null +++ b/egs/rt/s5/conf/vad_decode_pitch.conf @@ -0,0 +1,55 @@ +## Features paramters +window_size=10 # 1s +smooth_weights=false +smoothing_window=2 +smooth_mask=true + +## Phase 1 parameters +num_frames_init_silence=200 +num_frames_init_sound=200 +num_frames_init_sound_next=200 +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=2 +sil_gauss_incr=1 +sound_gauss_incr=1 +sil_frames_incr=200 +sound_frames_incr=200 +sound_frames_next_incr=200 +num_iters=5 +min_sil_variance=0.1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +num_frames_init_speech=5000 +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=7 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=1 +speech_gauss_incr_phase2=2 +num_iters_phase2=20 +window_size_phase2_init=10 +window_size_phase2_next=10 +window_size_incr_iter=5 + +num_frames_init_speech_phase2=100000 +num_frames_init_silence_phase2=200000 +num_frames_init_sound_phase2=200000 +speech_frames_incr_phase2=200000 +sil_frames_incr_phase2=200000 +sound_frames_incr_phase2=200000 + +## Phase 3 parameters +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +speech_max_gauss_phase3=12 +sil_gauss_incr_phase3=1 +speech_gauss_incr_phase3=2 +num_iters_phase3=7 + + diff --git a/egs/rt/s5/conf/vad_icsi_babel.conf b/egs/rt/s5/conf/vad_icsi_babel.conf new file mode 100644 index 00000000000..70f651403f5 --- /dev/null +++ b/egs/rt/s5/conf/vad_icsi_babel.conf @@ -0,0 +1,39 @@ +## Features paramters +window_size=10 # 100 ms +frames_per_gaussian=200 + 
+## Phase 1 parameters +num_frames_init_silence=2000 # 20s - Lowest energy frames selected to initialize Silence GMM +num_frames_init_sound=10000 # 100s - Highest energy frames selected to initialize Sound GMM +num_frames_init_sound_next=2000 # 20s - Highest zero crossing frames selected to initialize Sound GMM +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=6 +sil_gauss_incr=0 +sound_gauss_incr=2 +num_iters=5 +min_sil_variance=0.1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 + +## Phase 3 parameters +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +speech_max_gauss_phase3=12 +sil_gauss_incr_phase3=1 +speech_gauss_incr_phase3=2 +num_iters_phase3=7 + + diff --git a/egs/rt/s5/conf/vad_icsi_babel_3models.conf b/egs/rt/s5/conf/vad_icsi_babel_3models.conf new file mode 100644 index 00000000000..1196f0d2aff --- /dev/null +++ b/egs/rt/s5/conf/vad_icsi_babel_3models.conf @@ -0,0 +1,54 @@ +## Features paramters +window_size=10 # 100 ms +frames_per_gaussian=200 + +## Phase 1 parameters +num_frames_init_silence=2000 # 20s - Lowest energy frames selected to initialize Silence GMM +num_frames_init_sound=10000 # 100s - Highest energy frames selected to initialize Sound GMM +num_frames_init_sound_next=2000 # 20s - Highest zero crossing frames selected to initialize Sound GMM +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=6 +sil_gauss_incr=0 +sound_gauss_incr=2 +num_iters=5 +min_sil_variance=0.1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 + +## Phase 3 parameters +num_frames_silence_phase3_init=2000 +num_frames_speech_phase3_init=2000 +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +sil_max_gauss_phase4=8 +speech_max_gauss_phase4=16 +sil_gauss_incr_phase3=1 +sil_gauss_incr_phase4=1 +speech_gauss_incr_phase4=2 +num_iters_phase3=5 +num_iters_phase4=5 + +## Phase 5 parameters +sil_num_gauss_init_phase5=2 +speech_num_gauss_init_phase5=2 +sil_max_gauss_phase5=5 +speech_max_gauss_phase5=12 +sil_gauss_incr_phase5=1 +speech_gauss_incr_phase5=2 +num_iters_phase5=7 + + + diff --git a/egs/rt/s5/conf/vad_icsi_rt.conf b/egs/rt/s5/conf/vad_icsi_rt.conf new file mode 100644 index 00000000000..d19038014db --- /dev/null +++ b/egs/rt/s5/conf/vad_icsi_rt.conf @@ -0,0 +1,40 @@ +## Features paramters +window_size=10 # 100 ms +frames_per_gaussian=200 + +## Phase 1 parameters +num_frames_init_silence=2000 +num_frames_init_sound=10000 +num_frames_init_sound_next=2000 +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=6 +sil_gauss_incr=0 +sound_gauss_incr=2 +num_iters=5 +min_sil_variance=0.1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +#num_frames_init_speech=10000 +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 + +## Phase 3 parameters +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 
+sil_max_gauss_phase3=5 +speech_max_gauss_phase3=12 +sil_gauss_incr_phase3=1 +speech_gauss_incr_phase3=2 +num_iters_phase3=7 + + diff --git a/egs/rt/s5/conf/vad_snr_rt.conf b/egs/rt/s5/conf/vad_snr_rt.conf new file mode 100644 index 00000000000..a1029eb8fe6 --- /dev/null +++ b/egs/rt/s5/conf/vad_snr_rt.conf @@ -0,0 +1,35 @@ +## Features paramters +window_size=5 # 5 frame. Window over which initial selection of frames + +frames_per_silence_gaussian=200 # 2s per Gaussian +frames_per_sound_gaussian=200 # 2s per Gaussian +frames_per_speech_gaussian=2000 # 20s per Gaussian + +## Phase 1 parameters +num_frames_init_silence=1000 # 10s - Lowest energy frames selected to initialize Silence GMM +num_frames_init_silence_next=200 # 2s - Lowest energy frames selected to initialize Silence GMM +num_frames_init_sound=1000 # 10s - Highest energy frames selected to initialize Sound GMM +num_frames_init_sound_next=200 # 2s - Highest zero crossing frames selected to initialize Sound GMM +num_frames_init_speech=10000 # 100s - Highest energy frames selected to initialize Sound GMM +sil_num_gauss_init=2 +sound_num_gauss_init=2 +speech_num_gauss_init=6 +sil_max_gauss=7 +sound_max_gauss=12 +speech_max_gauss=16 +sil_gauss_incr=1 +sound_gauss_incr=2 +speech_gauss_incr=2 +num_iters=10 + +## Phase 3 parameters +num_frames_init_silence_phase3=1000 # 10s - Lowest energy frames selected to initialize Silence GMM +num_frames_init_silence_next_phase3=200 # 2s - Lowest energy frames selected to initialize Silence GMM +num_frames_init_speech_phase3=10000 # 100s - Highest energy frames selected to initialize Sound GMM +sil_num_gauss_init=2 +speech_num_gauss_init=6 +sil_max_gauss=7 +speech_max_gauss=16 +sil_gauss_incr=1 +speech_gauss_incr=2 +num_iters_phase3=10 diff --git a/egs/rt/s5/conf/zc_vad.conf b/egs/rt/s5/conf/zc_vad.conf new file mode 100644 index 00000000000..b5d94450709 --- /dev/null +++ b/egs/rt/s5/conf/zc_vad.conf @@ -0,0 +1,5 @@ +--sample-frequency=16000 +--frame-length=25 # the default is 25. +--dither=0.0 +--zero-crossing-threshold=1e-5 + diff --git a/egs/rt/s5/diarization b/egs/rt/s5/diarization new file mode 120000 index 00000000000..ba78a9126af --- /dev/null +++ b/egs/rt/s5/diarization @@ -0,0 +1 @@ +../../sre08/v1/diarization \ No newline at end of file diff --git a/egs/rt/s5/local/make_rt_2004_dev.pl b/egs/rt/s5/local/make_rt_2004_dev.pl new file mode 100755 index 00000000000..8a08dd268a7 --- /dev/null +++ b/egs/rt/s5/local/make_rt_2004_dev.pl @@ -0,0 +1,64 @@ +#!/usr/bin/perl -w +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +use strict; +use File::Basename; + +if (@ARGV != 2) { + print STDERR "Usage: $0 \n" . + " e.g.: $0 /export/corpora5/LDC/LDC2007S11 data\n"; + exit(1); +} + +my ($db_base, $out_dir) = @ARGV; +$out_dir = "$out_dir/rt04_dev"; + +if (system("mkdir -p $out_dir")) { + die "Error making directory $out_dir"; +} + +open(SPKR, ">", "$out_dir/utt2spk") + or die "Could not open the output file $out_dir/utt2spk"; +open(WAV, ">", "$out_dir/wav.scp") + or die "Could not open the output file $out_dir/wav.scp"; +open(RECO2FILE_AND_CHANNEL, ">", "$out_dir/reco2file_and_channel") + or die "Could not open the output file $out_dir/reco2file_and_channel"; + +open(LIST, 'find ' . $db_base . '/data/audio/dev04s -name "*.sph" |'); + + +my $sox =`which sox` || die "Could not find sox in PATH"; +chomp($sox); + +while (my $line = ) { + chomp($line); + my ($file_id, $path, $suffix) = fileparse($line, qr/\.[^.]*/); + if ($suffix =~ /.sph/) { + #print WAV $file_id . 
" $sox $line -c 1 -b 16 -t wav - |\n"; + print WAV $file_id . " sph2pipe -f wav $line |\n"; + } elsif ($suffix =~ /.wav/) { + print WAV $file_id . " $line |\n"; + } else { + die "$0: Unknown suffix $suffix in $line\n" + } + + print SPKR "$file_id $file_id\n"; + print RECO2FILE_AND_CHANNEL "$file_id $file_id 1\n"; +} + +close(LIST) || die; +close(WAV) || die; +close(SPKR) || die; + +if (system( + "utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) { + die "Error creating spk2utt file in directory $out_dir"; +} + +system("utils/fix_data_dir.sh $out_dir"); + +if (system( + "utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) { + die "Error validating directory $out_dir"; +} diff --git a/egs/rt/s5/local/make_rt_2004_eval.pl b/egs/rt/s5/local/make_rt_2004_eval.pl new file mode 100755 index 00000000000..4c1286ea1cc --- /dev/null +++ b/egs/rt/s5/local/make_rt_2004_eval.pl @@ -0,0 +1,64 @@ +#!/usr/bin/perl -w +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +use strict; +use File::Basename; + +if (@ARGV != 2) { + print STDERR "Usage: $0 \n" . + " e.g.: $0 /export/corpora5/LDC/LDC2007S12/package/rt04_eval data\n"; + exit(1); +} + +my ($db_base, $out_dir) = @ARGV; +$out_dir = "$out_dir/rt04_eval"; + +if (system("mkdir -p $out_dir")) { + die "Error making directory $out_dir"; +} + +open(SPKR, ">", "$out_dir/utt2spk") + or die "Could not open the output file $out_dir/utt2spk"; +open(WAV, ">", "$out_dir/wav.scp") + or die "Could not open the output file $out_dir/wav.scp"; +open(RECO2FILE_AND_CHANNEL, ">", "$out_dir/reco2file_and_channel") + or die "Could not open the output file $out_dir/reco2file_and_channel"; + +open(LIST, 'find ' . $db_base . '/data/audio/eval04s -name "*.sph" |'); + +my $sox =`which sox` || die "Could not find sox in PATH"; +chomp($sox); + +while (my $line = ) { + chomp($line); + my ($file_id, $path, $suffix) = fileparse($line, qr/\.[^.]*/); + if ($suffix =~ /.sph/) { + #print WAV $file_id . " $sox $line -c 1 -b 16 -t wav - |\n"; + print WAV $file_id . " sph2pipe -f wav $line |\n"; + } elsif ($suffix =~ /.wav/) { + print WAV $file_id . " $line |\n"; + } else { + die "$0: Unknown suffix $suffix in $line\n" + } + + print SPKR "$file_id $file_id\n"; + print RECO2FILE_AND_CHANNEL "$file_id $file_id 1\n"; +} + +close(LIST) || die; +close(WAV) || die; +close(SPKR) || die; + +if (system( + "utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) { + die "Error creating spk2utt file in directory $out_dir"; +} + +system("utils/fix_data_dir.sh $out_dir"); + +if (system( + "utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) { + die "Error validating directory $out_dir"; +} + diff --git a/egs/rt/s5/local/make_rt_2005_eval.pl b/egs/rt/s5/local/make_rt_2005_eval.pl new file mode 100755 index 00000000000..d48dcaae926 --- /dev/null +++ b/egs/rt/s5/local/make_rt_2005_eval.pl @@ -0,0 +1,64 @@ +#!/usr/bin/perl -w +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +use strict; +use File::Basename; + +if (@ARGV != 2) { + print STDERR "Usage: $0 \n" . 
+ " e.g.: $0 /export/corpora5/LDC/LDC2011S06 data\n"; + exit(1); +} + +my ($db_base, $out_dir) = @ARGV; +$out_dir = "$out_dir/rt05_eval"; + +if (system("mkdir -p $out_dir")) { + die "Error making directory $out_dir"; +} + +open(SPKR, ">", "$out_dir/utt2spk") + or die "Could not open the output file $out_dir/utt2spk"; +open(WAV, ">", "$out_dir/wav.scp") + or die "Could not open the output file $out_dir/wav.scp"; +open(RECO2FILE_AND_CHANNEL, ">", "$out_dir/reco2file_and_channel") + or die "Could not open the output file $out_dir/reco2file_and_channel"; + +open(LIST, 'find ' . $db_base . '/data/audio/eval05s -name "*.sph" |'); + +my $sox =`which sox` || die "Could not find sox in PATH"; +chomp($sox); + +while (my $line = ) { + chomp($line); + my ($file_id, $path, $suffix) = fileparse($line, qr/\.[^.]*/); + if ($suffix =~ /.sph/) { + print WAV $file_id . " $sox $line -c 1 -b 16 -t wav - |\n"; + } elsif ($suffix =~ /.wav/) { + print WAV $file_id . " $line |\n"; + } else { + die "$0: Unknown suffix $suffix in $line\n" + } + + print SPKR "$file_id $file_id\n"; + print RECO2FILE_AND_CHANNEL "$file_id $file_id 1\n"; +} + +close(LIST) || die; +close(WAV) || die; +close(SPKR) || die; + +if (system( + "utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) { + die "Error creating spk2utt file in directory $out_dir"; +} + +system("utils/fix_data_dir.sh $out_dir"); + +if (system( + "utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) { + die "Error validating directory $out_dir"; +} + + diff --git a/egs/rt/s5/local/run_prepare_rt.sh b/egs/rt/s5/local/run_prepare_rt.sh new file mode 100755 index 00000000000..c431f760dab --- /dev/null +++ b/egs/rt/s5/local/run_prepare_rt.sh @@ -0,0 +1,87 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + +set -e +set -o pipefail +set -u + +. path.sh +. cmd.sh + +mic=sdm +task=sad + +. parse_options.sh + +RT04_DEV_ROOT=/export/corpora5/LDC/LDC2007S11 +RT04_EVAL_ROOT=/export/corpora5/LDC/LDC2007S12/package/rt04_eval +RT05_EVAL_ROOT=/export/corpora5/LDC/LDC2011S06 + +if [ ! -f data/rt04_dev/.done ]; then + local/make_rt_2004_dev.pl $RT04_DEV_ROOT data + touch data/rt04_dev/.done +fi + +if [ ! -f data/rt04_eval/.done ]; then + local/make_rt_2004_eval.pl $RT04_EVAL_ROOT data + touch data/rt04_eval/.done +fi + +if [ ! 
-f data/rt05_eval/.done ]; then + local/make_rt_2005_eval.pl $RT05_EVAL_ROOT data + touch data/rt05_eval/.done +fi + +mkdir -p data/local + +dir=data/local/rt05_eval/$mic/$task +mkdir -p $dir + +if [ $task == "stt" ]; then + cp $RT05_EVAL_ROOT/data/reference/concatenated/rt05s.confmtg.050614.${task}.${mic}.stm $dir/stm +else + cp $RT05_EVAL_ROOT/data/reference/concatenated/rt05s.confmtg.050614.${task}.${mic}.rttm $dir/rttm +fi + +cp $RT05_EVAL_ROOT/data/indicies/expt_05s_${task}ul_eval05s_eng_confmtg_${mic}_1.uem $dir/uem +cat $dir/uem | awk '!/;;/{if (NF > 0) print $1}' | perl -pe 's/(.*)\.sph/$1/g' | sort -u > $dir/list +utils/subset_data_dir.sh --utt-list $dir/list data/rt05_eval data/rt05_eval_${mic}_${task} +[ -f $dir/stm ] && cp $dir/stm data/rt05_eval_${mic}_${task} +[ -f $dir/uem ] && cp $dir/uem data/rt05_eval_${mic}_${task} +[ -f $dir/rttm ] && cp $dir/rttm data/rt05_eval_${mic}_${task} + +dir=data/local/rt04_dev/$mic/$task +mkdir -p $dir + +if [ $task == "stt" ]; then + cp $RT04_DEV_ROOT/data/reference/dev04s/concatenated/dev04s.040809.${mic}.stm $dir/stm +elif [ $task == "spkr" ]; then + cp $RT04_DEV_ROOT/data/reference/dev04s/concatenated/dev04s.040809.${mic}.rttm $dir/rttm +else + cat $RT04_DEV_ROOT/data/reference/dev04s/concatenated/dev04s.040809.${mic}.rttm | spkr2sad.pl | rttmSmooth.pl -s 0 > $dir/rttm +fi +cp $RT04_DEV_ROOT/data/indices/dev04s/dev04s.${mic}.uem $dir/uem +cat $dir/uem | awk '!/;;/{if (NF > 0) print $1}' | perl -pe 's/(.*)\.sph/$1/g' | sort -u > $dir/list +utils/subset_data_dir.sh --utt-list $dir/list data/rt04_dev data/rt04_dev_${mic}_${task} +[ -f $dir/stm ] && cp $dir/stm data/rt04_dev_${mic}_${task} +[ -f $dir/uem ] && cp $dir/uem data/rt04_dev_${mic}_${task} +[ -f $dir/rttm ] && cp $dir/rttm data/rt04_dev_${mic}_${task} + +dir=data/local/rt04_eval/$mic/$task +mkdir -p $dir + +if [ $task == "stt" ]; then + cp $RT04_EVAL_ROOT/data/reference/eval04s/concatenated/eval04s.040511.${mic}.stm $dir/stm +elif [ $task == "spkr" ]; then + cp $RT04_EVAL_ROOT/data/reference/eval04s/concatenated/eval04s.040511.${mic}.rttm $dir/rttm +else + cat $RT04_EVAL_ROOT/data/reference/eval04s/concatenated/eval04s.040511.${mic}.rttm | spkr2sad.pl | rttmSmooth.pl -s 0 > $dir/rttm +fi +cp $RT04_EVAL_ROOT/data/indices/eval04s/eval04s.${mic}.uem $dir/uem +cat $dir/uem | awk '!/;;/{if (NF > 0) print $1}' | perl -pe 's/(.*)\.sph/$1/g' | sort -u > $dir/list +utils/subset_data_dir.sh --utt-list $dir/list data/rt04_eval data/rt04_eval_${mic}_${task} +[ -f $dir/stm ] && cp $dir/stm data/rt04_eval_${mic}_${task} +[ -f $dir/uem ] && cp $dir/uem data/rt04_eval_${mic}_${task} +[ -f $dir/rttm ] && cp $dir/rttm data/rt04_eval_${mic}_${task} diff --git a/egs/rt/s5/local/score.sh b/egs/rt/s5/local/score.sh new file mode 100755 index 00000000000..1c3e2cbe8c4 --- /dev/null +++ b/egs/rt/s5/local/score.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +# Copyright Johns Hopkins University (Author: Daniel Povey) 2012 +# Copyright University of Edinburgh (Author: Pawel Swietojanski) 2014 +# Apache 2.0 + +orig_args= +for x in "$@"; do orig_args="$orig_args '$x'"; done + +# begin configuration section. we include all the options that score_sclite.sh or +# score_basic.sh might need, or parse_options.sh will die. +cmd=run.pl +stage=0 +min_lmwt=9 # unused, +max_lmwt=15 # unused, +asclite=true +#end configuration section. + +[ -f ./path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + +if [ $# -ne 3 ]; then + echo "Usage: local/score.sh [options] " && exit; + echo " Options:" + echo " --cmd (run.pl|queue.pl...) 
# specify how to run the sub-processes." + echo " --stage (0|1|2) # start scoring script from part-way through." + echo " --min_lmwt # minumum LM-weight for lattice rescoring " + echo " --max_lmwt # maximum LM-weight for lattice rescoring " + echo " --asclite (true/false) # score with ascltie instead of sclite (overlapped speech)" + exit 1; +fi + +data=$1 + +mic=$(echo $data | awk -F '/' '{print $2}') +case $mic in + ihm*) + echo "Using sclite for IHM (close talk)," + eval local/score_asclite.sh --asclite false $orig_args + ;; + sdm*) + echo "Using asclite for overlapped speech SDM (single distant mic)," + eval local/score_asclite.sh --asclite $asclite $orig_args + ;; + mdm*) + echo "Using asclite for overlapped speech MDM (multiple distant mics)," + eval local/score_asclite.sh --asclite $asclite $orig_args + ;; + *) + echo "local/score.sh: no ihm/sdm/mdm directories found. AMI recipe assumes data/{ihm,sdm,mdm}/..." + exit 1; + ;; +esac diff --git a/egs/rt/s5/local/score_asclite.sh b/egs/rt/s5/local/score_asclite.sh new file mode 100755 index 00000000000..86b801b975d --- /dev/null +++ b/egs/rt/s5/local/score_asclite.sh @@ -0,0 +1,120 @@ +#!/bin/bash +# Copyright Johns Hopkins University (Author: Daniel Povey) 2012. Apache 2.0. +# 2014, University of Edinburgh, (Author: Pawel Swietojanski) + +# begin configuration section. +cmd=run.pl +stage=0 +min_lmwt=9 +max_lmwt=15 +reverse=false +asclite=true +overlap_spk=4 +#end configuration section. + +[ -f ./path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + +if [ $# -ne 3 ]; then + echo "Usage: local/score_asclite.sh [--cmd (run.pl|queue.pl...)] " + echo " Options:" + echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." + echo " --stage (0|1|2) # start scoring script from part-way through." + echo " --min_lmwt # minumum LM-weight for lattice rescoring " + echo " --max_lmwt # maximum LM-weight for lattice rescoring " + echo " --reverse (true/false) # score with time reversed features " + exit 1; +fi + +data=$1 +lang=$2 # Note: may be graph directory not lang directory, but has the necessary stuff copied. +dir=$3 + +model=$dir/../final.mdl # assume model one level up from decoding dir. + +hubscr=$KALDI_ROOT/tools/sctk/bin/hubscr.pl +[ ! -f $hubscr ] && echo "Cannot find scoring program at $hubscr" && exit 1; +hubdir=`dirname $hubscr` + +for f in $data/stm $data/glm $lang/words.txt $lang/phones/word_boundary.int \ + $model $data/segments $data/reco2file_and_channel $dir/lat.1.gz; do + [ ! -f $f ] && echo "$0: expecting file $f to exist" && exit 1; +done + +name=`basename $data`; # e.g. 
eval2000 + +mkdir -p $dir/ascoring/log + +if [ $stage -le 0 ]; then + if $reverse; then + $cmd LMWT=$min_lmwt:$max_lmwt $dir/ascoring/log/get_ctm.LMWT.log \ + mkdir -p $dir/ascore_LMWT/ '&&' \ + lattice-1best --lm-scale=LMWT "ark:gunzip -c $dir/lat.*.gz|" ark:- \| \ + lattice-reverse ark:- ark:- \| \ + lattice-align-words --reorder=false $lang/phones/word_boundary.int $model ark:- ark:- \| \ + nbest-to-ctm ark:- - \| \ + utils/int2sym.pl -f 5 $lang/words.txt \| \ + utils/convert_ctm.pl $data/segments $data/reco2file_and_channel \ + '>' $dir/ascore_LMWT/$name.ctm || exit 1; + else + $cmd LMWT=$min_lmwt:$max_lmwt $dir/ascoring/log/get_ctm.LMWT.log \ + mkdir -p $dir/ascore_LMWT/ '&&' \ + lattice-1best --lm-scale=LMWT "ark:gunzip -c $dir/lat.*.gz|" ark:- \| \ + lattice-align-words $lang/phones/word_boundary.int $model ark:- ark:- \| \ + nbest-to-ctm ark:- - \| \ + utils/int2sym.pl -f 5 $lang/words.txt \| \ + utils/convert_ctm.pl $data/segments $data/reco2file_and_channel \ + '>' $dir/ascore_LMWT/$name.ctm || exit 1; + fi +fi + +if [ $stage -le 1 ]; then +# Remove some stuff we don't want to score, from the ctm. + for x in $dir/ascore_*/$name.ctm; do + cp $x $dir/tmpf; + cat $dir/tmpf | grep -i -v -E '\[noise|laughter|vocalized-noise\]' | \ + grep -i -v -E '' > $x; +# grep -i -v -E '|%HESITATION' > $x; + done +fi + +if [ $stage -le 2 ]; then + if [ "$asclite" == "true" ]; then + oname=$name + [ ! -z $overlap_spk ] && oname=${name}_o$overlap_spk + echo "asclite is starting" + # Run scoring, meaning of hubscr.pl options: + # -G .. produce alignment graphs, + # -v .. verbose, + # -m .. max-memory in GBs, + # -o .. max N of overlapping speakers, + # -a .. use asclite, + # -C .. compression for asclite, + # -B .. blocksize for asclite (kBs?), + # -p .. path for other components, + # -V .. skip validation of input transcripts, + # -h rt-stt .. removes non-lexical items from CTM, + $cmd LMWT=$min_lmwt:$max_lmwt $dir/ascoring/log/score.LMWT.log \ + cp $data/stm $dir/ascore_LMWT/ '&&' \ + cp $dir/ascore_LMWT/${name}.ctm $dir/ascore_LMWT/${oname}.ctm '&&' \ + $hubscr -G -v -m 1:2 -o$overlap_spk -a -C -B 8192 -p $hubdir -V -l english \ + -h rt-stt -g $data/glm -r $dir/ascore_LMWT/stm $dir/ascore_LMWT/${oname}.ctm || exit 1 + # Compress some scoring outputs : alignment info and graphs, + echo -n "compressing asclite outputs " + for LMWT in $(seq $min_lmwt $max_lmwt); do + ascore=$dir/ascore_${LMWT} + gzip -f $ascore/${oname}.ctm.filt.aligninfo.csv + cp $ascore/${oname}.ctm.filt.alignments/index.html $ascore/${oname}.ctm.filt.overlap.html + tar -C $ascore -czf $ascore/${oname}.ctm.filt.alignments.tar.gz ${oname}.ctm.filt.alignments + rm -r $ascore/${oname}.ctm.filt.alignments + echo -n "LMWT:$LMWT " + done + echo done + else + $cmd LMWT=$min_lmwt:$max_lmwt $dir/ascoring/log/score.LMWT.log \ + cp $data/stm $dir/ascore_LMWT/ '&&' \ + $hubscr -p $hubdir -V -l english -h hub5 -g $data/glm -r $dir/ascore_LMWT/stm $dir/ascore_LMWT/${name}.ctm || exit 1 + fi +fi + +exit 0 diff --git a/egs/rt/s5/local/snr b/egs/rt/s5/local/snr new file mode 120000 index 00000000000..6d422e11960 --- /dev/null +++ b/egs/rt/s5/local/snr @@ -0,0 +1 @@ +../../../wsj_noisy/s5/local/snr \ No newline at end of file diff --git a/egs/rt/s5/path.sh b/egs/rt/s5/path.sh new file mode 100755 index 00000000000..8461d980758 --- /dev/null +++ b/egs/rt/s5/path.sh @@ -0,0 +1,5 @@ +export KALDI_ROOT=`pwd`/../../.. +[ -f $KALDI_ROOT/tools/env.sh ] && . 
$KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/src/bin:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/src/fstbin/:$KALDI_ROOT/src/gmmbin/:$KALDI_ROOT/src/featbin/:$KALDI_ROOT/src/lm/:$KALDI_ROOT/src/sgmmbin/:$KALDI_ROOT/src/sgmm2bin/:$KALDI_ROOT/src/fgmmbin/:$KALDI_ROOT/src/latbin/:$KALDI_ROOT/src/nnetbin:$KALDI_ROOT/src/nnet2bin/:$KALDI_ROOT/src/kwsbin:$KALDI_ROOT/src/online2bin/:$KALDI_ROOT/src/ivectorbin/:$KALDI_ROOT/src/lmbin/:$KALDI_ROOT/src/nnet3bin/:$KALDI_ROOT/src/segmenterbin/:$PWD:$PATH:$KALDI_ROOT/tools/sctk/bin +export PATH=$KALDI_ROOT/tools/sph2pipe_v2.5/:$PATH +export LC_ALL=C diff --git a/egs/rt/s5/sid b/egs/rt/s5/sid new file mode 120000 index 00000000000..5cb0274b7d6 --- /dev/null +++ b/egs/rt/s5/sid @@ -0,0 +1 @@ +../../sre08/v1/sid/ \ No newline at end of file diff --git a/egs/rt/s5/steps b/egs/rt/s5/steps new file mode 120000 index 00000000000..1b186770dd1 --- /dev/null +++ b/egs/rt/s5/steps @@ -0,0 +1 @@ +../../wsj/s5/steps/ \ No newline at end of file diff --git a/egs/rt/s5/utils b/egs/rt/s5/utils new file mode 120000 index 00000000000..a3279dc8679 --- /dev/null +++ b/egs/rt/s5/utils @@ -0,0 +1 @@ +../../wsj/s5/utils/ \ No newline at end of file diff --git a/src/ivectorbin/Makefile b/src/ivectorbin/Makefile index 71a855762fe..5df22a2bb8a 100644 --- a/src/ivectorbin/Makefile +++ b/src/ivectorbin/Makefile @@ -15,7 +15,9 @@ BINFILES = ivector-extractor-init ivector-extractor-acc-stats \ ivector-subtract-global-mean ivector-plda-scoring \ logistic-regression-train logistic-regression-eval \ logistic-regression-copy create-split-from-vad \ - ivector-extract-online ivector-adapt-plda + ivector-extract-online ivector-adapt-plda \ + ivector-extract-dense ivector-cluster \ + ivector-cluster-plda OBJFILES = From 32977de2bc08b401de488d0d6ae9eed26b68ed69 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Sat, 10 Dec 2016 17:11:09 -0500 Subject: [PATCH 058/213] asr_diarization: Adding ami_normalize_transcripts.pl --- .../s5b/local/ami_normalize_transcripts.pl | 129 ++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 egs/ami/s5b/local/ami_normalize_transcripts.pl diff --git a/egs/ami/s5b/local/ami_normalize_transcripts.pl b/egs/ami/s5b/local/ami_normalize_transcripts.pl new file mode 100644 index 00000000000..772e8b50fec --- /dev/null +++ b/egs/ami/s5b/local/ami_normalize_transcripts.pl @@ -0,0 +1,129 @@ +#!/usr/bin/env perl + +# Copyright 2014 University of Edinburgh (Author: Pawel Swietojanski) +# 2016 Vimal Manohar + +# The script - based on punctuation times - splits segments longer than #words (input parameter) +# and produces bit more more normalised form of transcripts, as follows +# MeetID Channel Spkr stime etime transcripts + +#use List::MoreUtils 'indexes'; +use strict; +use warnings; + +sub normalise_transcripts; + +sub merge_hashes { + my ($h1, $h2) = @_; + my %hash1 = %$h1; my %hash2 = %$h2; + foreach my $key2 ( keys %hash2 ) { + if( exists $hash1{$key2} ) { + warn "Key [$key2] is in both hashes!"; + next; + } else { + $hash1{$key2} = $hash2{$key2}; + } + } + return %hash1; +} + +sub print_hash { + my ($h) = @_; + my %hash = %$h; + foreach my $k (sort keys %hash) { + print "$k : $hash{$k}\n"; + } +} + +sub get_name { + #no warnings; + my $sname = sprintf("%07d_%07d", $_[0]*100, $_[1]*100) || die 'Input undefined!'; + #use warnings; + return $sname; +} + +sub split_on_comma { + + my ($text, $comma_times, $btime, $etime, $max_words_per_seg)= @_; + my %comma_hash = %$comma_times; + + print "Btime, Etime : $btime, $etime\n"; + + my $stime = 
($etime+$btime)/2; #split time + my $skey = ""; + my $otime = $btime; + foreach my $k (sort {$comma_hash{$a} cmp $comma_hash{$b} } keys %comma_hash) { + print "Key : $k : $comma_hash{$k}\n"; + my $ktime = $comma_hash{$k}; + if ($ktime==$btime) { next; } + if ($ktime==$etime) { last; } + if (abs($stime-$ktime)/20) { + $st=$comma_hash{$skey}; + $et = $etime; + } + my (@utts) = split (' ', $utts1[$i]); + if ($#utts < $max_words_per_seg) { + my $nm = get_name($st, $et); + print "SplittedOnComma[$i]: $nm : $utts1[$i]\n"; + $transcripts{$nm} = $utts1[$i]; + } else { + print 'Continue splitting!'; + my %transcripts2 = split_on_comma($utts1[$i], \%comma_hash, $st, $et, $max_words_per_seg); + %transcripts = merge_hashes(\%transcripts, \%transcripts2); + } + } + return %transcripts; +} + +sub normalise_transcripts { + my $text = $_; + + #DO SOME ROUGH AND OBVIOUS PRELIMINARY NORMALISATION, AS FOLLOWS + #remove the remaining punctation labels e.g. some text ,0 some text ,1 + $text =~ s/[\.\,\?\!\:][0-9]+//g; + #there are some extra spurious puncations without spaces, e.g. UM,I, replace with space + $text =~ s/[A-Z']+,[A-Z']+/ /g; + #split words combination, ie. ANTI-TRUST to ANTI TRUST (None of them appears in cmudict anyway) + #$text =~ s/(.*)([A-Z])\s+(\-)(.*)/$1$2$3$4/g; + $text =~ s/\-/ /g; + #substitute X_M_L with X. M. L. etc. + $text =~ s/\_/. /g; + #normalise and trim spaces + $text =~ s/^\s*//g; + $text =~ s/\s*$//g; + $text =~ s/\s+/ /g; + #some transcripts are empty with -, nullify (and ignore) them + $text =~ s/^\-$//g; + $text =~ s/\s+\-$//; + # apply few exception for dashed phrases, Mm-Hmm, Uh-Huh, etc. those are frequent in AMI + # and will be added to dictionary + $text =~ s/MM HMM/MM\-HMM/g; + $text =~ s/UH HUH/UH\-HUH/g; + + return $text; +} + +while(<>) { + chomp; + print normalise_transcripts($_) . "\n"; +} + From 4219de118a3e73cf3db0aa7fa346dcdabac3a86e Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 21 Sep 2016 19:53:25 -0400 Subject: [PATCH 059/213] asr_diarization: Two-stage decoding baseline AMI --- .../s5b/local/chain/run_decode_two_stage.sh | 135 ++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100755 egs/ami/s5b/local/chain/run_decode_two_stage.sh diff --git a/egs/ami/s5b/local/chain/run_decode_two_stage.sh b/egs/ami/s5b/local/chain/run_decode_two_stage.sh new file mode 100755 index 00000000000..0d354bfa574 --- /dev/null +++ b/egs/ami/s5b/local/chain/run_decode_two_stage.sh @@ -0,0 +1,135 @@ +#!/bin/bash + +set -e -u +set -o pipefail + +stage=-1 +decode_stage=1 + +mic=ihm +use_ihm_ali=false +exp_name=tdnn + +cleanup_affix= + +decode_set=dev +extractor= +use_ivectors=true +scoring_opts= +lmwt=8 +pad_frames=10 + +. path.sh +. cmd.sh + +. parse_options.sh + +new_mic=$mic +if [ $use_ihm_ali == "true" ]; then + new_mic=${mic}_cleanali +fi + +dir=exp/$new_mic/chain${cleanup_affix:+_$cleanup_affix}/${exp_name} + +nj=20 + +if [ $stage -le -1 ]; then + mfccdir=mfcc_${mic} + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $mfccdir/storage ]; then + utils/create_split_dir.pl /export/b0{1,2,3,4}/$USER/kaldi-data/egs/ami-$mic-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage + fi + + steps/make_mfcc.sh --nj $nj --mfcc-config conf/mfcc.conf \ + --cmd "$train_cmd" data/$mic/${decode_set} exp/make_${mic}/$decode_set $mfccdir || exit 1; + + steps/compute_cmvn_stats.sh data/$mic/${decode_set} exp/make_${mic}/$mic/$decode_set $mfccdir || exit 1; + + utils/fix_data_dir.sh data/$mic/${decode_set} +fi + +utils/data/get_utt2dur.sh data/$mic/${decode_set} + +if [ $stage -le 0 ]; then + mfccdir=mfcc_${mic}_hires + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then + utils/create_split_dir.pl /export/b0{1,2,3,4}/$USER/kaldi-data/egs/ami-$mic-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage + fi + + utils/copy_data_dir.sh data/$mic/$decode_set data/$mic/${decode_set}_hires + + steps/make_mfcc.sh --nj $nj --mfcc-config conf/mfcc_hires.conf \ + --cmd "$train_cmd" data/$mic/${decode_set}_hires exp/make_${mic}_hires/$decode_set $mfccdir || exit 1; + + steps/compute_cmvn_stats.sh data/$mic/${decode_set}_hires exp/make_${mic}_hires/$mic/$decode_set $mfccdir || exit 1; + + utils/fix_data_dir.sh data/$mic/${decode_set}_hires +fi + +if $use_ivectors && [ $stage -le 1 ]; then + if [ -z "$extractor" ]; then + "--extractor must be supplied when using ivectors" + fi + + steps/online/nnet2/extract_ivectors_online.sh \ + --cmd "$train_cmd" --nj 8 \ + data/$mic/${decode_set}_hires $extractor \ + exp/$mic/nnet3${cleanup_affix:+_$cleanup_affix}/ivectors_${decode_set} || exit 1 +fi + +final_lm=`cat data/local/lm/final_lm` +LM=$final_lm.pr1-7 +graph_dir=$dir/graph_${LM} +if [ $stage -le 2 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. 
+ utils/mkgraph.sh --self-loop-scale 1.0 data/lang_${LM} $dir $graph_dir +fi + +nj=`cat data/$mic/${decode_set}/utt2spk|cut -d' ' -f2|sort -u|wc -l` + +if [ $nj -gt 50 ]; then + nj=50 +fi + +if [ $stage -le 3 ]; then + ivector_opts= + if $use_ivectors; then + ivector_opts="--online-ivector-dir exp/$mic/nnet3${cleanup_affix:+_$cleanup_affix}/ivectors_${decode_set}" + fi + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --stage $decode_stage \ + --nj $nj --cmd "$decode_cmd" $ivector_opts \ + --scoring-opts "--min-lmwt 5 $scoring_opts" \ + $graph_dir data/$mic/${decode_set}_hires $dir/decode_${decode_set} || exit 1; +fi + +ivector_weights=$dir/decode_${decode_set}/ascore_$lmwt/ivector_weights.gz + +if [ $stage -le 4 ]; then + cat $dir/decode_${decode_set}/ascore_$lmwt/${decode_set}_hires.utt.ctm | \ + grep -i -v -E '\[noise|laughter|vocalized-noise\]' | \ + local/get_ivector_weights_from_ctm_conf.pl \ + --pad-frames $pad_frames data/$mic/${decode_set}/utt2dur | \ + gzip -c > $ivector_weights +fi + +if [ $stage -le 5 ]; then + steps/online/nnet2/extract_ivectors_online.sh \ + --cmd "$train_cmd" --nj $nj --weights $ivector_weights \ + data/$mic/${decode_set}_hires $extractor \ + exp/$mic/nnet3${cleanup_affix:+_$cleanup_affix}/ivectors_${decode_set}_stage2 || exit 1 +fi + +if [ $stage -le 6 ]; then + ivector_opts= + if $use_ivectors; then + ivector_opts="--online-ivector-dir exp/$mic/nnet3${cleanup_affix:+_$cleanup_affix}/ivectors_${decode_set}_stage2" + fi + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --stage $decode_stage \ + --nj $nj --cmd "$decode_cmd" $ivector_opts \ + --scoring-opts "--min-lmwt 5 $scoring_opts" \ + $graph_dir data/$mic/${decode_set}_hires $dir/decode_${decode_set}_stage2 || exit 1; +fi + From 318b52ef67dc678233a9c08edd8db04f8550cfe6 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Sat, 10 Dec 2016 17:36:44 -0500 Subject: [PATCH 060/213] asr_diarization: Adding modify_stm.pl to remove beginning and end from scoring --- egs/ami/s5b/local/modify_stm.py | 97 +++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100755 egs/ami/s5b/local/modify_stm.py diff --git a/egs/ami/s5b/local/modify_stm.py b/egs/ami/s5b/local/modify_stm.py new file mode 100755 index 00000000000..52ab6fed1ef --- /dev/null +++ b/egs/ami/s5b/local/modify_stm.py @@ -0,0 +1,97 @@ +#! 
/usr/bin/env python + +import sys +import collections +import itertools +import argparse + +from collections import defaultdict + +def IgnoreWordList(stm_lines, wordlist): + for i in range(0, len(stm_lines)): + line = stm_lines[i] + splits = line.strip().split() + + line_changed = False + for j in range(5, len(splits)): + if str.lower(splits[j]) in wordlist: + splits[j] = "{{ {0} / @ }}".format(splits[j]) + line_changed = True + + + if line_changed: + stm_lines[i] = " ".join(splits) + +def IgnoreIsolatedWords(stm_lines): + for i in range(0, len(stm_lines)): + line = stm_lines[i] + splits = line.strip().split() + + assert( splits[5][0] != '<' ) + + if len(splits) == 6 and splits[5] != "IGNORE_TIME_SEGMENT_IN_SCORING": + splits.insert(5, "") + else: + splits.insert(5, "") + stm_lines[i] = " ".join(splits) + +def IgnoreBeginnings(stm_lines): + beg_times = defaultdict(itertools.repeat(float("inf")).next) + + lines_to_add = [] + for line in stm_lines: + splits = line.strip().split() + + beg_times[(splits[0],splits[1])] = min(beg_times[(splits[0],splits[1])], float(splits[3])) + + for t,v in beg_times.iteritems(): + lines_to_add.append("{0} {1} {0} 0.0 {2} IGNORE_TIME_SEGMENT_IN_SCORING".format(t[0], t[1], v)) + + stm_lines.extend(lines_to_add) + +def WriteStmLines(stm_lines): + for line in stm_lines: + print(line) + +def GetArgs(): + parser = argparse.ArgumentParser("This script modifies STM to remove certain words and segments from scoring. Use sort +0 -1 +1 -2 +3nb -4 while writing out.", + formatter_class = argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("--ignore-beginnings", + type = str, choices = ["true", "false"], + help = "Ignore beginnings of the recordings since " + "they are not transcribed") + parser.add_argument("--ignore-isolated-words", + type = str, choices = ["true", "false"], + help = "Remove isolated words from scoring " + "because they may be hard to recognize without " + "speaker diarization") + parser.add_argument("--ignore-word-list", + type = str, + help = "List of words to be ignored") + + args = parser.parse_args() + + return args + +def Main(): + args = GetArgs() + + stm_lines = [ x.strip() for x in sys.stdin.readlines() ] + + print (';; LABEL "NO_ISO", "No isolated words", "Ignoring isolated words"') + print (';; LABEL "ISO", "Isolated words", "isolated words"') + + #if args.ignore_word_list is not None: + # wordlist = {} + # for x in open(args.ignore_word_list).readlines(): + # wordlist[str.lower(x.strip())] = 1 + # IgnoreWordList(stm_lines, wordlist) + + IgnoreIsolatedWords(stm_lines) + IgnoreBeginnings(stm_lines) + + WriteStmLines(stm_lines) + +if __name__ == "__main__": + Main() From f141eb0f86f2e25ec64f0b643735cc3bb649b324 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Tue, 22 Nov 2016 12:42:44 -0500 Subject: [PATCH 061/213] asr_diarization: Removing sctk and other additions to AMI path.sh --- egs/ami/s5b/path.sh | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/egs/ami/s5b/path.sh b/egs/ami/s5b/path.sh index d8f46e6b8a0..b4711d23926 100644 --- a/egs/ami/s5b/path.sh +++ b/egs/ami/s5b/path.sh @@ -9,7 +9,4 @@ LMBIN=$KALDI_ROOT/tools/irstlm/bin SRILM=$KALDI_ROOT/tools/srilm/bin/i686-m64 BEAMFORMIT=$KALDI_ROOT/tools/BeamformIt -export PATH=$PATH:$LMBIN:$BEAMFORMIT:$SRILM -export PATH=$PATH:$KALDI_ROOT/tools/sph2pipe_v2.5 -export PATH=$PATH:/home/vmanoha1/kaldi-waveform/src/segmenterbin -export PATH=$PATH:$KALDI_ROOT/tools/sctk/bin +export PATH=$LMBIN:$BEAMFORMIT:$SRILM:$PATH From d95b27eaac4d011127dcb5a32f635c05a00a60f3 Mon 
Sep 17 00:00:00 2001 From: Vimal Manohar Date: Fri, 18 Nov 2016 00:57:40 -0500 Subject: [PATCH 062/213] asr_diarization: Updating AMI nnet3 recipes --- egs/ami/s5b/local/nnet3/run_blstm.sh | 3 +- egs/ami/s5b/local/nnet3/run_lstm.sh | 7 ++-- egs/ami/s5b/local/nnet3/run_tdnn.sh | 53 +++++++++++++++++++++------- 3 files changed, 47 insertions(+), 16 deletions(-) diff --git a/egs/ami/s5b/local/nnet3/run_blstm.sh b/egs/ami/s5b/local/nnet3/run_blstm.sh index 776151fb5aa..e0e7bcfcdcf 100755 --- a/egs/ami/s5b/local/nnet3/run_blstm.sh +++ b/egs/ami/s5b/local/nnet3/run_blstm.sh @@ -7,6 +7,7 @@ remove_egs=true use_ihm_ali=false train_set=train_cleaned ihm_gmm=tri3 +gmm=tri3a_cleaned nnet3_affix=_cleaned # BLSTM params @@ -32,6 +33,7 @@ local/nnet3/run_lstm.sh --affix $affix \ --srand $srand \ --train-stage $train_stage \ --train-set $train_set \ + --gmm $gmm \ --ihm-gmm $ihm_gmm \ --nnet3-affix $nnet3_affix \ --lstm-delay " [-1,1] [-2,2] [-3,3] " \ @@ -49,4 +51,3 @@ local/nnet3/run_lstm.sh --affix $affix \ --num-epochs $num_epochs \ --use-ihm-ali $use_ihm_ali \ --remove-egs $remove_egs - diff --git a/egs/ami/s5b/local/nnet3/run_lstm.sh b/egs/ami/s5b/local/nnet3/run_lstm.sh index c5583e2d0ef..25254629933 100755 --- a/egs/ami/s5b/local/nnet3/run_lstm.sh +++ b/egs/ami/s5b/local/nnet3/run_lstm.sh @@ -225,9 +225,12 @@ if [ $stage -le 14 ]; then [ ! -z $decode_iter ] && model_opts=" --iter $decode_iter "; for decode_set in dev eval; do ( - num_jobs=`cat data/$mic/${decode_set}_hires/utt2spk|cut -d' ' -f2|sort -u|wc -l` + nj_dev=`cat data/$mic/${decode_set}_hires/spk2utt | wc -l` + if [ $nj_dev -gt $nj ]; then + nj_dev=$nj + fi decode_dir=${dir}/decode_${decode_set} - steps/nnet3/decode.sh --nj 250 --cmd "$decode_cmd" \ + steps/nnet3/decode.sh --nj $nj_dev --cmd "$decode_cmd" \ $model_opts \ --extra-left-context $extra_left_context \ --extra-right-context $extra_right_context \ diff --git a/egs/ami/s5b/local/nnet3/run_tdnn.sh b/egs/ami/s5b/local/nnet3/run_tdnn.sh index bbc6ed5c042..7b463f4ce57 100755 --- a/egs/ami/s5b/local/nnet3/run_tdnn.sh +++ b/egs/ami/s5b/local/nnet3/run_tdnn.sh @@ -45,10 +45,12 @@ tdnn_affix= #affix for TDNN directory e.g. "a" or "b", in case we change the co # Options which are not passed through to run_ivector_common.sh train_stage=-10 splice_indexes="-2,-1,0,1,2 -1,2 -3,3 -7,2 -3,3 0 0" -remove_egs=true +remove_egs=false relu_dim=850 num_epochs=3 +common_egs_dir= + . cmd.sh . ./path.sh . ./utils/parse_options.sh @@ -122,30 +124,55 @@ fi [ ! -f $ali_dir/ali.1.gz ] && echo "$0: expected $ali_dir/ali.1.gz to exist" && exit 1 if [ $stage -le 12 ]; then + steps/nnet3/tdnn/make_configs.py \ + --self-repair-scale-nonlinearity 0.00001 \ + --feat-dir $train_data_dir \ + --ivector-dir $train_ivector_dir \ + --ali-dir $ali_dir \ + --relu-dim $relu_dim \ + --splice-indexes "$splice_indexes" \ + --use-presoftmax-prior-scale true \ + --include-log-softmax true \ + --final-layer-normalize-target 1.0 \ + $dir/configs || exit 1; +fi + +if [ $stage -le 13 ]; then if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then utils/create_split_dir.pl \ /export/b0{3,4,5,6}/$USER/kaldi-data/egs/ami-$(date +'%m_%d_%H_%M')/s5b/$dir/egs/storage $dir/egs/storage fi - steps/nnet3/tdnn/train.sh --stage $train_stage \ - --num-epochs $num_epochs --num-jobs-initial 2 --num-jobs-final 12 \ - --splice-indexes "$splice_indexes" \ - --feat-type raw \ - --online-ivector-dir ${train_ivector_dir} \ - --cmvn-opts "--norm-means=false --norm-vars=false" \ - --initial-effective-lrate 0.0015 --final-effective-lrate 0.00015 \ + steps/nnet3/train_dnn.py --stage $train_stage \ --cmd "$decode_cmd" \ - --relu-dim "$relu_dim" \ - --remove-egs "$remove_egs" \ - $train_data_dir data/lang $ali_dir $dir + --feat.online-ivector-dir $train_ivector_dir \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --egs.dir "$common_egs_dir" \ + --trainer.samples-per-iter 400000 \ + --trainer.num-epochs $num_epochs \ + --trainer.optimization.num-jobs-initial 2 \ + --trainer.optimization.num-jobs-final 12 \ + --trainer.optimization.initial-effective-lrate 0.0015 \ + --trainer.optimization.final-effective-lrate 0.00015 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs "$remove_egs" \ + --cleanup true \ + --feat-dir $train_data_dir \ + --lang data/lang \ + --ali-dir $ali_dir \ + --dir $dir fi -if [ $stage -le 12 ]; then +if [ $stage -le 14 ]; then rm $dir/.error || true 2>/dev/null for decode_set in dev eval; do ( + nj_dev=`cat data/$mic/${decode_set}_hires/spk2utt | wc -l` + if [ $nj_dev -gt $nj ]; then + nj_dev=$nj + fi decode_dir=${dir}/decode_${decode_set} - steps/nnet3/decode.sh --nj $nj --cmd "$decode_cmd" \ + steps/nnet3/decode.sh --nj $nj_dev --cmd "$decode_cmd" \ --online-ivector-dir exp/$mic/nnet3${nnet3_affix}/ivectors_${decode_set}_hires \ $graph_dir data/$mic/${decode_set}_hires $decode_dir ) & From 0ad80e29586f6dc5acb04c34210774426a5ab4a1 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Sat, 10 Dec 2016 17:40:50 -0500 Subject: [PATCH 063/213] asr_diarization: Add initial training scripts for overlapped speech detection on AMI --- egs/ami/s5b/local/run_train_raw_lstm.sh | 143 ++++++++++++++++++++++++ egs/ami/s5b/local/run_train_raw_tdnn.sh | 120 ++++++++++++++++++++ 2 files changed, 263 insertions(+) create mode 100755 egs/ami/s5b/local/run_train_raw_lstm.sh create mode 100644 egs/ami/s5b/local/run_train_raw_tdnn.sh diff --git a/egs/ami/s5b/local/run_train_raw_lstm.sh b/egs/ami/s5b/local/run_train_raw_lstm.sh new file mode 100755 index 00000000000..5c0431fe796 --- /dev/null +++ b/egs/ami/s5b/local/run_train_raw_lstm.sh @@ -0,0 +1,143 @@ +#!/bin/bash + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= + +# LSTM options +splice_indexes="-2,-1,0,1,2 0" +label_delay=0 +num_lstm_layers=2 +cell_dim=64 +hidden_dim=64 +recurrent_projection_dim=32 +non_recurrent_projection_dim=32 +chunk_width=40 +chunk_left_context=40 +lstm_delay="-1 -2" + +# training options +num_epochs=3 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +momentum=0.5 +num_chunk_per_minibatch=256 +samples_per_iter=20000 +remove_egs=false +max_param_change=1 + +num_utts_subset_valid=6 +num_utts_subset_train=6 + +use_dense_targets=false +extra_egs_copy_cmd="nnet3-copy-egs-overlap-detection ark:- ark:- |" + +# target options +train_data_dir=data/sdm1/train_whole_sp_hires_bp +targets_scp=exp/sdm1/overlap_speech_train_cleaned_sp/overlap_feats.scp +deriv_weights_scp=exp/sdm1/overlap_speech_train_cleaned_sp/deriv_weights.scp +egs_dir= +nj=40 +feat_type=raw +config_dir= +compute_objf_opts= + +mic=sdm1 +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_hidden_layers=`echo $splice_indexes | perl -ane 'print scalar @F'` || exit 1 +if [ -z "$dir" ]; then + dir=exp/$mic/nnet3_raw/nnet_lstm +fi + +dir=$dir${affix:+_$affix}_n${num_hidden_layers} +if [ $label_delay -gt 0 ]; then dir=${dir}_ld$label_delay; fi + + +if ! cuda-compiled; then + cat < Date: Fri, 9 Dec 2016 20:15:19 -0500 Subject: [PATCH 064/213] asr_diarization: Babel changes --- .../lang/101-cantonese-limitedLP.official.conf | 2 +- .../lang/105-turkish-limitedLP.official.conf | 2 +- egs/babel/s5c/path.sh | 4 +++- egs/babel/s5c/run-1-main.sh | 1 - egs/babel/s5c/run-4-anydecode.sh | 17 +++++++---------- 5 files changed, 12 insertions(+), 14 deletions(-) diff --git a/egs/babel/s5c/conf/lang/101-cantonese-limitedLP.official.conf b/egs/babel/s5c/conf/lang/101-cantonese-limitedLP.official.conf index e5d60c12367..9efcdc6a164 100644 --- a/egs/babel/s5c/conf/lang/101-cantonese-limitedLP.official.conf +++ b/egs/babel/s5c/conf/lang/101-cantonese-limitedLP.official.conf @@ -92,7 +92,7 @@ oovSymbol="" lexiconFlags="--romanized --oov " # Scoring protocols (dummy GLM file to appease the scoring script) -glmFile=/export/babel/data/splits/Cantonese_Babel101/cantonese.glm +glmFile=dummy.glm lexicon_file=/export/babel/data/101-cantonese/release-babel101b-v0.4c_sub-train1/conversational/reference_materials/lexicon.sub-train1.txt cer=1 diff --git a/egs/babel/s5c/conf/lang/105-turkish-limitedLP.official.conf b/egs/babel/s5c/conf/lang/105-turkish-limitedLP.official.conf index ae4cb55f4d5..014b519f3b7 100644 --- a/egs/babel/s5c/conf/lang/105-turkish-limitedLP.official.conf +++ b/egs/babel/s5c/conf/lang/105-turkish-limitedLP.official.conf @@ -3,7 +3,7 @@ #speech corpora files location train_data_dir=/export/babel/data/105-turkish/release-current-b/conversational/training -train_data_list=/export/babel/data/splits/Turkish_Babel105/train.LimitedLP.official.list +train_data_list=/export/babel/data/splits/Turkish_Babel105/train.LimitedLP.list train_nj=16 #RADICAL DEV data files diff --git a/egs/babel/s5c/path.sh b/egs/babel/s5c/path.sh index c8fdbad6ff7..97954c1f560 100755 --- a/egs/babel/s5c/path.sh +++ b/egs/babel/s5c/path.sh @@ -1,5 +1,7 @@ export KALDI_ROOT=`pwd`/../../.. . 
/export/babel/data/software/env.sh -export PATH=$PWD/utils/:$KALDI_ROOT/tools/sph2pipe_v2.5/:$KALDI_ROOT/src/bin:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/src/fstbin/:$KALDI_ROOT/src/gmmbin/:$KALDI_ROOT/src/featbin/:$KALDI_ROOT/src/lm/:$KALDI_ROOT/src/sgmmbin/:$KALDI_ROOT/src/sgmm2bin/:$KALDI_ROOT/src/fgmmbin/:$KALDI_ROOT/src/latbin/:$KALDI_ROOT/src/nnetbin:$KALDI_ROOT/src/nnet2bin/:$KALDI_ROOT/src/kwsbin:$PWD:$PATH +. $KALDI_ROOT/tools/config/common_path.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/sph2pipe_v2.5/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH +. $KALDI_ROOT/tools/env.sh export LC_ALL=C diff --git a/egs/babel/s5c/run-1-main.sh b/egs/babel/s5c/run-1-main.sh index e01910ffac0..be37cc8dca0 100755 --- a/egs/babel/s5c/run-1-main.sh +++ b/egs/babel/s5c/run-1-main.sh @@ -249,7 +249,6 @@ if [ ! -f exp/tri5/.done ]; then touch exp/tri5/.done fi - ################################################################################ # Ready to start SGMM training ################################################################################ diff --git a/egs/babel/s5c/run-4-anydecode.sh b/egs/babel/s5c/run-4-anydecode.sh index 68b87ea1e27..2071eb94d2d 100755 --- a/egs/babel/s5c/run-4-anydecode.sh +++ b/egs/babel/s5c/run-4-anydecode.sh @@ -10,13 +10,13 @@ dir=dev10h.pem kind= data_only=false fast_path=true -skip_kws=false +skip_kws=true skip_stt=false skip_scoring=false max_states=150000 extra_kws=true vocab_kws=false -tri5_only=false +tri5_only=true wip=0.5 echo "run-4-test.sh $@" @@ -196,7 +196,6 @@ if [ ! -f $dataset_dir/.done ] ; then else echo "Unknown type of the dataset: \"$dataset_segments\"!"; echo "Valid dataset types are: seg, uem, pem"; - exit 1 fi elif [ "$dataset_kind" == "unsupervised" ] ; then if [ "$dataset_segments" == "seg" ] ; then @@ -215,12 +214,10 @@ if [ ! -f $dataset_dir/.done ] ; then else echo "Unknown type of the dataset: \"$dataset_segments\"!"; echo "Valid dataset types are: seg, uem, pem"; - exit 1 fi else echo "Unknown kind of the dataset: \"$dataset_kind\"!"; echo "Valid dataset kinds are: supervised, unsupervised, shadow"; - exit 1 fi if [ ! -f ${dataset_dir}/.plp.done ]; then @@ -284,11 +281,11 @@ if ! 
$fast_path ; then "${lmwt_plp_extra_opts[@]}" \ ${dataset_dir} data/lang ${decode} - local/run_kws_stt_task.sh --cer $cer --max-states $max_states \ - --skip-scoring $skip_scoring --extra-kws $extra_kws --wip $wip \ - --cmd "$decode_cmd" --skip-kws $skip_kws --skip-stt $skip_stt \ - "${lmwt_plp_extra_opts[@]}" \ - ${dataset_dir} data/lang ${decode}.si + #local/run_kws_stt_task.sh --cer $cer --max-states $max_states \ + # --skip-scoring $skip_scoring --extra-kws $extra_kws --wip $wip \ + # --cmd "$decode_cmd" --skip-kws $skip_kws --skip-stt $skip_stt \ + # "${lmwt_plp_extra_opts[@]}" \ + # ${dataset_dir} data/lang ${decode}.si fi if $tri5_only; then From 3408934e8cd2823e3caca268f98c2b74b58661e2 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Fri, 9 Dec 2016 23:48:49 -0500 Subject: [PATCH 065/213] diarization: Adding num-fft-bins --- src/feat/feature-fbank.cc | 8 ++++---- src/feat/feature-mfcc.cc | 8 ++++---- src/feat/feature-spectrogram.cc | 2 +- src/feat/feature-spectrogram.h | 11 ++++++++++- src/feat/feature-window.cc | 10 +++++----- src/feat/feature-window.h | 11 +++++++++++ src/feat/mel-computations.cc | 16 +++++----------- src/feat/pitch-functions.cc | 12 +++++++++++- src/feat/pitch-functions.h | 9 ++++++++- 9 files changed, 59 insertions(+), 28 deletions(-) diff --git a/src/feat/feature-fbank.cc b/src/feat/feature-fbank.cc index c54069696b5..3c53ef1ec08 100644 --- a/src/feat/feature-fbank.cc +++ b/src/feat/feature-fbank.cc @@ -28,9 +28,9 @@ FbankComputer::FbankComputer(const FbankOptions &opts): if (opts.energy_floor > 0.0) log_energy_floor_ = Log(opts.energy_floor); - int32 padded_window_size = opts.frame_opts.PaddedWindowSize(); - if ((padded_window_size & (padded_window_size-1)) == 0) // Is a power of two... - srfft_ = new SplitRadixRealFft(padded_window_size); + int32 num_fft_bins = opts.frame_opts.NumFftBins(); + if ((num_fft_bins & (num_fft_bins-1)) == 0) // Is a power of two... + srfft_ = new SplitRadixRealFft(num_fft_bins); // We'll definitely need the filterbanks info for VTLN warping factor 1.0. // [note: this call caches it.] @@ -76,7 +76,7 @@ void FbankComputer::Compute(BaseFloat signal_log_energy, const MelBanks &mel_banks = *(GetMelBanks(vtln_warp)); - KALDI_ASSERT(signal_frame->Dim() == opts_.frame_opts.PaddedWindowSize() && + KALDI_ASSERT(signal_frame->Dim() == opts_.frame_opts.NumFftBins() && feature->Dim() == this->Dim()); diff --git a/src/feat/feature-mfcc.cc b/src/feat/feature-mfcc.cc index c1962a5c1d1..47912cc8693 100644 --- a/src/feat/feature-mfcc.cc +++ b/src/feat/feature-mfcc.cc @@ -29,7 +29,7 @@ void MfccComputer::Compute(BaseFloat signal_log_energy, BaseFloat vtln_warp, VectorBase *signal_frame, VectorBase *feature) { - KALDI_ASSERT(signal_frame->Dim() == opts_.frame_opts.PaddedWindowSize() && + KALDI_ASSERT(signal_frame->Dim() == opts_.frame_opts.NumFftBins() && feature->Dim() == this->Dim()); const MelBanks &mel_banks = *(GetMelBanks(vtln_warp)); @@ -98,9 +98,9 @@ MfccComputer::MfccComputer(const MfccOptions &opts): if (opts.energy_floor > 0.0) log_energy_floor_ = Log(opts.energy_floor); - int32 padded_window_size = opts.frame_opts.PaddedWindowSize(); - if ((padded_window_size & (padded_window_size-1)) == 0) // Is a power of two... - srfft_ = new SplitRadixRealFft(padded_window_size); + int32 num_fft_bins = opts.frame_opts.NumFftBins(); + if ((num_fft_bins & (num_fft_bins-1)) == 0) // Is a power of two... + srfft_ = new SplitRadixRealFft(num_fft_bins); // We'll definitely need the filterbanks info for VTLN warping factor 1.0. 
// [note: this call caches it.] diff --git a/src/feat/feature-spectrogram.cc b/src/feat/feature-spectrogram.cc index 953f38fc54f..f5f1c420462 100644 --- a/src/feat/feature-spectrogram.cc +++ b/src/feat/feature-spectrogram.cc @@ -48,7 +48,7 @@ void SpectrogramComputer::Compute(BaseFloat signal_log_energy, BaseFloat vtln_warp, VectorBase *signal_frame, VectorBase *feature) { - KALDI_ASSERT(signal_frame->Dim() == opts_.frame_opts.PaddedWindowSize() && + KALDI_ASSERT(signal_frame->Dim() == opts_.frame_opts.NumFftBins() && feature->Dim() == this->Dim()); diff --git a/src/feat/feature-spectrogram.h b/src/feat/feature-spectrogram.h index ec318556f24..6ca0697ef78 100644 --- a/src/feat/feature-spectrogram.h +++ b/src/feat/feature-spectrogram.h @@ -39,10 +39,13 @@ struct SpectrogramOptions { FrameExtractionOptions frame_opts; BaseFloat energy_floor; bool raw_energy; // If true, compute energy before preemphasis and windowing + bool use_energy; // append an extra dimension with energy to the filter banks + BaseFloat low_freq; // e.g. 20; lower frequency cutoff + BaseFloat high_freq; // an upper frequency cutoff; 0 -> no cutoff, negative SpectrogramOptions() : energy_floor(0.0), // not in log scale: a small value e.g. 1.0e-10 - raw_energy(true) {} + raw_energy(true), use_energy(true), low_freq(0), high_freq(0) {} void Register(OptionsItf *opts) { frame_opts.Register(opts); @@ -50,6 +53,12 @@ struct SpectrogramOptions { "Floor on energy (absolute, not relative) in Spectrogram computation"); opts->Register("raw-energy", &raw_energy, "If true, compute energy before preemphasis and windowing"); + opts->Register("use-energy", &use_energy, + "Add an extra dimension with energy to the spectrogram output."); + opts->Register("low-freq", &low_freq, + "Low cutoff frequency for mel bins"); + opts->Register("high-freq", &high_freq, + "High cutoff frequency for mel bins (if < 0, offset from Nyquist)"); } }; diff --git a/src/feat/feature-window.cc b/src/feat/feature-window.cc index 65c0a2a29c3..7b86e71dbb7 100644 --- a/src/feat/feature-window.cc +++ b/src/feat/feature-window.cc @@ -163,7 +163,7 @@ void ExtractWindow(int64 sample_offset, BaseFloat *log_energy_pre_window) { KALDI_ASSERT(sample_offset >= 0 && wave.Dim() != 0); int32 frame_length = opts.WindowSize(), - frame_length_padded = opts.PaddedWindowSize(); + num_fft_bins = opts.NumFftBins(); int64 num_samples = sample_offset + wave.Dim(), start_sample = FirstSampleOfFrame(f, opts), end_sample = start_sample + frame_length; @@ -175,8 +175,8 @@ void ExtractWindow(int64 sample_offset, KALDI_ASSERT(sample_offset == 0 || start_sample >= sample_offset); } - if (window->Dim() != frame_length_padded) - window->Resize(frame_length_padded, kUndefined); + if (window->Dim() != num_fft_bins) + window->Resize(num_fft_bins, kUndefined); // wave_start and wave_end are start and end indexes into 'wave', for the // piece of wave that we're trying to extract. @@ -206,8 +206,8 @@ void ExtractWindow(int64 sample_offset, } } - if (frame_length_padded > frame_length) - window->Range(frame_length, frame_length_padded - frame_length).SetZero(); + if (num_fft_bins > frame_length) + window->Range(frame_length, num_fft_bins - frame_length).SetZero(); SubVector frame(*window, 0, frame_length); diff --git a/src/feat/feature-window.h b/src/feat/feature-window.h index 287f1bf01f6..4165f43f1f0 100644 --- a/src/feat/feature-window.h +++ b/src/feat/feature-window.h @@ -42,6 +42,7 @@ struct FrameExtractionOptions { std::string window_type; // e.g. 
Hamming window bool round_to_power_of_two; BaseFloat blackman_coeff; + int32 num_fft_bins; bool snip_edges; // May be "hamming", "rectangular", "povey", "hanning", "blackman" // "povey" is a window I made to be similar to Hamming but to go to zero at the @@ -57,6 +58,7 @@ struct FrameExtractionOptions { window_type("povey"), round_to_power_of_two(true), blackman_coeff(0.42), + num_fft_bins(128), snip_edges(true){ } void Register(OptionsItf *opts) { @@ -77,6 +79,8 @@ struct FrameExtractionOptions { "Constant coefficient for generalized Blackman window."); opts->Register("round-to-power-of-two", &round_to_power_of_two, "If true, round window size to power of two."); + opts->Register("num-fft-bins", &num_fft_bins, + "Number of FFT bins to compute spectrogram"); opts->Register("snip-edges", &snip_edges, "If true, end effects will be handled by outputting only frames that " "completely fit in the file, and the number of frames depends on the " @@ -93,6 +97,13 @@ struct FrameExtractionOptions { return (round_to_power_of_two ? RoundUpToNearestPowerOfTwo(WindowSize()) : WindowSize()); } + int32 NumFftBins() const { + int32 padded_window_size = PaddedWindowSize(); + if (num_fft_bins > padded_window_size) + return (round_to_power_of_two ? RoundUpToNearestPowerOfTwo(num_fft_bins) : + num_fft_bins); + return padded_window_size; + } }; diff --git a/src/feat/mel-computations.cc b/src/feat/mel-computations.cc index 714d963f01b..db3f3334ca2 100644 --- a/src/feat/mel-computations.cc +++ b/src/feat/mel-computations.cc @@ -37,13 +37,7 @@ MelBanks::MelBanks(const MelBanksOptions &opts, int32 num_bins = opts.num_bins; if (num_bins < 3) KALDI_ERR << "Must have at least 3 mel bins"; BaseFloat sample_freq = frame_opts.samp_freq; - int32 window_length = static_cast(frame_opts.samp_freq*0.001*frame_opts.frame_length_ms); - int32 window_length_padded = - (frame_opts.round_to_power_of_two ? - RoundUpToNearestPowerOfTwo(window_length) : - window_length); - KALDI_ASSERT(window_length_padded % 2 == 0); - int32 num_fft_bins = window_length_padded/2; + int32 num_fft_bins = frame_opts.NumFftBins(); BaseFloat nyquist = 0.5 * sample_freq; BaseFloat low_freq = opts.low_freq, high_freq; @@ -59,8 +53,8 @@ MelBanks::MelBanks(const MelBanksOptions &opts, << " and high-freq " << high_freq << " vs. nyquist " << nyquist; - BaseFloat fft_bin_width = sample_freq / window_length_padded; - // fft-bin width [think of it as Nyquist-freq / half-window-length] + BaseFloat fft_bin_width = sample_freq / num_fft_bins; + // fft-bin width [think of it as Nyquist-freq / num_fft_bins] BaseFloat mel_low_freq = MelScale(low_freq); BaseFloat mel_high_freq = MelScale(high_freq); @@ -104,9 +98,9 @@ MelBanks::MelBanks(const MelBanksOptions &opts, center_freqs_(bin) = InverseMelScale(center_mel); // this_bin will be a vector of coefficients that is only // nonzero where this mel bin is active. - Vector this_bin(num_fft_bins); + Vector this_bin(num_fft_bins / 2); int32 first_index = -1, last_index = -1; - for (int32 i = 0; i < num_fft_bins; i++) { + for (int32 i = 0; i < num_fft_bins / 2; i++) { BaseFloat freq = (fft_bin_width * i); // Center frequency of this fft // bin. BaseFloat mel = MelScale(freq); diff --git a/src/feat/pitch-functions.cc b/src/feat/pitch-functions.cc index 430e9bdb53a..07e1d181243 100644 --- a/src/feat/pitch-functions.cc +++ b/src/feat/pitch-functions.cc @@ -1402,7 +1402,8 @@ OnlineProcessPitch::OnlineProcessPitch( dim_ ((opts.add_pov_feature ? 1 : 0) + (opts.add_normalized_log_pitch ? 1 : 0) + (opts.add_delta_pitch ? 
1 : 0) - + (opts.add_raw_log_pitch ? 1 : 0)) { + + (opts.add_raw_log_pitch ? 1 : 0) + + (opts.add_raw_pov ? 1 : 0)) { KALDI_ASSERT(dim_ > 0 && " At least one of the pitch features should be chosen. " "Check your post-process-pitch options."); @@ -1425,6 +1426,8 @@ void OnlineProcessPitch::GetFrame(int32 frame, (*feat)(index++) = GetDeltaPitchFeature(frame_delayed); if (opts_.add_raw_log_pitch) (*feat)(index++) = GetRawLogPitchFeature(frame_delayed); + if (opts_.add_raw_pov) + (*feat)(index++) = GetRawPov(frame_delayed); KALDI_ASSERT(index == dim_); } @@ -1482,6 +1485,13 @@ BaseFloat OnlineProcessPitch::GetNormalizedLogPitchFeature(int32 frame) { return normalized_log_pitch * opts_.pitch_scale; } +BaseFloat OnlineProcessPitch::GetRawPov(int32 frame) const { + Vector tmp(kRawFeatureDim); + src_->GetFrame(frame, &tmp); // (NCCF, pitch) from pitch extractor + BaseFloat nccf = tmp(0); + return NccfToPov(nccf); +} + // inline void OnlineProcessPitch::GetNormalizationWindow(int32 t, diff --git a/src/feat/pitch-functions.h b/src/feat/pitch-functions.h index 70e85380be6..b94ac661c10 100644 --- a/src/feat/pitch-functions.h +++ b/src/feat/pitch-functions.h @@ -231,6 +231,7 @@ struct ProcessPitchOptions { bool add_normalized_log_pitch; bool add_delta_pitch; bool add_raw_log_pitch; + bool add_raw_pov; ProcessPitchOptions() : pitch_scale(2.0), @@ -245,7 +246,7 @@ struct ProcessPitchOptions { add_pov_feature(true), add_normalized_log_pitch(true), add_delta_pitch(true), - add_raw_log_pitch(false) { } + add_raw_log_pitch(false), add_raw_pov(false) { } void Register(ParseOptions *opts) { @@ -286,6 +287,8 @@ struct ProcessPitchOptions { "features"); opts->Register("add-raw-log-pitch", &add_raw_log_pitch, "If true, log(pitch) is added to output features"); + opts->Register("add-raw-pov", &add_raw_pov, + "If true, add NCCF converted to POV"); } }; @@ -396,6 +399,10 @@ class OnlineProcessPitch: public OnlineFeatureInterface { /// Called from GetFrame(). inline BaseFloat GetNormalizedLogPitchFeature(int32 frame); + /// Computes and retures the raw POV for this frames. + /// Called from GetFrames(). + inline BaseFloat GetRawPov(int32 frame) const; + /// Computes the normalization window sizes. inline void GetNormalizationWindow(int32 frame, int32 src_frames_ready, From 0755e2ccad9f9a1e0886eaf334d007d72a7b5ae6 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Fri, 9 Dec 2016 23:50:21 -0500 Subject: [PATCH 066/213] asr_diarization: Removing stats component from make_jesus_configs.py --- egs/wsj/s5/steps/nnet3/make_jesus_configs.py | 67 -------------------- 1 file changed, 67 deletions(-) diff --git a/egs/wsj/s5/steps/nnet3/make_jesus_configs.py b/egs/wsj/s5/steps/nnet3/make_jesus_configs.py index 7f3aba2328c..0742fb4f1df 100755 --- a/egs/wsj/s5/steps/nnet3/make_jesus_configs.py +++ b/egs/wsj/s5/steps/nnet3/make_jesus_configs.py @@ -139,73 +139,6 @@ printable_name, old_val, new_val, args.num_jesus_blocks)) setattr(args, name, new_val); -# this is a bit like a struct, initialized from a string, which describes how to -# set up the statistics-pooling and statistics-extraction components. -# An example string is 'mean(-99:3:9::99)', which means, compute the mean of -# data within a window of -99 to +99, with distinct means computed every 9 frames -# (we round to get the appropriate one), and with the input extracted on multiples -# of 3 frames (so this will force the input to this layer to be evaluated -# every 3 frames). 
Another example string is 'mean+stddev(-99:3:9:99)', -# which will also cause the standard deviation to be computed. -class StatisticsConfig: - # e.g. c = StatisticsConfig('mean+stddev(-99:3:9:99)', 400, 'jesus1-forward-output-affine') - def __init__(self, config_string, input_dim, input_name): - self.input_dim = input_dim - self.input_name = input_name - - m = re.search("(mean|mean\+stddev)\((-?\d+):(-?\d+):(-?\d+):(-?\d+)\)", - config_string) - if m == None: - sys.exit("Invalid splice-index or statistics-config string: " + config_string) - self.output_stddev = (m.group(1) != 'mean') - self.left_context = -int(m.group(2)) - self.input_period = int(m.group(3)) - self.stats_period = int(m.group(4)) - self.right_context = int(m.group(5)) - if not (self.left_context > 0 and self.right_context > 0 and - self.input_period > 0 and self.stats_period > 0 and - self.left_context % self.stats_period == 0 and - self.right_context % self.stats_period == 0 and - self.stats_period % self.input_period == 0): - sys.exit("Invalid configuration of statistics-extraction: " + config_string) - - # OutputDim() returns the output dimension of the node that this produces. - def OutputDim(self): - return self.input_dim * (2 if self.output_stddev else 1) - - # OutputDims() returns an array of output dimensions, consisting of - # [ input-dim ] if just "mean" was specified, otherwise - # [ input-dim input-dim ] - def OutputDims(self): - return [ self.input_dim, self.input_dim ] if self.output_stddev else [ self.input_dim ] - - # Descriptor() returns the textual form of the descriptor by which the - # output of this node is to be accessed. - def Descriptor(self): - return 'Round({0}-pooling-{1}-{2}, {3})'.format(self.input_name, self.left_context, self.right_context, - self.stats_period) - - # This function writes the configuration lines need to compute the specified - # statistics, to the file f. 
- def WriteConfigs(self, f): - print('component name={0}-extraction-{1}-{2} type=StatisticsExtractionComponent input-dim={3} ' - 'input-period={4} output-period={5} include-variance={6} '.format( - self.input_name, self.left_context, self.right_context, - self.input_dim, self.input_period, self.stats_period, - ('true' if self.output_stddev else 'false')), file=f) - print('component-node name={0}-extraction-{1}-{2} component={0}-extraction-{1}-{2} input={0} '.format( - self.input_name, self.left_context, self.right_context), file=f) - stats_dim = 1 + self.input_dim * (2 if self.output_stddev else 1) - print('component name={0}-pooling-{1}-{2} type=StatisticsPoolingComponent input-dim={3} ' - 'input-period={4} left-context={1} right-context={2} num-log-count-features=0 ' - 'output-stddevs={5} '.format(self.input_name, self.left_context, self.right_context, - stats_dim, self.stats_period, - ('true' if self.output_stddev else 'false')), - file=f) - print('component-node name={0}-pooling-{1}-{2} component={0}-pooling-{1}-{2} input={0}-extraction-{1}-{2} '.format( - self.input_name, self.left_context, self.right_context), file=f) - - ## Work out splice_array From 3cda27b74cffab0d665416863fcb68b44f2df388 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Fri, 9 Dec 2016 23:51:51 -0500 Subject: [PATCH 067/213] asr_diarzation: Raw nnet3 changes --- .../nnet3/train/chain_objf/acoustic_model.py | 8 +- egs/wsj/s5/steps/libs/nnet3/train/common.py | 142 +++++++++++++++++- .../nnet3/train/frame_level_objf/common.py | 10 +- egs/wsj/s5/steps/nnet3/chain/train.py | 24 ++- egs/wsj/s5/steps/nnet3/train_dnn.py | 8 + egs/wsj/s5/steps/nnet3/train_raw_dnn.py | 8 + egs/wsj/s5/steps/nnet3/train_raw_rnn.py | 22 ++- egs/wsj/s5/steps/nnet3/train_rnn.py | 60 ++++++-- 8 files changed, 250 insertions(+), 32 deletions(-) diff --git a/egs/wsj/s5/steps/libs/nnet3/train/chain_objf/acoustic_model.py b/egs/wsj/s5/steps/libs/nnet3/train/chain_objf/acoustic_model.py index 0c871f07c2e..7e712ad912e 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/chain_objf/acoustic_model.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/chain_objf/acoustic_model.py @@ -223,7 +223,9 @@ def train_one_iteration(dir, iter, srand, egs_dir, leaky_hmm_coefficient, momentum, max_param_change, shuffle_buffer_size, frame_subsampling_factor, truncate_deriv_weights, - run_opts, background_process_handler=None): + run_opts, + dropout_proportions=None, + background_process_handler=None): """ Called from steps/nnet3/chain/train.py for one iteration for neural network training with LF-MMI objective @@ -302,6 +304,10 @@ def train_one_iteration(dir, iter, srand, egs_dir, cur_num_chunk_per_minibatch = num_chunk_per_minibatch / 2 cur_max_param_change = float(max_param_change) / math.sqrt(2) + if dropout_proportions is not None: + raw_model_string = common_train_lib.apply_dropout( + dropout_proportions, raw_model_string) + train_new_models(dir=dir, iter=iter, srand=srand, num_jobs=num_jobs, num_archives_processed=num_archives_processed, num_archives=num_archives, diff --git a/egs/wsj/s5/steps/libs/nnet3/train/common.py b/egs/wsj/s5/steps/libs/nnet3/train/common.py index dc24b37fdee..f2485e36784 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/common.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/common.py @@ -22,7 +22,7 @@ logger.addHandler(logging.NullHandler()) -class RunOpts: +class RunOpts(object): """A structure to store run options. 
Run options like queue.pl and run.pl, along with their memory @@ -318,6 +318,122 @@ def get_learning_rate(iter, num_jobs, num_iters, num_archives_processed, return num_jobs * effective_learning_rate +def parse_dropout_option(num_archives_to_process, dropout_option): + components = dropout_option.strip().split(' ') + dropout_schedule = [] + for component in components: + parts = component.split('=') + + if len(parts) == 2: + component_name = parts[0] + this_dropout_str = parts[1] + elif len(parts) == 1: + component_name = '*' + this_dropout_str = parts[0] + else: + raise Exception("The dropout schedule must be specified in the " + "format 'pattern1=func1 patter2=func2' where " + "the pattern can be omitted for a global function " + "for all components.\n" + "Got {0} in {1}".format(component, dropout_option)) + + this_dropout_values = _parse_dropout_string( + num_archives_to_process, this_dropout_str) + dropout_schedule.append((component_name, this_dropout_values)) + return dropout_schedule + + +def _parse_dropout_string(num_archives_to_process, dropout_str): + dropout_values = [] + parts = dropout_str.strip().split(',') + try: + if len(parts) < 2: + raise Exception("dropout proportion string must specify " + "at least the start and end dropouts") + + dropout_values.append((0, float(parts[0]))) + for i in range(1, len(parts)): + value_x_pair = parts[i].split('@') + if len(value_x_pair) == 1: + dropout_proportion = float(parts[i]) + dropout_values.append((0.5 * num_archives_to_process, + dropout_proportion)) + else: + assert len(value_x_pair) == 2 + dropout_proportion, data_fraction = value_x_pair + dropout_values.append( + (float(data_fraction) * num_archives_to_process, + float(dropout_proportion))) + + dropout_values.append((num_archives_to_process, float(parts[-1]))) + except Exception as e: + logger.error("Unable to parse dropout proportion string {0}. 
" + "See help for option " + "--dropout-schedule.".format(dropout_str)) + raise e + + # reverse sort so that its easy to retrieve the dropout proportion + # for a particular data fraction + dropout_values.sort(key=lambda x: x[0], reverse=True) + for num_archives, proportion in dropout_values: + assert num_archives <= num_archives_to_process and num_archives >= 0 + assert proportion <= 1 and proportion >= 0 + + return dropout_values + + +def get_dropout_proportions(dropout_schedule, + num_archives_processed): + + dropout_proportions = [] + for component_name, component_dropout_schedule in dropout_schedule: + dropout_proportions.append( + (component_name, + _get_component_dropout(component_dropout_schedule, + num_archives_processed))) + return dropout_proportions + + +def _get_component_dropout(dropout_schedule, num_archives_processed): + if num_archives_processed == 0: + assert dropout_schedule[-1][0] == 0 + return dropout_schedule[-1][1] + try: + (dropout_schedule_index, initial_num_archives, + initial_dropout) = next((i, tup[0], tup[1]) + for i, tup in enumerate(dropout_schedule) + if tup[0] < num_archives_processed) + except StopIteration as e: + logger.error("Could not find num_archives in dropout schedule " + "corresponding to num_archives_processed {0}.\n" + "Maybe something wrong with the parsed " + "dropout schedule {1}.".format( + num_archives_processed, dropout_schedule)) + raise e + + final_num_archives, final_dropout = dropout_schedule[ + dropout_schedule_index - 1] + assert (num_archives_processed > initial_num_archives + and num_archives_processed < final_num_archives) + + return ((num_archives_processed - initial_num_archives) + * (final_dropout - initial_dropout) + / (final_num_archives - initial_num_archives)) + + +def apply_dropout(dropout_proportions, raw_model_string): + edit_config_lines = [] + + for component_name, dropout_proportion in dropout_proportions: + edit_config_lines.append( + "set-dropout-proportion name={0} proportion={1}".format( + component_name, dropout_proportion)) + + return ("""{raw_model_string} nnet3-copy --edits='{edits}' \ + - - |""".format(raw_model_string=raw_model_string, + edits=";".join(edit_config_lines))) + + def do_shrinkage(iter, model_file, shrink_saturation_threshold, get_raw_nnet_from_am=True): @@ -530,6 +646,30 @@ def __init__(self): Note: we implemented it in such a way that it doesn't increase the effective learning rate.""") + self.parser.add_argument("--trainer.dropout-schedule", type=str, + dest='dropout_schedule', default='', + help="""Use this to specify the dropout + schedule. You specify a piecewise linear + function on the domain [0,1], where 0 is the + start and 1 is the end of training; the + function-argument (x) rises linearly with the + amount of data you have seen, not iteration + number (this improves invariance to + num-jobs-{initial-final}). E.g. '0,0.2,0' + means 0 at the start; 0.2 after seeing half + the data; and 0 at the end. You may specify + the x-value of selected points, e.g. + '0,0.2@0.25,0' means that the 0.2 + dropout-proportion is reached a quarter of the + way through the data. The start/end x-values + are at x=0/x=1, and other unspecified x-values + are interpolated between known x-values. You + may specify different rules for different + component-name patterns using 'pattern1=func1 + pattern2=func2', e.g. 'relu*=0,0.1,0 + lstm*=0,0.2,0'. 
More general should precede + less general patterns, as they are applied + sequentially.""") # General options self.parser.add_argument("--stage", type=int, default=-4, diff --git a/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py b/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py index 55508daf02c..37dd36aa392 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py @@ -13,6 +13,7 @@ import logging import math import os +import random import time import libs.common as common_lib @@ -153,7 +154,7 @@ def train_one_iteration(dir, iter, srand, egs_dir, cv_minibatch_size=256, frames_per_eg=-1, min_deriv_time=None, max_deriv_time=None, min_left_context=None, min_right_context=None, - shrinkage_value=1.0, + shrinkage_value=1.0, dropout_proportions=None, get_raw_nnet_from_am=True, background_process_handler=None, extra_egs_copy_cmd=""): @@ -280,11 +281,16 @@ def train_one_iteration(dir, iter, srand, egs_dir, except OSError: pass + if dropout_proportions is not None: + raw_model_string = common_train_lib.apply_dropout( + dropout_proportions, raw_model_string) + train_new_models(dir=dir, iter=iter, srand=srand, num_jobs=num_jobs, num_archives_processed=num_archives_processed, num_archives=num_archives, raw_model_string=raw_model_string, egs_dir=egs_dir, - left_context=left_context, right_context=right_context, + left_context=left_context, + right_context=right_context, momentum=momentum, max_param_change=cur_max_param_change, shuffle_buffer_size=shuffle_buffer_size, minibatch_size=cur_minibatch_size, diff --git a/egs/wsj/s5/steps/nnet3/chain/train.py b/egs/wsj/s5/steps/nnet3/chain/train.py index 0254589be85..ca5c5f098ad 100755 --- a/egs/wsj/s5/steps/nnet3/chain/train.py +++ b/egs/wsj/s5/steps/nnet3/chain/train.py @@ -134,14 +134,17 @@ def get_args(): shrink-threshold at the non-linearities. E.g. 0.99. Only applicable when the neural net contains sigmoid or tanh units.""") - parser.add_argument("--trainer.optimization.shrink-saturation-threshold", type=float, + parser.add_argument("--trainer.optimization.shrink-saturation-threshold", + type=float, dest='shrink_saturation_threshold', default=0.40, - help="""Threshold that controls when we apply the 'shrinkage' - (i.e. scaling by shrink-value). If the saturation of the - sigmoid and tanh nonlinearities in the neural net (as - measured by steps/nnet3/get_saturation.pl) exceeds this - threshold we scale the parameter matrices with the + help="""Threshold that controls when we apply the + 'shrinkage' (i.e. scaling by shrink-value). 
If the + saturation of the sigmoid and tanh nonlinearities in + the neural net (as measured by + steps/nnet3/get_saturation.pl) exceeds this threshold + we scale the parameter matrices with the shrink-value.""") + # RNN-specific training options parser.add_argument("--trainer.deriv-truncate-margin", type=int, dest='deriv_truncate_margin', default=None, @@ -307,7 +310,6 @@ def train(args, run_opts, background_process_handler): nnet3-init --srand=-2 {dir}/configs/init.config \ {dir}/init.raw""".format(command=run_opts.command, dir=args.dir)) - egs_left_context = left_context + args.frame_subsampling_factor/2 egs_right_context = right_context + args.frame_subsampling_factor/2 @@ -392,6 +394,10 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): args.initial_effective_lrate, args.final_effective_lrate) + if args.dropout_schedule is not None: + dropout_schedule = common_train_lib.parse_dropout_option( + num_archives_to_process, args.dropout_schedule) + min_deriv_time = None max_deriv_time = None if args.deriv_truncate_margin is not None: @@ -436,6 +442,10 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): num_archives=num_archives, learning_rate=learning_rate(iter, current_num_jobs, num_archives_processed), + dropout_proportions=( + None if args.dropout_schedule is None + else common_train_lib.get_dropout_proportions( + dropout_schedule, num_archives_processed)), shrinkage_value=shrinkage_value, num_chunk_per_minibatch=args.num_chunk_per_minibatch, num_hidden_layers=num_hidden_layers, diff --git a/egs/wsj/s5/steps/nnet3/train_dnn.py b/egs/wsj/s5/steps/nnet3/train_dnn.py index 83170ea1e8e..2813f719606 100755 --- a/egs/wsj/s5/steps/nnet3/train_dnn.py +++ b/egs/wsj/s5/steps/nnet3/train_dnn.py @@ -286,6 +286,10 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): args.initial_effective_lrate, args.final_effective_lrate) + if args.dropout_schedule is not None: + dropout_schedule = common_train_lib.parse_dropout_option( + num_archives_to_process, args.dropout_schedule) + logger.info("Training will run for {0} epochs = " "{1} iterations".format(args.num_epochs, num_iters)) @@ -312,6 +316,10 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): num_archives=num_archives, learning_rate=learning_rate(iter, current_num_jobs, num_archives_processed), + dropout_proportions=( + None if args.dropout_schedule is None + else common_train_lib.get_dropout_proportions( + dropout_schedule, num_archives_processed)), minibatch_size=args.minibatch_size, frames_per_eg=args.frames_per_eg, num_hidden_layers=num_hidden_layers, diff --git a/egs/wsj/s5/steps/nnet3/train_raw_dnn.py b/egs/wsj/s5/steps/nnet3/train_raw_dnn.py index d7651889d83..efeaa13662e 100755 --- a/egs/wsj/s5/steps/nnet3/train_raw_dnn.py +++ b/egs/wsj/s5/steps/nnet3/train_raw_dnn.py @@ -288,6 +288,10 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): args.initial_effective_lrate, args.final_effective_lrate) + if args.dropout_schedule is not None: + dropout_schedule = common_train_lib.parse_dropout_option( + num_archives_to_process, args.dropout_schedule) + logger.info("Training will run for {0} epochs = " "{1} iterations".format(args.num_epochs, num_iters)) @@ -314,6 +318,10 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): num_archives=num_archives, learning_rate=learning_rate(iter, current_num_jobs, num_archives_processed), + dropout_proportions=( + None if args.dropout_schedule is None + else common_train_lib.get_dropout_proportions( + 
dropout_schedule, num_archives_processed)), minibatch_size=args.minibatch_size, frames_per_eg=args.frames_per_eg, num_hidden_layers=num_hidden_layers, diff --git a/egs/wsj/s5/steps/nnet3/train_raw_rnn.py b/egs/wsj/s5/steps/nnet3/train_raw_rnn.py index e4af318fb57..4a2424e54f5 100755 --- a/egs/wsj/s5/steps/nnet3/train_raw_rnn.py +++ b/egs/wsj/s5/steps/nnet3/train_raw_rnn.py @@ -116,13 +116,15 @@ def get_args(): shrink-threshold at the non-linearities. E.g. 0.99. Only applicable when the neural net contains sigmoid or tanh units.""") - parser.add_argument("--trainer.optimization.shrink-saturation-threshold", type=float, + parser.add_argument("--trainer.optimization.shrink-saturation-threshold", + type=float, dest='shrink_saturation_threshold', default=0.40, - help="""Threshold that controls when we apply the 'shrinkage' - (i.e. scaling by shrink-value). If the saturation of the - sigmoid and tanh nonlinearities in the neural net (as - measured by steps/nnet3/get_saturation.pl) exceeds this - threshold we scale the parameter matrices with the + help="""Threshold that controls when we apply the + 'shrinkage' (i.e. scaling by shrink-value). If the + saturation of the sigmoid and tanh nonlinearities in + the neural net (as measured by + steps/nnet3/get_saturation.pl) exceeds this threshold + we scale the parameter matrices with the shrink-value.""") parser.add_argument("--trainer.optimization.cv-minibatch-size", type=int, dest='cv_minibatch_size', default=256, @@ -391,6 +393,10 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): args.initial_effective_lrate, args.final_effective_lrate) + if args.dropout_schedule is not None: + dropout_schedule = common_train_lib.parse_dropout_option( + num_archives_to_process, args.dropout_schedule) + min_deriv_time = None max_deriv_time = None if args.deriv_truncate_margin is not None: @@ -437,6 +443,10 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): num_archives=num_archives, learning_rate=learning_rate(iter, current_num_jobs, num_archives_processed), + dropout_proportions=( + None if args.dropout_schedule is None + else common_train_lib.get_dropout_proportions( + dropout_schedule, num_archives_processed)), shrinkage_value=shrinkage_value, minibatch_size=args.num_chunk_per_minibatch, num_hidden_layers=num_hidden_layers, diff --git a/egs/wsj/s5/steps/nnet3/train_rnn.py b/egs/wsj/s5/steps/nnet3/train_rnn.py index 482c9a8ee03..a0318f28829 100755 --- a/egs/wsj/s5/steps/nnet3/train_rnn.py +++ b/egs/wsj/s5/steps/nnet3/train_rnn.py @@ -65,13 +65,24 @@ def get_args(): used to train an LSTM. 
Caution: if you double this you should halve --trainer.samples-per-iter.""") - parser.add_argument("--egs.chunk-left-context", type=int, - dest='chunk_left_context', default=40, - help="""Number of left steps used in the estimation of - LSTM state before prediction of the first label""") - - parser.add_argument("--trainer.samples-per-iter", type=int, - dest='samples_per_iter', default=20000, + parser.add_argument("--egs.chunk-left-context", type=int, dest='chunk_left_context', + default = 40, + help="""Number of left steps used in the estimation of LSTM + state before prediction of the first label""") + parser.add_argument("--egs.chunk-right-context", type=int, dest='chunk_right_context', + default = 0, + help="""Number of right steps used in the estimation of BLSTM + state before prediction of the first label""") + parser.add_argument("--trainer.min-extra-left-context", type=int, dest='min_extra_left_context', + default = None, + help="""Number of left steps used in the estimation of LSTM + state before prediction of the first label""") + parser.add_argument("--trainer.min-extra-right-context", type=int, dest='min_extra_right_context', + default = None, + help="""Number of right steps used in the estimation of BLSTM + state before prediction of the first label""") + parser.add_argument("--trainer.samples-per-iter", type=int, dest='samples_per_iter', + default=20000, help="""This is really the number of egs in each archive. Each eg has 'chunk_width' frames in it-- for chunk_width=20, this value (20k) is equivalent @@ -100,13 +111,15 @@ def get_args(): shrink-threshold at the non-linearities. E.g. 0.99. Only applicable when the neural net contains sigmoid or tanh units.""") - parser.add_argument("--trainer.optimization.shrink-saturation-threshold", type=float, + parser.add_argument("--trainer.optimization.shrink-saturation-threshold", + type=float, dest='shrink_saturation_threshold', default=0.40, - help="""Threshold that controls when we apply the 'shrinkage' - (i.e. scaling by shrink-value). If the saturation of the - sigmoid and tanh nonlinearities in the neural net (as - measured by steps/nnet3/get_saturation.pl) exceeds this - threshold we scale the parameter matrices with the + help="""Threshold that controls when we apply the + 'shrinkage' (i.e. scaling by shrink-value). 
If the + saturation of the sigmoid and tanh nonlinearities in + the neural net (as measured by + steps/nnet3/get_saturation.pl) exceeds this threshold + we scale the parameter matrices with the shrink-value.""") parser.add_argument("--trainer.optimization.cv-minibatch-size", type=int, dest='cv_minibatch_size', default=256, @@ -177,12 +190,19 @@ def process_args(args): "--trainer.deriv-truncate-margin.".format( args.deriv_truncate_margin)) + if args.min_extra_left_context is None: + args.min_extra_left_context = args.chunk_left_context + + if args.min_extra_right_context is None: + args.min_extra_right_context = args.chunk_right_context + if (not os.path.exists(args.dir) or not os.path.exists(args.dir+"/configs")): raise Exception("This scripts expects {0} to exist and have a configs " "directory which is the output of " "make_configs.py script") + if args.transform_dir is None: args.transform_dir = args.ali_dir @@ -363,6 +383,10 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): args.initial_effective_lrate, args.final_effective_lrate) + if args.dropout_schedule is not None: + dropout_schedule = common_train_lib.parse_dropout_option( + num_archives_to_process, args.dropout_schedule) + min_deriv_time = None max_deriv_time = None if args.deriv_truncate_margin is not None: @@ -408,12 +432,18 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): num_archives=num_archives, learning_rate=learning_rate(iter, current_num_jobs, num_archives_processed), + dropout_proportions=( + None if args.dropout_schedule is None + else common_train_lib.get_dropout_proportions( + dropout_schedule, num_archives_processed)), shrinkage_value=shrinkage_value, minibatch_size=args.num_chunk_per_minibatch, num_hidden_layers=num_hidden_layers, add_layers_period=args.add_layers_period, - left_context=left_context, - right_context=right_context, + min_left_context = model_left_context + args.min_extra_left_context, + min_right_context = model_right_context + args.min_extra_right_context, + max_left_context = left_context, + max_right_context = right_context, min_deriv_time=min_deriv_time, max_deriv_time=max_deriv_time, momentum=args.momentum, From d2e7742663c34a401fad616a289a59b4f101a0b0 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Fri, 9 Dec 2016 23:52:19 -0500 Subject: [PATCH 068/213] asr_diarzation: Minor consmetic change --- egs/wsj/s5/steps/nnet3/lstm/make_configs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/wsj/s5/steps/nnet3/lstm/make_configs.py b/egs/wsj/s5/steps/nnet3/lstm/make_configs.py index 9fb9fad1d0c..ff0fb8225ac 100755 --- a/egs/wsj/s5/steps/nnet3/lstm/make_configs.py +++ b/egs/wsj/s5/steps/nnet3/lstm/make_configs.py @@ -319,7 +319,8 @@ def MakeConfigs(config_dir, feat_dim, ivector_dim, num_targets, add_lda, for i in range(num_lstm_layers, num_hidden_layers): prev_layer_output = nodes.AddAffRelNormLayer(config_lines, "L{0}".format(i+1), prev_layer_output, hidden_dim, - ng_affine_options, self_repair_scale = self_repair_scale_nonlinearity, max_change_per_component = max_change_per_component) + ng_affine_options, self_repair_scale = self_repair_scale_nonlinearity, + max_change_per_component = max_change_per_component) # make the intermediate config file for layerwise discriminative # training nodes.AddFinalLayer(config_lines, prev_layer_output, num_targets, ng_affine_options, max_change_per_component = max_change_per_component_final, label_delay = label_delay, include_log_softmax = include_log_softmax, add_final_sigmoid = 
add_final_sigmoid, objective_type = objective_type) From ebb3c5afd26005f67bcd6ca0b9bc6826ce12fa25 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Fri, 9 Dec 2016 23:52:56 -0500 Subject: [PATCH 069/213] asr_diarization: fixing bug in reverberate_data_dir.py --- egs/wsj/s5/steps/data/reverberate_data_dir.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/egs/wsj/s5/steps/data/reverberate_data_dir.py b/egs/wsj/s5/steps/data/reverberate_data_dir.py index 69bc5e08b3b..0080bdba5f0 100755 --- a/egs/wsj/s5/steps/data/reverberate_data_dir.py +++ b/egs/wsj/s5/steps/data/reverberate_data_dir.py @@ -113,6 +113,12 @@ def CheckArgs(args): if not os.path.exists(args.output_additive_noise_dir): os.makedirs(args.output_additive_noise_dir) + ## Check arguments. + + if args.num_replicas > 1 and args.prefix is None: + args.prefix = "rvb" + warnings.warn("--prefix is set to 'rvb' as --num-replications is larger than 1.") + if not args.num_replicas > 0: raise Exception("--num-replications cannot be non-positive") From f99d7cdcd26940a2974249afdeadf526a67a214d Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Fri, 9 Dec 2016 23:53:38 -0500 Subject: [PATCH 070/213] asr_diarization: Adding learning rate facor --- src/nnet3/nnet-component-itf.h | 6 +-- src/nnet3/nnet-utils.cc | 92 ++++++++++++++++++++++++++++++---- src/nnet3/nnet-utils.h | 14 ++++++ 3 files changed, 98 insertions(+), 14 deletions(-) diff --git a/src/nnet3/nnet-component-itf.h b/src/nnet3/nnet-component-itf.h index 3013c485ea4..600450a578a 100644 --- a/src/nnet3/nnet-component-itf.h +++ b/src/nnet3/nnet-component-itf.h @@ -413,15 +413,15 @@ class UpdatableComponent: public Component { /// a different value than x will returned. BaseFloat LearningRate() const { return learning_rate_; } + /// Gets the learning rate factor + BaseFloat LearningRateFactor() const { return learning_rate_factor_; } + /// Gets per-component max-change value. Note: the components themselves do /// not enforce the per-component max-change; it's enforced in class /// NnetTrainer by querying the max-changes for each component. /// See NnetTrainer::UpdateParamsWithMaxChange() in nnet3/nnet-training.cc. BaseFloat MaxChange() const { return max_change_; } - /// Gets the learning rate factor - BaseFloat LearningRateFactor() const { return learning_rate_factor_; } - virtual std::string Info() const; /// The following new virtual function returns the total dimension of diff --git a/src/nnet3/nnet-utils.cc b/src/nnet3/nnet-utils.cc index d65193d9a54..49cb16126ed 100644 --- a/src/nnet3/nnet-utils.cc +++ b/src/nnet3/nnet-utils.cc @@ -144,7 +144,7 @@ void ComputeSimpleNnetContext(const Nnet &nnet, // This will crash if the total context (left + right) is greater // than window_size. - int32 window_size = 100; + int32 window_size = 150; // by going "<= modulus" instead of "< modulus" we do one more computation // than we really need; it becomes a sanity check. for (int32 input_start = 0; input_start <= modulus; input_start++) @@ -301,6 +301,25 @@ void SetLearningRates(const Vector &learning_rates, KALDI_ASSERT(i == learning_rates.Dim()); } +void SetLearningRateFactors(const Vector &learning_rate_factors, + Nnet *nnet) { + int32 i = 0; + for (int32 c = 0; c < nnet->NumComponents(); c++) { + Component *comp = nnet->GetComponent(c); + if (comp->Properties() & kUpdatableComponent) { + // For now all updatable components inherit from class UpdatableComponent. + // If that changes in future, we will change this code. 
+ UpdatableComponent *uc = dynamic_cast(comp); + if (uc == NULL) + KALDI_ERR << "Updatable component does not inherit from class " + "UpdatableComponent; change this code."; + KALDI_ASSERT(i < learning_rate_factors.Dim()); + uc->SetLearningRateFactor(learning_rate_factors(i++)); + } + } + KALDI_ASSERT(i == learning_rate_factors.Dim()); +} + void GetLearningRates(const Nnet &nnet, Vector *learning_rates) { learning_rates->Resize(NumUpdatableComponents(nnet)); @@ -320,6 +339,25 @@ void GetLearningRates(const Nnet &nnet, KALDI_ASSERT(i == learning_rates->Dim()); } +void GetLearningRateFactors(const Nnet &nnet, + Vector *learning_rate_factors) { + learning_rate_factors->Resize(NumUpdatableComponents(nnet)); + int32 i = 0; + for (int32 c = 0; c < nnet.NumComponents(); c++) { + const Component *comp = nnet.GetComponent(c); + if (comp->Properties() & kUpdatableComponent) { + // For now all updatable components inherit from class UpdatableComponent. + // If that changes in future, we will change this code. + const UpdatableComponent *uc = dynamic_cast(comp); + if (uc == NULL) + KALDI_ERR << "Updatable component does not inherit from class " + "UpdatableComponent; change this code."; + (*learning_rate_factors)(i++) = uc->LearningRateFactor(); + } + } + KALDI_ASSERT(i == learning_rate_factors->Dim()); +} + void ScaleNnetComponents(const Vector &scale_factors, Nnet *nnet) { int32 i = 0; @@ -351,6 +389,25 @@ void ScaleNnet(BaseFloat scale, Nnet *nnet) { } } +void ScaleSingleComponent(BaseFloat scale, Nnet *nnet, std::string component_name) { + if (scale == 1.0) return; + else if (scale == 0.0) { + SetZero(false, nnet); + } else { + for (int32 c = 0; c < nnet->NumComponents(); c++) { + Component *comp = nnet->GetComponent(c); + std::string this_component_type = nnet->GetComponent(c)->Type(); + if (this_component_type == component_name) { + if (comp->Properties() & kUpdatableComponent) + comp->Scale(scale); + else + KALDI_ERR << "component " << component_name + << "is not an updatable component."; + } + } + } +} + void AddNnetComponents(const Nnet &src, const Vector &alphas, BaseFloat scale, Nnet *dest) { if (src.NumComponents() != dest->NumComponents()) @@ -523,16 +580,6 @@ std::string NnetInfo(const Nnet &nnet) { return ostr.str(); } -void SetDropoutProportion(BaseFloat dropout_proportion, - Nnet *nnet) { - for (int32 c = 0; c < nnet->NumComponents(); c++) { - Component *comp = nnet->GetComponent(c); - DropoutComponent *dc = dynamic_cast(comp); - if (dc != NULL) - dc->SetDropoutProportion(dropout_proportion); - } -} - void FindOrphanComponents(const Nnet &nnet, std::vector *components) { int32 num_components = nnet.NumComponents(), num_nodes = nnet.NumNodes(); std::vector is_used(num_components, false); @@ -688,6 +735,29 @@ void ReadEditConfig(std::istream &edit_config_is, Nnet *nnet) { if (outputs_remaining == 0) KALDI_ERR << "All outputs were removed."; nnet->RemoveSomeNodes(nodes_to_remove); + } else if (directive == "set-dropout-proportion") { + std::string name_pattern = "*"; + // name_pattern defaults to '*' if none is given. This pattern + // matches names of components, not nodes. 
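// [Editorial note, not part of the patch: judging from the parsing just below
//  (name is read with GetValue and defaults to "*"; proportion is required),
//  an edits-config line for this directive would look something like
//      set-dropout-proportion name=lstm* proportion=0.3
//  which sets the dropout proportion of every DropoutComponent whose name
//  matches the pattern "lstm*".]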
+ config_line.GetValue("name", &name_pattern); + BaseFloat proportion = -1; + if (!config_line.GetValue("proportion", &proportion)) { + KALDI_ERR << "In edits-config, expected proportion to be set in line: " + << config_line.WholeLine(); + } + DropoutComponent *component = NULL; + int32 num_dropout_proportions_set = 0; + for (int32 c = 0; c < nnet->NumComponents(); c++) { + if (NameMatchesPattern(nnet->GetComponentName(c).c_str(), + name_pattern.c_str()) && + (component = + dynamic_cast(nnet->GetComponent(c)))) { + component->SetDropoutProportion(proportion); + num_dropout_proportions_set++; + } + } + KALDI_LOG << "Set dropout proportions for " + << num_dropout_proportions_set << " nodes."; } else { KALDI_ERR << "Directive '" << directive << "' is not currently " "supported (reading edit-config)."; diff --git a/src/nnet3/nnet-utils.h b/src/nnet3/nnet-utils.h index 1e0dcefd703..21dbc67be2a 100644 --- a/src/nnet3/nnet-utils.h +++ b/src/nnet3/nnet-utils.h @@ -127,11 +127,22 @@ void ScaleLearningRate(BaseFloat learning_rate_scale, void SetLearningRates(const Vector &learning_rates, Nnet *nnet); +/// Sets the learning rate factors for all the updatable components in +/// the neural net to the values in 'learning_rate_factors' vector +/// (one for each updatable component). +void SetLearningRateFactors( + const Vector &learning_rate_factors, + Nnet *nnet); + /// Get the learning rates for all the updatable components in the neural net /// (the output must have dim equal to the number of updatable components). void GetLearningRates(const Nnet &nnet, Vector *learning_rates); +/// Get the learning rate factors for all the updatable components in the neural net +void GetLearningRateFactors(const Nnet &nnet, + Vector *learning_rate_factors); + /// Scales the nnet parameters and stats by this scale. void ScaleNnet(BaseFloat scale, Nnet *nnet); @@ -233,6 +244,9 @@ void FindOrphanNodes(const Nnet &nnet, std::vector *nodes); remove internal nodes directly; instead you should use the command 'remove-orphans'. + set-dropout-proportion [name=] proportion= + Sets the dropout rates for any components of type DropoutComponent whose + names match the given (e.g. lstm*). defaults to "*". \endverbatim */ void ReadEditConfig(std::istream &config_file, Nnet *nnet); From b47298741f99f18d0720da20709da0d902595458 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Fri, 9 Dec 2016 23:54:17 -0500 Subject: [PATCH 071/213] asr_diarization: Adding dropout --- src/nnet3bin/nnet3-copy.cc | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/nnet3bin/nnet3-copy.cc b/src/nnet3bin/nnet3-copy.cc index e4a41933fff..ce0fb510260 100644 --- a/src/nnet3bin/nnet3-copy.cc +++ b/src/nnet3bin/nnet3-copy.cc @@ -41,8 +41,9 @@ int main(int argc, char *argv[]) { " nnet3-copy --binary=false 0.raw text.raw\n"; bool binary_write = true; + BaseFloat learning_rate = -1, - dropout = 0.0; + dropout = -1; std::string nnet_config, edits_config, edits_str; BaseFloat scale = 1.0; @@ -64,7 +65,10 @@ int main(int argc, char *argv[]) { "will be converted to newlines before parsing. E.g. " "'--edits=remove-orphans'."); po.Register("set-dropout-proportion", &dropout, "Set dropout proportion " - "in all DropoutComponent to this value."); + "in all DropoutComponent to this value. " + "This option is deprecated. Use set-dropout-proportion " + "option in edits-config. 
See comments in ReadEditConfig() " + "in nnet3/nnet-utils.h."); po.Register("scale", &scale, "The parameter matrices are scaled" " by the specified value."); po.Read(argc, argv); @@ -92,7 +96,10 @@ int main(int argc, char *argv[]) { ScaleNnet(scale, &nnet); if (dropout > 0) - SetDropoutProportion(dropout, &nnet); + KALDI_ERR << "--dropout option is deprecated. " + << "Use set-dropout-proportion " + << "option in edits-config. See comments in ReadEditConfig() " + << "in nnet3/nnet-utils.h."; if (!edits_config.empty()) { Input ki(edits_config); From df3319eff2fadd941cff17fb390c4676838124b3 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Fri, 9 Dec 2016 23:54:38 -0500 Subject: [PATCH 072/213] asr_diarization: Adding more info in nnet3-info --- src/nnet3bin/nnet3-info.cc | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/src/nnet3bin/nnet3-info.cc b/src/nnet3bin/nnet3-info.cc index 6b7fb2c629e..7f8dc82b3ce 100644 --- a/src/nnet3bin/nnet3-info.cc +++ b/src/nnet3bin/nnet3-info.cc @@ -20,6 +20,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "nnet3/nnet-nnet.h" +#include "nnet3/nnet-utils.h" int main(int argc, char *argv[]) { try { @@ -36,7 +37,14 @@ int main(int argc, char *argv[]) { " nnet3-info 0.raw\n" "See also: nnet3-am-info\n"; + bool print_detailed_info = false; + bool print_learning_rates = false; + ParseOptions po(usage); + po.Register("print-detailed-info", &print_detailed_info, + "Print more detailed info"); + po.Register("print-learning-rates", &print_learning_rates, + "Print learning rates of updatable components"); po.Read(argc, argv); @@ -50,7 +58,24 @@ int main(int argc, char *argv[]) { Nnet nnet; ReadKaldiObject(raw_nnet_rxfilename, &nnet); - std::cout << nnet.Info(); + if (print_learning_rates) { + Vector learning_rates; + GetLearningRates(nnet, &learning_rates); + std::cout << "learning-rates: " + << PrintVectorPerUpdatableComponent(nnet, learning_rates) + << "\n"; + + Vector learning_rate_factors; + GetLearningRateFactors(nnet, &learning_rate_factors); + std::cout << "learning-rate-factors: " + << PrintVectorPerUpdatableComponent(nnet, learning_rate_factors) + << "\n"; + } + + if (print_detailed_info) + std::cout << NnetInfo(nnet); + else + std::cout << nnet.Info(); return 0; } catch(const std::exception &e) { From 99c88451245e72916cd0b7b59ec08a62b8a483d3 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Fri, 9 Dec 2016 23:56:46 -0500 Subject: [PATCH 073/213] asr_diarization: Minor bug fix in AMI run_cleanup_segmentation.sh --- egs/ami/s5b/local/run_cleanup_segmentation.sh | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/egs/ami/s5b/local/run_cleanup_segmentation.sh b/egs/ami/s5b/local/run_cleanup_segmentation.sh index e2f0b0516ce..9a947ce1fce 100755 --- a/egs/ami/s5b/local/run_cleanup_segmentation.sh +++ b/egs/ami/s5b/local/run_cleanup_segmentation.sh @@ -129,7 +129,6 @@ fi final_lm=`cat data/local/lm/final_lm` LM=$final_lm.pr1-7 - if [ $stage -le 5 ]; then graph_dir=exp/$mic/${gmm}_${cleanup_affix}/graph_$LM nj_dev=$(cat data/$mic/dev/spk2utt | wc -l) @@ -137,9 +136,9 @@ if [ $stage -le 5 ]; then $decode_cmd $graph_dir/mkgraph.log \ utils/mkgraph.sh data/lang_$LM exp/$mic/${gmm}_${cleanup_affix} $graph_dir - steps/decode_fmllr.sh --nj $nj --cmd "$decode_cmd" --config conf/decode.conf \ + steps/decode_fmllr.sh --nj $nj_dev --cmd "$decode_cmd" --config conf/decode.conf \ $graph_dir data/$mic/dev exp/$mic/${gmm}_${cleanup_affix}/decode_dev_$LM - steps/decode_fmllr.sh --nj $nj --cmd 
"$decode_cmd" --config conf/decode.conf \ + steps/decode_fmllr.sh --nj $nj_eval --cmd "$decode_cmd" --config conf/decode.conf \ $graph_dir data/$mic/eval exp/$mic/${gmm}_${cleanup_affix}/decode_eval_$LM fi From 26b49b1284d2c488dcb862c265c8752addadfd35 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Fri, 9 Dec 2016 23:57:19 -0500 Subject: [PATCH 074/213] Adding missing break in nnet-test-utils --- src/nnet3/nnet-test-utils.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/nnet3/nnet-test-utils.cc b/src/nnet3/nnet-test-utils.cc index da519fa1cd3..7a6e476ded1 100644 --- a/src/nnet3/nnet-test-utils.cc +++ b/src/nnet3/nnet-test-utils.cc @@ -1401,6 +1401,7 @@ static void GenerateRandomComponentConfig(std::string *component_type, *component_type = "DropoutComponent"; os << "dim=" << RandInt(1, 200) << " dropout-proportion=" << RandUniform(); + break; } case 30: { *component_type = "LstmNonlinearityComponent"; From ad1c10c88e686e45041881b3a34683ba980dfbfc Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Fri, 9 Dec 2016 23:58:57 -0500 Subject: [PATCH 075/213] asr_diarization: adding compute-fscore binary --- src/bin/Makefile | 2 +- src/bin/compute-fscore.cc | 153 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 154 insertions(+), 1 deletion(-) create mode 100644 src/bin/compute-fscore.cc diff --git a/src/bin/Makefile b/src/bin/Makefile index 3dc59fe8112..1948ba2d681 100644 --- a/src/bin/Makefile +++ b/src/bin/Makefile @@ -25,7 +25,7 @@ BINFILES = align-equal align-equal-compiled acc-tree-stats \ build-pfile-from-ali get-post-on-ali tree-info am-info \ vector-sum matrix-sum-rows est-pca sum-lda-accs sum-mllt-accs \ transform-vec align-text matrix-dim weight-pdf-post weight-matrix \ - matrix-add-offset matrix-dot-product + matrix-add-offset matrix-dot-product compute-fscore OBJFILES = diff --git a/src/bin/compute-fscore.cc b/src/bin/compute-fscore.cc new file mode 100644 index 00000000000..eb231fe361e --- /dev/null +++ b/src/bin/compute-fscore.cc @@ -0,0 +1,153 @@ +// bin/compute-fscore.cc + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" + +int main(int argc, char *argv[]) { + using namespace kaldi; + typedef kaldi::int32 int32; + + try { + const char *usage = + "Compute F1-score, precision, recall etc.\n" + "Takes two alignment files and computes statistics\n" + "\n" + "Usage: compute-fscore [options] \n" + " e.g.: compute-fscore ark:data/train/text ark:hyp_text\n"; + + ParseOptions po(usage); + + std::string mode = "strict"; + std::string mask_rspecifier; + + po.Register("mode", &mode, + "Scoring mode: \"present\"|\"all\"|\"strict\":\n" + " \"present\" means score those we have transcriptions for\n" + " \"all\" means treat absent transcriptions as empty\n" + " \"strict\" means die if all in ref not also in hyp"); + po.Register("mask", &mask_rspecifier, + "Only score on frames where mask is 1"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string ref_rspecifier = po.GetArg(1); + std::string hyp_rspecifier = po.GetArg(2); + + if (mode != "strict" && mode != "present" && mode != "all") { + KALDI_ERR << "--mode option invalid: expected \"present\"|\"all\"|\"strict\", got " + << mode; + } + + int64 num_tp = 0, num_fp = 0, num_tn = 0, num_fn = 0, num_frames = 0; + int32 num_absent_sents = 0; + + // Both text and integers are loaded as vector of strings, + SequentialInt32VectorReader ref_reader(ref_rspecifier); + RandomAccessInt32VectorReader hyp_reader(hyp_rspecifier); + RandomAccessInt32VectorReader mask_reader(mask_rspecifier); + + // Main loop, accumulate WER stats, + for (; !ref_reader.Done(); ref_reader.Next()) { + const std::string &key = ref_reader.Key(); + const std::vector &ref_ali = ref_reader.Value(); + std::vector hyp_ali; + if (!hyp_reader.HasKey(key)) { + if (mode == "strict") + KALDI_ERR << "No hypothesis for key " << key << " and strict " + "mode specifier."; + num_absent_sents++; + if (mode == "present") // do not score this one. + continue; + } else { + hyp_ali = hyp_reader.Value(key); + } + + std::vector mask_ali; + if (!mask_rspecifier.empty()) { + if (!mask_reader.HasKey(key)) { + if (mode == "strict") + KALDI_ERR << "No hypothesis for key " << key << " and strict " + "mode specifier."; + num_absent_sents++; + if (mode == "present") // do not score this one. 
+ continue; + } else { + mask_ali = mask_reader.Value(key); + } + } + + for (int32 i = 0; i < ref_ali.size(); i++) { + if ( (i < hyp_ali.size() && hyp_ali[i] != 0 && hyp_ali[i] != 1) || + (i < ref_ali.size() && ref_ali[i] != 0 && ref_ali[i] != 1) || + (i < mask_ali.size() && mask_ali[i] != 0 && mask_ali[i] != 1) ) { + KALDI_ERR << "Expecting alignment to be 0s or 1s"; + } + + if (!mask_rspecifier.empty() && (std::abs(static_cast(ref_ali.size()) - static_cast(mask_ali.size())) > 2) ) + KALDI_ERR << "Length mismatch: mask vs ref"; + + if (!mask_rspecifier.empty() && (i > mask_ali.size() || mask_ali[i] == 0)) continue; + num_frames++; + + if (ref_ali[i] == 1 && i > hyp_ali.size()) { num_fn++; continue; } + if (ref_ali[i] == 0 && i > hyp_ali.size()) { num_tn++; continue; } + + if (ref_ali[i] == 1 && hyp_ali[i] == 1) num_tp++; + else if (ref_ali[i] == 0 && hyp_ali[i] == 1) num_fp++; + else if (ref_ali[i] == 1 && hyp_ali[i] == 0) num_fn++; + else if (ref_ali[i] == 0 && hyp_ali[i] == 0) num_tn++; + else + KALDI_ERR << "Unknown condition"; + } + } + + // Print the ouptut, + std::cout.precision(2); + std::cerr.precision(2); + + BaseFloat precision = static_cast(num_tp) / (num_tp + num_fp); + BaseFloat recall = static_cast(num_tp) / (num_tp + num_fn); + + std::cout << "F1 " << 2 * precision * recall / (precision + recall) << "\n"; + std::cout << "Precision " << precision << "\n"; + std::cout << "Recall " << recall << "\n"; + std::cout << "Specificity " + << static_cast(num_tn) / (num_tn + num_fp) << "\n"; + std::cout << "Accuracy " + << static_cast(num_tp + num_tn) / num_frames << "\n"; + + std::cerr << "TP " << num_tp << "\n"; + std::cerr << "FP " << num_fp << "\n"; + std::cerr << "TN " << num_tn << "\n"; + std::cerr << "FN " << num_fn << "\n"; + std::cerr << "Length " << num_frames << "\n"; + + return 0; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + From 62e18da82083331baf59e283a1c68d569a6cc44f Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 23 Nov 2016 22:42:54 -0500 Subject: [PATCH 076/213] asr_diarization: Create make_overlapped_data_dir.py for overlapped speech detection --- .../s5/steps/data/make_overlapped_data_dir.py | 541 ++++++++++++++++++ 1 file changed, 541 insertions(+) create mode 100644 egs/wsj/s5/steps/data/make_overlapped_data_dir.py diff --git a/egs/wsj/s5/steps/data/make_overlapped_data_dir.py b/egs/wsj/s5/steps/data/make_overlapped_data_dir.py new file mode 100644 index 00000000000..86137c26e25 --- /dev/null +++ b/egs/wsj/s5/steps/data/make_overlapped_data_dir.py @@ -0,0 +1,541 @@ +#!/usr/bin/env python +# Copyright 2016 Tom Ko +# Apache 2.0 +# script to generate reverberated data + +# we're using python 3.x style print but want it to work in python 2.x, +from __future__ import print_function +import argparse, shlex, glob, math, os, random, sys, warnings, copy, imp, ast + +data_lib = imp.load_source('dml', 'steps/data/data_dir_manipulation_lib.py') + +def GetArgs(): + # we add required arguments as named arguments for readability + parser = argparse.ArgumentParser(description="Reverberate the data directory with an option " + "to add isotropic and point source noises. " + "Usage: reverberate_data_dir.py [options...] " + "E.g. 
reverberate_data_dir.py --rir-set-parameters rir_list " + "--foreground-snrs 20:10:15:5:0 --background-snrs 20:10:15:5:0 " + "--noise-list-file noise_list --speech-rvb-probability 1 --num-replications 2 " + "--random-seed 1 data/train data/train_rvb", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("--rir-set-parameters", type=str, action='append', required = True, dest = "rir_set_para_array", + help="Specifies the parameters of an RIR set. " + "Supports the specification of mixture_weight and rir_list_file_name. The mixture weight is optional. " + "The default mixture weight is the probability mass remaining after adding the mixture weights " + "of all the RIR lists, uniformly divided among the RIR lists without mixture weights. " + "E.g. --rir-set-parameters '0.3, rir_list' or 'rir_list' " + "the format of the RIR list file is " + "--rir-id --room-id " + "--receiver-position-id --source-position-id " + "--rt-60 --drr location " + "E.g. --rir-id 00001 --room-id 001 --receiver-position-id 001 --source-position-id 00001 " + "--rt60 0.58 --drr -4.885 data/impulses/Room001-00001.wav") + parser.add_argument("--noise-set-parameters", type=str, action='append', default = None, dest = "noise_set_para_array", + help="Specifies the parameters of an noise set. " + "Supports the specification of mixture_weight and noise_list_file_name. The mixture weight is optional. " + "The default mixture weight is the probability mass remaining after adding the mixture weights " + "of all the noise lists, uniformly divided among the noise lists without mixture weights. " + "E.g. --noise-set-parameters '0.3, noise_list' or 'noise_list' " + "the format of the noise list file is " + "--noise-id --noise-type " + "--bg-fg-type " + "--room-linkage " + "location " + "E.g. --noise-id 001 --noise-type isotropic --rir-id 00019 iso_noise.wav") + parser.add_argument("--speech-segments-set-parameters", type=str, action='append', default = None, dest = "speech_segments_set_para_array", + help="Specifies the speech segments for overlapped speech generation") + parser.add_argument("--num-replications", type=int, dest = "num_replicas", default = 1, + help="Number of replicate to generated for the data") + parser.add_argument('--foreground-snrs', type=str, dest = "foreground_snr_string", default = '20:10:0', help='When foreground noises are being added the script will iterate through these SNRs.') + parser.add_argument('--background-snrs', type=str, dest = "background_snr_string", default = '20:10:0', help='When background noises are being added the script will iterate through these SNRs.') + parser.add_argument('--overlap-snrs', type=str, dest = "overlap_snr_string", default = "20:10:0", help='When overlapping speech segments are being added the script will iterate through these SNRs.') + parser.add_argument('--prefix', type=str, default = None, help='This prefix will modified for each reverberated copy, by adding additional affixes.') + parser.add_argument("--speech-rvb-probability", type=float, default = 1.0, + help="Probability of reverberating a speech signal, e.g. 0 <= p <= 1") + parser.add_argument("--pointsource-noise-addition-probability", type=float, default = 1.0, + help="Probability of adding point-source noises, e.g. 0 <= p <= 1") + parser.add_argument("--isotropic-noise-addition-probability", type=float, default = 1.0, + help="Probability of adding isotropic noises, e.g. 
0 <= p <= 1") + parser.add_argument("--overlapped-speech-addition-probability", type=float, default = 1.0, + help="Probability of adding overlapped speech, e.g. 0 <= p <= 1") + parser.add_argument("--rir-smoothing-weight", type=float, default = 0.3, + help="Smoothing weight for the RIR probabilties, e.g. 0 <= p <= 1. If p = 0, no smoothing will be done. " + "The RIR distribution will be mixed with a uniform distribution according to the smoothing weight") + parser.add_argument("--noise-smoothing-weight", type=float, default = 0.3, + help="Smoothing weight for the noise probabilties, e.g. 0 <= p <= 1. If p = 0, no smoothing will be done. " + "The noise distribution will be mixed with a uniform distribution according to the smoothing weight") + parser.add_argument("--overlapped-speech-smoothing-weight", type=float, default = 0.3, + help="The overlapped speech distribution will be mixed with a uniform distribution according to the smoothing weight") + parser.add_argument("--max-noises-per-minute", type=int, default = 2, + help="This controls the maximum number of point-source noises that could be added to a recording according to its duration") + parser.add_argument("--max-overlapped-segments-per-minute", type=int, default = 5, + help="This controls the maximum number of overlapping segments of speech that could be added to a recording per minute") + parser.add_argument('--random-seed', type=int, default=0, help='seed to be used in the randomization of impulses and noises') + parser.add_argument("--shift-output", type=str, help="If true, the reverberated waveform will be shifted by the amount of the peak position of the RIR", + choices=['true', 'false'], default = "true") + parser.add_argument("--output-additive-noise-dir", type=str, help="Output directory corresponding to the additive noise part of the data corruption") + parser.add_argument("--output-reverb-dir", type=str, help="Output directory corresponding to the reverberated signal part of the data corruption") + + parser.add_argument("input_dir", + help="Input data directory") + parser.add_argument("output_dir", + help="Output data directory") + + print(' '.join(sys.argv)) + + args = parser.parse_args() + args = CheckArgs(args) + + return args + +def CheckArgs(args): + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + if args.output_reverb_dir is not None: + if args.output_reverb_dir == "": + args.output_reverb_dir = None + + if args.output_reverb_dir is not None: + if not os.path.exists(args.output_reverb_dir): + os.makedirs(args.output_reverb_dir) + + if args.output_additive_noise_dir is not None: + if args.output_additive_noise_dir == "": + args.output_additive_noise_dir = None + + if args.output_additive_noise_dir is not None: + if not os.path.exists(args.output_additive_noise_dir): + os.makedirs(args.output_additive_noise_dir) + + ## Check arguments. 
+ + if args.num_replicas > 1 and args.prefix is None: + args.prefix = "rvb" + warnings.warn("--prefix is set to 'rvb' as --num-replications is larger than 1.") + + if not args.num_replicas > 0: + raise Exception("--num-replications cannot be non-positive") + + if args.speech_rvb_probability < 0 or args.speech_rvb_probability > 1: + raise Exception("--speech-rvb-probability must be between 0 and 1") + + if args.pointsource_noise_addition_probability < 0 or args.pointsource_noise_addition_probability > 1: + raise Exception("--pointsource-noise-addition-probability must be between 0 and 1") + + if args.isotropic_noise_addition_probability < 0 or args.isotropic_noise_addition_probability > 1: + raise Exception("--isotropic-noise-addition-probability must be between 0 and 1") + + if args.overlapped_speech_addition_probability < 0 or args.overlapped_speech_addition_probability > 1: + raise Exception("--overlapped-speech-addition-probability must be between 0 and 1") + + if args.rir_smoothing_weight < 0 or args.rir_smoothing_weight > 1: + raise Exception("--rir-smoothing-weight must be between 0 and 1") + + if args.noise_smoothing_weight < 0 or args.noise_smoothing_weight > 1: + raise Exception("--noise-smoothing-weight must be between 0 and 1") + + if args.overlapped_speech_smoothing_weight < 0 or args.overlapped_speech_smoothing_weight > 1: + raise Exception("--overlapped-speech-smoothing-weight must be between 0 and 1") + + if args.max_noises_per_minute < 0: + raise Exception("--max-noises-per-minute cannot be negative") + + if args.max_overlapped_segments_per_minute < 0: + raise Exception("--max-overlapped-segments-per-minute cannot be negative") + + return args + +def ParseSpeechSegmentsList(speech_segments_set_para_array, smoothing_weight): + set_list = [] + for set_para in speech_segments_set_para_array: + set = lambda: None + setattr(set, "wav_scp", None) + setattr(set, "segments", None) + setattr(set, "probability", None) + parts = set_para.split(',') + if len(parts) == 3: + set.probability = float(parts[0]) + set.wav_scp = parts[1].strip() + set.segments = parts[2].strip() + else: + set.wav_scp = parts[0].strip() + set.segments = parts[1].strip() + if not os.path.isfile(set.wav_scp): + raise Exception(set.wav_scp + " not found") + if not os.path.isfile(set.segments): + raise Exception(set.segments + " not found") + set_list.append(set) + + data_lib.SmoothProbabilityDistribution(set_list) + + segments_list = [] + for segments_set in set_list: + current_segments_list = [] + + wav_dict = {} + for s in open(segments_set.wav_scp): + parts = s.strip().split() + wav_dict[parts[0]] = ' '.join(parts[1:]) + + for s in open(segments_set.segments): + parts = s.strip().split() + current_segment = argparse.Namespace() + current_segment.utt_id = parts[0] + current_segment.probability = None + + start_time = float(parts[2]) + end_time = float(parts[3]) + + current_segment.duration = (end_time - start_time) + + wav_rxfilename = wav_dict[parts[1]] + if wav_rxfilename.split()[-1] == '|': + current_segment.wav_rxfilename = "{0} sox -t wav - -t wav - trim {1} {2} |".format(wav_rxfilename, start_time, end_time - start_time) + else: + current_segment.wav_rxfilename = "sox {0} -t wav - trim {1} {2} |".format(wav_rxfilename, start_time, end_time - start_time) + + current_segments_list.append(current_segment) + + segments_list += data_lib.SmoothProbabilityDistribution(current_segments_list, smoothing_weight, segments_set.probability) + + return segments_list + +def AddOverlappedSpeech(room, # the room selected + 
speech_segments_list, # the speech list + overlapped_speech_addition_probability, # Probability of another speech waveform + snrs, # the SNR for adding the foreground speech + speech_dur, # duration of the recording + max_overlapped_speech_segments, # Maximum number of speech signals that can be added + overlapped_speech_descriptor # descriptor to store the information of the overlapped speech + ): + if (len(speech_segments_list) > 0 and random.random() < overlapped_speech_addition_probability + and max_overlapped_speech_segments >= 1): + for k in range(random.randint(1, max_overlapped_speech_segments)): + # pick the overlapped speech signal and the RIR to + # reverberate the overlapped speech signal + speech_segment = data_lib.PickItemWithProbability(speech_segments_list) + rir = data_lib.PickItemWithProbability(room.rir_list) + + speech_rvb_command = """wav-reverberate --impulse-response="{0}" --shift-output=true """.format(rir.rir_rspecifier) + overlapped_speech_descriptor['start_times'].append(round(random.random() * speech_dur, 2)) + overlapped_speech_descriptor['snrs'].append(snrs.next()) + overlapped_speech_descriptor['utt_ids'].append(speech_segment.utt_id) + overlapped_speech_descriptor['durations'].append(speech_segment.duration) + + if len(speech_segment.wav_rxfilename.split()) == 1: + overlapped_speech_descriptor['speech_segments'].append("{1} {0} - |".format(speech_segment.wav_rxfilename, speech_rvb_command)) + else: + overlapped_speech_descriptor['speech_segments'].append("{0} {1} - - |".format(speech_segment.wav_rxfilename, speech_rvb_command)) + +# This function randomly decides whether to reverberate, and sample a RIR if it does +# It also decides whether to add the appropriate noises +# This function return the string of options to the binary wav-reverberate +def GenerateReverberationAndOverlappedSpeechOpts( + room_dict, # the room dictionary, please refer to MakeRoomDict() for the format + pointsource_noise_list, # the point source noise list + iso_noise_dict, # the isotropic noise dictionary + foreground_snrs, # the SNR for adding the foreground noises + background_snrs, # the SNR for adding the background noises + speech_segments_list, + overlap_snrs, + speech_rvb_probability, # Probability of reverberating a speech signal + isotropic_noise_addition_probability, # Probability of adding isotropic noises + pointsource_noise_addition_probability, # Probability of adding point-source noises + overlapped_speech_addition_probability, + speech_dur, # duration of the recording + max_noises_recording, # Maximum number of point-source noises that can be added + max_overlapped_segments_recording # Maximum number of overlapped segments that can be added + ): + impulse_response_opts = "" + additive_noise_opts = "" + + noise_addition_descriptor = {'noise_io': [], + 'start_times': [], + 'snrs': []} + # Randomly select the room + # Here the room probability is a sum of the probabilities of the RIRs recorded in the room. 
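    # [Editorial note, illustrative only, not part of the patch:
    #  PickItemWithProbability (from steps/data/data_dir_manipulation_lib.py) is
    #  assumed to draw one entry with chance proportional to its 'probability'
    #  attribute, along the lines of:
    #      r, cum = random.random(), 0.0
    #      for item in items:          # items: a list, or the values of a dict
    #          cum += item.probability
    #          if r <= cum:
    #              return item
    #  so a room is chosen in proportion to the summed probabilities of its RIRs.]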
+ room = data_lib.PickItemWithProbability(room_dict) + # Randomly select the RIR in the room + speech_rir = data_lib.PickItemWithProbability(room.rir_list) + if random.random() < speech_rvb_probability: + # pick the RIR to reverberate the speech + impulse_response_opts = """--impulse-response="{0}" """.format(speech_rir.rir_rspecifier) + + rir_iso_noise_list = [] + if speech_rir.room_id in iso_noise_dict: + rir_iso_noise_list = iso_noise_dict[speech_rir.room_id] + # Add the corresponding isotropic noise associated with the selected RIR + if len(rir_iso_noise_list) > 0 and random.random() < isotropic_noise_addition_probability: + isotropic_noise = data_lib.PickItemWithProbability(rir_iso_noise_list) + # extend the isotropic noise to the length of the speech waveform + # check if it is really a pipe + if len(isotropic_noise.noise_rspecifier.split()) == 1: + noise_addition_descriptor['noise_io'].append("wav-reverberate --duration={1} {0} - |".format(isotropic_noise.noise_rspecifier, speech_dur)) + else: + noise_addition_descriptor['noise_io'].append("{0} wav-reverberate --duration={1} - - |".format(isotropic_noise.noise_rspecifier, speech_dur)) + noise_addition_descriptor['start_times'].append(0) + noise_addition_descriptor['snrs'].append(background_snrs.next()) + + data_lib.AddPointSourceNoise(room, # the room selected + pointsource_noise_list, # the point source noise list + pointsource_noise_addition_probability, # Probability of adding point-source noises + foreground_snrs, # the SNR for adding the foreground noises + background_snrs, # the SNR for adding the background noises + speech_dur, # duration of the recording + max_noises_recording, # Maximum number of point-source noises that can be added + noise_addition_descriptor # descriptor to store the information of the noise added + ) + + assert len(noise_addition_descriptor['noise_io']) == len(noise_addition_descriptor['start_times']) + assert len(noise_addition_descriptor['noise_io']) == len(noise_addition_descriptor['snrs']) + + overlapped_speech_descriptor = {'speech_segments': [], + 'start_times': [], + 'snrs': [], + 'utt_ids': [], + 'durations': [] + } + + AddOverlappedSpeech(room, + speech_segments_list, # speech segments list + overlapped_speech_addition_probability, + overlap_snrs, + speech_dur, + max_overlapped_segments_recording, + overlapped_speech_descriptor + ) + + if len(overlapped_speech_descriptor['speech_segments']) > 0: + noise_addition_descriptor['noise_io'] += overlapped_speech_descriptor['speech_segments'] + noise_addition_descriptor['start_times'] += overlapped_speech_descriptor['start_times'] + noise_addition_descriptor['snrs'] += overlapped_speech_descriptor['snrs'] + + if len(noise_addition_descriptor['noise_io']) > 0: + additive_noise_opts += "--additive-signals='{0}' ".format(','.join(noise_addition_descriptor['noise_io'])) + additive_noise_opts += "--start-times='{0}' ".format(','.join(map(lambda x:str(x), noise_addition_descriptor['start_times']))) + additive_noise_opts += "--snrs='{0}' ".format(','.join(map(lambda x:str(x), noise_addition_descriptor['snrs']))) + + return [impulse_response_opts, additive_noise_opts, + zip(overlapped_speech_descriptor['utt_ids'], [ str(x) for x in overlapped_speech_descriptor['start_times'] ], [ str(x) for x in overlapped_speech_descriptor['durations'] ])] + +# This is the main function to generate pipeline command for the corruption +# The generic command of wav-reverberate will be like: +# wav-reverberate --duration=t --impulse-response=rir.wav +# 
--additive-signals='noise1.wav,noise2.wav' --snrs='snr1,snr2' --start-times='s1,s2' input.wav output.wav +def GenerateReverberatedWavScpWithOverlappedSpeech( + wav_scp, # a dictionary whose values are the Kaldi-IO strings of the speech recordings + durations, # a dictionary whose values are the duration (in sec) of the speech recordings + output_dir, # output directory to write the corrupted wav.scp + room_dict, # the room dictionary, please refer to MakeRoomDict() for the format + pointsource_noise_list, # the point source noise list + iso_noise_dict, # the isotropic noise dictionary + foreground_snr_array, # the SNR for adding the foreground noises + background_snr_array, # the SNR for adding the background noises + speech_segments_list, # list of speech segments to create overlapped speech + overlap_snr_array, # the SNR for adding overlapped speech + num_replicas, # Number of replicate to generated for the data + prefix, # prefix for the id of the corrupted utterances + speech_rvb_probability, # Probability of reverberating a speech signal + shift_output, # option whether to shift the output waveform + isotropic_noise_addition_probability, # Probability of adding isotropic noises + pointsource_noise_addition_probability, # Probability of adding point-source noises + max_noises_per_minute, # maximum number of point-source noises that can be added to a recording according to its duration + overlapped_speech_addition_probability, + max_overlapped_segments_per_minute, + output_reverb_dir = None, + output_additive_noise_dir = None + ): + foreground_snrs = data_lib.list_cyclic_iterator(foreground_snr_array) + background_snrs = data_lib.list_cyclic_iterator(background_snr_array) + overlap_snrs = data_lib.list_cyclic_iterator(overlap_snr_array) + + + corrupted_wav_scp = {} + reverb_wav_scp = {} + additive_noise_wav_scp = {} + overlapped_segments_info = {} + + keys = wav_scp.keys() + keys.sort() + for i in range(1, num_replicas+1): + for recording_id in keys: + wav_original_pipe = wav_scp[recording_id] + # check if it is really a pipe + if len(wav_original_pipe.split()) == 1: + wav_original_pipe = "cat {0} |".format(wav_original_pipe) + speech_dur = durations[recording_id] + max_noises_recording = math.floor(max_noises_per_minute * speech_dur / 60) + max_overlapped_segments_recording = math.floor(max_overlapped_segments_per_minute * speech_dur / 60) + + [impulse_response_opts, + additive_noise_opts, + overlapped_speech_segments] = GenerateReverberationAndOverlappedSpeechOpts( + room_dict = room_dict, # the room dictionary, please refer to MakeRoomDict() for the format + pointsource_noise_list = pointsource_noise_list, # the point source noise list + iso_noise_dict = iso_noise_dict, # the isotropic noise dictionary + foreground_snrs = foreground_snrs, # the SNR for adding the foreground noises + background_snrs = background_snrs, # the SNR for adding the background noises + speech_segments_list = speech_segments_list, # Speech segments for creating overlapped speech + overlap_snrs = overlap_snrs, # the SNR for adding overlapped speech + speech_rvb_probability = speech_rvb_probability, # Probability of reverberating a speech signal + isotropic_noise_addition_probability = isotropic_noise_addition_probability, # Probability of adding isotropic noises + pointsource_noise_addition_probability = pointsource_noise_addition_probability, # Probability of adding point-source noises + overlapped_speech_addition_probability = overlapped_speech_addition_probability, + speech_dur = speech_dur, # duration of 
the recording + max_noises_recording = max_noises_recording, # Maximum number of point-source noises that can be added + max_overlapped_segments_recording = max_overlapped_segments_recording + ) + reverberate_opts = impulse_response_opts + additive_noise_opts + + new_recording_id = data_lib.GetNewId(recording_id, prefix, i) + + if reverberate_opts == "": + wav_corrupted_pipe = "{0}".format(wav_original_pipe) + else: + wav_corrupted_pipe = "{0} wav-reverberate --shift-output={1} {2} - - |".format(wav_original_pipe, shift_output, reverberate_opts) + + corrupted_wav_scp[new_recording_id] = wav_corrupted_pipe + + if output_reverb_dir is not None: + if impulse_response_opts == "": + wav_reverb_pipe = "{0}".format(wav_original_pipe) + else: + wav_reverb_pipe = "{0} wav-reverberate --shift-output={1} --reverb-out-wxfilename=- {2} - /dev/null |".format(wav_original_pipe, shift_output, reverberate_opts) + reverb_wav_scp[new_recording_id] = wav_reverb_pipe + + if output_additive_noise_dir is not None: + if additive_noise_opts != "": + wav_additive_noise_pipe = "{0} wav-reverberate --shift-output={1} --additive-noise-out-wxfilename=- {2} - /dev/null |".format(wav_original_pipe, shift_output, reverberate_opts) + additive_noise_wav_scp[new_recording_id] = wav_additive_noise_pipe + + if len(overlapped_speech_segments) > 0: + overlapped_segments_info[new_recording_id] = [ ':'.join(x) for x in overlapped_speech_segments ] + + data_lib.WriteDictToFile(corrupted_wav_scp, output_dir + "/wav.scp") + + # Write for each new recording, the utterance id of the segments and + # the start time at which they are added + data_lib.WriteDictToFile(overlapped_segments_info, output_dir + "/overlapped_segments_info.txt") + + if output_reverb_dir is not None: + data_lib.WriteDictToFile(reverb_wav_scp, output_reverb_dir + "/wav.scp") + + if output_additive_noise_dir is not None: + data_lib.WriteDictToFile(additive_noise_wav_scp, output_additive_noise_dir + "/wav.scp") + + +# This function creates multiple copies of the necessary files, e.g. utt2spk, wav.scp ... 
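# [Editorial note, not part of the patch: the replica ids used throughout this
#  function come from data_lib.GetNewId(recording_id, prefix, i) for
#  i in 1..num_replicas, so every copy of a recording gets a distinct prefixed
#  id; the exact id format is whatever GetNewId produces (assumed to combine
#  the prefix, the replica index and the original id).]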
+def CreateReverberatedCopy(input_dir, + output_dir, + room_dict, # the room dictionary, please refer to MakeRoomDict() for the format + pointsource_noise_list, # the point source noise list + iso_noise_dict, # the isotropic noise dictionary + speech_segments_list, + foreground_snr_string, # the SNR for adding the foreground noises + background_snr_string, # the SNR for adding the background noises + overlap_snr_string, # the SNR for overlapped speech + num_replicas, # Number of replicate to generated for the data + prefix, # prefix for the id of the corrupted utterances + speech_rvb_probability, # Probability of reverberating a speech signal + shift_output, # option whether to shift the output waveform + isotropic_noise_addition_probability, # Probability of adding isotropic noises + pointsource_noise_addition_probability, # Probability of adding point-source noises + max_noises_per_minute, # maximum number of point-source noises that can be added to a recording according to its duration + overlapped_speech_addition_probability, + max_overlapped_segments_per_minute, + output_reverb_dir = None, + output_additive_noise_dir = None + ): + + wav_scp = data_lib.ParseFileToDict(input_dir + "/wav.scp", value_processor = lambda x: " ".join(x)) + if not os.path.isfile(input_dir + "/reco2dur"): + print("Getting the duration of the recordings..."); + read_entire_file="false" + for value in wav_scp.values(): + # we will add more checks for sox commands which modify the header as we come across these cases in our data + if "sox" in value and "speed" in value: + read_entire_file="true" + break + data_lib.RunKaldiCommand("wav-to-duration --read-entire-file={1} scp:{0}/wav.scp ark,t:{0}/reco2dur".format(input_dir, read_entire_file)) + durations = data_lib.ParseFileToDict(input_dir + "/reco2dur", value_processor = lambda x: float(x[0])) + foreground_snr_array = map(lambda x: float(x), foreground_snr_string.split(':')) + background_snr_array = map(lambda x: float(x), background_snr_string.split(':')) + overlap_snr_array = map(lambda x: float(x), overlap_snr_string.split(':')) + + GenerateReverberatedWavScpWithOverlappedSpeech( + wav_scp = wav_scp, + durations = durations, + output_dir = output_dir, + room_dict = room_dict, + pointsource_noise_list = pointsource_noise_list, + iso_noise_dict = iso_noise_dict, + foreground_snr_array = foreground_snr_array, + background_snr_array = background_snr_array, + speech_segments_list = speech_segments_list, + overlap_snr_array = overlap_snr_array, + num_replicas = num_replicas, prefix = prefix, + speech_rvb_probability = speech_rvb_probability, + shift_output = shift_output, + isotropic_noise_addition_probability = isotropic_noise_addition_probability, + pointsource_noise_addition_probability = pointsource_noise_addition_probability, + max_noises_per_minute = max_noises_per_minute, + overlapped_speech_addition_probability = overlapped_speech_addition_probability, + max_overlapped_segments_per_minute = max_overlapped_segments_per_minute, + output_reverb_dir = output_reverb_dir, + output_additive_noise_dir = output_additive_noise_dir) + + data_lib.CopyDataDirFiles(input_dir, output_dir, num_replicas, prefix) + data_lib.AddPrefixToFields(input_dir + "/reco2dur", output_dir + "/reco2dur", num_replicas, prefix, field = [0]) + + if output_reverb_dir is not None: + data_lib.CopyDataDirFiles(input_dir, output_reverb_dir, num_replicas, prefix) + data_lib.AddPrefixToFields(input_dir + "/reco2dur", output_reverb_dir + "/reco2dur", num_replicas, prefix, field = [0]) + + if 
output_additive_noise_dir is not None: + data_lib.CopyDataDirFiles(input_dir, output_additive_noise_dir, num_replicas, prefix) + data_lib.AddPrefixToFields(input_dir + "/reco2dur", output_additive_noise_dir + "/reco2dur", num_replicas, prefix, field = [0]) + + +def Main(): + args = GetArgs() + random.seed(args.random_seed) + rir_list = data_lib.ParseRirList(args.rir_set_para_array, args.rir_smoothing_weight) + print("Number of RIRs is {0}".format(len(rir_list))) + pointsource_noise_list = [] + iso_noise_dict = {} + if args.noise_set_para_array is not None: + pointsource_noise_list, iso_noise_dict = data_lib.ParseNoiseList(args.noise_set_para_array, args.noise_smoothing_weight) + print("Number of point-source noises is {0}".format(len(pointsource_noise_list))) + print("Number of isotropic noises is {0}".format(sum(len(iso_noise_dict[key]) for key in iso_noise_dict.keys()))) + room_dict = data_lib.MakeRoomDict(rir_list) + + speech_segments_list = ParseSpeechSegmentsList(args.speech_segments_set_para_array, args.overlapped_speech_smoothing_weight) + + CreateReverberatedCopy(input_dir = args.input_dir, + output_dir = args.output_dir, + room_dict = room_dict, + pointsource_noise_list = pointsource_noise_list, + iso_noise_dict = iso_noise_dict, + speech_segments_list = speech_segments_list, + foreground_snr_string = args.foreground_snr_string, + background_snr_string = args.background_snr_string, + overlap_snr_string = args.overlap_snr_string, + num_replicas = args.num_replicas, + prefix = args.prefix, + speech_rvb_probability = args.speech_rvb_probability, + shift_output = args.shift_output, + isotropic_noise_addition_probability = args.isotropic_noise_addition_probability, + pointsource_noise_addition_probability = args.pointsource_noise_addition_probability, + max_noises_per_minute = args.max_noises_per_minute, + overlapped_speech_addition_probability = args.overlapped_speech_addition_probability, + max_overlapped_segments_per_minute = args.max_overlapped_segments_per_minute, + output_reverb_dir = args.output_reverb_dir, + output_additive_noise_dir = args.output_additive_noise_dir) + +if __name__ == "__main__": + Main() + + From be41b741e9f034d2eda68263aea4bc1845495dfd Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 24 Nov 2016 00:12:19 -0500 Subject: [PATCH 077/213] asr_diarization: Added do_corruption_data_dir_overlapped_speech.sh --- ...o_corruption_data_dir_overlapped_speech.sh | 270 ++++++++++++++++++ 1 file changed, 270 insertions(+) create mode 100644 egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh diff --git a/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh new file mode 100644 index 00000000000..f387acb8552 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh @@ -0,0 +1,270 @@ +#!/bin/bash +set -e +set -u +set -o pipefail + +. path.sh +. cmd.sh + +num_data_reps=5 +data_dir=data/train_si284 +whole_data_dir=data/train_si284_whole + +nj=40 +reco_nj=40 + +stage=0 +corruption_stage=-10 + +pad_silence=false + +mfcc_config=conf/mfcc_hires_bp_vh.conf +energy_config=conf/log_energy.conf + +dry_run=false +corrupt_only=false +speed_perturb=true + +reco_vad_dir= +utt_vad_dir= + +max_jobs_run=20 + +overlap_snrs="5:2:1:0:-1:-2" +base_rirs=simulated + +. 
utils/parse_options.sh + +if [ $# -ne 0 ]; then + echo "Usage: $0" + exit 1 +fi + +rvb_opts=() +# This is the config for the system using simulated RIRs and point-source noises +rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/smallroom/rir_list") +rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/mediumroom/rir_list") +rvb_opts+=(--speech-segments-set-parameters="$data_dir/wav.scp,$data_dir/segments") + +whole_data_id=`basename ${whole_data_dir}` + +corrupted_data_id=${whole_data_id}_ovlp_corrupted +clean_data_id=${whole_data_id}_ovlp_clean +noise_data_id=${whole_data_id}_ovlp_noise + +if [ $stage -le 2 ]; then + python steps/data/make_corrupted_data_dir.py \ + "${rvb_opts[@]}" \ + --prefix="ovlp" \ + --overlap-snrs=$overlap_snrs \ + --speech-rvb-probability=1 \ + --overlapping-speech-addition-probability=1 \ + --num-replications=$num_data_reps \ + --min-overlapping-segments-per-minute=5 \ + --max-overlapping-segments-per-minute=20 \ + --output-additive-noise-dir=data/${noise_data_id} \ + --output-reverb-dir=data/${clean_data_id} \ + data/${whole_data_id} data/${corrupted_data_id} +fi + +if $dry_run; then + exit 0 +fi + +clean_data_dir=data/${clean_data_id} +corrupted_data_dir=data/${corrupted_data_id} +noise_data_dir=data/${noise_data_id} +orig_corrupted_data_dir=$corrupted_data_dir + +if $speed_perturb; then + if [ $stage -le 3 ]; then + ## Assuming whole data directories + for x in $clean_data_dir $corrupted_data_dir $noise_data_dir; do + cp $x/reco2dur $x/utt2dur + utils/data/perturb_data_dir_speed_3way.sh $x ${x}_sp + done + fi + + corrupted_data_dir=${corrupted_data_dir}_sp + clean_data_dir=${clean_data_dir}_sp + noise_data_dir=${noise_data_dir}_sp + + corrupted_data_id=${corrupted_data_id}_sp + clean_data_id=${clean_data_id}_sp + noise_data_id=${noise_data_id}_sp + + if [ $stage -le 4 ]; then + utils/data/perturb_data_dir_volume.sh --force true ${corrupted_data_dir} + utils/data/perturb_data_dir_volume.sh --force true --reco2vol ${corrupted_data_dir}/reco2vol ${clean_data_dir} + utils/data/perturb_data_dir_volume.sh --force true --reco2vol ${corrupted_data_dir}/reco2vol ${noise_data_dir} + fi +fi + +if $corrupt_only; then + echo "$0: Got corrupted data directory in ${corrupted_data_dir}" + exit 0 +fi + +mfccdir=`basename $mfcc_config` +mfccdir=${mfccdir%%.conf} + +if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage +fi + +if [ $stage -le 5 ]; then + steps/make_mfcc.sh --mfcc-config $mfcc_config \ + --cmd "$train_cmd" --nj $reco_nj \ + $corrupted_data_dir exp/make_hires_bp/${corrupted_data_id} $mfccdir +fi + +if [ $stage -le 6 ]; then + steps/make_mfcc.sh --mfcc-config $energy_config \ + --cmd "$train_cmd" --nj $reco_nj \ + $clean_data_dir exp/make_log_energy/${clean_data_id} log_energy_feats +fi + +if [ $stage -le 7 ]; then + steps/make_mfcc.sh --mfcc-config $energy_config \ + --cmd "$train_cmd" --nj $reco_nj \ + $noise_data_dir exp/make_log_energy/${noise_data_id} log_energy_feats +fi + +if [ -z "$reco_vad_dir" ]; then + echo "reco-vad-dir must be provided" + exit 1 +fi + +targets_dir=irm_targets +if [ $stage -le 8 ]; then + mkdir -p exp/make_irm_targets/${corrupted_data_id} + + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $targets_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$targets_dir/storage $targets_dir/storage + fi + + steps/segmentation/make_snr_targets.sh \ + --nj $nj --cmd "$train_cmd --max-jobs-run $max_jobs_run" \ + --target-type Irm --compress true --apply-exp false \ + ${clean_data_dir} ${noise_data_dir} ${corrupted_data_dir} \ + exp/make_irm_targets/${corrupted_data_id} $targets_dir +fi + +# Data dirs without speed perturbation +overlap_dir=exp/make_overlap_labels/${corrupted_data_id} +unreliable_dir=exp/make_overlap_labels/unreliable_${corrupted_data_id} +overlap_data_dir=$overlap_dir/overlap_data +unreliable_data_dir=$overlap_dir/unreliable_data + +mkdir -p $unreliable_dir + +if [ $stage -le 8 ]; then + cat $reco_vad_dir/sad_seg.scp | \ + steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps "ovlp" \ + | sort -k1,1 > ${corrupted_data_dir}/sad_seg.scp + utils/data/get_utt2num_frames.sh $corrupted_data_dir + utils/split_data.sh --per-reco ${orig_corrupted_data_dir} $reco_nj + + # Combine the VAD from the base recording and the VAD from the overlapping segments + # to create per-frame labels of the number of overlapping speech segments + # Unreliable segments are regions where no VAD labels were available for the + # overlapping segments. These can be later removed by setting deriv weights to 0. + $train_cmd JOB=1:$reco_nj $overlap_dir/log/get_overlap_seg.JOB.log \ + segmentation-init-from-overlap-info --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ + "scp:utils/filter_scp.pl ${orig_corrupted_data_dir}/split${reco_nj}reco/JOB/utt2spk $corrupted_data_dir/sad_seg.scp |" \ + ark,t:$orig_corrupted_data_dir/overlapped_segments_info.txt \ + scp:$utt_vad_dir/sad_seg.scp ark:- ark:$unreliable_dir/unreliable_seg_speed_unperturbed.JOB.ark \| \ + segmentation-copy --keep-label=1 ark:- ark:- \| \ + segmentation-get-stats --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ + ark:- ark:- ark:/dev/null \| \ + segmentation-init-from-ali ark:- ark:$overlap_dir/overlap_seg_speed_unperturbed.JOB.ark +fi + +if [ $stage -le 9 ]; then + mkdir -p $overlap_data_dir $unreliable_data_dir + cp $orig_corrupted_data_dir/wav.scp $overlap_data_dir + cp $orig_corrupted_data_dir/wav.scp $unreliable_data_dir + + # Create segments where there is definitely an overlap. 
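The commands below take the per-frame overlap counts, drop the values 0 and 1 (silence and single-speaker frames), and collapse every larger count into a single label before converting the result back into segments. A minimal frame-level sketch of that reduction, assuming a plain Python list of per-frame speaker counts (the function and variable names are illustrative only, not part of the scripts):

def definite_overlap_mask(speaker_counts):
    # keep only frames where at least two speakers are active at once;
    # contiguous runs of 1s in the returned mask correspond to the
    # overlap segments later written out by segmentation-to-segments
    return [1 if count >= 2 else 0 for count in speaker_counts]

# example: counts [0, 1, 1, 2, 3, 1, 0] -> mask [0, 0, 0, 1, 1, 0, 0]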
+ $train_cmd JOB=1:$reco_nj $overlap_dir/log/process_to_segments.JOB.log \ + segmentation-post-process --remove-labels=0:1 \ + ark:$overlap_dir/overlap_seg_speed_unperturbed.JOB.ark ark:- \| \ + segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-to-segments ark:- ark:$overlap_data_dir/utt2spk.JOB $overlap_data_dir/segments.JOB + + $train_cmd JOB=1:$reco_nj $overlap_dir/log/get_unreliable_segments.JOB.log \ + segmentation-to-segments --single-speaker \ + ark:$unreliable_dir/unreliable_seg_speed_unperturbed.JOB.ark \ + ark:$unreliable_data_dir/utt2spk.JOB $unreliable_data_dir/segments.JOB + + for n in `seq $reco_nj`; do cat $overlap_data_dir/utt2spk.$n; done > $overlap_data_dir/utt2spk + for n in `seq $reco_nj`; do cat $overlap_data_dir/segments.$n; done > $overlap_data_dir/segments + for n in `seq $reco_nj`; do cat $unreliable_data_dir/utt2spk.$n; done > $unreliable_data_dir/utt2spk + for n in `seq $reco_nj`; do cat $unreliable_data_dir/segments.$n; done > $unreliable_data_dir/segments + + utils/fix_data_dir.sh $overlap_data_dir + utils/fix_data_dir.sh $unreliable_data_dir + + if $speed_perturb; then + utils/data/perturb_data_dir_speed_3way.sh $overlap_data_dir ${overlap_data_dir}_sp + utils/data/perturb_data_dir_speed_3way.sh $unreliable_data_dir ${unreliable_data_dir}_sp + fi +fi + +if $speed_perturb; then + overlap_data_dir=${overlap_data_dir}_sp + unreliable_data_dir=${unreliable_data_dir}_sp +fi + +if [ $stage -le 10 ]; then + utils/split_data.sh --per-reco ${overlap_data_dir} $reco_nj + + $train_cmd JOB=1:$reco_nj $overlap_dir/log/get_overlap_speech_labels.JOB.log \ + utils/data/get_reco2utt.sh ${overlap_data_dir}/split${reco_nj}reco/JOB '&&' \ + segmentation-init-from-segments --shift-to-zero=false \ + ${overlap_data_dir}/split${reco_nj}reco/JOB/segments ark:- \| \ + segmentation-combine-segments-to-recordings ark:- ark,t:${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt \ + ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- \ + ark,scp:overlap_labels/overlapped_speech_${corrupted_data_id}.JOB.ark,overlap_labels/overlapped_speech_${corrupted_data_id}.JOB.scp +fi + +for n in `seq $reco_nj`; do + cat overlap_labels/overlapped_speech_${corrupted_data_id}.$n.scp +done > ${corrupted_data_dir}/overlapped_speech_labels.scp + +if [ $stage -le 11 ]; then + utils/data/get_reco2utt.sh ${unreliable_data_dir} + + # First convert the unreliable segments into a recording-level segmentation. + # Initialize a segmentation from utt2num_frames and set to 0, the regions + # of unreliable segments. At this stage deriv weights is 1 for all but the + # unreliable segment regions. + # Initialize a segmentation from the VAD labels and retain only the speech segments. + # Intersect this with the deriv weights segmentation from above. At this stage + # deriv weights is 1 for only the regions where base VAD label is 1 and + # the overlapping segment is not unreliable. Convert this to deriv weights. 
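Stated as per-frame logic, the pipeline that follows produces a derivative weight of 1 exactly where the base VAD marks speech and the frame is not covered by an unreliable overlapping segment, and 0 everywhere else. A rough sketch of that rule, assuming equal-length 0/1 lists base_speech and unreliable (both names are illustrative; the actual computation is carried out by the segmentation-* binaries on Kaldi archives):

def deriv_weights_for_overlapped_speech(base_speech, unreliable):
    # 1 = use this frame's derivative during training, 0 = ignore it
    return [1 if speech == 1 and unrel == 0 else 0
            for speech, unrel in zip(base_speech, unreliable)]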
+ $train_cmd JOB=1:$reco_nj $unreliable_dir/log/get_deriv_weights.JOB.log\ + segmentation-init-from-segments --shift-to-zero=false \ + "utils/filter_scp.pl -f 2 ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt ${unreliable_data_dir}/segments |" ark:- \| \ + segmentation-combine-segments-to-recordings ark:- "ark,t:utils/filter_scp.pl ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt ${unreliable_data_dir}/reco2utt |" \ + ark:- \| \ + segmentation-create-subsegments --filter-label=1 --subsegment-label=0 --ignore-missing \ + "ark:utils/filter_scp.pl ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt $corrupted_data_dir/utt2num_frames | segmentation-init-from-lengths ark,t:- ark:- |" \ + ark:- ark:- \| \ + segmentation-intersect-segments --mismatch-label=0 \ + "ark:utils/filter_scp.pl ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt $corrupted_data_dir/sad_seg.scp | segmentation-post-process --remove-labels=0:2:3 scp:- ark:- |" \ + ark:- ark:- \| \ + segmentation-post-process --remove-labels=0 ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- ark,t:- \| \ + steps/segmentation/convert_ali_to_vec.pl \| copy-vector ark,t:- \ + ark,scp:$unreliable_dir/deriv_weights_for_overlapped_speech.JOB.ark,$unreliable_dir/deriv_weights_for_overlapped_speech.JOB.scp + + for n in `seq $reco_nj`; do + cat $unreliable_dir/deriv_weights_for_overlapped_speech.${n}.scp + done > $corrupted_data_dir/deriv_weights_for_overlapped_speech.scp +fi + +exit 0 From a1be1dd02795ac51b6ccd03a21fa084a553e68ba Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 24 Nov 2016 00:12:46 -0500 Subject: [PATCH 078/213] asr_diarization: Added train_sad_ovlp{,_prob}.sh --- .../s5/local/segmentation/train_sad_ovlp.sh | 144 +++++++++++++++++ .../local/segmentation/train_sad_ovlp_prob.sh | 145 ++++++++++++++++++ 2 files changed, 289 insertions(+) create mode 100644 egs/aspire/s5/local/segmentation/train_sad_ovlp.sh create mode 100644 egs/aspire/s5/local/segmentation/train_sad_ovlp_prob.sh diff --git a/egs/aspire/s5/local/segmentation/train_sad_ovlp.sh b/egs/aspire/s5/local/segmentation/train_sad_ovlp.sh new file mode 100644 index 00000000000..2d553875db0 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/train_sad_ovlp.sh @@ -0,0 +1,144 @@ +#!/bin/bash + +# this is the standard "tdnn" system, built in nnet3; it's what we use to +# call multi-splice. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= + +splice_indexes="-3,-2,-1,0,1,2,3 -6,0 -9,0,3 0" +relu_dim=256 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=1 +extra_egs_copy_cmd= + +num_utts_subset_valid=40 +num_utts_subset_train=40 +add_idct=true + +# target options +train_data_dir=data/train_azteec_unsad_music_whole_sp_multi_lessreverb_1k_hires + +snr_scp= +speech_feat_scp= +overlapped_speech_labels_scp= + +deriv_weights_scp= +deriv_weights_for_overlapped_speech_scp= + +egs_dir= +nj=40 +feat_type=raw +config_dir= +compute_objf_opts= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. 
./utils/parse_options.sh + +num_hidden_layers=`echo $splice_indexes | perl -ane 'print scalar @F'` || exit 1 +if [ -z "$dir" ]; then + dir=exp/nnet3_sad_ovlp_snr/nnet_tdnn +fi + +dir=$dir${affix:+_$affix}_n${num_hidden_layers} + +if ! cuda-compiled; then + cat < Date: Tue, 22 Nov 2016 12:44:55 -0500 Subject: [PATCH 079/213] asr_diarization: New copy-egs-overlap-detection in nnet3bin/Makefile --- src/nnet3bin/Makefile | 3 +- .../nnet3-copy-egs-overlap-detection.cc | 187 ++++++++++++++++++ 2 files changed, 189 insertions(+), 1 deletion(-) create mode 100644 src/nnet3bin/nnet3-copy-egs-overlap-detection.cc diff --git a/src/nnet3bin/Makefile b/src/nnet3bin/Makefile index aeb3dc1dc03..2a660da232c 100644 --- a/src/nnet3bin/Makefile +++ b/src/nnet3bin/Makefile @@ -17,7 +17,8 @@ BINFILES = nnet3-init nnet3-info nnet3-get-egs nnet3-copy-egs nnet3-subset-egs \ nnet3-discriminative-merge-egs nnet3-discriminative-shuffle-egs \ nnet3-discriminative-compute-objf nnet3-discriminative-train \ discriminative-get-supervision nnet3-discriminative-subset-egs \ - nnet3-discriminative-compute-from-egs nnet3-get-egs-multiple-targets + nnet3-discriminative-compute-from-egs nnet3-get-egs-multiple-targets \ + nnet3-copy-egs-overlap-detection OBJFILES = diff --git a/src/nnet3bin/nnet3-copy-egs-overlap-detection.cc b/src/nnet3bin/nnet3-copy-egs-overlap-detection.cc new file mode 100644 index 00000000000..3f180a6393e --- /dev/null +++ b/src/nnet3bin/nnet3-copy-egs-overlap-detection.cc @@ -0,0 +1,187 @@ +// nnet3bin/nnet3-copy-egs.cc + +// Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey) +// 2014 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "hmm/transition-model.h" +#include "nnet3/nnet-example.h" +#include "nnet3/nnet-example-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace kaldi::nnet3; + typedef kaldi::int32 int32; + typedef kaldi::int64 int64; + + const char *usage = + "Copy examples (single frames or fixed-size groups of frames) for neural\n" + "network training, possibly changing the binary mode. 
Supports multiple wspecifiers, in\n" + "which case it will write the examples round-robin to the outputs.\n" + "\n" + "Usage: nnet3-copy-egs [options] \n" + "\n" + "e.g.\n" + "nnet3-copy-egs ark:train.egs ark,t:text.egs\n" + "or:\n" + "nnet3-copy-egs ark:train.egs ark:1.egs\n"; + + ParseOptions po(usage); + + bool add_silence_output = true; + bool add_speech_output = true; + int32 srand_seed = 0; + + std::string keep_proportion_positive_rxfilename; + std::string keep_proportion_negative_rxfilename; + + po.Register("add-silence-output", &add_silence_output, + "Add silence output"); + po.Register("add-speech-output", &add_speech_output, + "Add speech output"); + po.Register("srand", &srand_seed, "Seed for random number generator " + "(only relevant if --keep-proportion-vec is specified"); + po.Register("keep-proportion-positive-vec", &keep_proportion_positive_rxfilename, + "If a dimension of this is <1.0, this program will " + "randomly set deriv weight 0 for this proportion of the input samples of the " + "corresponding positive examples"); + po.Register("keep-proportion-negative-vec", &keep_proportion_negative_rxfilename, + "If a dimension of this is <1.0, this program will " + "randomly set deriv weight 0 for this proportion of the input samples of the " + "corresponding negative examples"); + + Vector p_positive_vec(3); + p_positive_vec.Set(1); + if (!keep_proportion_positive_rxfilename.empty()) + ReadKaldiObject(keep_proportion_positive_rxfilename, &p_positive_vec); + + Vector p_negative_vec(3); + p_negative_vec.Set(1); + if (!keep_proportion_negative_rxfilename.empty()) + ReadKaldiObject(keep_proportion_negative_rxfilename, &p_negative_vec); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string examples_rspecifier = po.GetArg(1); + std::string examples_wspecifier = po.GetArg(2); + + SequentialNnetExampleReader example_reader(examples_rspecifier); + NnetExampleWriter example_writer(examples_wspecifier); + + int64 num_read = 0, num_written = 0; + for (; !example_reader.Done(); example_reader.Next(), num_read++) { + std::string key = example_reader.Key(); + NnetExample eg = example_reader.Value(); + + KALDI_ASSERT(eg.io.size() == 2); + NnetIo &io = eg.io[1]; + + KALDI_ASSERT(io.name == "output"); + + NnetIo silence_output(io); + silence_output.name = "output-silence"; + + NnetIo speech_output(io); + speech_output.name = "output-speech"; + + NnetIo overlap_speech_output(io); + overlap_speech_output.name = "output-overlap_speech"; + + io.features.Uncompress(); + + KALDI_ASSERT(io.features.Type() == kFullMatrix); + const Matrix &feats = io.features.GetFullMatrix(); + + typedef std::vector > SparseVec; + std::vector silence_post(feats.NumRows(), SparseVec()); + std::vector speech_post(feats.NumRows(), SparseVec()); + std::vector overlap_speech_post(feats.NumRows(), SparseVec()); + + Vector silence_deriv_weights(feats.NumRows()); + Vector speech_deriv_weights(feats.NumRows()); + Vector overlap_speech_deriv_weights(feats.NumRows()); + + for (int32 i = 0; i < feats.NumRows(); i++) { + if (feats(i,0) < 0.5) { + silence_deriv_weights(i) = WithProb(p_negative_vec(0)) ? 1.0 : 0.0; + silence_post[i].push_back(std::make_pair(0, 1)); + } else { + silence_deriv_weights(i) = WithProb(p_positive_vec(0)) ? 1.0 : 0.0; + silence_post[i].push_back(std::make_pair(1, 1)); + } + + if (feats(i,1) < 0.5) { + speech_deriv_weights(i) = WithProb(p_negative_vec(1)) ? 
1.0 : 0.0; + speech_post[i].push_back(std::make_pair(0, 1)); + } else { + speech_deriv_weights(i) = WithProb(p_positive_vec(1)) ? 1.0 : 0.0; + speech_post[i].push_back(std::make_pair(1, 1)); + } + + if (feats(i,2) < 0.5) { + overlap_speech_deriv_weights(i) = WithProb(p_negative_vec(2)) ? 1.0 : 0.0; + overlap_speech_post[i].push_back(std::make_pair(0, 1)); + } else { + overlap_speech_deriv_weights(i) = WithProb(p_positive_vec(2)) ? 1.0 : 0.0; + overlap_speech_post[i].push_back(std::make_pair(1, 1)); + } + } + + SparseMatrix silence_feats(2, silence_post); + SparseMatrix speech_feats(2, speech_post); + SparseMatrix overlap_speech_feats(2, overlap_speech_post); + + silence_output.features = silence_feats; + speech_output.features = speech_feats; + overlap_speech_output.features = overlap_speech_feats; + + io = overlap_speech_output; + io.deriv_weights.MulElements(overlap_speech_deriv_weights); + + if (add_silence_output) { + silence_output.deriv_weights.MulElements(silence_deriv_weights); + eg.io.push_back(silence_output); + } + + if (add_speech_output) { + speech_output.deriv_weights.MulElements(speech_deriv_weights); + eg.io.push_back(speech_output); + } + + example_writer.Write(key, eg); + num_written++; + } + + KALDI_LOG << "Read " << num_read << " neural-network training examples, wrote " + << num_written; + return (num_written == 0 ? 1 : 0); + } catch(const std::exception &e) { + std::cerr << e.what() << '\n'; + return -1; + } +} + + + From 25a21f9c60edf3186fc7d38ca829afb395d2f153 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Sat, 10 Dec 2016 14:30:27 -0500 Subject: [PATCH 080/213] Modifying do_corruption_data_dir_overlapped_speech.sh --- ...o_corruption_data_dir_overlapped_speech.sh | 30 +++++++++++++------ 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh index f387acb8552..242dfca8170 100644 --- a/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh +++ b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh @@ -1,23 +1,35 @@ #!/bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + set -e set -u set -o pipefail . path.sh -. 
cmd.sh - -num_data_reps=5 -data_dir=data/train_si284 -whole_data_dir=data/train_si284_whole - -nj=40 -reco_nj=40 stage=0 corruption_stage=-10 +corrupt_only=false + +# Data options +data_dir=data/train_si284 # Excpecting non-whole data directory +speed_perturb=true +num_data_reps=5 # Number of corrupted versions +snrs="20:10:15:5:0:-5" +foreground_snrs="20:10:15:5:0:-5" +background_snrs="20:10:15:5:0:-5" +base_rirs=simulated +# Whole-data directory corresponding to data_dir +whole_data_dir=data/train_si284_whole -pad_silence=false +# Parallel options +reco_nj=40 +nj=40 +cmd=queue.pl +# Options for feature extraction mfcc_config=conf/mfcc_hires_bp_vh.conf energy_config=conf/log_energy.conf From c7ba2080a72d253080a6cb038bceb399a7dd5633 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Sun, 11 Dec 2016 13:41:35 -0500 Subject: [PATCH 081/213] dropout_schedule: Changing default in dropout-schedule option --- egs/wsj/s5/steps/libs/nnet3/train/common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/wsj/s5/steps/libs/nnet3/train/common.py b/egs/wsj/s5/steps/libs/nnet3/train/common.py index f2485e36784..90ee209a092 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/common.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/common.py @@ -647,7 +647,8 @@ def __init__(self): doesn't increase the effective learning rate.""") self.parser.add_argument("--trainer.dropout-schedule", type=str, - dest='dropout_schedule', default='', + dest='dropout_schedule', default=None, + action=common_lib.NullstrToNoneAction, help="""Use this to specify the dropout schedule. You specify a piecewise linear function on the domain [0,1], where 0 is the From 851e98a993d8232b99c77efc7340ae65c616d29d Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 12 Dec 2016 23:35:14 -0500 Subject: [PATCH 082/213] Bug fix in xconfig/basic_layers.py --- egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py index 24eea922968..c612af984b1 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py @@ -349,7 +349,8 @@ def set_default_configs(self): # note: self.config['input'] is a descriptor, '[-1]' means output # the most recent layer. 
- self.config = { 'input':'[-1]' } + self.config = {'input': '[-1]', + 'dim': -1} def check_configs(self): From 693bd140ebb2f8203a7635518dbb4c0037e78c5a Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 12 Dec 2016 23:36:24 -0500 Subject: [PATCH 083/213] asr_diarization: Addind stats_layer to xconfigs --- egs/wsj/s5/steps/libs/nnet3/xconfig/layers.py | 1 + egs/wsj/s5/steps/libs/nnet3/xconfig/parser.py | 3 +- .../steps/libs/nnet3/xconfig/stats_layer.py | 142 ++++++++++++++++++ 3 files changed, 145 insertions(+), 1 deletion(-) create mode 100644 egs/wsj/s5/steps/libs/nnet3/xconfig/stats_layer.py diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/layers.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/layers.py index 353b9d3bba4..1092be572b4 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/layers.py +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/layers.py @@ -6,3 +6,4 @@ from basic_layers import * from lstm import * from tdnn import * +from stats_layer import * diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/parser.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/parser.py index 7ccab2f6c6f..7b34481993b 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/parser.py +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/parser.py @@ -29,7 +29,8 @@ 'lstmp-layer' : xlayers.XconfigLstmpLayer, 'lstmpc-layer' : xlayers.XconfigLstmpcLayer, 'fast-lstm-layer' : xlayers.XconfigFastLstmLayer, - 'fast-lstmp-layer' : xlayers.XconfigFastLstmpLayer + 'fast-lstmp-layer' : xlayers.XconfigFastLstmpLayer, + 'stats-layer': xlayers.XconfigStatsLayer } # Converts a line as parsed by ParseConfigLine() into a first diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/stats_layer.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/stats_layer.py new file mode 100644 index 00000000000..beaf7c8923a --- /dev/null +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/stats_layer.py @@ -0,0 +1,142 @@ +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +""" This module contains the statistics extraction and pooling layer. +""" + +from __future__ import print_function +import re +from libs.nnet3.xconfig.utils import XconfigParserError as xparser_error +from libs.nnet3.xconfig.basic_layers import XconfigLayerBase + + +class XconfigStatsLayer(XconfigLayerBase): + """This class is for parsing lines like + stats-layer name=tdnn1-stats config=mean+stddev(-99:3:9:99) input=tdnn1 + + This adds statistics-pooling and statistics-extraction components. An + example string is 'mean(-99:3:9::99)', which means, compute the mean of + data within a window of -99 to +99, with distinct means computed every 9 + frames (we round to get the appropriate one), and with the input extracted + on multiples of 3 frames (so this will force the input to this layer to be + evaluated every 3 frames). Another example string is + 'mean+stddev(-99:3:9:99)', which will also cause the standard deviation to + be computed. + + The dimension is worked out from the input. mean and stddev add a + dimension of input_dim each to the output dimension. If counts is + specified, an additional dimension is added to the output to store log + counts. + + Parameters of the class, and their defaults: + input='[-1]' [Descriptor giving the input of the layer.] + dim=-1 [Output dimension of layer. If provided, must match the + dimension computed from input] + config='' [Required. Defines what stats must be computed.] 
+ """ + def __init__(self, first_token, key_to_value, prev_names=None): + assert first_token in ['stats-layer'] + XconfigLayerBase.__init__(self, first_token, key_to_value, prev_names) + + def set_default_configs(self): + self.config = {'input': '[-1]', + 'dim': -1, + 'config': ''} + + def set_derived_configs(self): + config_string = self.config['config'] + if config_string == '': + raise xparser_error("config has to be non-empty", + self.str()) + m = re.search("(mean|mean\+stddev|mean\+count|mean\+stddev\+count)" + "\((-?\d+):(-?\d+):(-?\d+):(-?\d+)\)", + config_string) + if m is None: + raise xparser_error("Invalid statistic-config string: {0}".format( + config_string), self) + + self._output_stddev = (m.group(1) in ['mean+stddev', + 'mean+stddev+count']) + self._output_log_counts = (m.group(1) in ['mean+count', + 'mean+stddev+count']) + self._left_context = -int(m.group(2)) + self._input_period = int(m.group(3)) + self._stats_period = int(m.group(4)) + self._right_context = int(m.group(5)) + + output_dim = (self.descriptors['input']['dim'] + * (2 if self._output_stddev else 1) + + 1 if self._output_log_counts else 0) + + if self.config['dim'] > 0 and self.config['dim'] != output_dim: + raise xparser_error( + "Invalid dim supplied {0:d} != " + "actual output dim {1:d}".format( + self.config['dim'], output_dim)) + self.config['dim'] = output_dim + + def check_configs(self): + if not (self._left_context > 0 and self._right_context > 0 + and self._input_period > 0 and self._stats_period > 0 + and self._left_context % self._stats_period == 0 + and self._right_context % self._stats_period == 0 + and self._stats_period % self._input_period == 0): + raise xparser_error( + "Invalid configuration of statistics-extraction: {0}".format( + self.config['config']), self) + super(XconfigStatsLayer, self).check_configs() + + def _generate_config(self): + input_desc = self.descriptors['input']['final-string'] + input_dim = self.descriptors['input']['dim'] + + configs = [] + configs.append( + 'component name={name}-extraction-{lc}-{rc} ' + 'type=StatisticsExtractionComponent input-dim={dim} ' + 'input-period={input_period} output-period={output_period} ' + 'include-variance={var} '.format( + name=self.name, lc=self._left_context, rc=self._right_context, + dim=input_dim, input_period=self._input_period, + output_period=self._stats_period, + var='true' if self._output_stddev else 'false')) + configs.append( + 'component-node name={name}-extraction-{lc}-{rc} ' + 'component={name}-extraction-{lc}-{rc} input={input} '.format( + name=self.name, lc=self._left_context, rc=self._right_context, + input=input_desc)) + + stats_dim = 1 + input_dim * (2 if self._output_stddev else 1) + configs.append( + 'component name={name}-pooling-{lc}-{rc} ' + 'type=StatisticsPoolingComponent input-dim={dim} ' + 'input-period={input_period} left-context={lc} right-context={rc} ' + 'num-log-count-features={count} output-stddevs={var} '.format( + name=self.name, lc=self._left_context, rc=self._right_context, + dim=stats_dim, input_period=self._stats_period, + count=1 if self._output_log_counts else 0, + var='true' if self._output_stddev else 'false')) + configs.append( + 'component-node name={name}-pooling-{lc}-{rc} ' + 'component={name}-pooling-{lc}-{rc} ' + 'input={name}-extraction-{lc}-{rc} '.format( + name=self.name, lc=self._left_context, rc=self._right_context)) + return configs + + def output_name(self, auxiliary_output=None): + return 'Round({name}-pooling-{lc}-{rc}, {period})'.format( + name=self.name, lc=self._left_context, 
+ rc=self._right_context, period=self._stats_period) + + def output_dim(self, auxiliary_outputs=None): + return self.config['dim'] + + def get_full_config(self): + ans = [] + config_lines = self._generate_config() + + for line in config_lines: + for config_name in ['ref', 'final']: + ans.append((config_name, line)) + + return ans From ed938f63cdd5dc19dba257de765e4773c279fb42 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 12 Dec 2016 23:36:48 -0500 Subject: [PATCH 084/213] asr_diarization: Making xconfigs support more general networks --- egs/wsj/s5/steps/nnet3/xconfig_to_configs.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/egs/wsj/s5/steps/nnet3/xconfig_to_configs.py b/egs/wsj/s5/steps/nnet3/xconfig_to_configs.py index c55dae18b19..5edd3303942 100755 --- a/egs/wsj/s5/steps/nnet3/xconfig_to_configs.py +++ b/egs/wsj/s5/steps/nnet3/xconfig_to_configs.py @@ -25,6 +25,9 @@ def get_args(): help='Filename of input xconfig file') parser.add_argument('--config-dir', required=True, help='Directory to write config files and variables') + parser.add_argument('--nnet-edits', type=str, default=None, + action=common_lib.NullstrToNoneAction, + help="Edit network before getting nnet3-info") print(' '.join(sys.argv)) @@ -187,13 +190,19 @@ def write_config_files(config_dir, all_layers): raise -def add_back_compatibility_info(config_dir): +def add_back_compatibility_info(config_dir, nnet_edits=None): """This will be removed when python script refactoring is done.""" common_lib.run_kaldi_command("nnet3-init {0}/ref.config " "{0}/ref.raw".format(config_dir)) - out, err = common_lib.run_kaldi_command("nnet3-info {0}/ref.raw | " - "head -4".format(config_dir)) + model = "{0}/ref.raw".format(config_dir) + if nnet_edits is not None: + model = """nnet3-copy --edits='{0}' {1} - |""".format(nnet_edits, + model) + + print("""nnet3-info "{0}" | head -4""".format(model), file=sys.stderr) + out, err = common_lib.run_kaldi_command("""nnet3-info "{0}" | """ + """head -4""".format(model)) # out looks like this # left-context: 7 # right-context: 0 @@ -226,7 +235,7 @@ def main(): all_layers = xparser.read_xconfig_file(args.xconfig_file) write_expanded_xconfig_files(args.config_dir, all_layers) write_config_files(args.config_dir, all_layers) - add_back_compatibility_info(args.config_dir) + add_back_compatibility_info(args.config_dir, args.nnet_edits) if __name__ == '__main__': From c6ade8c84b7f05c6ac9d6ce53b379c87722d633b Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 12 Dec 2016 23:37:41 -0500 Subject: [PATCH 085/213] asr_diarization: Update do_corruption_data_dir.sh with better default valuesD --- .../segmentation/do_corruption_data_dir.sh | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/egs/aspire/s5/local/segmentation/do_corruption_data_dir.sh b/egs/aspire/s5/local/segmentation/do_corruption_data_dir.sh index 36bf4c93306..1bfa08370e7 100755 --- a/egs/aspire/s5/local/segmentation/do_corruption_data_dir.sh +++ b/egs/aspire/s5/local/segmentation/do_corruption_data_dir.sh @@ -27,8 +27,8 @@ reco_nj=40 cmd=queue.pl # Options for feature extraction -mfcc_config=conf/mfcc_hires_bp_vh.conf -feat_suffix=hires_bp_vh +mfcc_config=conf/mfcc_hires_bp.conf +feat_suffix=hires_bp reco_vad_dir= # Output of prepare_unsad_data.sh. # If provided, the speech labels and deriv weights will be @@ -105,19 +105,15 @@ if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then fi if [ $stage -le 4 ]; then - if [ ! 
-z $feat_suffix ]; then - utils/copy_data_dir.sh $corrupted_data_dir ${corrupted_data_dir}_$feat_suffix - corrupted_data_dir=${corrupted_data_dir}_$feat_suffix - fi + utils/copy_data_dir.sh $corrupted_data_dir ${corrupted_data_dir}_$feat_suffix + corrupted_data_dir=${corrupted_data_dir}_$feat_suffix steps/make_mfcc.sh --mfcc-config $mfcc_config \ --cmd "$cmd" --nj $reco_nj \ - $corrupted_data_dir exp/make_${mfccdir}/${corrupted_data_id} $mfccdir + $corrupted_data_dir exp/make_${feat_suffix}/${corrupted_data_id} $mfccdir steps/compute_cmvn_stats.sh --fake \ - $corrupted_data_dir exp/make_${mfccdir}/${corrupted_data_id} $mfccdir + $corrupted_data_dir exp/make_${feat_suffix}/${corrupted_data_id} $mfccdir else - if [ ! -z $feat_suffix ]; then - corrupted_data_dir=${corrupted_data_dir}_$feat_suffix - fi + corrupted_data_dir=${corrupted_data_dir}_$feat_suffix fi if [ $stage -le 8 ]; then From d639b319e08b63049e91eb7e0f898a49941d7d68 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 12 Dec 2016 23:37:59 -0500 Subject: [PATCH 086/213] asr_diarization: Update do_corruption_data_dir_overlapped_speech.sh with better default values --- ...o_corruption_data_dir_overlapped_speech.sh | 60 ++++++++++--------- 1 file changed, 31 insertions(+), 29 deletions(-) diff --git a/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh index 242dfca8170..75dbce578b2 100644 --- a/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh +++ b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#! /bin/bash # Copyright 2016 Vimal Manohar # Apache 2.0 @@ -20,9 +20,10 @@ num_data_reps=5 # Number of corrupted versions snrs="20:10:15:5:0:-5" foreground_snrs="20:10:15:5:0:-5" background_snrs="20:10:15:5:0:-5" -base_rirs=simulated +overlap_snrs="5:2:1:0:-1:-2" # Whole-data directory corresponding to data_dir whole_data_dir=data/train_si284_whole +overlap_labels_dir=overlap_labels # Parallel options reco_nj=40 @@ -30,21 +31,15 @@ nj=40 cmd=queue.pl # Options for feature extraction -mfcc_config=conf/mfcc_hires_bp_vh.conf +mfcc_config=conf/mfcc_hires_bp.conf +feat_suffix=hires_bp energy_config=conf/log_energy.conf -dry_run=false -corrupt_only=false -speed_perturb=true - -reco_vad_dir= +reco_vad_dir= # Output of prepare_unsad_data.sh. + # If provided, the speech labels and deriv weights will be + # copied into the output data directory. utt_vad_dir= -max_jobs_run=20 - -overlap_snrs="5:2:1:0:-1:-2" -base_rirs=simulated - . 
utils/parse_options.sh if [ $# -ne 0 ]; then @@ -64,7 +59,7 @@ corrupted_data_id=${whole_data_id}_ovlp_corrupted clean_data_id=${whole_data_id}_ovlp_clean noise_data_id=${whole_data_id}_ovlp_noise -if [ $stage -le 2 ]; then +if [ $stage -le 1 ]; then python steps/data/make_corrupted_data_dir.py \ "${rvb_opts[@]}" \ --prefix="ovlp" \ @@ -89,7 +84,7 @@ noise_data_dir=data/${noise_data_id} orig_corrupted_data_dir=$corrupted_data_dir if $speed_perturb; then - if [ $stage -le 3 ]; then + if [ $stage -le 2 ]; then ## Assuming whole data directories for x in $clean_data_dir $corrupted_data_dir $noise_data_dir; do cp $x/reco2dur $x/utt2dur @@ -105,8 +100,8 @@ if $speed_perturb; then clean_data_id=${clean_data_id}_sp noise_data_id=${noise_data_id}_sp - if [ $stage -le 4 ]; then - utils/data/perturb_data_dir_volume.sh --force true ${corrupted_data_dir} + if [ $stage -le 3 ]; then + utils/data/perturb_data_dir_volume.sh --scale-low 0.03125 --scale-high 2 --force true ${corrupted_data_dir} utils/data/perturb_data_dir_volume.sh --force true --reco2vol ${corrupted_data_dir}/reco2vol ${clean_data_dir} utils/data/perturb_data_dir_volume.sh --force true --reco2vol ${corrupted_data_dir}/reco2vol ${noise_data_dir} fi @@ -125,19 +120,21 @@ if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage fi -if [ $stage -le 5 ]; then +if [ $stage -le 4 ]; then + utils/copy_data_dir.sh $corrupted_data_dir ${corrupted_data_dir}_$feat_suffix + corrupted_data_dir=${corrupted_data_dir}_$feat_suffix steps/make_mfcc.sh --mfcc-config $mfcc_config \ --cmd "$train_cmd" --nj $reco_nj \ - $corrupted_data_dir exp/make_hires_bp/${corrupted_data_id} $mfccdir + $corrupted_data_dir exp/make_${feat_suffix}/${corrupted_data_id} $mfccdir fi -if [ $stage -le 6 ]; then +if [ $stage -le 5 ]; then steps/make_mfcc.sh --mfcc-config $energy_config \ --cmd "$train_cmd" --nj $reco_nj \ $clean_data_dir exp/make_log_energy/${clean_data_id} log_energy_feats fi -if [ $stage -le 7 ]; then +if [ $stage -le 6 ]; then steps/make_mfcc.sh --mfcc-config $energy_config \ --cmd "$train_cmd" --nj $reco_nj \ $noise_data_dir exp/make_log_energy/${noise_data_id} log_energy_feats @@ -164,6 +161,11 @@ if [ $stage -le 8 ]; then exp/make_irm_targets/${corrupted_data_id} $targets_dir fi +# Combine the VAD from the base recording and the VAD from the overlapping segments +# to create per-frame labels of the number of overlapping speech segments +# Unreliable segments are regions where no VAD labels were available for the +# overlapping segments. These can be later removed by setting deriv weights to 0. + # Data dirs without speed perturbation overlap_dir=exp/make_overlap_labels/${corrupted_data_id} unreliable_dir=exp/make_overlap_labels/unreliable_${corrupted_data_id} @@ -179,10 +181,6 @@ if [ $stage -le 8 ]; then utils/data/get_utt2num_frames.sh $corrupted_data_dir utils/split_data.sh --per-reco ${orig_corrupted_data_dir} $reco_nj - # Combine the VAD from the base recording and the VAD from the overlapping segments - # to create per-frame labels of the number of overlapping speech segments - # Unreliable segments are regions where no VAD labels were available for the - # overlapping segments. These can be later removed by setting deriv weights to 0. 
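The get_overlap_seg command that follows is where those per-frame labels are produced: the speech/non-speech decisions of the base recording and of each overlapping segment placed on top of it are combined into a count of simultaneously active speech sources at every frame. A simplified sketch of that combination, assuming equal-length 0/1 lists already aligned to the corrupted recording's frames (the names base_vad and overlapping_vads are illustrative; in the script this is done by segmentation-init-from-overlap-info and segmentation-get-stats):

def per_frame_overlap_counts(base_vad, overlapping_vads):
    # count how many speech sources (base speaker plus overlapping
    # segments) are active at each frame of the corrupted recording
    counts = list(base_vad)
    for vad in overlapping_vads:
        counts = [c + v for c, v in zip(counts, vad)]
    return counts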
$train_cmd JOB=1:$reco_nj $overlap_dir/log/get_overlap_seg.JOB.log \ segmentation-init-from-overlap-info --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ "scp:utils/filter_scp.pl ${orig_corrupted_data_dir}/split${reco_nj}reco/JOB/utt2spk $corrupted_data_dir/sad_seg.scp |" \ @@ -200,6 +198,7 @@ if [ $stage -le 9 ]; then cp $orig_corrupted_data_dir/wav.scp $unreliable_data_dir # Create segments where there is definitely an overlap. + # Assume no more than 10 speakers overlap. $train_cmd JOB=1:$reco_nj $overlap_dir/log/process_to_segments.JOB.log \ segmentation-post-process --remove-labels=0:1 \ ark:$overlap_dir/overlap_seg_speed_unperturbed.JOB.ark ark:- \| \ @@ -230,6 +229,9 @@ if $speed_perturb; then unreliable_data_dir=${unreliable_data_dir}_sp fi +# make $overlap_labels_dir an absolute pathname. +overlap_labels_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $overlap_labels_dir ${PWD}` + if [ $stage -le 10 ]; then utils/split_data.sh --per-reco ${overlap_data_dir} $reco_nj @@ -240,11 +242,11 @@ if [ $stage -le 10 ]; then segmentation-combine-segments-to-recordings ark:- ark,t:${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt \ ark:- \| \ segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- \ - ark,scp:overlap_labels/overlapped_speech_${corrupted_data_id}.JOB.ark,overlap_labels/overlapped_speech_${corrupted_data_id}.JOB.scp + ark,scp:$overlap_labels_dir/overlapped_speech_${corrupted_data_id}.JOB.ark,$overlap_labels_dir/overlapped_speech_${corrupted_data_id}.JOB.scp fi for n in `seq $reco_nj`; do - cat overlap_labels/overlapped_speech_${corrupted_data_id}.$n.scp + cat $overlap_labels_dir/overlapped_speech_${corrupted_data_id}.$n.scp done > ${corrupted_data_dir}/overlapped_speech_labels.scp if [ $stage -le 11 ]; then @@ -272,10 +274,10 @@ if [ $stage -le 11 ]; then segmentation-post-process --remove-labels=0 ark:- ark:- \| \ segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- ark,t:- \| \ steps/segmentation/convert_ali_to_vec.pl \| copy-vector ark,t:- \ - ark,scp:$unreliable_dir/deriv_weights_for_overlapped_speech.JOB.ark,$unreliable_dir/deriv_weights_for_overlapped_speech.JOB.scp + ark,scp:$overlap_labels_dir/deriv_weights_for_overlapped_speech.JOB.ark,$overlap_labels_dir/deriv_weights_for_overlapped_speech.JOB.scp for n in `seq $reco_nj`; do - cat $unreliable_dir/deriv_weights_for_overlapped_speech.${n}.scp + cat $overlap_labels_dir/deriv_weights_for_overlapped_speech.${n}.scp done > $corrupted_data_dir/deriv_weights_for_overlapped_speech.scp fi From 8668bcfe7771506cac40775431831b7ff9871ac8 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 12 Dec 2016 23:39:45 -0500 Subject: [PATCH 087/213] asr_diarization: Moving ../egs/aspire/s5/local/segmentation/train_stats_sad_music.sh to tuning --- .../segmentation/train_stats_sad_music.sh | 173 +----------------- .../tuning/train_stats_sad_music_1a.sh | 172 +++++++++++++++++ 2 files changed, 173 insertions(+), 172 deletions(-) mode change 100644 => 120000 egs/aspire/s5/local/segmentation/train_stats_sad_music.sh create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1a.sh diff --git a/egs/aspire/s5/local/segmentation/train_stats_sad_music.sh b/egs/aspire/s5/local/segmentation/train_stats_sad_music.sh deleted file mode 100644 index 8242b83c747..00000000000 --- a/egs/aspire/s5/local/segmentation/train_stats_sad_music.sh +++ /dev/null @@ -1,172 +0,0 @@ -#!/bin/bash - -# This is a script to 
train a time-delay neural network for speech activity detection (SAD) and -# music-id using statistic pooling component for long-context information. - -set -o pipefail -set -e -set -u - -. cmd.sh - -# At this script level we don't support not running on GPU, as it would be painfully slow. -# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, -# --num-threads 16 and --minibatch-size 128. - -stage=0 -train_stage=-10 -get_egs_stage=-10 -egs_opts= # Directly passed to get_egs_multiple_targets.py - -# TDNN options -splice_indexes="-3,-2,-1,0,1,2,3 -6,0,mean+count(-99:3:9:99) -9,0,3 0" -relu_dim=256 -chunk_width=20 # We use chunk training for training TDNN -extra_left_context=100 # Maximum left context in egs apart from TDNN's left context -extra_right_context=20 # Maximum right context in egs apart from TDNN's right context - -# We randomly select an extra {left,right} context for each job between -# min_extra_*_context and extra_*_context so that the network can get used -# to different contexts used to compute statistics. -min_extra_left_context=20 -min_extra_right_context=0 - -# training options -num_epochs=2 -initial_effective_lrate=0.0003 -final_effective_lrate=0.00003 -num_jobs_initial=3 -num_jobs_final=8 -remove_egs=false -max_param_change=0.2 # Small max-param change for small network -extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs - # such as removing one of the targets - -num_utts_subset_valid=50 # "utts" is actually recording. So this is prettly small. -num_utts_subset_train=50 - -# target options -train_data_dir=data/train_azteec_whole_sp_corrupted_hires - -speech_feat_scp= -music_labels_scp= - -deriv_weights_scp= - -egs_dir= -nj=40 -feat_type=raw -config_dir= - -dir= -affix=a - -. cmd.sh -. ./path.sh -. ./utils/parse_options.sh - -num_hidden_layers=`echo $splice_indexes | perl -ane 'print scalar @F'` || exit 1 -if [ -z "$dir" ]; then - dir=exp/nnet3_stats_sad_music/nnet_tdnn -fi - -dir=$dir${affix:+_$affix}_n${num_hidden_layers} - -if ! 
cuda-compiled; then - cat < Date: Mon, 12 Dec 2016 23:40:27 -0500 Subject: [PATCH 088/213] asr_diarization: Bug fix in random extra contexts --- .../s5/steps/libs/nnet3/train/frame_level_objf/common.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py b/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py index 37dd36aa392..0b0149ece3d 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py @@ -66,13 +66,17 @@ def train_new_models(dir, iter, srand, num_jobs, deriv_time_opts.append("--optimization.max-deriv-time={0}".format( max_deriv_time)) - this_random = random.Random(srand) + this_random = random.Random(srand + iter) + if min_left_context is not None: left_context = this_random.randint(min_left_context, left_context) if min_right_context is not None: right_context = this_random.randint(min_right_context, right_context) + logger.info("On iteration %d, left-context=%d and right-context=%s", + iter, left_context, right_context) + context_opts = "--left-context={0} --right-context={1}".format( left_context, right_context) From 0dc172cbf8f15ce65088be936f6827238d56c763 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 12 Dec 2016 23:41:04 -0500 Subject: [PATCH 089/213] asr_diarization: New tuning scrits for music id --- .../tuning/train_stats_sad_music_1b.sh | 191 ++++++++++++++++++ .../tuning/train_stats_sad_music_1c.sh | 185 +++++++++++++++++ 2 files changed, 376 insertions(+) create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1b.sh create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1c.sh diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1b.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1b.sh new file mode 100644 index 00000000000..685dd846b26 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1b.sh @@ -0,0 +1,191 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=20 # We use chunk training for training TDNN +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +num_utts_subset_valid=50 # "utts" is actually recording. 
So this is prettly small. +num_utts_subset_train=50 + +# target options +train_data_dir=data/train_azteec_whole_sp_corrupted_hires + +speech_feat_scp= +music_labels_scp= + +deriv_weights_scp= + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_hidden_layers=`echo $splice_indexes | perl -ane 'print scalar @F'` || exit 1 +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_music/nnet_tdnn +fi + +dir=$dir${affix:+_$affix}_n${num_hidden_layers} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$train_data_dir/feats.scp -` name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + # This is disabled for now. + # fixed-affine-layer name=lda input=Append(-3,-2,-1,0,1,2,3) affine-transform-file=$dir/configs/lda.mat + # the first splicing is moved before the lda layer, so no splicing here + # relu-renorm-layer name=tdnn1 dim=625 + + relu-renorm-layer name=tdnn1 input=Append(-3,-2,-1,0,1,2,3) dim=256 + stats-layer name=tdnn2.stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(Offset(tdnn1, -6), tdnn1, tdnn2.stats) dim=256 + relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=256 + relu-renorm-layer name=tdnn4 dim=256 + + output-layer name=output-speech include-log-softmax=true dim=2 input=tdnn4 + output-layer name=output-music include-log-softmax=true dim=2 input=tdnn4 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ +fi + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$train_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=20000 \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$speech_feat_scp --deriv-weights-scp=$deriv_weights_scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_labels_scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --dir=$dir/egs + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=20 \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=64 \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$train_data_dir \ + --targets-scp="$speech_feat_scp" \ + --dir=$dir || exit 1 +fi + +if [ $stage -le 6 ]; then + $train_cmd JOB=1:100 $dir/log/compute_post_output-speech.JOB.log \ + extract-column "scp:utils/split_scp.pl -j 100 \$[JOB-1] $speech_feat_scp |" ark,t:- \| \ + steps/segmentation/quantize_vector.pl \| \ + ali-to-post ark,t:- ark:- \| \ + weight-post ark:- scp:$deriv_weights_scp ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-speech.vec.JOB + eval vector-sum $dir/post_output-speech.vec.{`seq -s, 100`} $dir/post_output-speech.vec + + $train_cmd JOB=1:100 $dir/log/compute_post_output-music.JOB.log \ + ali-to-post "scp:utils/split_scp.pl -j 100 \$[JOB-1] $music_labels_scp |" ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-music.vec.JOB + eval vector-sum $dir/post_output-music.vec.{`seq -s, 100`} $dir/post_output-music.vec +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1c.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1c.sh new file mode 100644 index 00000000000..163ea6df14d --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1c.sh @@ -0,0 +1,185 @@ 
+#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=20 # We use chunk training for training TDNN +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +num_utts_subset_valid=50 # "utts" is actually recording. So this is prettly small. +num_utts_subset_train=50 + +# target options +train_data_dir=data/train_azteec_whole_sp_corrupted_hires + +speech_feat_scp= +music_labels_scp= + +deriv_weights_scp= + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_music/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$train_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-9, tdnn1@-3, tdnn1, tdnn1@3, tdnn2_stats) dim=256 + stats-layer name=tdnn3_stats config=mean+count(-108:9:27:108) + relu-renorm-layer name=tdnn3 input=Append(tdnn2@-27, tdnn2@-9, tdnn2, tdnn2@9, tdnn3_stats) dim=256 + relu-renorm-layer name=tdnn4 dim=256 + + output-layer name=output-speech include-log-softmax=true dim=2 input=tdnn4 + output-layer name=output-music include-log-softmax=true dim=2 input=tdnn4 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$train_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=20000 \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$speech_feat_scp --deriv-weights-scp=$deriv_weights_scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_labels_scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --dir=$dir/egs + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=20 \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=64 \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$train_data_dir \ + --targets-scp="$speech_feat_scp" \ + --dir=$dir || exit 1 +fi + +if [ $stage -le 6 ]; then + $train_cmd JOB=1:100 $dir/log/compute_post_output-speech.JOB.log \ + extract-column "scp:utils/split_scp.pl -j 100 \$[JOB-1] $speech_feat_scp |" ark,t:- \| \ + steps/segmentation/quantize_vector.pl \| \ + ali-to-post ark,t:- ark:- \| \ + weight-post ark:- scp:$deriv_weights_scp ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-speech.vec.JOB + eval vector-sum $dir/post_output-speech.vec.{`seq -s, 100`} $dir/post_output-speech.vec + + $train_cmd JOB=1:100 $dir/log/compute_post_output-music.JOB.log \ + ali-to-post "scp:utils/split_scp.pl -j 100 \$[JOB-1] $music_labels_scp |" ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-music.vec.JOB + eval vector-sum $dir/post_output-music.vec.{`seq -s, 100`} $dir/post_output-music.vec +fi From 869b6694066c1ca7f8f06b99acbbf117284bbd9e Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 12 Dec 2016 23:42:28 -0500 Subject: [PATCH 090/213] asr_diarization: New overlap detection with stats script --- .../tuning/train_stats_sad_overlap_1a.sh | 203 ++++++++++++++++++ 1 file changed, 203 
insertions(+) create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1a.sh diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1a.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1a.sh new file mode 100644 index 00000000000..425f8230418 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1a.sh @@ -0,0 +1,203 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for overlapped speech activity detection +# using statistic pooling component for long-context information. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=20 # We use chunk training for training TDNN +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=1 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +num_utts_subset_valid=50 # "utts" is actually recording. So this is prettly small. +num_utts_subset_train=50 + +# target options +train_data_dir=data/train_azteec_whole_sp_corrupted_hires + +snr_scp= +speech_feat_scp= +overlapped_speech_labels_scp= + +deriv_weights_scp= +deriv_weights_for_overlapped_speech_scp= + +train_data_dir=data/train_aztec_small_unsad_whole_all_corrupted_sp_hires_bp +speech_feat_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400/speech_feat.scp +deriv_weights_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400/deriv_weights.scp + +snr_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400/irm_targets.scp +deriv_weights_for_irm_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400/deriv_weights_manual_seg.scp + +deriv_weights_for_overlapped_speech_scp= +overlapped_speech_labels_scp= + +#extra_left_context=79 +#extra_right_context=11 + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_ovlp_snr/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$train_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1, tdnn2_stats) dim=256 + relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=256 + relu-renorm-layer name=tdnn4 dim=256 + + output-layer name=output-speech include-log-softmax=true dim=2 + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective=quadratic + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$train_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=20000 \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$snr_scp --deriv-weights-scp=$deriv_weights_scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$speech_feat_scp --deriv-weights-scp=$deriv_weights_scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$overlapped_speech_labels_scp --deriv-weights-scp=$deriv_weights_for_overlapped_speech_scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --dir=$dir/egs + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=20 \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=64 \ + --trainer.deriv-truncate-margin=8 \ + 
--trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$train_data_dir \ + --targets-scp="$speech_feat_scp" \ + --dir=$dir || exit 1 +fi + +if [ $stage -le 6 ]; then + $train_cmd JOB=1:100 $dir/log/compute_post_output-speech.JOB.log \ + extract-column "scp:utils/split_scp.pl -j 100 \$[JOB-1] $speech_feat_scp |" ark,t:- \| \ + steps/segmentation/quantize_vector.pl \| \ + ali-to-post ark,t:- ark:- \| \ + weight-post ark:- scp:$deriv_weights_scp ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-speech.vec.JOB + eval vector-sum $dir/post_output-speech.vec.{`seq -s, 100`} $dir/post_output-speech.vec + + $train_cmd JOB=1:100 $dir/log/compute_post_output-overlapped_speech\ + ali-to-post "scp:utils/split_scp.pl -j 100 \$[JOB-1] $overlapped_speech_scp |" ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-overlapped_speech.vec.JOB + eval vector-sum $dir/post_output-overlapped_speech.vec.{`seq -s, 100`} $dir/post_output-overlapped_speech.vec +fi From 3b6b460bfc4dea6526910c7bd956c3999a522619 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 12 Dec 2016 23:43:50 -0500 Subject: [PATCH 091/213] asr_diarization: remove junk from ami path.sh --- egs/aspire/s5/path.sh | 3 --- 1 file changed, 3 deletions(-) diff --git a/egs/aspire/s5/path.sh b/egs/aspire/s5/path.sh index 5c0d3a92f19..7fb6d91c543 100755 --- a/egs/aspire/s5/path.sh +++ b/egs/aspire/s5/path.sh @@ -2,8 +2,5 @@ export KALDI_ROOT=`pwd`/../../.. export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH [ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 . $KALDI_ROOT/tools/config/common_path.sh -export PATH=/home/vmanoha1/kaldi-raw-signal/src/segmenterbin:$PATH -export PATH=$KALDI_ROOT/tools/sph2pipe_v2.5:$PATH export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH -export PYTHONPATH=steps:${PYTHONPATH} export LC_ALL=C From 7cb6d56122d06fb7c6f387ec2c0c63ace118e952 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Tue, 13 Dec 2016 23:51:02 -0500 Subject: [PATCH 092/213] asr_diarization: Adding segmentation-init-from-additive-signals-info --- ...ntation-init-from-additive-signals-info.cc | 164 ++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 src/segmenterbin/segmentation-init-from-additive-signals-info.cc diff --git a/src/segmenterbin/segmentation-init-from-additive-signals-info.cc b/src/segmenterbin/segmentation-init-from-additive-signals-info.cc new file mode 100644 index 00000000000..139048ac17b --- /dev/null +++ b/src/segmenterbin/segmentation-init-from-additive-signals-info.cc @@ -0,0 +1,164 @@ +// segmenterbin/segmentation-init-from-overlap-info.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Convert overlapping segments information into segmentation\n" + "\n" + "Usage: segmentation-init-from-additive-signals-info [options] " + " \n" + " e.g.: segmentation-init-from-additive-signals-info --additive-signals-segmentation-rspecifier=ark:utt_segmentation.ark " + "ark:reco_segmentation.ark ark,t:overlapped_segments_info.txt ark:-\n"; + + BaseFloat frame_shift = 0.01; + std::string lengths_rspecifier; + std::string additive_signals_segmentation_rspecifier; + std::string unreliable_segmentation_wspecifier; + + ParseOptions po(usage); + + po.Register("frame-shift", &frame_shift, "Frame shift in seconds"); + po.Register("lengths-rspecifier", &lengths_rspecifier, + "Archive of lengths for recordings; if provided, will be " + "used to truncate the output segmentation."); + po.Register("additive-signals-segmentation-rspecifier", + &additive_signals_segmentation_rspecifier, + "Archive of segmentation of the additive signal which will used " + "instead of an all 1 segmentation"); + po.Register("unreliable-segmentation-wspecifier", + &unreliable_segmentation_wspecifier, + "Applicable when additive-signals-segmentation-rspecifier is " + "provided and some utterances in it are missing"); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string reco_segmentation_rspecifier = po.GetArg(1), + additive_signals_info_rspecifier = po.GetArg(2), + segmentation_wspecifier = po.GetArg(3); + + SequentialSegmentationReader reco_segmentation_reader(reco_segmentation_rspecifier); + RandomAccessTokenVectorReader additive_signals_info_reader(additive_signals_info_rspecifier); + SegmentationWriter writer(segmentation_wspecifier); + + RandomAccessSegmentationReader additive_signals_segmentation_reader(additive_signals_segmentation_rspecifier); + SegmentationWriter unreliable_writer(unreliable_segmentation_wspecifier); + + RandomAccessInt32Reader lengths_reader(lengths_rspecifier); + + int32 num_done = 0, num_err = 0, num_missing = 0; + + for (; !reco_segmentation_reader.Done(); reco_segmentation_reader.Next()) { + const std::string &key = reco_segmentation_reader.Key(); + + if (!additive_signals_info_reader.HasKey(key)) { + KALDI_WARN << "Could not find additive_signals_info for key " << key; + num_missing++; + continue; + } + const std::vector &additive_signals_info = additive_signals_info_reader.Value(key); + + Segmentation segmentation(reco_segmentation_reader.Value()); + Segmentation unreliable_segmentation; + + for (size_t i = 0; i < additive_signals_info.size(); i++) { + std::vector parts; + SplitStringToVector(additive_signals_info[i], ",:", false, &parts); + + if (parts.size() != 3) { + KALDI_ERR << "Invalid format of overlap info " << additive_signals_info[i] + << "for key " << key << " in " << additive_signals_info_rspecifier; + } + const 
std::string &utt_id = parts[0]; + double start_time; + double duration; + ConvertStringToReal(parts[1], &start_time); + ConvertStringToReal(parts[2], &duration); + + int32 start_frame = round(start_time / frame_shift); + + if (!additive_signals_segmentation_reader.HasKey(utt_id)) { + KALDI_WARN << "Could not find utterance " << utt_id << " in " + << "segmentation " << additive_signals_segmentation_rspecifier; + if (duration < 0) { + KALDI_ERR << "duration < 0 for utt_id " << utt_id << " in " + << "additive_signals_info " << additive_signals_info_rspecifier + << "; additive-signals-segmentation must be provided in such a case"; + } + num_err++; + unreliable_segmentation.EmplaceBack(start_frame, start_frame + duration - 1, 0); + continue; // Treated as non-overlapping even though there + // is overlap + } + + InsertFromSegmentation(additive_signals_segmentation_reader.Value(utt_id), + start_frame, false, &segmentation); + } + + Sort(&segmentation); + if (!lengths_rspecifier.empty()) { + if (!lengths_reader.HasKey(key)) { + KALDI_WARN << "Could not find length for the recording " << key + << "in " << lengths_rspecifier; + continue; + } + TruncateToLength(lengths_reader.Value(key), &segmentation); + } + writer.Write(key, segmentation); + + if (!unreliable_segmentation_wspecifier.empty()) { + Sort(&unreliable_segmentation); + if (!lengths_rspecifier.empty()) { + if (!lengths_reader.HasKey(key)) { + KALDI_WARN << "Could not find length for the recording " << key + << "in " << lengths_rspecifier; + continue; + } + TruncateToLength(lengths_reader.Value(key), &unreliable_segmentation); + } + unreliable_writer.Write(key, unreliable_segmentation); + } + + num_done++; + } + + KALDI_LOG << "Successfully processed " << num_done << " recordings " + << " in additive signals info; failed for " << num_missing + << "; could not get segmentation for " << num_err; + + return (num_done > (num_missing/ 2) ? 
0 : 1); + + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + From 8f1ee41a9064aa3f24a9b74a03f8ef5dbef42d09 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 14 Dec 2016 00:07:00 -0500 Subject: [PATCH 093/213] asr_diarization: Update make_overlapped_data_dir.py and data_dir_Manipulation_lib --- .../steps/data/data_dir_manipulation_lib.py | 12 +- .../s5/steps/data/make_overlapped_data_dir.py | 252 +++++++++++------- 2 files changed, 159 insertions(+), 105 deletions(-) diff --git a/egs/wsj/s5/steps/data/data_dir_manipulation_lib.py b/egs/wsj/s5/steps/data/data_dir_manipulation_lib.py index 7f1a5f74fe2..26fb17324dc 100644 --- a/egs/wsj/s5/steps/data/data_dir_manipulation_lib.py +++ b/egs/wsj/s5/steps/data/data_dir_manipulation_lib.py @@ -168,24 +168,24 @@ def CopyDataDirFiles(input_dir, output_dir, num_replicas, include_original, pref if not os.path.isfile(output_dir + "/wav.scp"): raise Exception("CopyDataDirFiles function expects output_dir to contain wav.scp already") - AddPrefixToFields(input_dir + "/utt2spk", output_dir + "/utt2spk", num_replicas, include_original, prefix, field = [0,1]) + AddPrefixToFields(input_dir + "/utt2spk", output_dir + "/utt2spk", num_replicas, include_original=include_original, prefix=prefix, field = [0,1]) RunKaldiCommand("utils/utt2spk_to_spk2utt.pl <{output_dir}/utt2spk >{output_dir}/spk2utt" .format(output_dir = output_dir)) if os.path.isfile(input_dir + "/utt2uniq"): - AddPrefixToFields(input_dir + "/utt2uniq", output_dir + "/utt2uniq", num_replicas, include_original, prefix, field =[0]) + AddPrefixToFields(input_dir + "/utt2uniq", output_dir + "/utt2uniq", num_replicas, include_original=include_original, prefix=prefix, field =[0]) else: # Create the utt2uniq file CreateCorruptedUtt2uniq(input_dir, output_dir, num_replicas, include_original, prefix) if os.path.isfile(input_dir + "/text"): - AddPrefixToFields(input_dir + "/text", output_dir + "/text", num_replicas, prefix, include_original, field =[0]) + AddPrefixToFields(input_dir + "/text", output_dir + "/text", num_replicas, include_original=include_original, prefix=prefix, field =[0]) if os.path.isfile(input_dir + "/segments"): - AddPrefixToFields(input_dir + "/segments", output_dir + "/segments", num_replicas, prefix, include_original, field = [0,1]) + AddPrefixToFields(input_dir + "/segments", output_dir + "/segments", num_replicas, prefix=prefix, include_original=include_original, field = [0,1]) if os.path.isfile(input_dir + "/reco2file_and_channel"): - AddPrefixToFields(input_dir + "/reco2file_and_channel", output_dir + "/reco2file_and_channel", num_replicas, include_original, prefix, field = [0,1]) + AddPrefixToFields(input_dir + "/reco2file_and_channel", output_dir + "/reco2file_and_channel", num_replicas, include_original=include_original, prefix=prefix, field = [0,1]) - AddPrefixToFields(input_dir + "/reco2dur", output_dir + "/reco2dur", num_replicas, include_original, prefix, field = [0]) + AddPrefixToFields(input_dir + "/reco2dur", output_dir + "/reco2dur", num_replicas, include_original=include_original, prefix=prefix, field = [0]) RunKaldiCommand("utils/validate_data_dir.sh --no-feats {output_dir}" .format(output_dir = output_dir)) diff --git a/egs/wsj/s5/steps/data/make_overlapped_data_dir.py b/egs/wsj/s5/steps/data/make_overlapped_data_dir.py index 86137c26e25..e4bf85f9af7 100644 --- a/egs/wsj/s5/steps/data/make_overlapped_data_dir.py +++ b/egs/wsj/s5/steps/data/make_overlapped_data_dir.py @@ -9,6 +9,9 @@ data_lib = imp.load_source('dml', 
'steps/data/data_dir_manipulation_lib.py') +sys.path.insert(0, 'steps') +import libs.common as common_lib + def GetArgs(): # we add required arguments as named arguments for readability parser = argparse.ArgumentParser(description="Reverberate the data directory with an option " @@ -32,7 +35,8 @@ def GetArgs(): "--rt-60 --drr location " "E.g. --rir-id 00001 --room-id 001 --receiver-position-id 001 --source-position-id 00001 " "--rt60 0.58 --drr -4.885 data/impulses/Room001-00001.wav") - parser.add_argument("--noise-set-parameters", type=str, action='append', default = None, dest = "noise_set_para_array", + parser.add_argument("--noise-set-parameters", type=str, action='append', + default = None, dest = "noise_set_para_array", help="Specifies the parameters of an noise set. " "Supports the specification of mixture_weight and noise_list_file_name. The mixture weight is optional. " "The default mixture weight is the probability mass remaining after adding the mixture weights " @@ -44,39 +48,56 @@ def GetArgs(): "--room-linkage " "location " "E.g. --noise-id 001 --noise-type isotropic --rir-id 00019 iso_noise.wav") - parser.add_argument("--speech-segments-set-parameters", type=str, action='append', default = None, dest = "speech_segments_set_para_array", - help="Specifies the speech segments for overlapped speech generation") + parser.add_argument("--speech-segments-set-parameters", type=str, action='append', + default = None, dest = "speech_segments_set_para_array", + help="Specifies the speech segments for overlapped speech generation.\n" + "Format: [], wav_scp, segments_list\n"); parser.add_argument("--num-replications", type=int, dest = "num_replicas", default = 1, help="Number of replicate to generated for the data") - parser.add_argument('--foreground-snrs', type=str, dest = "foreground_snr_string", default = '20:10:0', help='When foreground noises are being added the script will iterate through these SNRs.') - parser.add_argument('--background-snrs', type=str, dest = "background_snr_string", default = '20:10:0', help='When background noises are being added the script will iterate through these SNRs.') - parser.add_argument('--overlap-snrs', type=str, dest = "overlap_snr_string", default = "20:10:0", help='When overlapping speech segments are being added the script will iterate through these SNRs.') - parser.add_argument('--prefix', type=str, default = None, help='This prefix will modified for each reverberated copy, by adding additional affixes.') + parser.add_argument('--foreground-snrs', type=str, dest = "foreground_snr_string", + default = '20:10:0', + help='When foreground noises are being added the script will iterate through these SNRs.') + parser.add_argument('--background-snrs', type=str, dest = "background_snr_string", + default = '20:10:0', + help='When background noises are being added the script will iterate through these SNRs.') + parser.add_argument('--overlap-snrs', type=str, dest = "overlap_snr_string", + default = "20:10:0", + help='When overlapping speech segments are being added the script will iterate through these SNRs.') + parser.add_argument('--prefix', type=str, default = None, + help='This prefix will modified for each reverberated copy, by adding additional affixes.') parser.add_argument("--speech-rvb-probability", type=float, default = 1.0, help="Probability of reverberating a speech signal, e.g. 0 <= p <= 1") parser.add_argument("--pointsource-noise-addition-probability", type=float, default = 1.0, help="Probability of adding point-source noises, e.g. 
0 <= p <= 1") parser.add_argument("--isotropic-noise-addition-probability", type=float, default = 1.0, help="Probability of adding isotropic noises, e.g. 0 <= p <= 1") - parser.add_argument("--overlapped-speech-addition-probability", type=float, default = 1.0, - help="Probability of adding overlapped speech, e.g. 0 <= p <= 1") + parser.add_argument("--overlapping-speech-addition-probability", type=float, default = 1.0, + help="Probability of adding overlapping speech, e.g. 0 <= p <= 1") parser.add_argument("--rir-smoothing-weight", type=float, default = 0.3, help="Smoothing weight for the RIR probabilties, e.g. 0 <= p <= 1. If p = 0, no smoothing will be done. " "The RIR distribution will be mixed with a uniform distribution according to the smoothing weight") parser.add_argument("--noise-smoothing-weight", type=float, default = 0.3, help="Smoothing weight for the noise probabilties, e.g. 0 <= p <= 1. If p = 0, no smoothing will be done. " "The noise distribution will be mixed with a uniform distribution according to the smoothing weight") - parser.add_argument("--overlapped-speech-smoothing-weight", type=float, default = 0.3, - help="The overlapped speech distribution will be mixed with a uniform distribution according to the smoothing weight") + parser.add_argument("--overlapping-speech-smoothing-weight", type=float, default = 0.3, + help="The overlapping speech distribution will be mixed with a uniform distribution according to the smoothing weight") parser.add_argument("--max-noises-per-minute", type=int, default = 2, help="This controls the maximum number of point-source noises that could be added to a recording according to its duration") - parser.add_argument("--max-overlapped-segments-per-minute", type=int, default = 5, + parser.add_argument("--min-overlapping-segments-per-minute", type=int, default = 1, + help="This controls the minimum number of overlapping segments of speech that could be added to a recording per minute") + parser.add_argument("--max-overlapping-segments-per-minute", type=int, default = 5, help="This controls the maximum number of overlapping segments of speech that could be added to a recording per minute") - parser.add_argument('--random-seed', type=int, default=0, help='seed to be used in the randomization of impulses and noises') - parser.add_argument("--shift-output", type=str, help="If true, the reverberated waveform will be shifted by the amount of the peak position of the RIR", - choices=['true', 'false'], default = "true") - parser.add_argument("--output-additive-noise-dir", type=str, help="Output directory corresponding to the additive noise part of the data corruption") - parser.add_argument("--output-reverb-dir", type=str, help="Output directory corresponding to the reverberated signal part of the data corruption") + parser.add_argument('--random-seed', type=int, default=0, + help='seed to be used in the randomization of impulses and noises') + parser.add_argument("--shift-output", type=str, + help="If true, the reverberated waveform will be shifted by the amount of the peak position of the RIR", + choices=['true', 'false'], default = "true") + parser.add_argument("--output-additive-noise-dir", type=str, + action = common_train_lib.NullstrToNoneAction, default = None, + help="Output directory corresponding to the additive noise part of the data corruption") + parser.add_argument("--output-reverb-dir", type=str, + action = common_train_lib.NullstrToNoneAction, default = None, + help="Output directory corresponding to the reverberated signal part of 
the data corruption") parser.add_argument("input_dir", help="Input data directory") @@ -128,8 +149,8 @@ def CheckArgs(args): if args.isotropic_noise_addition_probability < 0 or args.isotropic_noise_addition_probability > 1: raise Exception("--isotropic-noise-addition-probability must be between 0 and 1") - if args.overlapped_speech_addition_probability < 0 or args.overlapped_speech_addition_probability > 1: - raise Exception("--overlapped-speech-addition-probability must be between 0 and 1") + if args.overlapping_speech_addition_probability < 0 or args.overlapping_speech_addition_probability > 1: + raise Exception("--overlapping-speech-addition-probability must be between 0 and 1") if args.rir_smoothing_weight < 0 or args.rir_smoothing_weight > 1: raise Exception("--rir-smoothing-weight must be between 0 and 1") @@ -137,14 +158,17 @@ def CheckArgs(args): if args.noise_smoothing_weight < 0 or args.noise_smoothing_weight > 1: raise Exception("--noise-smoothing-weight must be between 0 and 1") - if args.overlapped_speech_smoothing_weight < 0 or args.overlapped_speech_smoothing_weight > 1: - raise Exception("--overlapped-speech-smoothing-weight must be between 0 and 1") + if args.overlapping_speech_smoothing_weight < 0 or args.overlapping_speech_smoothing_weight > 1: + raise Exception("--overlapping-speech-smoothing-weight must be between 0 and 1") if args.max_noises_per_minute < 0: raise Exception("--max-noises-per-minute cannot be negative") - if args.max_overlapped_segments_per_minute < 0: - raise Exception("--max-overlapped-segments-per-minute cannot be negative") + if args.min_overlapping_segments_per_minute < 0: + raise Exception("--min-overlapping-segments-per-minute cannot be negative") + + if args.max_overlapping_segments_per_minute < 0: + raise Exception("--max-overlapping-segments-per-minute cannot be negative") return args @@ -203,32 +227,33 @@ def ParseSpeechSegmentsList(speech_segments_set_para_array, smoothing_weight): return segments_list -def AddOverlappedSpeech(room, # the room selected - speech_segments_list, # the speech list - overlapped_speech_addition_probability, # Probability of another speech waveform - snrs, # the SNR for adding the foreground speech - speech_dur, # duration of the recording - max_overlapped_speech_segments, # Maximum number of speech signals that can be added - overlapped_speech_descriptor # descriptor to store the information of the overlapped speech +def AddOverlappingSpeech(room, # the room selected + speech_segments_list, # the speech list + overlapping_speech_addition_probability, # Probability of another speech waveform + snrs, # the SNR for adding the foreground speech + speech_dur, # duration of the recording + min_overlapping_speech_segments, # Minimum number of speech signals that can be added + max_overlapping_speech_segments, # Maximum number of speech signals that can be added + overlapping_speech_descriptor # descriptor to store the information of the overlapping speech ): - if (len(speech_segments_list) > 0 and random.random() < overlapped_speech_addition_probability - and max_overlapped_speech_segments >= 1): - for k in range(random.randint(1, max_overlapped_speech_segments)): - # pick the overlapped speech signal and the RIR to - # reverberate the overlapped speech signal + if (len(speech_segments_list) > 0 and random.random() < overlapping_speech_addition_probability + and max_overlapping_speech_segments >= 1): + for k in range(random.randint(min_overlapping_speech_segments, max_overlapping_speech_segments)): + # pick the 
overlapping_speech speech signal and the RIR to + # reverberate the overlapping_speech speech signal speech_segment = data_lib.PickItemWithProbability(speech_segments_list) rir = data_lib.PickItemWithProbability(room.rir_list) speech_rvb_command = """wav-reverberate --impulse-response="{0}" --shift-output=true """.format(rir.rir_rspecifier) - overlapped_speech_descriptor['start_times'].append(round(random.random() * speech_dur, 2)) - overlapped_speech_descriptor['snrs'].append(snrs.next()) - overlapped_speech_descriptor['utt_ids'].append(speech_segment.utt_id) - overlapped_speech_descriptor['durations'].append(speech_segment.duration) + overlapping_speech_descriptor['start_times'].append(round(random.random() * speech_dur, 2)) + overlapping_speech_descriptor['snrs'].append(snrs.next()) + overlapping_speech_descriptor['utt_ids'].append(speech_segment.utt_id) + overlapping_speech_descriptor['durations'].append(speech_segment.duration) if len(speech_segment.wav_rxfilename.split()) == 1: - overlapped_speech_descriptor['speech_segments'].append("{1} {0} - |".format(speech_segment.wav_rxfilename, speech_rvb_command)) + overlapping_speech_descriptor['speech_segments'].append("{1} {0} - |".format(speech_segment.wav_rxfilename, speech_rvb_command)) else: - overlapped_speech_descriptor['speech_segments'].append("{0} {1} - - |".format(speech_segment.wav_rxfilename, speech_rvb_command)) + overlapping_speech_descriptor['speech_segments'].append("{0} {1} - - |".format(speech_segment.wav_rxfilename, speech_rvb_command)) # This function randomly decides whether to reverberate, and sample a RIR if it does # It also decides whether to add the appropriate noises @@ -244,17 +269,21 @@ def GenerateReverberationAndOverlappedSpeechOpts( speech_rvb_probability, # Probability of reverberating a speech signal isotropic_noise_addition_probability, # Probability of adding isotropic noises pointsource_noise_addition_probability, # Probability of adding point-source noises - overlapped_speech_addition_probability, + overlapping_speech_addition_probability, # Probability of adding overlapping speech segments speech_dur, # duration of the recording max_noises_recording, # Maximum number of point-source noises that can be added - max_overlapped_segments_recording # Maximum number of overlapped segments that can be added + min_overlapping_segments_recording, # Minimum number of overlapping segments that can be added + max_overlapping_segments_recording # Maximum number of overlapping segments that can be added ): impulse_response_opts = "" - additive_noise_opts = "" noise_addition_descriptor = {'noise_io': [], 'start_times': [], - 'snrs': []} + 'snrs': [], + 'noise_ids': [], + 'durations': [] + } + # Randomly select the room # Here the room probability is a sum of the probabilities of the RIRs recorded in the room. 
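    # data_lib.PickItemWithProbability is assumed here to make a standard
    # cumulative-probability draw over the candidate items; a minimal sketch of
    # that idea (hypothetical helper name, assuming each item carries a
    # .probability attribute and that the probabilities sum to 1) would be:
    #
    #   def pick_item_with_probability(items):
    #       r = random.random()          # uniform draw in [0, 1)
    #       cumulative = 0.0
    #       for item in items:
    #           cumulative += item.probability
    #           if r < cumulative:
    #               return item
    #       return items[-1]             # guard against floating-point rounding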
room = data_lib.PickItemWithProbability(room_dict) @@ -278,6 +307,8 @@ def GenerateReverberationAndOverlappedSpeechOpts( noise_addition_descriptor['noise_io'].append("{0} wav-reverberate --duration={1} - - |".format(isotropic_noise.noise_rspecifier, speech_dur)) noise_addition_descriptor['start_times'].append(0) noise_addition_descriptor['snrs'].append(background_snrs.next()) + noise_addition_descriptor['noise_ids'].append(isotropic_noise.noise_id) + noise_addition_descriptor['durations'].append(speech_dur) data_lib.AddPointSourceNoise(room, # the room selected pointsource_noise_list, # the point source noise list @@ -291,35 +322,29 @@ def GenerateReverberationAndOverlappedSpeechOpts( assert len(noise_addition_descriptor['noise_io']) == len(noise_addition_descriptor['start_times']) assert len(noise_addition_descriptor['noise_io']) == len(noise_addition_descriptor['snrs']) - - overlapped_speech_descriptor = {'speech_segments': [], - 'start_times': [], - 'snrs': [], - 'utt_ids': [], - 'durations': [] - } - - AddOverlappedSpeech(room, - speech_segments_list, # speech segments list - overlapped_speech_addition_probability, - overlap_snrs, - speech_dur, - max_overlapped_segments_recording, - overlapped_speech_descriptor + assert len(noise_addition_descriptor['noise_io']) == len(noise_addition_descriptor['utt_ids']) + assert len(noise_addition_descriptor['noise_io']) == len(noise_addition_descriptor['durations']) + + overlapping_speech_descriptor = {'speech_segments': [], + 'start_times': [], + 'snrs': [], + 'utt_ids': [], + 'durations': [] + } + + print ("Adding overlapping speech...") + AddOverlappingSpeech(room, + speech_segments_list, # speech segments list + overlapping_speech_addition_probability, + overlap_snrs, + speech_dur, + min_overlapping_segments_recording, + max_overlapping_segments_recording, + overlapping_speech_descriptor ) - if len(overlapped_speech_descriptor['speech_segments']) > 0: - noise_addition_descriptor['noise_io'] += overlapped_speech_descriptor['speech_segments'] - noise_addition_descriptor['start_times'] += overlapped_speech_descriptor['start_times'] - noise_addition_descriptor['snrs'] += overlapped_speech_descriptor['snrs'] - - if len(noise_addition_descriptor['noise_io']) > 0: - additive_noise_opts += "--additive-signals='{0}' ".format(','.join(noise_addition_descriptor['noise_io'])) - additive_noise_opts += "--start-times='{0}' ".format(','.join(map(lambda x:str(x), noise_addition_descriptor['start_times']))) - additive_noise_opts += "--snrs='{0}' ".format(','.join(map(lambda x:str(x), noise_addition_descriptor['snrs']))) - - return [impulse_response_opts, additive_noise_opts, - zip(overlapped_speech_descriptor['utt_ids'], [ str(x) for x in overlapped_speech_descriptor['start_times'] ], [ str(x) for x in overlapped_speech_descriptor['durations'] ])] + return [impulse_response_opts, noise_addition_descriptor, + overlapping_speech_descriptor] # This is the main function to generate pipeline command for the corruption # The generic command of wav-reverberate will be like: @@ -335,7 +360,7 @@ def GenerateReverberatedWavScpWithOverlappedSpeech( foreground_snr_array, # the SNR for adding the foreground noises background_snr_array, # the SNR for adding the background noises speech_segments_list, # list of speech segments to create overlapped speech - overlap_snr_array, # the SNR for adding overlapped speech + overlap_snr_array, # the SNR for adding overlapping speech num_replicas, # Number of replicate to generated for the data prefix, # prefix for the id of the 
corrupted utterances speech_rvb_probability, # Probability of reverberating a speech signal @@ -343,20 +368,20 @@ def GenerateReverberatedWavScpWithOverlappedSpeech( isotropic_noise_addition_probability, # Probability of adding isotropic noises pointsource_noise_addition_probability, # Probability of adding point-source noises max_noises_per_minute, # maximum number of point-source noises that can be added to a recording according to its duration - overlapped_speech_addition_probability, - max_overlapped_segments_per_minute, + overlapping_speech_addition_probability, + min_overlapping_segments_per_minute, + max_overlapping_segments_per_minute, output_reverb_dir = None, - output_additive_noise_dir = None + output_additive_noise_dir = None, ): foreground_snrs = data_lib.list_cyclic_iterator(foreground_snr_array) background_snrs = data_lib.list_cyclic_iterator(background_snr_array) overlap_snrs = data_lib.list_cyclic_iterator(overlap_snr_array) - corrupted_wav_scp = {} reverb_wav_scp = {} additive_noise_wav_scp = {} - overlapped_segments_info = {} + overlapping_segments_info = {} keys = wav_scp.keys() keys.sort() @@ -368,26 +393,48 @@ def GenerateReverberatedWavScpWithOverlappedSpeech( wav_original_pipe = "cat {0} |".format(wav_original_pipe) speech_dur = durations[recording_id] max_noises_recording = math.floor(max_noises_per_minute * speech_dur / 60) - max_overlapped_segments_recording = math.floor(max_overlapped_segments_per_minute * speech_dur / 60) + min_overlapping_segments_recording = max(math.floor(min_overlapping_segments_per_minute * speech_dur / 60), 1) + max_overlapping_segments_recording = math.floor(max_overlapping_segments_per_minute * speech_dur / 60) - [impulse_response_opts, - additive_noise_opts, - overlapped_speech_segments] = GenerateReverberationAndOverlappedSpeechOpts( + [impulse_response_opts, noise_addition_descriptor, + overlapping_speech_descriptor] = GenerateReverberationAndOverlappedSpeechOpts( room_dict = room_dict, # the room dictionary, please refer to MakeRoomDict() for the format pointsource_noise_list = pointsource_noise_list, # the point source noise list iso_noise_dict = iso_noise_dict, # the isotropic noise dictionary foreground_snrs = foreground_snrs, # the SNR for adding the foreground noises background_snrs = background_snrs, # the SNR for adding the background noises speech_segments_list = speech_segments_list, # Speech segments for creating overlapped speech - overlap_snrs = overlap_snrs, # the SNR for adding overlapped speech + overlap_snrs = overlap_snrs, # the SNR for adding overlapping speech speech_rvb_probability = speech_rvb_probability, # Probability of reverberating a speech signal isotropic_noise_addition_probability = isotropic_noise_addition_probability, # Probability of adding isotropic noises pointsource_noise_addition_probability = pointsource_noise_addition_probability, # Probability of adding point-source noises - overlapped_speech_addition_probability = overlapped_speech_addition_probability, + overlapping_speech_addition_probability = overlapping_speech_addition_probability, speech_dur = speech_dur, # duration of the recording max_noises_recording = max_noises_recording, # Maximum number of point-source noises that can be added - max_overlapped_segments_recording = max_overlapped_segments_recording + min_overlapping_segments_recording = min_overlapping_segments_recording, + max_overlapping_segments_recording = max_overlapping_segments_recording ) + + additive_noise_opts = "" + + if (len(noise_addition_descriptor['noise_io']) > 0 or 
+ len(overlapping_speech_descriptor['speech_segments']) > 0): + additive_noise_opts += ("--additive-signals='{0}' " + .format(',' + .join(noise_addition_descriptor['noise_io'] + + overlapping_speech_descriptor['speech_segments'])) + ) + additive_noise_opts += ("--start-times='{0}' " + .format(',' + .join(map(lambda x:str(x), noise_addition_descriptor['start_times'] + + overlapping_speech_descriptor['start_times']))) + ) + additive_noise_opts += ("--snrs='{0}' " + .format(',' + .join(map(lambda x:str(x), noise_addition_descriptor['snrs'] + + overlapping_speech_descriptor['snrs']))) + ) + reverberate_opts = impulse_response_opts + additive_noise_opts new_recording_id = data_lib.GetNewId(recording_id, prefix, i) @@ -411,14 +458,19 @@ def GenerateReverberatedWavScpWithOverlappedSpeech( wav_additive_noise_pipe = "{0} wav-reverberate --shift-output={1} --additive-noise-out-wxfilename=- {2} - /dev/null |".format(wav_original_pipe, shift_output, reverberate_opts) additive_noise_wav_scp[new_recording_id] = wav_additive_noise_pipe - if len(overlapped_speech_segments) > 0: - overlapped_segments_info[new_recording_id] = [ ':'.join(x) for x in overlapped_speech_segments ] + if len(overlapping_speech_descriptor['speech_segments']) > 0: + overlapping_segments_info[new_recording_id] = [ + ':'.join(x) + for x in zip(overlapping_speech_descriptor['utt_ids'], + [ str(x) for x in overlapping_speech_descriptor['start_times'] ], + [ str(x) for x in overlapping_speech_descriptor['durations'] ]) + ] data_lib.WriteDictToFile(corrupted_wav_scp, output_dir + "/wav.scp") - # Write for each new recording, the utterance id of the segments and - # the start time at which they are added - data_lib.WriteDictToFile(overlapped_segments_info, output_dir + "/overlapped_segments_info.txt") + # Write for each new recording, the id, start time and durations + # of the overlapping segments + data_lib.WriteDictToFile(overlapping_segments_info, output_dir + "/overlapped_segments_info.txt") if output_reverb_dir is not None: data_lib.WriteDictToFile(reverb_wav_scp, output_reverb_dir + "/wav.scp") @@ -426,7 +478,6 @@ def GenerateReverberatedWavScpWithOverlappedSpeech( if output_additive_noise_dir is not None: data_lib.WriteDictToFile(additive_noise_wav_scp, output_additive_noise_dir + "/wav.scp") - # This function creates multiple copies of the necessary files, e.g. utt2spk, wav.scp ... 
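# For illustration, a single corrupted wav.scp entry written out above by
# GenerateReverberatedWavScpWithOverlappedSpeech could look roughly like the
# following (recording id, paths, RIR, SNR and start time are all hypothetical;
# the entry is one line in wav.scp and is only wrapped here for readability):
#
#   ovlp1_rec001 cat /data/rec001.wav | wav-reverberate --shift-output=true
#     --impulse-response="RIRS_NOISES/simulated_rirs/smallroom/Room001-00001.wav"
#     --additive-signals='wav-reverberate --impulse-response="Room001-00002.wav" --shift-output=true /data/rec042.wav - |'
#     --start-times='7.35' --snrs='2' - - |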
def CreateReverberatedCopy(input_dir, output_dir, @@ -436,7 +487,7 @@ def CreateReverberatedCopy(input_dir, speech_segments_list, foreground_snr_string, # the SNR for adding the foreground noises background_snr_string, # the SNR for adding the background noises - overlap_snr_string, # the SNR for overlapped speech + overlap_snr_string, # the SNR for overlapping speech num_replicas, # Number of replicate to generated for the data prefix, # prefix for the id of the corrupted utterances speech_rvb_probability, # Probability of reverberating a speech signal @@ -444,8 +495,9 @@ def CreateReverberatedCopy(input_dir, isotropic_noise_addition_probability, # Probability of adding isotropic noises pointsource_noise_addition_probability, # Probability of adding point-source noises max_noises_per_minute, # maximum number of point-source noises that can be added to a recording according to its duration - overlapped_speech_addition_probability, - max_overlapped_segments_per_minute, + overlapping_speech_addition_probability, + min_overlapping_segments_per_minute, + max_overlapping_segments_per_minute, output_reverb_dir = None, output_additive_noise_dir = None ): @@ -482,8 +534,9 @@ def CreateReverberatedCopy(input_dir, isotropic_noise_addition_probability = isotropic_noise_addition_probability, pointsource_noise_addition_probability = pointsource_noise_addition_probability, max_noises_per_minute = max_noises_per_minute, - overlapped_speech_addition_probability = overlapped_speech_addition_probability, - max_overlapped_segments_per_minute = max_overlapped_segments_per_minute, + overlapping_speech_addition_probability = overlapping_speech_addition_probability, + min_overlapping_segments_per_minute = min_overlapping_segments_per_minute, + max_overlapping_segments_per_minute = max_overlapping_segments_per_minute, output_reverb_dir = output_reverb_dir, output_additive_noise_dir = output_additive_noise_dir) @@ -512,7 +565,7 @@ def Main(): print("Number of isotropic noises is {0}".format(sum(len(iso_noise_dict[key]) for key in iso_noise_dict.keys()))) room_dict = data_lib.MakeRoomDict(rir_list) - speech_segments_list = ParseSpeechSegmentsList(args.speech_segments_set_para_array, args.overlapped_speech_smoothing_weight) + speech_segments_list = ParseSpeechSegmentsList(args.speech_segments_set_para_array, args.overlapping_speech_smoothing_weight) CreateReverberatedCopy(input_dir = args.input_dir, output_dir = args.output_dir, @@ -530,8 +583,9 @@ def Main(): isotropic_noise_addition_probability = args.isotropic_noise_addition_probability, pointsource_noise_addition_probability = args.pointsource_noise_addition_probability, max_noises_per_minute = args.max_noises_per_minute, - overlapped_speech_addition_probability = args.overlapped_speech_addition_probability, - max_overlapped_segments_per_minute = args.max_overlapped_segments_per_minute, + overlapping_speech_addition_probability = args.overlapping_speech_addition_probability, + min_overlapping_segments_per_minute = args.min_overlapping_segments_per_minute, + max_overlapping_segments_per_minute = args.max_overlapping_segments_per_minute, output_reverb_dir = args.output_reverb_dir, output_additive_noise_dir = args.output_additive_noise_dir) From 7b3723da82631d76b7ddc001f7f234d226576db5 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 14 Dec 2016 00:07:19 -0500 Subject: [PATCH 094/213] asr_diarization: Update reverberate_data_dir.py --- egs/wsj/s5/steps/data/reverberate_data_dir.py | 20 +++++++------------ 1 file changed, 7 insertions(+), 13 deletions(-) diff 
--git a/egs/wsj/s5/steps/data/reverberate_data_dir.py b/egs/wsj/s5/steps/data/reverberate_data_dir.py index 0080bdba5f0..9a71126dde3 100755 --- a/egs/wsj/s5/steps/data/reverberate_data_dir.py +++ b/egs/wsj/s5/steps/data/reverberate_data_dir.py @@ -7,7 +7,7 @@ from __future__ import print_function import argparse, glob, math, os, random, sys, warnings, copy, imp, ast -data_lib = imp.load_source('dml', 'steps/data/data_dir_manipulation_lib.py') +import data_dir_manipulation_lib as data_lib def GetArgs(): # we add required arguments as named arguments for readability @@ -71,8 +71,12 @@ def GetArgs(): "the RIRs/noises will be resampled to the rate of the source data.") parser.add_argument("--include-original-data", type=str, help="If true, the output data includes one copy of the original data", choices=['true', 'false'], default = "false") - parser.add_argument("--output-additive-noise-dir", type=str, help="Output directory corresponding to the additive noise part of the data corruption") - parser.add_argument("--output-reverb-dir", type=str, help="Output directory corresponding to the reverberated signal part of the data corruption") + parser.add_argument("--output-additive-noise-dir", type=str, + action = common_lib.NullstrToNoneAction, default = None, + help="Output directory corresponding to the additive noise part of the data corruption") + parser.add_argument("--output-reverb-dir", type=str, + action = common_lib.NullstrToNoneAction, default = None, + help="Output directory corresponding to the reverberated signal part of the data corruption") parser.add_argument("input_dir", help="Input data directory") @@ -97,18 +101,10 @@ def CheckArgs(args): args.prefix = "rvb" warnings.warn("--prefix is set to 'rvb' as more than one copy of data is generated") - if args.output_reverb_dir is not None: - if args.output_reverb_dir == "": - args.output_reverb_dir = None - if args.output_reverb_dir is not None: if not os.path.exists(args.output_reverb_dir): os.makedirs(args.output_reverb_dir) - if args.output_additive_noise_dir is not None: - if args.output_additive_noise_dir == "": - args.output_additive_noise_dir = None - if args.output_additive_noise_dir is not None: if not os.path.exists(args.output_additive_noise_dir): os.makedirs(args.output_additive_noise_dir) @@ -346,5 +342,3 @@ def Main(): if __name__ == "__main__": Main() - - From b9328f700fc353de71d0c95362c99110982c4924 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 14 Dec 2016 00:07:36 -0500 Subject: [PATCH 095/213] asr_diarization: Better way of checking vol perturbation --- .../s5/utils/data/perturb_data_dir_volume.sh | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/egs/wsj/s5/utils/data/perturb_data_dir_volume.sh b/egs/wsj/s5/utils/data/perturb_data_dir_volume.sh index 185c7abf426..ee3c281bdbb 100755 --- a/egs/wsj/s5/utils/data/perturb_data_dir_volume.sh +++ b/egs/wsj/s5/utils/data/perturb_data_dir_volume.sh @@ -30,7 +30,27 @@ if [ ! -f $data/wav.scp ]; then exit 1 fi -if ! 
$force && grep -q "sox --vol" $data/wav.scp; then +volume_perturb_done=`head -n100 $data/wav.scp | python -c " +import sys, re +for line in sys.stdin.readlines(): + if len(line.strip()) == 0: + continue + # Handle three cases of rxfilenames appropriately; 'input piped command', 'file offset' and 'filename' + parts = line.strip().split() + if line.strip()[-1] == '|': + if re.search('sox --vol', ' '.join(parts[-11:])): + print 'true' + sys.exit(0) + elif re.search(':[0-9]+$', line.strip()) is not None: + continue + else: + if ' '.join(parts[1:3]) == 'sox --vol': + print 'true' + sys.exit(0) +print 'false' +"` || exit 1 + +if $volume_perturb_done; then echo "$0: It looks like the data was already volume perturbed. Not doing anything." exit 0 fi From 1ac065d241d367bbdbdcff0cc6364e734a59cc8d Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 14 Dec 2016 00:08:14 -0500 Subject: [PATCH 096/213] asr_diarization: Updated train_stats_sad_overlap_1a.sh --- .../segmentation/tuning/train_stats_sad_overlap_1a.sh | 8 ++++---- egs/wsj/s5/utils/data/get_reco2dur.sh | 10 +++++++++- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1a.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1a.sh index 425f8230418..aae1fd995e0 100644 --- a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1a.sh +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1a.sh @@ -54,12 +54,12 @@ overlapped_speech_labels_scp= deriv_weights_scp= deriv_weights_for_overlapped_speech_scp= -train_data_dir=data/train_aztec_small_unsad_whole_all_corrupted_sp_hires_bp -speech_feat_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400/speech_feat.scp +train_data_dir=data/train_aztec_small_unsad_whole_sad_ovlp_corrupted_sp +speech_feat_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp/speech_feat.scp deriv_weights_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400/deriv_weights.scp -snr_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400/irm_targets.scp -deriv_weights_for_irm_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400/deriv_weights_manual_seg.scp +snr_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp/irm_targets.scp +deriv_weights_for_irm_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp/deriv_weights_manual_seg.scp deriv_weights_for_overlapped_speech_scp= overlapped_speech_labels_scp= diff --git a/egs/wsj/s5/utils/data/get_reco2dur.sh b/egs/wsj/s5/utils/data/get_reco2dur.sh index 7d2ccb71769..5e925fc3e75 100755 --- a/egs/wsj/s5/utils/data/get_reco2dur.sh +++ b/egs/wsj/s5/utils/data/get_reco2dur.sh @@ -11,6 +11,8 @@ # files in entirely.) frame_shift=0.01 +cmd=run.pl +nj=4 . utils/parse_options.sh . ./path.sh @@ -74,11 +76,17 @@ else echo "... perturb_data_dir_speed_3way.sh." fi - if ! wav-to-duration --read-entire-file=$read_entire_file scp:$data/wav.scp ark,t:$data/reco2dur 2>&1 | grep -v 'nonzero return status'; then + utils/split_data.sh $data $nj + if ! 
$cmd JOB=1:$nj $data/log/get_wav_duration.JOB.log wav-to-duration --read-entire-file=$read_entire_file scp:$data/split$nj/JOB/wav.scp ark,t:$data/split$nj/JOB/reco2dur 2>&1; then echo "$0: there was a problem getting the durations; moving $data/reco2dur to $data/.backup/" mkdir -p $data/.backup/ mv $data/reco2dur $data/.backup/ + exit 1 fi + + for n in `seq $nj`; do + cat $data/split$nj/$n/reco2dur + done > $data/reco2dur fi echo "$0: computed $data/reco2dur" From 7c6e40a8fe020f7dbbecfd2f5dcc3c48035dedb8 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 14 Dec 2016 00:08:45 -0500 Subject: [PATCH 097/213] asr_diarization: New version of corruption_data_dir for overlapped_speech that works on non-whole dirs --- ...o_corruption_data_dir_overlapped_speech.sh | 79 ++++++++++--------- 1 file changed, 40 insertions(+), 39 deletions(-) mode change 100644 => 100755 egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh diff --git a/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh old mode 100644 new mode 100755 index 75dbce578b2..4d532be4353 --- a/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh +++ b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh @@ -21,12 +21,9 @@ snrs="20:10:15:5:0:-5" foreground_snrs="20:10:15:5:0:-5" background_snrs="20:10:15:5:0:-5" overlap_snrs="5:2:1:0:-1:-2" -# Whole-data directory corresponding to data_dir -whole_data_dir=data/train_si284_whole overlap_labels_dir=overlap_labels # Parallel options -reco_nj=40 nj=40 cmd=queue.pl @@ -35,9 +32,6 @@ mfcc_config=conf/mfcc_hires_bp.conf feat_suffix=hires_bp energy_config=conf/log_energy.conf -reco_vad_dir= # Output of prepare_unsad_data.sh. - # If provided, the speech labels and deriv weights will be - # copied into the output data directory. utt_vad_dir= . 
utils/parse_options.sh @@ -53,11 +47,20 @@ rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/smallroom/rir_l rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/mediumroom/rir_list") rvb_opts+=(--speech-segments-set-parameters="$data_dir/wav.scp,$data_dir/segments") -whole_data_id=`basename ${whole_data_dir}` +if [ $stage -le 0 ]; then + steps/segmentation/get_data_dir_with_segmented_wav.py \ + $data_dir ${data_dir}_seg +fi + +data_dir=${data_dir}_seg + +data_id=`basename ${data_dir}` -corrupted_data_id=${whole_data_id}_ovlp_corrupted -clean_data_id=${whole_data_id}_ovlp_clean -noise_data_id=${whole_data_id}_ovlp_noise +corrupted_data_id=${data_id}_ovlp_corrupted +clean_data_id=${data_id}_ovlp_clean +noise_data_id=${data_id}_ovlp_noise + +utils/data/get_reco2dur.sh --cmd $cmd --nj 40 $data_dir if [ $stage -le 1 ]; then python steps/data/make_corrupted_data_dir.py \ @@ -67,15 +70,11 @@ if [ $stage -le 1 ]; then --speech-rvb-probability=1 \ --overlapping-speech-addition-probability=1 \ --num-replications=$num_data_reps \ - --min-overlapping-segments-per-minute=5 \ - --max-overlapping-segments-per-minute=20 \ + --min-overlapping-segments-per-minute=1 \ + --max-overlapping-segments-per-minute=1 \ --output-additive-noise-dir=data/${noise_data_id} \ --output-reverb-dir=data/${clean_data_id} \ - data/${whole_data_id} data/${corrupted_data_id} -fi - -if $dry_run; then - exit 0 + ${data_dir} data/${corrupted_data_id} fi clean_data_dir=data/${clean_data_id} @@ -85,9 +84,7 @@ orig_corrupted_data_dir=$corrupted_data_dir if $speed_perturb; then if [ $stage -le 2 ]; then - ## Assuming whole data directories for x in $clean_data_dir $corrupted_data_dir $noise_data_dir; do - cp $x/reco2dur $x/utt2dur utils/data/perturb_data_dir_speed_3way.sh $x ${x}_sp done fi @@ -101,9 +98,9 @@ if $speed_perturb; then noise_data_id=${noise_data_id}_sp if [ $stage -le 3 ]; then - utils/data/perturb_data_dir_volume.sh --scale-low 0.03125 --scale-high 2 --force true ${corrupted_data_dir} - utils/data/perturb_data_dir_volume.sh --force true --reco2vol ${corrupted_data_dir}/reco2vol ${clean_data_dir} - utils/data/perturb_data_dir_volume.sh --force true --reco2vol ${corrupted_data_dir}/reco2vol ${noise_data_dir} + utils/data/perturb_data_dir_volume.sh --scale-low 0.03125 --scale-high 2 ${corrupted_data_dir} + utils/data/perturb_data_dir_volume.sh --reco2vol ${corrupted_data_dir}/reco2vol ${clean_data_dir} + utils/data/perturb_data_dir_volume.sh --reco2vol ${corrupted_data_dir}/reco2vol ${noise_data_dir} fi fi @@ -124,19 +121,20 @@ if [ $stage -le 4 ]; then utils/copy_data_dir.sh $corrupted_data_dir ${corrupted_data_dir}_$feat_suffix corrupted_data_dir=${corrupted_data_dir}_$feat_suffix steps/make_mfcc.sh --mfcc-config $mfcc_config \ - --cmd "$train_cmd" --nj $reco_nj \ + --cmd "$cmd" --nj $nj \ $corrupted_data_dir exp/make_${feat_suffix}/${corrupted_data_id} $mfccdir fi +if false; then if [ $stage -le 5 ]; then steps/make_mfcc.sh --mfcc-config $energy_config \ - --cmd "$train_cmd" --nj $reco_nj \ + --cmd "$cmd" --nj $nj \ $clean_data_dir exp/make_log_energy/${clean_data_id} log_energy_feats fi if [ $stage -le 6 ]; then steps/make_mfcc.sh --mfcc-config $energy_config \ - --cmd "$train_cmd" --nj $reco_nj \ + --cmd "$cmd" --nj $nj \ $noise_data_dir exp/make_log_energy/${noise_data_id} log_energy_feats fi @@ -155,11 +153,12 @@ if [ $stage -le 8 ]; then fi steps/segmentation/make_snr_targets.sh \ - --nj $nj --cmd "$train_cmd --max-jobs-run $max_jobs_run" \ + --nj $nj --cmd "$cmd --max-jobs-run 
$max_jobs_run" \ --target-type Irm --compress true --apply-exp false \ ${clean_data_dir} ${noise_data_dir} ${corrupted_data_dir} \ exp/make_irm_targets/${corrupted_data_id} $targets_dir fi +fi # Combine the VAD from the base recording and the VAD from the overlapping segments # to create per-frame labels of the number of overlapping speech segments @@ -175,15 +174,15 @@ unreliable_data_dir=$overlap_dir/unreliable_data mkdir -p $unreliable_dir if [ $stage -le 8 ]; then - cat $reco_vad_dir/sad_seg.scp | \ + cat $utt_vad_dir/sad_seg.scp | \ steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps "ovlp" \ | sort -k1,1 > ${corrupted_data_dir}/sad_seg.scp utils/data/get_utt2num_frames.sh $corrupted_data_dir - utils/split_data.sh --per-reco ${orig_corrupted_data_dir} $reco_nj + utils/split_data.sh ${orig_corrupted_data_dir} $nj - $train_cmd JOB=1:$reco_nj $overlap_dir/log/get_overlap_seg.JOB.log \ + $cmd JOB=1:$nj $overlap_dir/log/get_overlap_seg.JOB.log \ segmentation-init-from-overlap-info --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ - "scp:utils/filter_scp.pl ${orig_corrupted_data_dir}/split${reco_nj}reco/JOB/utt2spk $corrupted_data_dir/sad_seg.scp |" \ + "scp:utils/filter_scp.pl ${orig_corrupted_data_dir}/split${nj}/JOB/utt2spk $corrupted_data_dir/sad_seg.scp |" \ ark,t:$orig_corrupted_data_dir/overlapped_segments_info.txt \ scp:$utt_vad_dir/sad_seg.scp ark:- ark:$unreliable_dir/unreliable_seg_speed_unperturbed.JOB.ark \| \ segmentation-copy --keep-label=1 ark:- ark:- \| \ @@ -192,6 +191,8 @@ if [ $stage -le 8 ]; then segmentation-init-from-ali ark:- ark:$overlap_dir/overlap_seg_speed_unperturbed.JOB.ark fi +exit 1 + if [ $stage -le 9 ]; then mkdir -p $overlap_data_dir $unreliable_data_dir cp $orig_corrupted_data_dir/wav.scp $overlap_data_dir @@ -199,21 +200,21 @@ if [ $stage -le 9 ]; then # Create segments where there is definitely an overlap. # Assume no more than 10 speakers overlap. 
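As a rough illustration of what the segmentation-post-process pipeline below computes (a hedged Python sketch, not the Kaldi binaries; the function name and toy counts are assumptions for this example): per-frame speaker counts are collapsed into "definite overlap" regions by dropping labels 0 and 1 and merging labels 2 through 10 into a single label, and the resulting mask is then emitted as segments.

    def overlap_segments(speaker_counts):
        """Return (start_frame, end_frame) pairs where >= 2 speakers are active."""
        segments, start = [], None
        for i, c in enumerate(speaker_counts):
            if c >= 2 and start is None:        # entering an overlapped region
                start = i
            elif c < 2 and start is not None:   # leaving an overlapped region
                segments.append((start, i))
                start = None
        if start is not None:
            segments.append((start, len(speaker_counts)))
        return segments

    # counts 0,1,1,2,3,2,1,0 -> one "definite overlap" segment covering frames 3-5
    print(overlap_segments([0, 1, 1, 2, 3, 2, 1, 0]))  # [(3, 6)]
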
- $train_cmd JOB=1:$reco_nj $overlap_dir/log/process_to_segments.JOB.log \ + $cmd JOB=1:$nj $overlap_dir/log/process_to_segments.JOB.log \ segmentation-post-process --remove-labels=0:1 \ ark:$overlap_dir/overlap_seg_speed_unperturbed.JOB.ark ark:- \| \ segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=1 ark:- ark:- \| \ segmentation-to-segments ark:- ark:$overlap_data_dir/utt2spk.JOB $overlap_data_dir/segments.JOB - $train_cmd JOB=1:$reco_nj $overlap_dir/log/get_unreliable_segments.JOB.log \ + $cmd JOB=1:$nj $overlap_dir/log/get_unreliable_segments.JOB.log \ segmentation-to-segments --single-speaker \ ark:$unreliable_dir/unreliable_seg_speed_unperturbed.JOB.ark \ ark:$unreliable_data_dir/utt2spk.JOB $unreliable_data_dir/segments.JOB - for n in `seq $reco_nj`; do cat $overlap_data_dir/utt2spk.$n; done > $overlap_data_dir/utt2spk - for n in `seq $reco_nj`; do cat $overlap_data_dir/segments.$n; done > $overlap_data_dir/segments - for n in `seq $reco_nj`; do cat $unreliable_data_dir/utt2spk.$n; done > $unreliable_data_dir/utt2spk - for n in `seq $reco_nj`; do cat $unreliable_data_dir/segments.$n; done > $unreliable_data_dir/segments + for n in `seq $nj`; do cat $overlap_data_dir/utt2spk.$n; done > $overlap_data_dir/utt2spk + for n in `seq $nj`; do cat $overlap_data_dir/segments.$n; done > $overlap_data_dir/segments + for n in `seq $nj`; do cat $unreliable_data_dir/utt2spk.$n; done > $unreliable_data_dir/utt2spk + for n in `seq $nj`; do cat $unreliable_data_dir/segments.$n; done > $unreliable_data_dir/segments utils/fix_data_dir.sh $overlap_data_dir utils/fix_data_dir.sh $unreliable_data_dir @@ -233,9 +234,9 @@ fi overlap_labels_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $overlap_labels_dir ${PWD}` if [ $stage -le 10 ]; then - utils/split_data.sh --per-reco ${overlap_data_dir} $reco_nj + utils/split_data.sh ${overlap_data_dir} $nj - $train_cmd JOB=1:$reco_nj $overlap_dir/log/get_overlap_speech_labels.JOB.log \ + $cmd JOB=1:$nj $overlap_dir/log/get_overlap_speech_labels.JOB.log \ utils/data/get_reco2utt.sh ${overlap_data_dir}/split${reco_nj}reco/JOB '&&' \ segmentation-init-from-segments --shift-to-zero=false \ ${overlap_data_dir}/split${reco_nj}reco/JOB/segments ark:- \| \ @@ -260,7 +261,7 @@ if [ $stage -le 11 ]; then # Intersect this with the deriv weights segmentation from above. At this stage # deriv weights is 1 for only the regions where base VAD label is 1 and # the overlapping segment is not unreliable. Convert this to deriv weights. 
- $train_cmd JOB=1:$reco_nj $unreliable_dir/log/get_deriv_weights.JOB.log\ + $cmd JOB=1:$reco_nj $unreliable_dir/log/get_deriv_weights.JOB.log\ segmentation-init-from-segments --shift-to-zero=false \ "utils/filter_scp.pl -f 2 ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt ${unreliable_data_dir}/segments |" ark:- \| \ segmentation-combine-segments-to-recordings ark:- "ark,t:utils/filter_scp.pl ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt ${unreliable_data_dir}/reco2utt |" \ From ac7e7166eb5b7f0327339b451c5bfd9ba154d396 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 14 Dec 2016 00:09:02 -0500 Subject: [PATCH 098/213] asr_diarization: Updated AMI segmentation recipe --- egs/aspire/s5/local/segmentation/run_segmentation_ami.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh b/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh index 46ebf013b82..98ff4210780 100755 --- a/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh +++ b/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh @@ -75,6 +75,7 @@ cat $src_dir/data/sdm1/dev/reco2file_and_channel | \ utils/apply_map.pl -f 3 $dir/channel_map > $dir/reco2file_and_channel if [ $stage -le 5 ]; then + # Reference RTTM where SPEECH frames are obtainted by combining IHM VAD alignments $train_cmd $dir/log/get_ref_rttm.log \ segmentation-combine-segments scp:$dir/sad_seg.scp \ "ark:segmentation-init-from-segments --shift-to-zero=false $src_dir/data/sdm1/dev_ihmdata/segments ark:- |" \ @@ -87,6 +88,7 @@ if [ $stage -le 5 ]; then fi if [ $stage -le 6 ]; then + # Get an UEM which evaluates only on the manual segments. $train_cmd $dir/log/get_uem.log \ segmentation-init-from-segments --shift-to-zero=false $src_dir/data/sdm1/dev/segments ark:- \| \ segmentation-combine-segments-to-recordings ark:- ark,t:$src_dir/data/sdm1/dev/reco2utt ark:- \| \ From bd499c8557fd0ffb517ad75e782edbf37c5402fe Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 15 Dec 2016 19:41:38 -0500 Subject: [PATCH 099/213] asr_diarization: Restructuring do_segmentation_data_dir.sh --- .../segmentation/do_segmentation_data_dir.sh | 47 +++++++++++-------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir.sh b/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir.sh index 663655eef77..9feb421ccd3 100755 --- a/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir.sh +++ b/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir.sh @@ -42,15 +42,17 @@ echo $* . utils/parse_options.sh -if [ $# -ne 3 ]; then - echo "Usage: $0 " - echo " e.g.: $0 ~/workspace/egs/ami/s5b/data/sdm1/dev data/ami_sdm1_dev exp/nnet3_sad_snr/nnet_tdnn_j_n4" +if [ $# -ne 4 ]; then + echo "Usage: $0 " + echo " e.g.: $0 ~/workspace/egs/ami/s5b/data/sdm1/dev exp/nnet3_sad_snr/nnet_tdnn_j_n4 mfcc_hires_bp data/ami_sdm1_dev" exit 1 fi -src_data_dir=$1 -data_dir=$2 -sad_nnet_dir=$3 +src_data_dir=$1 # The input data directory that needs to be segmented. + # Any segments in that will be ignored. +sad_nnet_dir=$2 # The SAD neural network +mfcc_dir=$3 # The directory to store the features +data_dir=$4 # The output data directory will be ${data_dir}_seg affix=${affix:+_$affix} feat_affix=${feat_affix:+_$feat_affix} @@ -62,8 +64,10 @@ seg_dir=${sad_nnet_dir}/${segmentation_name}${affix}_${data_id}_whole${feat_affi export PATH="$KALDI_ROOT/tools/sph2pipe_v2.5/:$PATH" [ ! 
-z `which sph2pipe` ] +whole_data_dir=${sad_dir}/${data_id}_whole + if [ $stage -le 0 ]; then - utils/data/convert_data_dir_to_whole.sh $src_data_dir ${data_dir}_whole + utils/data/convert_data_dir_to_whole.sh $src_data_dir ${whole_data_dir} if $do_downsampling; then freq=`cat $mfcc_config | perl -pe 's/\s*#.*//g' | grep "sample-frequency=" | awk -F'=' '{if (NF == 0) print 16000; else print $2}'` @@ -76,18 +80,19 @@ for line in sys.stdin.readlines(): out_line = line.strip() + ' $sox -t wav - -r $freq -c 1 -b 16 -t wav - downsample |' else: out_line = 'cat {0} {1} | $sox -t wav - -r $freq -c 1 -b 16 -t wav - downsample |'.format(splits[0], ' '.join(splits[1:])) - print (out_line)" > ${data_dir}_whole/wav.scp + print (out_line)" > ${whole_data_dir}/wav.scp fi - utils/copy_data_dir.sh ${data_dir}_whole ${data_dir}_whole${feat_affix}_hires + utils/copy_data_dir.sh ${whole_data_dir} ${whole_data_dir}${feat_affix}_hires fi -test_data_dir=${data_dir}_whole${feat_affix}_hires +test_data_dir=${whole_data_dir}${feat_affix}_hires if [ $stage -le 1 ]; then steps/make_mfcc.sh --mfcc-config $mfcc_config --nj $reco_nj --cmd "$train_cmd" \ - ${data_dir}_whole${feat_affix}_hires exp/make_hires/${data_id}_whole${feat_affix} mfcc_hires - steps/compute_cmvn_stats.sh ${data_dir}_whole${feat_affix}_hires exp/make_hires/${data_id}_whole${feat_affix} mfcc_hires + ${whole_data_dir}${feat_affix}_hires exp/make_hires/${data_id}_whole${feat_affix} $mfcc_dir + steps/compute_cmvn_stats.sh ${whole_data_dir}${feat_affix}_hires exp/make_hires/${data_id}_whole${feat_affix} $mfcc_dir + utils/fix_data_dir.sh ${whole_data_dir}${feat_affix}_hires fi post_vec=$sad_nnet_dir/post_${output_name}.vec @@ -114,21 +119,23 @@ if [ $stage -le 3 ]; then --min-silence-duration $min_silence_duration \ --min-speech-duration $min_speech_duration \ --segmentation-config $segmentation_config --cmd "$train_cmd" \ - ${test_data_dir} $sad_dir $seg_dir $seg_dir/${data_id}_seg + ${test_data_dir} $sad_dir $seg_dir ${data_dir}_seg fi # Subsegment data directory if [ $stage -le 4 ]; then - rm $seg_dir/${data_id}_seg/feats.scp || true + rm ${data_dir}_seg/feats.scp || true utils/data/get_reco2num_frames.sh ${test_data_dir} - awk '{print $1" "$2}' ${seg_dir}/${data_id}_seg/segments | \ + awk '{print $1" "$2}' ${data_dir}_seg/segments | \ utils/apply_map.pl -f 2 ${test_data_dir}/reco2num_frames > \ - $seg_dir/${data_id}_seg/utt2max_frames + ${data_dir}_seg/utt2max_frames frame_shift_info=`cat $mfcc_config | steps/segmentation/get_frame_shift_info_from_config.pl` utils/data/get_subsegment_feats.sh ${test_data_dir}/feats.scp \ - $frame_shift_info $seg_dir/${data_id}_seg/segments | \ - utils/data/fix_subsegmented_feats.pl ${seg_dir}/${data_id}_seg/utt2max_frames > \ - $seg_dir/${data_id}_seg/feats.scp - steps/compute_cmvn_stats.sh --fake $seg_dir/${data_id}_seg + $frame_shift_info ${data_dir}_seg/segments | \ + utils/data/fix_subsegmented_feats.pl ${data_dir}_seg/utt2max_frames > \ + ${data_dir}_seg/feats.scp + steps/compute_cmvn_stats.sh --fake ${data_dir}_seg + + utils/fix_data_dir.sh ${data_dir}_seg fi From eaa31a4bcdaefc65a242def1a7e59d325505f88d Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 15 Dec 2016 19:44:37 -0500 Subject: [PATCH 100/213] asr_diarization: SAD on Aspire --- .../s5/local/multi_condition/get_ctm.sh | 6 +- .../nnet3/prep_test_aspire_segmentation.sh | 230 ++++++++++++++++++ egs/aspire/s5/local/score_aspire.sh | 9 +- 3 files changed, 238 insertions(+), 7 deletions(-) create mode 100755 
egs/aspire/s5/local/nnet3/prep_test_aspire_segmentation.sh diff --git a/egs/aspire/s5/local/multi_condition/get_ctm.sh b/egs/aspire/s5/local/multi_condition/get_ctm.sh index f67a1191544..6fc87fec7b0 100755 --- a/egs/aspire/s5/local/multi_condition/get_ctm.sh +++ b/egs/aspire/s5/local/multi_condition/get_ctm.sh @@ -7,7 +7,7 @@ decode_mbr=true filter_ctm_command=cp glm= stm= -window=10 +resolve_overlaps=true overlap=5 [ -f ./path.sh ] && . ./path.sh . parse_options.sh || exit 1; @@ -62,7 +62,9 @@ lattice-align-words-lexicon --output-error-lats=true --output-if-empty=true --ma lattice-to-ctm-conf $frame_shift_opt --decode-mbr=$decode_mbr ark:- $decode_dir/score_$LMWT/penalty_$wip/ctm.overlapping || exit 1; # combine the segment-wise ctm files, while resolving overlaps -python local/multi_condition/resolve_ctm_overlaps.py --overlap $overlap --window-length $window $data_dir/utt2spk $decode_dir/score_$LMWT/penalty_$wip/ctm.overlapping $decode_dir/score_$LMWT/penalty_$wip/ctm.merged || exit 1; +if $resolve_overlaps; then + steps/resolve_ctm_overlaps.py $data_dir/segments $decode_dir/score_$LMWT/penalty_$wip/ctm.overlapping $decode_dir/score_$LMWT/penalty_$wip/ctm.merged || exit 1; +fi merged_ctm=$decode_dir/score_$LMWT/penalty_$wip/ctm.merged cat $merged_ctm | utils/int2sym.pl -f 5 $lang/words.txt | \ diff --git a/egs/aspire/s5/local/nnet3/prep_test_aspire_segmentation.sh b/egs/aspire/s5/local/nnet3/prep_test_aspire_segmentation.sh new file mode 100755 index 00000000000..5f38f6de51f --- /dev/null +++ b/egs/aspire/s5/local/nnet3/prep_test_aspire_segmentation.sh @@ -0,0 +1,230 @@ +#!/bin/bash + +# Copyright Johns Hopkins University (Author: Daniel Povey, Vijayaditya Peddinti) 2016. Apache 2.0. +# This script generates the ctm files for dev_aspire, test_aspire and eval_aspire +# for scoring with ASpIRE scoring server. +# It also provides the WER for dev_aspire data. + +set -e +set -o pipefail +set -u + +# general opts +iter=final +stage=0 +decode_num_jobs=30 +num_jobs=30 +affix= + +# ivector opts +max_count=75 # parameter for extract_ivectors.sh +sub_speaker_frames=6000 +ivector_scale=0.75 +filter_ctm=true +weights_file= +silence_weight=0.00001 + +# decode opts +pass2_decode_opts="--min-active 1000" +lattice_beam=8 +extra_left_context=0 # change for (B)LSTM +extra_right_context=0 # change for BLSTM +frames_per_chunk=50 # change for (B)LSTM +acwt=0.1 # important to change this when using chain models +post_decode_acwt=1.0 # important to change this when using chain models + +. ./cmd.sh +[ -f ./path.sh ] && . ./path.sh +. utils/parse_options.sh || exit 1; + +if [ $# -ne 5 ]; then + echo "Usage: $0 [options] " + echo " Options:" + echo " --stage (0|1|2) # start scoring script from part-way through." 
+ echo "e.g.:" + echo "$0 dev_aspire data/lang exp/tri5a/graph_pp exp/nnet3/tdnn" + exit 1; +fi + +data_set=$1 +sad_nnet_dir=$2 +lang=$3 # data/lang +graph=$4 #exp/tri5a/graph_pp +dir=$5 # exp/nnet3/tdnn + +model_affix=`basename $dir` +ivector_dir=exp/nnet3 +ivector_affix=${affix:+_$affix}_chain_${model_affix}_iter$iter +affix=_${affix}_iter${iter} +act_data_set=${data_set} # we will modify the data dir, when segmenting it + # so we will keep track of original data dirfor the glm and stm files + +if [[ "$data_set" =~ "test_aspire" ]]; then + out_file=single_dev_test${affix}_$model_affix.ctm +elif [[ "$data_set" =~ "eval_aspire" ]]; then + out_file=single_eval${affix}_$model_affix.ctm +elif [[ "$data_set" =~ "dev_aspire" ]]; then + # we will just decode the directory without oracle segments file + # as we would like to operate in the actual evaluation condition + out_file=single_dev${affix}_${model_affix}.ctm +else + exit 1 +fi + +if [ $stage -le 1 ]; then + steps/segmentation/do_segmentation_data_dir.sh --reco-nj $num_jobs \ + --mfcc-config conf/mfcc_hires_bp.conf --feat-affix bp \ + --do-downsampling false --extra-left-context 100 --extra-right-context 20 \ + --output-name output-speech --frame-subsampling-factor 6 \ + data/${data_set} $sad_nnet_dir mfcc_hires_bp data/${data_set} + # Output will be in data/${data_set}_seg +fi + +# uniform segmentation script would have created this dataset +# so update that script if you plan to change this variable +segmented_data_set=${data_set}_seg + +if [ $stage -le 2 ]; then + mfccdir=mfcc_reverb + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then + date=$(date +'%m_%d_%H_%M') + utils/create_split_dir.pl /export/b0{1,2,3,4}/$USER/kaldi-data/egs/aspire-$date/s5/$mfccdir/storage $mfccdir/storage + fi + + utils/copy_data_dir.sh data/${segmented_data_set} data/${segmented_data_set}_hires + steps/make_mfcc.sh --nj 30 --cmd "$train_cmd" \ + --mfcc-config conf/mfcc_hires.conf data/${segmented_data_set}_hires \ + exp/make_reverb_hires/${segmented_data_set} $mfccdir + steps/compute_cmvn_stats.sh data/${segmented_data_set}_hires \ + exp/make_reverb_hires/${segmented_data_set} $mfccdir + utils/fix_data_dir.sh data/${segmented_data_set}_hires + utils/validate_data_dir.sh --no-text data/${segmented_data_set}_hires +fi + +decode_dir=$dir/decode_${segmented_data_set}${affix}_pp +false && { +if [ $stage -le 2 ]; then + echo "Extracting i-vectors, stage 1" + steps/online/nnet2/extract_ivectors_online.sh --cmd "$train_cmd" --nj 20 \ + --max-count $max_count \ + data/${segmented_data_set}_hires $ivector_dir/extractor \ + $ivector_dir/ivectors_${segmented_data_set}${ivector_affix}_stage1; + # float comparisons are hard in bash + if [ `bc <<< "$ivector_scale != 1"` -eq 1 ]; then + ivector_scale_affix=_scale$ivector_scale + else + ivector_scale_affix= + fi + + if [ ! 
-z "$ivector_scale_affix" ]; then + echo "$0: Scaling iVectors, stage 1" + srcdir=$ivector_dir/ivectors_${segmented_data_set}${ivector_affix}_stage1 + outdir=$ivector_dir/ivectors_${segmented_data_set}${ivector_affix}${ivector_scale_affix}_stage1 + mkdir -p $outdir + copy-matrix --scale=$ivector_scale scp:$srcdir/ivector_online.scp ark:- | \ + copy-feats --compress=true ark:- ark,scp:$outdir/ivector_online.ark,$outdir/ivector_online.scp; + cp $srcdir/ivector_period $outdir/ivector_period + fi +fi + +# generate the lattices +if [ $stage -le 3 ]; then + echo "Generating lattices, stage 1" + steps/nnet3/decode.sh --nj $decode_num_jobs --cmd "$decode_cmd" --config conf/decode.config \ + --acwt $acwt --post-decode-acwt $post_decode_acwt \ + --extra-left-context $extra_left_context \ + --extra-right-context $extra_right_context \ + --frames-per-chunk "$frames_per_chunk" \ + --online-ivector-dir $ivector_dir/ivectors_${segmented_data_set}${ivector_affix}${ivector_scale_affix}_stage1 \ + --skip-scoring true --iter $iter \ + $graph data/${segmented_data_set}_hires ${decode_dir}_stage1; +fi + +if [ $stage -le 4 ]; then + if $filter_ctm; then + if [ ! -z $weights_file ]; then + echo "$0: Using provided vad weights file $weights_file" + ivector_extractor_input=$weights_file + else + echo "$0 : Generating vad weights file" + ivector_extractor_input=${decode_dir}_stage1/weights${affix}.gz + local/extract_vad_weights.sh --cmd "$decode_cmd" --iter $iter \ + data/${segmented_data_set}_hires $lang \ + ${decode_dir}_stage1 $ivector_extractor_input + fi + else + # just use all the frames + ivector_extractor_input=${decode_dir}_stage1 + fi +fi + +if [ $stage -le 5 ]; then + echo "Extracting i-vectors, stage 2 with input $ivector_extractor_input" + # this does offline decoding, except we estimate the iVectors per + # speaker, excluding silence (based on alignments from a DNN decoding), with a + # different script. This is just to demonstrate that script. + # the --sub-speaker-frames is optional; if provided, it will divide each speaker + # up into "sub-speakers" of at least that many frames... can be useful if + # acoustic conditions drift over time within the speaker's data. + steps/online/nnet2/extract_ivectors.sh --cmd "$train_cmd" --nj 20 \ + --silence-weight $silence_weight \ + --sub-speaker-frames $sub_speaker_frames --max-count $max_count \ + data/${segmented_data_set}_hires $lang $ivector_dir/extractor \ + $ivector_extractor_input $ivector_dir/ivectors_${segmented_data_set}${ivector_affix}; +fi +} + +if [ $stage -le 5 ]; then + echo "Extracting i-vectors, stage 2" + # this does offline decoding, except we estimate the iVectors per + # speaker, excluding silence (based on alignments from a DNN decoding), with a + # different script. This is just to demonstrate that script. + # the --sub-speaker-frames is optional; if provided, it will divide each speaker + # up into "sub-speakers" of at least that many frames... can be useful if + # acoustic conditions drift over time within the speaker's data. 
+ steps/online/nnet2/extract_ivectors.sh --cmd "$train_cmd" --nj 20 \ + --sub-speaker-frames $sub_speaker_frames --max-count $max_count \ + data/${segmented_data_set}_hires $lang $ivector_dir/extractor \ + $ivector_dir/ivectors_${segmented_data_set}${ivector_affix}; +fi + +if [ $stage -le 6 ]; then + echo "Generating lattices, stage 2 with --acwt $acwt" + rm -f ${decode_dir}_tg/.error + steps/nnet3/decode.sh --nj $decode_num_jobs --cmd "$decode_cmd" --config conf/decode.config $pass2_decode_opts \ + --acwt $acwt --post-decode-acwt $post_decode_acwt \ + --extra-left-context $extra_left_context \ + --extra-right-context $extra_right_context \ + --frames-per-chunk "$frames_per_chunk" \ + --skip-scoring true --iter $iter --lattice-beam $lattice_beam \ + --online-ivector-dir $ivector_dir/ivectors_${segmented_data_set}${ivector_affix} \ + $graph data/${segmented_data_set}_hires ${decode_dir}_tg || touch ${decode_dir}_tg/.error + [ -f ${decode_dir}_tg/.error ] && echo "$0: Error decoding" && exit 1; +fi + +if [ $stage -le 7 ]; then + echo "Rescoring lattices" + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ + --skip-scoring true \ + ${lang}_pp_test{,_fg} data/${segmented_data_set}_hires \ + ${decode_dir}_{tg,fg}; +fi + +decode_dir=${decode_dir}_fg + +if [ $stage -le 8 ]; then + local/score_aspire.sh --cmd "$decode_cmd" \ + --min-lmwt 1 --max-lmwt 20 \ + --word-ins-penalties "0.0,0.25,0.5,0.75,1.0" \ + --ctm-beam 6 \ + --iter $iter \ + --decode-mbr true \ + --resolve-overlaps false \ + --tune-hyper true \ + $lang $decode_dir $act_data_set $segmented_data_set $out_file +fi + +# Two-pass decoding baseline +# %WER 27.8 | 2120 27217 | 78.2 13.6 8.2 6.0 27.8 75.9 | -0.613 | exp/chain/tdnn_7b/decode_dev_aspire_whole_uniformsegmented_win10_over5_v6_200jobs_iterfinal_pp_fg/score_9/penalty_0.0/ctm.filt.filt.sys +# Using automatic segmentation +# %WER 28.2 | 2120 27214 | 76.5 12.4 11.1 4.7 28.2 75.2 | -0.522 | exp/chain/tdnn_7b/decode_dev_aspire_seg_v7_n_stddev_iterfinal_pp_fg/score_10/penalty_0.0/ctm.filt.filt.sys diff --git a/egs/aspire/s5/local/score_aspire.sh b/egs/aspire/s5/local/score_aspire.sh index 3e35b6d3dae..9c08a6c85d1 100755 --- a/egs/aspire/s5/local/score_aspire.sh +++ b/egs/aspire/s5/local/score_aspire.sh @@ -14,10 +14,9 @@ word_ins_penalties=0.0,0.25,0.5,0.75,1.0 default_wip=0.0 ctm_beam=6 decode_mbr=true -window=30 -overlap=5 cmd=run.pl stage=1 +resolve_overlaps=true tune_hyper=true # if true: # if the data set is "dev_aspire" we check for the # best lmwt and word_insertion_penalty, @@ -89,7 +88,7 @@ if $tune_hyper ; then # or use the default values if [ $stage -le 1 ]; then - if [ "$act_data_set" == "dev_aspire" ]; then + if [[ "$act_data_set" =~ "dev_aspire" ]]; then wip_string=$(echo $word_ins_penalties | sed 's/,/ /g') temp_wips=($wip_string) $cmd WIP=1:${#temp_wips[@]} $decode_dir/scoring/log/score.wip.WIP.log \ @@ -98,8 +97,8 @@ if $tune_hyper ; then echo \$wip \&\& \ $cmd LMWT=$min_lmwt:$max_lmwt $decode_dir/scoring/log/score.LMWT.\$wip.log \ local/multi_condition/get_ctm.sh --filter-ctm-command "$filter_ctm_command" \ - --window $window --overlap $overlap \ --beam $ctm_beam --decode-mbr $decode_mbr \ + --resolve-overlaps $resolve_overlaps \ --glm data/${act_data_set}/glm --stm data/${act_data_set}/stm \ LMWT \$wip $lang data/${segmented_data_set}_hires $model $decode_dir || exit 1; @@ -124,7 +123,7 @@ wipfile.close() fi - if [ "$act_data_set" == "test_aspire" ] || [ "$act_data_set" == "eval_aspire" ]; then + if [[ "$act_data_set" =~ "test_aspire" ]] || [[ "$act_data_set" =~ 
"eval_aspire" ]]; then # check for the best values from dev_aspire decodes dev_decode_dir=$(echo $decode_dir|sed "s/test_aspire/dev_aspire_whole/g; s/eval_aspire/dev_aspire_whole/g") if [ -f $dev_decode_dir/scoring/bestLMWT ]; then From c9a8da1b43590251164ad83a07f5d6fe3f7c983b Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 15 Dec 2016 19:45:19 -0500 Subject: [PATCH 101/213] asr_diaization: Changes to run_segmentation_ami based on restructuring --- .../s5/local/segmentation/run_segmentation_ami.sh | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh b/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh index 98ff4210780..f9374aaf55a 100755 --- a/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh +++ b/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh @@ -12,6 +12,8 @@ set -u stage=-1 nnet_dir=exp/nnet3_sad_snr/nnet_tdnn_k_n4 +extra_left_context=100 +extra_right_context=20 . utils/parse_options.sh @@ -99,23 +101,24 @@ if [ $stage -le 6 ]; then rttmSmooth.pl -s 0 \| awk '{ print $2" "$3" "$4" "$5+$4 }' '>' $dir/uem fi -hyp_dir=$nnet_dir/segmentation_ami_sdm1_dev_whole_bp +hyp_dir=${nnet_dir}/segmentation_ami_sdm1_dev_whole_bp/ami_sdm1_dev if [ $stage -le 7 ]; then steps/segmentation/do_segmentation_data_dir.sh --reco-nj 18 \ --mfcc-config conf/mfcc_hires_bp.conf --feat-affix bp --do-downsampling true \ - --extra-left-context 100 --extra-right-context 20 \ + --extra-left-context $extra_left_context --extra-right-context $extra_right_context \ --output-name output-speech --frame-subsampling-factor 6 \ - $src_dir/data/sdm1/dev data/ami_sdm1_dev $nnet_dir + $src_dir/data/sdm1/dev $nnet_dir mfcc_hires_bp $hyp_dir fi +hyp_dir=${hyp_dir}_seg if [ $stage -le 8 ]; then utils/data/get_reco2utt.sh $src_dir/data/sdm1/dev_ihmdata steps/segmentation/convert_utt2spk_and_segments_to_rttm.py \ - $hyp_dir/ami_sdm1_dev_seg/utt2spk \ - $hyp_dir/ami_sdm1_dev_seg/segments \ + $hyp_dir/utt2spk \ + $hyp_dir/segments \ $dir/reco2file_and_channel \ /dev/stdout | spkr2sad.pl > $hyp_dir/sys.rttm fi From 5d0b82808287e53f927a2c1948dde93874c0ab9a Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 15 Dec 2016 19:46:01 -0500 Subject: [PATCH 102/213] Bug fix in basic_layers.py --- egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py index c612af984b1..38ff36622ec 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py @@ -458,7 +458,7 @@ def check_configs(self): "".format(self.config['dim']), self.str()) if self.config['objective-type'] != 'linear' and \ - self.config['objective_type'] != 'quadratic': + self.config['objective-type'] != 'quadratic': raise xparser_error("In output-layer, objective-type has" " invalid value {0}" "".format(self.config['objective-type']), From 93fe5b3399ff7feb7ad7787605294abf8cc3d1cb Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 15 Dec 2016 19:46:17 -0500 Subject: [PATCH 103/213] asr_diarization: Minor fix in get_egs_multiple_targets.py --- egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py b/egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py index 16e1f98a019..fa8a68f5c64 100755 --- a/egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py +++ 
b/egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py @@ -14,6 +14,7 @@ import math import glob +sys.path.insert(0, 'steps') import libs.data as data_lib import libs.common as common_lib From 87511da98aceb588420f68aa254e4494a8953af7 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 15 Dec 2016 19:47:04 -0500 Subject: [PATCH 104/213] asr_diarization: Change the way do_corruption_data_dir_overlapped_speech.sh works to non-whole dirs --- ...o_corruption_data_dir_overlapped_speech.sh | 172 +++---------- .../prepare_unsad_overlapped_speech_data.sh | 236 ++++++++++++++++++ 2 files changed, 266 insertions(+), 142 deletions(-) create mode 100755 egs/aspire/s5/local/segmentation/prepare_unsad_overlapped_speech_data.sh diff --git a/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh index 4d532be4353..aa1d9adc3e9 100755 --- a/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh +++ b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh @@ -15,7 +15,6 @@ corrupt_only=false # Data options data_dir=data/train_si284 # Excpecting non-whole data directory -speed_perturb=true num_data_reps=5 # Number of corrupted versions snrs="20:10:15:5:0:-5" foreground_snrs="20:10:15:5:0:-5" @@ -80,9 +79,9 @@ fi clean_data_dir=data/${clean_data_id} corrupted_data_dir=data/${corrupted_data_id} noise_data_dir=data/${noise_data_id} -orig_corrupted_data_dir=$corrupted_data_dir +orig_corrupted_data_dir=data/${corrupted_data_id} -if $speed_perturb; then +if false; then if [ $stage -le 2 ]; then for x in $clean_data_dir $corrupted_data_dir $noise_data_dir; do utils/data/perturb_data_dir_speed_3way.sh $x ${x}_sp @@ -96,12 +95,12 @@ if $speed_perturb; then corrupted_data_id=${corrupted_data_id}_sp clean_data_id=${clean_data_id}_sp noise_data_id=${noise_data_id}_sp +fi - if [ $stage -le 3 ]; then - utils/data/perturb_data_dir_volume.sh --scale-low 0.03125 --scale-high 2 ${corrupted_data_dir} - utils/data/perturb_data_dir_volume.sh --reco2vol ${corrupted_data_dir}/reco2vol ${clean_data_dir} - utils/data/perturb_data_dir_volume.sh --reco2vol ${corrupted_data_dir}/reco2vol ${noise_data_dir} - fi +if [ $stage -le 3 ]; then + utils/data/perturb_data_dir_volume.sh --scale-low 0.03125 --scale-high 2 ${corrupted_data_dir} + utils/data/perturb_data_dir_volume.sh --reco2vol ${corrupted_data_dir}/reco2vol ${clean_data_dir} + utils/data/perturb_data_dir_volume.sh --reco2vol ${corrupted_data_dir}/reco2vol ${noise_data_dir} fi if $corrupt_only; then @@ -123,24 +122,32 @@ if [ $stage -le 4 ]; then steps/make_mfcc.sh --mfcc-config $mfcc_config \ --cmd "$cmd" --nj $nj \ $corrupted_data_dir exp/make_${feat_suffix}/${corrupted_data_id} $mfccdir +else + corrupted_data_dir=${corrupted_data_dir}_$feat_suffix fi -if false; then +exit 0 + if [ $stage -le 5 ]; then - steps/make_mfcc.sh --mfcc-config $energy_config \ + # clean here is the reverberated first-speaker signal + utils/copy_data_dir.sh $clean_data_dir ${clean_data_dir}_$feat_suffix + clean_data_dir=${clean_data_dir}_$feat_suffix + steps/make_mfcc.sh --mfcc-config $mfcc_config \ --cmd "$cmd" --nj $nj \ - $clean_data_dir exp/make_log_energy/${clean_data_id} log_energy_feats + $clean_data_dir exp/make_${feat_suffix}/${clean_data_id} $mfccdir +else + clean_data_dir=${clean_data_dir}_$feat_suffix fi if [ $stage -le 6 ]; then - steps/make_mfcc.sh --mfcc-config $energy_config \ + # noise here is the reverberated second-speaker signal + 
utils/copy_data_dir.sh $noise_data_dir ${noise_data_dir}_$feat_suffix + noise_data_dir=${noise_data_dir}_$feat_suffix + steps/make_mfcc.sh --mfcc-config $mfcc_config \ --cmd "$cmd" --nj $nj \ - $noise_data_dir exp/make_log_energy/${noise_data_id} log_energy_feats -fi - -if [ -z "$reco_vad_dir" ]; then - echo "reco-vad-dir must be provided" - exit 1 + $noise_data_dir exp/make_${feat_suffix}/${noise_data_id} $mfccdir +else + noise_data_dir=${noise_data_dir}_$feat_suffix fi targets_dir=irm_targets @@ -152,134 +159,15 @@ if [ $stage -le 8 ]; then /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$targets_dir/storage $targets_dir/storage fi + # Get SNR targets only for the overlapped speech labels. steps/segmentation/make_snr_targets.sh \ --nj $nj --cmd "$cmd --max-jobs-run $max_jobs_run" \ - --target-type Irm --compress true --apply-exp false \ + --target-type Irm --compress false --apply-exp true \ + --ali-rspecifier "ark,s,cs:cat ${corrupted_data_dir}/sad_seg.scp | segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames scp:- ark:- |" \ + overlapped_speech_labels.scp \ + --silence-phones 0 \ ${clean_data_dir} ${noise_data_dir} ${corrupted_data_dir} \ exp/make_irm_targets/${corrupted_data_id} $targets_dir fi -fi - -# Combine the VAD from the base recording and the VAD from the overlapping segments -# to create per-frame labels of the number of overlapping speech segments -# Unreliable segments are regions where no VAD labels were available for the -# overlapping segments. These can be later removed by setting deriv weights to 0. - -# Data dirs without speed perturbation -overlap_dir=exp/make_overlap_labels/${corrupted_data_id} -unreliable_dir=exp/make_overlap_labels/unreliable_${corrupted_data_id} -overlap_data_dir=$overlap_dir/overlap_data -unreliable_data_dir=$overlap_dir/unreliable_data - -mkdir -p $unreliable_dir - -if [ $stage -le 8 ]; then - cat $utt_vad_dir/sad_seg.scp | \ - steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps "ovlp" \ - | sort -k1,1 > ${corrupted_data_dir}/sad_seg.scp - utils/data/get_utt2num_frames.sh $corrupted_data_dir - utils/split_data.sh ${orig_corrupted_data_dir} $nj - - $cmd JOB=1:$nj $overlap_dir/log/get_overlap_seg.JOB.log \ - segmentation-init-from-overlap-info --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ - "scp:utils/filter_scp.pl ${orig_corrupted_data_dir}/split${nj}/JOB/utt2spk $corrupted_data_dir/sad_seg.scp |" \ - ark,t:$orig_corrupted_data_dir/overlapped_segments_info.txt \ - scp:$utt_vad_dir/sad_seg.scp ark:- ark:$unreliable_dir/unreliable_seg_speed_unperturbed.JOB.ark \| \ - segmentation-copy --keep-label=1 ark:- ark:- \| \ - segmentation-get-stats --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ - ark:- ark:- ark:/dev/null \| \ - segmentation-init-from-ali ark:- ark:$overlap_dir/overlap_seg_speed_unperturbed.JOB.ark -fi - -exit 1 - -if [ $stage -le 9 ]; then - mkdir -p $overlap_data_dir $unreliable_data_dir - cp $orig_corrupted_data_dir/wav.scp $overlap_data_dir - cp $orig_corrupted_data_dir/wav.scp $unreliable_data_dir - - # Create segments where there is definitely an overlap. - # Assume no more than 10 speakers overlap. 
- $cmd JOB=1:$nj $overlap_dir/log/process_to_segments.JOB.log \ - segmentation-post-process --remove-labels=0:1 \ - ark:$overlap_dir/overlap_seg_speed_unperturbed.JOB.ark ark:- \| \ - segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=1 ark:- ark:- \| \ - segmentation-to-segments ark:- ark:$overlap_data_dir/utt2spk.JOB $overlap_data_dir/segments.JOB - - $cmd JOB=1:$nj $overlap_dir/log/get_unreliable_segments.JOB.log \ - segmentation-to-segments --single-speaker \ - ark:$unreliable_dir/unreliable_seg_speed_unperturbed.JOB.ark \ - ark:$unreliable_data_dir/utt2spk.JOB $unreliable_data_dir/segments.JOB - - for n in `seq $nj`; do cat $overlap_data_dir/utt2spk.$n; done > $overlap_data_dir/utt2spk - for n in `seq $nj`; do cat $overlap_data_dir/segments.$n; done > $overlap_data_dir/segments - for n in `seq $nj`; do cat $unreliable_data_dir/utt2spk.$n; done > $unreliable_data_dir/utt2spk - for n in `seq $nj`; do cat $unreliable_data_dir/segments.$n; done > $unreliable_data_dir/segments - - utils/fix_data_dir.sh $overlap_data_dir - utils/fix_data_dir.sh $unreliable_data_dir - - if $speed_perturb; then - utils/data/perturb_data_dir_speed_3way.sh $overlap_data_dir ${overlap_data_dir}_sp - utils/data/perturb_data_dir_speed_3way.sh $unreliable_data_dir ${unreliable_data_dir}_sp - fi -fi - -if $speed_perturb; then - overlap_data_dir=${overlap_data_dir}_sp - unreliable_data_dir=${unreliable_data_dir}_sp -fi - -# make $overlap_labels_dir an absolute pathname. -overlap_labels_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $overlap_labels_dir ${PWD}` - -if [ $stage -le 10 ]; then - utils/split_data.sh ${overlap_data_dir} $nj - - $cmd JOB=1:$nj $overlap_dir/log/get_overlap_speech_labels.JOB.log \ - utils/data/get_reco2utt.sh ${overlap_data_dir}/split${reco_nj}reco/JOB '&&' \ - segmentation-init-from-segments --shift-to-zero=false \ - ${overlap_data_dir}/split${reco_nj}reco/JOB/segments ark:- \| \ - segmentation-combine-segments-to-recordings ark:- ark,t:${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt \ - ark:- \| \ - segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- \ - ark,scp:$overlap_labels_dir/overlapped_speech_${corrupted_data_id}.JOB.ark,$overlap_labels_dir/overlapped_speech_${corrupted_data_id}.JOB.scp -fi - -for n in `seq $reco_nj`; do - cat $overlap_labels_dir/overlapped_speech_${corrupted_data_id}.$n.scp -done > ${corrupted_data_dir}/overlapped_speech_labels.scp - -if [ $stage -le 11 ]; then - utils/data/get_reco2utt.sh ${unreliable_data_dir} - - # First convert the unreliable segments into a recording-level segmentation. - # Initialize a segmentation from utt2num_frames and set to 0, the regions - # of unreliable segments. At this stage deriv weights is 1 for all but the - # unreliable segment regions. - # Initialize a segmentation from the VAD labels and retain only the speech segments. - # Intersect this with the deriv weights segmentation from above. At this stage - # deriv weights is 1 for only the regions where base VAD label is 1 and - # the overlapping segment is not unreliable. Convert this to deriv weights. 
- $cmd JOB=1:$reco_nj $unreliable_dir/log/get_deriv_weights.JOB.log\ - segmentation-init-from-segments --shift-to-zero=false \ - "utils/filter_scp.pl -f 2 ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt ${unreliable_data_dir}/segments |" ark:- \| \ - segmentation-combine-segments-to-recordings ark:- "ark,t:utils/filter_scp.pl ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt ${unreliable_data_dir}/reco2utt |" \ - ark:- \| \ - segmentation-create-subsegments --filter-label=1 --subsegment-label=0 --ignore-missing \ - "ark:utils/filter_scp.pl ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt $corrupted_data_dir/utt2num_frames | segmentation-init-from-lengths ark,t:- ark:- |" \ - ark:- ark:- \| \ - segmentation-intersect-segments --mismatch-label=0 \ - "ark:utils/filter_scp.pl ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt $corrupted_data_dir/sad_seg.scp | segmentation-post-process --remove-labels=0:2:3 scp:- ark:- |" \ - ark:- ark:- \| \ - segmentation-post-process --remove-labels=0 ark:- ark:- \| \ - segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- ark,t:- \| \ - steps/segmentation/convert_ali_to_vec.pl \| copy-vector ark,t:- \ - ark,scp:$overlap_labels_dir/deriv_weights_for_overlapped_speech.JOB.ark,$overlap_labels_dir/deriv_weights_for_overlapped_speech.JOB.scp - - for n in `seq $reco_nj`; do - cat $overlap_labels_dir/deriv_weights_for_overlapped_speech.${n}.scp - done > $corrupted_data_dir/deriv_weights_for_overlapped_speech.scp -fi exit 0 diff --git a/egs/aspire/s5/local/segmentation/prepare_unsad_overlapped_speech_data.sh b/egs/aspire/s5/local/segmentation/prepare_unsad_overlapped_speech_data.sh new file mode 100755 index 00000000000..36eb4de2afe --- /dev/null +++ b/egs/aspire/s5/local/segmentation/prepare_unsad_overlapped_speech_data.sh @@ -0,0 +1,236 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + +set -e +set -u +set -o pipefail + +. path.sh + +num_data_reps=5 +nj=40 +cmd=queue.pl +stage=-1 + +. utils/parse_options.sh + +if [ $# -ne 5 ]; then + echo "Usage: $0 " + echo " e.g.: $0 data/fisher_train_100k_sp_75k_hires_bp data/fisher_train_100k_sp_75k/overlapped_segments_info.txt exp/unsad/make_unsad_fisher_train_100k_sp/tri4_ali_fisher_train_100k_sp_vad_fisher_train_100k_sp exp/unsad/make_overlap_labels/fisher_train_100k_sp_75k overlap_labels" + exit 1 +fi + +corrupted_data_dir=$1 +orig_corrupted_data_dir=$2 +utt_vad_dir=$3 +tmpdir=$4 +overlap_labels_dir=$5 + +overlapped_segments_info=$orig_corrupted_data_dir/overlapped_segments_info.txt +corrupted_data_id=`basename $orig_corrupted_data_dir` + +for f in $corrupted_data_dir/feats.scp $overlapped_segments_info $utt_vad_dir/sad_seg.scp; do + [ ! -f $f ] && echo "Could not find file $f" && exit 1 +done + +overlap_dir=$tmpdir/make_overlap_labels_${corrupted_data_id} +unreliable_dir=$tmpdir/unreliable_${corrupted_data_id} + +mkdir -p $unreliable_dir + +# make $overlap_labels_dir an absolute pathname. +overlap_labels_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $overlap_labels_dir ${PWD}` + +# Combine the VAD from the base recording and the VAD from the overlapping segments +# to create per-frame labels of the number of overlapping speech segments +# Unreliable segments are regions where no VAD labels were available for the +# overlapping segments. These can be later removed by setting deriv weights to 0. 
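To make the comments above concrete, here is a hedged Python sketch of the per-frame counting and "unreliable" marking (it stands in for the segmentation-init-from-additive-signals-info / segmentation-get-stats pipeline that follows; the function name, arguments, and toy inputs are assumptions for illustration):

    def count_speakers(num_frames, base_vad, added, vad_by_utt):
        """added: list of (utt_id, start_frame, length) for the overlapping segments."""
        counts = list(base_vad)                 # base speaker's 0/1 speech mask
        unreliable = [0] * num_frames
        for utt, start, length in added:
            vad = vad_by_utt.get(utt)           # 0/1 speech mask of the added utterance
            for t in range(length):
                if start + t >= num_frames:
                    break
                if vad is None:                 # no VAD for this source: unreliable
                    unreliable[start + t] = 1
                elif t < len(vad) and vad[t] == 1:
                    counts[start + t] += 1
        return counts, unreliable

    base = [1, 1, 0, 0, 1, 1]
    counts, unrel = count_speakers(6, base, [("ovlp-utt", 1, 3)], {"ovlp-utt": [1, 1, 1]})
    print(counts)  # [1, 2, 1, 1, 1, 1]
    print(unrel)   # [0, 0, 0, 0, 0, 0]
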
+ +if [ $stage -le 1 ]; then + for n in `seq $num_data_reps`; do + cat $utt_vad_dir/sad_seg.scp | \ + awk -v n=$n '{print "ovlp"n"_"$0}' + done | sort -k1,1 > ${corrupted_data_dir}/sad_seg.scp + utils/data/get_utt2num_frames.sh $corrupted_data_dir + utils/split_data.sh ${corrupted_data_dir} $nj + + $cmd JOB=1:$nj $overlap_dir/log/get_overlap_seg.JOB.log \ + segmentation-init-from-additive-signals-info --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ + --additive-signals-segmentation-rspecifier=scp:$utt_vad_dir/sad_seg.scp \ + --unreliable-segmentation-wspecifier="ark:| gzip -c > $unreliable_dir/unreliable_seg.JOB.gz" \ + "scp:utils/filter_scp.pl ${corrupted_data_dir}/split${nj}/JOB/utt2spk $corrupted_data_dir/sad_seg.scp |" \ + ark,t:$orig_corrupted_data_dir/overlapped_segments_info.txt ark:- \| \ + segmentation-copy --keep-label=1 ark:- ark:- \| \ + segmentation-get-stats --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ + ark:- ark:- ark:/dev/null \| \ + segmentation-init-from-ali ark:- "ark:| gzip -c > $overlap_dir/overlap_seg.JOB.gz" +fi + +if [ $stage -le 2 ]; then + $cmd JOB=1:$nj $overlap_dir/log/get_overlapped_speech_labels.JOB.log \ + gunzip -c $overlap_dir/overlap_seg.JOB.gz \| \ + segmentation-post-process --remove-labels=0:1 ark:- ark:- \| \ + segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- \ + ark,scp:$overlap_labels_dir/overlapped_speech_labels_${corrupted_data_id}.JOB.ark,$overlap_labels_dir/overlapped_speech_labels_${corrupted_data_id}.JOB.scp + + for n in `seq $nj`; do + cat $overlap_labels_dir/overlapped_speech_labels_${corrupted_data_id}.$n.scp + done > ${corrupted_data_dir}/overlapped_speech_labels.scp +fi + +if [ $stage -le 3 ]; then + # First convert the unreliable segments into a segmentation. + # Initialize a segmentation from utt2num_frames and set to 0, the regions + # of unreliable segments. At this stage deriv weights is 1 for all but the + # unreliable segment regions. + # Initialize a segmentation from the overlap labels and retain regions where + # there is speech from at least one speaker. + # Intersect this with the deriv weights segmentation from above. + # At this stage deriv weights is 1 for only the regions where there is + # at least one speaker and the the overlapping segment is not unreliable. + # Convert this to deriv weights. 
+ $cmd JOB=1:$nj $unreliable_dir/log/get_deriv_weights.JOB.log \ + utils/filter_scp.pl $corrupted_data_dir/split$nj/JOB/utt2spk $corrupted_data_dir/utt2num_frames \| \ + segmentation-init-from-lengths ark,t:- ark:- \| \ + segmentation-create-subsegments --filter-label=1 --subsegment-label=0 --ignore-missing \ + ark:- "ark,s,cs:gunzip -c $unreliable_dir/unreliable_seg.JOB.gz | segmentation-to-segments ark:- - | segmentation-init-from-segments - ark:- |" ark:- \| \ + segmentation-intersect-segments --mismatch-label=0 \ + "ark:gunzip -c $overlap_dir/overlap_seg.JOB.gz | segmentation-post-process --remove-labels=0 --merge-labels=1:2:3:4:5:6:7:8:9:10 --merge-dst-label=1 ark:- ark:- |" \ + ark,s,cs:- ark:- \| segmentation-post-process --remove-labels=0 ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- ark,t:- \| \ + steps/segmentation/convert_ali_to_vec.pl \| copy-vector ark,t:- \ + ark,scp:$overlap_labels_dir/deriv_weights_for_overlapped_speech_${corrupted_data_id}.JOB.ark,$overlap_labels_dir/deriv_weights_for_overlapped_speech_${corrupted_data_id}.JOB.scp + + for n in `seq $nj`; do + cat $overlap_labels_dir/deriv_weights_for_overlapped_speech_$corrupted_data_id.${n}.scp + done > $corrupted_data_dir/deriv_weights_for_overlapped_speech.scp +fi + +if [ $stage -le 4 ]; then + # Get only first speaker labels as speech_feat as we are not sure of the energy levels of the other speaker. + $cmd JOB=1:$nj $overlap_dir/log/get_first_speaker_labels.JOB.log \ + gunzip -c $overlap_dir/overlap_seg.JOB.gz \| \ + segmentation-post-process --remove-labels=0 --merge-labels=1:2:3:4:5:6:7:8:9:10 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- ark,t:- \| \ + steps/segmentation/convert_ali_to_vec.pl \| \ + vector-to-feat ark:- \ + ark,scp:$overlap_labels_dir/speech_feat_${corrupted_data_id}.JOB.ark,$overlap_labels_dir/speech_feat_${corrupted_data_id}.JOB.scp + + for n in `seq $nj`; do + cat $overlap_labels_dir/speech_feat_${corrupted_data_id}.$n.scp + done > ${corrupted_data_dir}/speech_feat.scp +fi + +if [ $stage -le 5 ]; then + $cmd JOB=1:$nj $unreliable_dir/log/get_deriv_weights.JOB.log \ + utils/filter_scp.pl $corrupted_data_dir/split$nj/JOB/utt2spk $corrupted_data_dir/utt2num_frames \| \ + segmentation-init-from-lengths ark,t:- ark:- \| \ + segmentation-create-subsegments --filter-label=1 --subsegment-label=0 --ignore-missing \ + ark:- "ark,s,cs:gunzip -c $unreliable_dir/unreliable_seg.JOB.gz | segmentation-to-segments ark:- - | segmentation-init-from-segments - ark:- |" ark:- \| \ + segmentation-post-process --remove-labels=0 ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- ark,t:- \| \ + steps/segmentation/convert_ali_to_vec.pl \| copy-vector ark,t:- \ + ark,scp:$overlap_labels_dir/deriv_weights_${corrupted_data_id}.JOB.ark,$overlap_labels_dir/deriv_weights_${corrupted_data_id}.JOB.scp + + for n in `seq $nj`; do + cat $overlap_labels_dir/deriv_weights_$corrupted_data_id.${n}.scp + done > $corrupted_data_dir/deriv_weights.scp +fi + +exit 0 + +####exit 1 +#### +####if [ $stage -le 9 ]; then +#### mkdir -p $overlap_data_dir $unreliable_data_dir +#### cp $orig_corrupted_data_dir/wav.scp $overlap_data_dir +#### cp $orig_corrupted_data_dir/wav.scp $unreliable_data_dir +#### +#### # Create segments where there is definitely an overlap. +#### # Assume no more than 10 speakers overlap. 
+#### $cmd JOB=1:$nj $overlap_dir/log/process_to_segments.JOB.log \ +#### segmentation-post-process --remove-labels=0:1 \ +#### ark:$overlap_dir/overlap_seg_speed_unperturbed.JOB.ark ark:- \| \ +#### segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=1 ark:- ark:- \| \ +#### segmentation-to-segments ark:- ark:$overlap_data_dir/utt2spk.JOB $overlap_data_dir/segments.JOB +#### +#### $cmd JOB=1:$nj $overlap_dir/log/get_unreliable_segments.JOB.log \ +#### segmentation-to-segments --single-speaker \ +#### ark:$unreliable_dir/unreliable_seg_speed_unperturbed.JOB.ark \ +#### ark:$unreliable_data_dir/utt2spk.JOB $unreliable_data_dir/segments.JOB +#### +#### for n in `seq $nj`; do cat $overlap_data_dir/utt2spk.$n; done > $overlap_data_dir/utt2spk +#### for n in `seq $nj`; do cat $overlap_data_dir/segments.$n; done > $overlap_data_dir/segments +#### for n in `seq $nj`; do cat $unreliable_data_dir/utt2spk.$n; done > $unreliable_data_dir/utt2spk +#### for n in `seq $nj`; do cat $unreliable_data_dir/segments.$n; done > $unreliable_data_dir/segments +#### +#### utils/fix_data_dir.sh $overlap_data_dir +#### utils/fix_data_dir.sh $unreliable_data_dir +#### +#### if $speed_perturb; then +#### utils/data/perturb_data_dir_speed_3way.sh $overlap_data_dir ${overlap_data_dir}_sp +#### utils/data/perturb_data_dir_speed_3way.sh $unreliable_data_dir ${unreliable_data_dir}_sp +#### fi +####fi +#### +####if $speed_perturb; then +#### overlap_data_dir=${overlap_data_dir}_sp +#### unreliable_data_dir=${unreliable_data_dir}_sp +####fi +#### +##### make $overlap_labels_dir an absolute pathname. +####overlap_labels_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $overlap_labels_dir ${PWD}` +#### +####if [ $stage -le 10 ]; then +#### utils/split_data.sh ${overlap_data_dir} $nj +#### +#### $cmd JOB=1:$nj $overlap_dir/log/get_overlap_speech_labels.JOB.log \ +#### utils/data/get_reco2utt.sh ${overlap_data_dir}/split${reco_nj}reco/JOB '&&' \ +#### segmentation-init-from-segments --shift-to-zero=false \ +#### ${overlap_data_dir}/split${reco_nj}reco/JOB/segments ark:- \| \ +#### segmentation-combine-segments-to-recordings ark:- ark,t:${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt \ +#### ark:- \| \ +#### segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- \ +#### ark,scp:$overlap_labels_dir/overlapped_speech_${corrupted_data_id}.JOB.ark,$overlap_labels_dir/overlapped_speech_${corrupted_data_id}.JOB.scp +####fi +#### +####for n in `seq $reco_nj`; do +#### cat $overlap_labels_dir/overlapped_speech_${corrupted_data_id}.$n.scp +####done > ${corrupted_data_dir}/overlapped_speech_labels.scp +#### +####if [ $stage -le 11 ]; then +#### utils/data/get_reco2utt.sh ${unreliable_data_dir} +#### +#### # First convert the unreliable segments into a recording-level segmentation. +#### # Initialize a segmentation from utt2num_frames and set to 0, the regions +#### # of unreliable segments. At this stage deriv weights is 1 for all but the +#### # unreliable segment regions. +#### # Initialize a segmentation from the VAD labels and retain only the speech segments. +#### # Intersect this with the deriv weights segmentation from above. At this stage +#### # deriv weights is 1 for only the regions where base VAD label is 1 and +#### # the overlapping segment is not unreliable. Convert this to deriv weights. 
+#### $cmd JOB=1:$reco_nj $unreliable_dir/log/get_deriv_weights.JOB.log\ +#### segmentation-init-from-segments --shift-to-zero=false \ +#### "utils/filter_scp.pl -f 2 ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt ${unreliable_data_dir}/segments |" ark:- \| \ +#### segmentation-combine-segments-to-recordings ark:- "ark,t:utils/filter_scp.pl ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt ${unreliable_data_dir}/reco2utt |" \ +#### ark:- \| \ +#### segmentation-create-subsegments --filter-label=1 --subsegment-label=0 --ignore-missing \ +#### "ark:utils/filter_scp.pl ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt $corrupted_data_dir/utt2num_frames | segmentation-init-from-lengths ark,t:- ark:- |" \ +#### ark:- ark:- \| \ +#### segmentation-intersect-segments --mismatch-label=0 \ +#### "ark:utils/filter_scp.pl ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt $corrupted_data_dir/sad_seg.scp | segmentation-post-process --remove-labels=0:2:3 scp:- ark:- |" \ +#### ark:- ark:- \| \ +#### segmentation-post-process --remove-labels=0 ark:- ark:- \| \ +#### segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- ark,t:- \| \ +#### steps/segmentation/convert_ali_to_vec.pl \| copy-vector ark,t:- \ +#### ark,scp:$overlap_labels_dir/deriv_weights_for_overlapped_speech.JOB.ark,$overlap_labels_dir/deriv_weights_for_overlapped_speech.JOB.scp +#### +#### for n in `seq $reco_nj`; do +#### cat $overlap_labels_dir/deriv_weights_for_overlapped_speech.${n}.scp +#### done > $corrupted_data_dir/deriv_weights_for_overlapped_speech.scp +####fi +#### +####exit 0 From 3368166cf2af1f6fadcc9899e67b717385d575b7 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Sun, 18 Dec 2016 01:39:54 -0500 Subject: [PATCH 105/213] Bug fix in nnet3 training Conflicts: egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py --- .../steps/libs/nnet3/train/frame_level_objf/common.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py b/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py index 0b0149ece3d..b8a28d2e2bf 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py @@ -98,6 +98,12 @@ def train_new_models(dir, iter, srand, num_jobs, cache_write_opt = "--write-cache={dir}/cache.{iter}".format( dir=dir, iter=iter+1) + minibatch_opts = "--minibatch-size={0}".format(minibatch_size) + + if chunk_level_training: + minibatch_opts = "{0} --measure-output-frames=false".format( + minibatch_size) + process_handle = common_lib.run_job( """{command} {train_queue_opt} {dir}/log/train.{iter}.{job}.log \ nnet3-train {parallel_train_opts} {cache_read_opt} \ @@ -109,8 +115,7 @@ def train_new_models(dir, iter, srand, num_jobs, """ark:{egs_dir}/egs.{archive_index}.ark ark:- |{extra_egs_copy_cmd}""" """nnet3-shuffle-egs --buffer-size={shuffle_buffer_size} """ """--srand={srand} ark:- ark:- | """ - """nnet3-merge-egs --minibatch-size={minibatch_size} """ - """--measure-output-frames=false """ + """nnet3-merge-egs {minibatch_opts} """ """--discard-partial-minibatches=true ark:- ark:- |" \ {dir}/{next_iter}.{job}.raw""".format( command=run_opts.command, @@ -124,12 +129,12 @@ def train_new_models(dir, iter, srand, num_jobs, frame_opts=("" if chunk_level_training else "--frame={0}".format(frame)), + minibatch_opts=minibatch_opts, momentum=momentum, max_param_change=max_param_change, deriv_time_opts=" 
".join(deriv_time_opts), raw_model=raw_model_string, context_opts=context_opts, egs_dir=egs_dir, archive_index=archive_index, shuffle_buffer_size=shuffle_buffer_size, - minibatch_size=minibatch_size, extra_egs_copy_cmd=extra_egs_copy_cmd), wait=False) From 56f087bdee4602c5c38bbbda8ac2e442d2b90ffb Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 19 Dec 2016 16:47:06 -0500 Subject: [PATCH 106/213] asr_diarization: Adding multilingual egs --- .../allocate_multilingual_examples.py | 282 ++++++++++++++++++ .../s5/steps/nnet3/multilingual/get_egs.sh | 130 ++++++++ 2 files changed, 412 insertions(+) create mode 100644 egs/wsj/s5/steps/nnet3/multilingual/allocate_multilingual_examples.py create mode 100755 egs/wsj/s5/steps/nnet3/multilingual/get_egs.sh diff --git a/egs/wsj/s5/steps/nnet3/multilingual/allocate_multilingual_examples.py b/egs/wsj/s5/steps/nnet3/multilingual/allocate_multilingual_examples.py new file mode 100644 index 00000000000..cba804b1a66 --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/multilingual/allocate_multilingual_examples.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python + +# This script generates egs.Archive.scp and ranges.* used for generating egs.Archive.scp +# for multilingual setup. +# Also this script generates outputs.*.scp and weight.*.scp, where each line +# corresponds to language-id and weight for the same example in egs.*.scp. +# weight.*.scp used to scale the output's posterior during training. +# ranges.*.scp is generated w.r.t frequency distribution of remaining examples +# in each language. +# +# You call this script as (e.g.) +# +# allocate_multilingual_examples.py [opts] num-of-languages example-scp-lists multilingual-egs-dir +# +# allocate_multilingual_examples.py --num-jobs 10 --samples-per-iter 10000 --minibatch-size 512 +# --lang2weight exp/multi/lang2weight 2 "exp/lang1/egs.scp exp/lang2/egs.scp" +# exp/multi/egs +# +# This script outputs specific ranges.* files to the temp directory (exp/multi/egs/temp) +# that will enable you to creat egs.*.scp files for multilingual training. +# exp/multi/egs/temp/ranges.* contains something like the following: +# e.g. +# lang1 0 0 256 +# lang2 1 256 256 +# +# where each line can be interpreted as follows: +# +# +# note that is the zero-based line number in egs.scp for +# that language. +# num-examples is multiple of actual minibatch-size. +# +# +# egs.1.scp is generated using ranges.1.scp as following: +# "num_examples" consecutive examples starting from line "local-scp-line" from +# egs.scp file for language "source-lang" is copied to egs.1.scp. +# +# + +from __future__ import print_function +import re, os, argparse, sys, math, warnings, random, io, imp + +import logging + +sys.path.insert(0, 'steps') +import libs.common as common_lib + +logger = logging.getLogger('libs') +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +handler.setLevel(logging.INFO) +formatter = logging.Formatter("%(asctime)s [%(filename)s:%(lineno)s - " + "%(funcName)s - %(levelname)s ] %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) + + +def GetArgs(): + + parser = argparse.ArgumentParser(description="Writes ranges.*, outputs.* and weights.* files " + "in preparation for dumping egs for multilingual training.", + epilog="Called by steps/nnet3/multilingual/get_egs.sh") + parser.add_argument("--samples-per-iter", type=int, default=40000, + help="The target number of egs in each archive of egs, " + "(prior to merging egs). 
"); + parser.add_argument("--num-jobs", type=int, default=20, + help="This can be used for better randomness in distributing languages across archives." + ", where egs.job.archive.scp generated randomly and examples are combined " + " across all jobs as eg.archive.scp.") + parser.add_argument("--random-lang", type=str, action=common_lib.StrToBoolAction, + help="If true, the lang-id in ranges.* selected" + " w.r.t frequency distribution of remaining examples in each language," + " otherwise it is selected sequentially.", + default=True, choices = ["false", "true"]) + parser.add_argument("--max-archives", type=int, default=1000, + help="max number of archives used to generate egs.*.scp"); + parser.add_argument("--seed", type=int, default=1, + help="Seed for random number generator") + + parser.add_argument("--minibatch-size", type=int, default=512, + help="The minibatch size used to generate scp files per job. " + "It should be multiple of actual minibatch size."); + + parser.add_argument("--prefix", type=str, default="", + help="Adds a prefix to the range files. This is used to distinguish between the train " + "and diagnostic files.") + + parser.add_argument("--lang2weight", type=str, + help="lang2weight file contains the weight per language to scale output posterior for that language.(format is: " + " )"); +# now the positional arguments + parser.add_argument("num_langs", type=int, + help="num of languages used in multilingual training setup."); + parser.add_argument("egs_scp_lists", type=str, + help="list of egs.scp files per input language." + "e.g. exp/lang1/egs/egs.scp exp/lang2/egs/egs.scp"); + + parser.add_argument("egs_dir", + help="Name of egs directory e.g. exp/multilingual_a/egs"); + + + print(' '.join(sys.argv)) + + args = parser.parse_args() + + return args + + +# Returns a random language number w.r.t +# amount of examples in each language. +# It works based on sampling from a +# discrete distribution, where it returns i +# with prob(i) as (num_egs in lang(i)/ tot_egs). +# tot_egs is sum of lang_len. +def RandomLang(lang_len, tot_egs, random_selection): + assert(tot_egs > 0) + rand_int = random.randint(0, tot_egs - 1) + count = 0 + for l in range(len(lang_len)): + if random_selection: + if rand_int > count and rand_int <= (count + lang_len[l]): + rand_lang = l + break + else: + count += lang_len[l] + else: + if (lang_len[l] > 0): + rand_lang = l + break + assert(rand_lang >= 0 and rand_lang < len(lang_len)) + return rand_lang + +# Read lang2weight file and return lang2weight array +# where lang2weight[i] is weight for language i. +def ReadLang2weight(lang2w_file): + f = open(lang2w_file, "r"); + if f is None: + raise Exception("Error opening lang2weight file " + str(lang2w_file)) + lang2w = [] + for line in f: + a = line.split() + if len(a) != 2: + raise Exception("bad line in lang2weight file " + line) + lang2w.append(int(a[1])) + f.close() + return lang2w + +# struct to keep archives correspond to each job +class ArchiveToJob(): + def __init__(self, job_id, archives_for_job): + self.job_id = job_id + self.archives = archives_for_job + +def Main(): + args = GetArgs() + random.seed(args.seed) + num_langs = args.num_langs + rand_select = args.random_lang + + # read egs.scp for input languages + scp_lists = args.egs_scp_lists.split(); + assert(len(scp_lists) == num_langs); + + scp_files = [open(scp_lists[lang], 'r') for lang in range(num_langs)] + + # computes lang2len, where lang2len[i] shows number of + # examples for language i. 
+ lang2len = [0] * num_langs + for lang in range(num_langs): + lang2len[lang] = sum(1 for line in open(scp_lists[lang])) + logger.info("Number of examples for language {0} is {1}".format(lang, lang2len[lang])) + + # If weights are not provided, the scaling weights + # are one. + if args.lang2weight is None: + lang2weight = [ 1.0 ] * num_langs + else: + lang2weight = ReadLang2Len(args.lang2weight) + assert(len(lang2weight) == num_langs) + + if not os.path.exists(args.egs_dir + "/temp"): + os.makedirs(args.egs_dir + "/temp") + + num_lang_file = open(args.egs_dir + "/info/" + args.prefix + "num_lang", "w"); + print("{0}".format(num_langs), file = num_lang_file) + + + # Each element of all_egs (one per num_archive * num_jobs) is + # an array of 3-tuples (lang-id, local-start-egs-line, num-egs) + all_egs = [] + lang_len = lang2len[:] + tot_num_egs = sum(lang2len[i] for i in range(len(lang2len))) # total num of egs in all languages + num_archives = max(1, min(args.max_archives, tot_num_egs / args.samples_per_iter)) + + + num_arch_file = open(args.egs_dir + "/info/" + args.prefix + "num_archives", "w"); + print("{0}".format(num_archives), file = num_arch_file) + num_arch_file.close() + + this_num_egs_per_archive = tot_num_egs / (num_archives * args.num_jobs) # num of egs per archive + for job_index in range(args.num_jobs): + for archive_index in range(num_archives): + # Temporary scp.job_index.archive_index files to store egs.scp correspond to each archive. + logger.debug("Processing archive {0} for job {1}".format(archive_index + 1, job_index + 1)) + archfile = open(args.egs_dir + "/temp/" + args.prefix + "scp." + str(job_index + 1) + "." + str(archive_index + 1), "w") + + this_egs = [] # this will be array of 2-tuples (lang-id start-frame num-frames) + + num_egs = 0 + while num_egs <= this_num_egs_per_archive: + rem_egs = sum(lang_len[i] for i in range(len(lang_len))) + if rem_egs > 0: + lang_id = RandomLang(lang_len, rem_egs, rand_select) + start_egs = lang2len[lang_id] - lang_len[lang_id] + this_egs.append((lang_id, start_egs, args.minibatch_size)) + for scpline in range(args.minibatch_size): + print("{0} {1}".format(scp_files[lang_id].readline().splitlines()[0], lang_id), file = archfile) + + lang_len[lang_id] = lang_len[lang_id] - args.minibatch_size + num_egs = num_egs + args.minibatch_size; + # If the num of remaining egs in each lang is less than minibatch_size, + # they are discarded. + if lang_len[lang_id] < args.minibatch_size: + lang_len[lang_id] = 0 + logger.debug("Run out of data for language {0}".format(lang_id)) + else: + logger.debug("Run out of data for all languages.") + break + all_egs.append(this_egs) + archfile.close() + + # combine examples across all jobs correspond to each archive. + for archive in range(num_archives): + logger.debug("Processing archive {0} by combining all jobs.".format(archive + 1)) + this_ranges = [] + f = open(args.egs_dir + "/temp/" + args.prefix + "ranges." + str(archive + 1), "w") + o = open(args.egs_dir + "/" + args.prefix + "output." + str(archive + 1), "w") + w = open(args.egs_dir + "/" + args.prefix + "weight." + str(archive + 1), "w") + scp_per_archive_file = open(args.egs_dir + "/" + args.prefix + "egs." + str(archive + 1), "w") + + # check files befor writing. + if f is None: + raise Exception("Error opening file " + args.egs_dir + "/temp/" + args.prefix + "ranges." + str(job + 1)) + if o is None: + raise Exception("Error opening file " + args.egs_dir + "/" + args.prefix + "output." 
+ str(job + 1)) + if w is None: + raise Exception("Error opening file " + args.egs_dir + "/" + args.prefix + "weight." + str(job + 1)) + if scp_per_archive_file is None: + raise Exception("Error opening file " + args.egs_dir + "/" + args.prefix + "egs." + str(archive + 1), "w") + + for job in range(args.num_jobs): + # combine egs.job.archive.scp across all jobs. + scp = args.egs_dir + "/temp/" + args.prefix + "scp." + str(job + 1) + "." + str(archive + 1) + with open(scp, "r") as scpfile: + for line in scpfile: + try: + scp_line = line.splitlines()[0].split() + print("{0} {1}".format(scp_line[0], scp_line[1]), file=scp_per_archive_file) + print("{0} output-{1}".format(scp_line[0], scp_line[2]), file=o) + print("{0} {1}".format(scp_line[0], lang2weight[int(scp_line[2])]), file=w) + except Exception: + logger.error("Failed processing line %s in scp %s", line, + scpfile.name) + raise + os.remove(scp) + + # combine ranges.* across all jobs for archive + for (lang_id, start_eg_line, num_egs) in all_egs[num_archives * job + archive]: + this_ranges.append((lang_id, start_eg_line, num_egs)) + + # write ranges.archive + for (lang_id, start_eg_line, num_egs) in this_ranges: + print("{0} {1} {2}".format(lang_id, start_eg_line, num_egs), file=f) + + scp_per_archive_file.close() + f.close() + o.close() + w.close() + print("allocate_multilingual_examples.py finished generating " + args.prefix + "egs.*.scp and " + args.prefix + "ranges.* and " + args.prefix + "output.*" + args.prefix + "weight.* files") + +if __name__ == "__main__": + Main() diff --git a/egs/wsj/s5/steps/nnet3/multilingual/get_egs.sh b/egs/wsj/s5/steps/nnet3/multilingual/get_egs.sh new file mode 100755 index 00000000000..aa9a911ffb2 --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/multilingual/get_egs.sh @@ -0,0 +1,130 @@ +#!/bin/bash +# +# This script uses separate input egs directory for each language as input, +# to generate egs.*.scp files in multilingual egs directory +# where the scp line points to the original archive for each egs directory. +# $megs/egs.*.scp is randomized w.r.t language id. +# +# Also this script generates egs.JOB.scp, output.JOB.scp and weight.JOB.scp, +# where output file contains language-id for each example +# and weight file contains weights for scaling output posterior +# for each example w.r.t input language. +# + +set -e +set -o pipefail +set -u + +# Begin configuration section. +cmd=run.pl +minibatch_size=512 # multiple of minibatch used during training. +num_jobs=10 # This can be set to max number of jobs to run in parallel; + # Helps for better randomness across languages + # per archive. +samples_per_iter=400000 # this is the target number of egs in each archive of egs + # (prior to merging egs). We probably should have called + # it egs_per_iter. This is just a guideline; it will pick + # a number that divides the number of samples in the + # entire data. +stage=0 + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +num_langs=$1 +shift 1 +args=("$@") +megs_dir=${args[-1]} # multilingual directory +mkdir -p $megs_dir +mkdir -p $megs_dir/info + +if [ ${#args[@]} != $[$num_langs+1] ]; then + echo "$0: Number of input example dirs provided is not compatible with num_langs $num_langs." + echo "Usage:$0 [opts] ... 
" + echo "Usage:$0 [opts] 2 exp/lang1/egs exp/lang2/egs exp/multi/egs" + exit 1; +fi + +required_files="egs.scp combine.egs.scp train_diagnostic.egs.scp valid_diagnostic.egs.scp" +train_scp_list= +train_diagnostic_scp_list= +valid_diagnostic_scp_list= +combine_scp_list= + +# copy paramters from $egs_dir[0]/info +# into multilingual dir egs_dir/info + +params_to_check="feat_dim ivector_dim left_context right_context frames_per_eg" +for param in $params_to_check; do + cat ${args[0]}/info/$param > $megs_dir/info/$param || exit 1; +done + +for lang in $(seq 0 $[$num_langs-1]);do + multi_egs_dir[$lang]=${args[$lang]} + echo "arg[$lang] = ${args[$lang]}" + for f in $required_files; do + if [ ! -f ${multi_egs_dir[$lang]}/$f ]; then + echo "$0: no such a file ${multi_egs_dir[$lang]}/$f." && exit 1; + fi + done + train_scp_list="$train_scp_list ${args[$lang]}/egs.scp" + train_diagnostic_scp_list="$train_diagnostic_scp_list ${args[$lang]}/train_diagnostic.egs.scp" + valid_diagnostic_scp_list="$valid_diagnostic_scp_list ${args[$lang]}/valid_diagnostic.egs.scp" + combine_scp_list="$combine_scp_list ${args[$lang]}/combine.egs.scp" + + # check parameter dimension to be the same in all egs dirs + for f in $params_to_check; do + f1=`cat $megs_dir/info/$param`; + f2=`cat ${multi_egs_dir[$lang]}/info/$f`; + if [ $f1 != $f1 ]; then + echo "$0: mismatch in dimension for $f parameter in ${multi_egs_dir[$lang]}." + exit 1; + fi + done +done + +if [ $stage -le 0 ]; then + echo "$0: allocating multilingual examples for training." + # Generate egs.*.scp for multilingual setup. + $cmd $megs_dir/log/allocate_multilingual_examples_train.log \ + python steps/nnet3/multilingual/allocate_multilingual_examples.py \ + --minibatch-size $minibatch_size \ + --samples-per-iter $samples_per_iter \ + $num_langs "$train_scp_list" $megs_dir || exit 1; +fi + +if [ $stage -le 1 ]; then + echo "$0: combine combine.egs.scp examples from all langs in $megs_dir/combine.egs.scp." + # Generate combine.egs.scp for multilingual setup. + $cmd $megs_dir/log/allocate_multilingual_examples_combine.log \ + python steps/nnet3/multilingual/allocate_multilingual_examples.py \ + --random-lang false \ + --max-archives 1 --num-jobs 1 \ + --minibatch-size $minibatch_size \ + --prefix "combine." \ + $num_langs "$combine_scp_list" $megs_dir || exit 1; + + echo "$0: combine train_diagnostic.egs.scp examples from all langs in $megs_dir/train_diagnostic.egs.scp." + # Generate train_diagnostic.egs.scp for multilingual setup. + $cmd $megs_dir/log/allocate_multilingual_examples_train_diagnostic.log \ + python steps/nnet3/multilingual/allocate_multilingual_examples.py \ + --random-lang false \ + --max-archives 1 --num-jobs 1 \ + --minibatch-size $minibatch_size \ + --prefix "train_diagnostic." \ + $num_langs "$train_diagnostic_scp_list" $megs_dir || exit 1; + + + echo "$0: combine valid_diagnostic.egs.scp examples from all langs in $megs_dir/valid_diagnostic.egs.scp." + # Generate valid_diagnostic.egs.scp for multilingual setup. + $cmd $megs_dir/log/allocate_multilingual_examples_valid_diagnostic.log \ + python steps/nnet3/multilingual/allocate_multilingual_examples.py \ + --random-lang false --max-archives 1 --num-jobs 1\ + --minibatch-size $minibatch_size \ + --prefix "valid_diagnostic." 
\ + $num_langs "$valid_diagnostic_scp_list" $megs_dir || exit 1; + +fi + From 34e34e9ac35fe750ff5af0cdf2d891e2a9b561b0 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 19 Dec 2016 18:49:44 -0500 Subject: [PATCH 107/213] asr_diarization: Add fake targets to get-egs-multiple-targets --- .../nnet3-get-egs-multiple-targets.cc | 33 +++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/src/nnet3bin/nnet3-get-egs-multiple-targets.cc b/src/nnet3bin/nnet3-get-egs-multiple-targets.cc index 49f0dde4af7..2c5fb364309 100644 --- a/src/nnet3bin/nnet3-get-egs-multiple-targets.cc +++ b/src/nnet3bin/nnet3-get-egs-multiple-targets.cc @@ -236,6 +236,7 @@ int main(int argc, char *argv[]) { bool compress_input = true; + bool add_fake_targets= true; int32 input_compress_format = 0; int32 left_context = 0, right_context = 0, num_frames = 1, length_tolerance = 2; @@ -247,6 +248,9 @@ int main(int argc, char *argv[]) { std::string output_names_str; ParseOptions po(usage); + po.Register("add-fake-targets", &add_fake_targets, + "Add fake targets so that " + "all the egs contain the same number of outputs"); po.Register("compress-input", &compress_input, "If true, write egs in " "compressed format."); po.Register("input-compress-format", &input_compress_format, "Format for " @@ -298,7 +302,6 @@ int main(int argc, char *argv[]) { std::vector sparse_targets_readers(num_outputs, static_cast(NULL)); - std::vector compress_targets(1, true); std::vector compress_targets_vector; @@ -360,7 +363,11 @@ int main(int argc, char *argv[]) { std::vector targets_rspecifiers(num_outputs); std::vector deriv_weights_rspecifiers(num_outputs); - + + std::vector > fake_dense_targets(num_outputs); + std::vector > fake_deriv_weights(num_outputs); + std::vector fake_sparse_targets(num_outputs); + for (int32 n = 0; n < num_outputs; n++) { const std::string &targets_rspecifier = po.GetArg(2*n + 2); const std::string &deriv_weights_rspecifier = po.GetArg(2*n + 3); @@ -428,6 +435,16 @@ int main(int argc, char *argv[]) { KALDI_WARN << "No dense targets matrix for key " << key << " in " << "rspecifier " << targets_rspecifiers[n] << " for output " << output_names[n]; + + if (add_fake_targets) { + fake_dense_targets[n].Resize(feats.NumRows(), -output_dims[n]); + dense_targets[n] = &(fake_dense_targets[n]); + + fake_deriv_weights[n].Resize(feats.NumRows()); + deriv_weights[n] = &(fake_deriv_weights[n]); + + num_outputs_found++; + } continue; } const MatrixBase *target_matrix = &(dense_targets_readers[n]->Value(key)); @@ -446,6 +463,12 @@ int main(int argc, char *argv[]) { KALDI_WARN << "No sparse target matrix for key " << key << " in " << "rspecifier " << targets_rspecifiers[n] << " for output " << output_names[n]; + + if (add_fake_targets) { + fake_sparse_targets[n].resize(feats.NumRows()); + sparse_targets[n] = &(fake_sparse_targets[n]); + num_outputs_found++; + } continue; } const Posterior *posterior = &(sparse_targets_readers[n]->Value(key)); @@ -499,6 +522,12 @@ int main(int argc, char *argv[]) { continue; } + if (add_fake_targets && num_outputs_found != output_names.size()) { + KALDI_WARN << "Not all outputs found for key " << key; + num_err++; + continue; + } + ProcessFile(feats, ivector_feats, output_names, output_dims, dense_targets, sparse_targets, deriv_weights, key, From 2d4eeeb7f1a465a9c7ca4c0f053ff58bb4ae7ae9 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 19 Dec 2016 19:58:29 -0500 Subject: [PATCH 108/213] asr_diarization: Support scaling of nnet3 egs feats --- src/matrix/sparse-matrix.cc | 28 
++++++++++ src/matrix/sparse-matrix.h | 6 +++ src/nnet3bin/nnet3-copy-egs.cc | 93 ++++++++++++++++++++++++++++++++-- 3 files changed, 124 insertions(+), 3 deletions(-) diff --git a/src/matrix/sparse-matrix.cc b/src/matrix/sparse-matrix.cc index 777819ed677..c5bc868f48e 100644 --- a/src/matrix/sparse-matrix.cc +++ b/src/matrix/sparse-matrix.cc @@ -281,6 +281,14 @@ void SparseVector::Resize(MatrixIndexT dim, dim_ = dim; } +template +void SparseVector::Scale(BaseFloat scale) { + typename std::vector >::iterator it = pairs_.begin(); + for (; it != pairs_.end(); ++it) { + it->second *= scale; + } +} + template MatrixIndexT SparseMatrix::NumRows() const { return rows_.size(); @@ -574,6 +582,14 @@ void SparseMatrix::Resize(MatrixIndexT num_rows, rows_[row].Resize(num_cols, kCopyData); } } + +template +void SparseMatrix::Scale(BaseFloat scale) { + for (typename std::vector >::iterator it = rows_.begin(); + it != rows_.end(); ++it) { + it->Scale(scale); + } +} template void SparseMatrix::AppendSparseMatrixRows( @@ -1053,6 +1069,18 @@ void GeneralMatrix::AddToMat(BaseFloat alpha, MatrixBase *mat, } } + +void GeneralMatrix::Scale(BaseFloat scale) { + if(Type() == kCompressedMatrix) + Uncompress(); + if (Type() == kFullMatrix) { + mat_.Scale(scale); + } else if (Type() == kSparseMatrix) { + smat_.Scale(scale); + } +} + + template Real SparseVector::Max(int32 *index_out) const { KALDI_ASSERT(dim_ > 0 && pairs_.size() <= static_cast(dim_)); diff --git a/src/matrix/sparse-matrix.h b/src/matrix/sparse-matrix.h index 88619da3034..8ad62e0ac51 100644 --- a/src/matrix/sparse-matrix.h +++ b/src/matrix/sparse-matrix.h @@ -98,6 +98,8 @@ class SparseVector { /// Resizes to this dimension. resize_type == kUndefined /// behaves the same as kSetZero. void Resize(MatrixIndexT dim, MatrixResizeType resize_type = kSetZero); + + void Scale(BaseFloat scale); void Write(std::ostream &os, bool binary) const; @@ -196,6 +198,8 @@ class SparseMatrix { void Resize(MatrixIndexT rows, MatrixIndexT cols, MatrixResizeType resize_type = kSetZero); + void Scale(BaseFloat scale); + // Use the Matrix::CopyFromSmat() function to copy from this to Matrix. Also // see Matrix::AddSmat(). There is not very extensive functionality for // SparseMat just yet (e.g. no matrix multiply); we will add things as needed @@ -286,6 +290,8 @@ class GeneralMatrix { void AddToMat(BaseFloat alpha, CuMatrixBase *cu_mat, MatrixTransposeType trans = kNoTrans) const; + void Scale(BaseFloat alpha); + /// Assignment from regular matrix. GeneralMatrix &operator= (const MatrixBase &mat); diff --git a/src/nnet3bin/nnet3-copy-egs.cc b/src/nnet3bin/nnet3-copy-egs.cc index 2702ae5fae9..5189ee4046f 100644 --- a/src/nnet3bin/nnet3-copy-egs.cc +++ b/src/nnet3bin/nnet3-copy-egs.cc @@ -28,6 +28,34 @@ namespace kaldi { namespace nnet3 { +// rename io-name of eg w.r.t io_names list e.g. input/input-1,output/output-1 +// 'input' is renamed to input-1 and 'output' renamed to output-1. 
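+// Illustrative example (hypothetical output name): nnet3-copy-egs below builds
+// io_names as "output/" + new_output_name, so for an eg whose --outputs entry
+// is "output-en", the io named "output" is renamed to "output-en".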
+void RenameIoNames(const std::string &io_names, + NnetExample *eg_modified) { + std::vector separated_io_names; + SplitStringToVector(io_names, ",", true, &separated_io_names); + int32 num_modified_io = separated_io_names.size(), + io_size = eg_modified->io.size(); + std::vector orig_io_list; + for (int32 io_ind = 0; io_ind < io_size; io_ind++) + orig_io_list.push_back(eg_modified->io[io_ind].name); + + for (int32 ind = 0; ind < num_modified_io; ind++) { + std::vector rename_io_name; + SplitStringToVector(separated_io_names[ind], "/", true, &rename_io_name); + // find the io in eg with specific name and rename it to new name. + + int32 rename_io_ind = + std::find(orig_io_list.begin(), orig_io_list.end(), rename_io_name[0]) - + orig_io_list.begin(); + + if (rename_io_ind >= io_size) + KALDI_ERR << "No io-node with name " << rename_io_name[0] + << "exists in eg."; + eg_modified->io[rename_io_ind].name = rename_io_name[1]; + } +} + bool KeepOutputs(const std::vector &keep_outputs, NnetExample *eg) { std::vector io_new; @@ -330,6 +358,8 @@ int main(int argc, char *argv[]) { // you can set frame to a number to select a single frame with a particular // offset, or to 'random' to select a random single frame. std::string frame_str; + std::string weight_str; + std::string output_str; ParseOptions po(usage); po.Register("random", &random, "If true, will write frames to output " @@ -357,6 +387,16 @@ int main(int argc, char *argv[]) { po.Register("remove-zero-deriv-outputs", &remove_zero_deriv_outputs, "Remove outputs that do not contribute to the objective " "because of zero deriv-weights"); + po.Register("weights", &weight_str, + "Rspecifier maps the output posterior to each example" + "If provided, the supervision weight for output is scaled." + " Scaling supervision weight is the same as scaling to the derivative during training " + " in case of linear objective." + "The default is one, which means we are not applying per-example weights."); + po.Register("outputs", &output_str, + "Rspecifier maps example old output-name to new output-name in example." + " If provided, the NnetIo with name 'output' in each example " + " is renamed to new output name."); po.Read(argc, argv); @@ -370,6 +410,8 @@ int main(int argc, char *argv[]) { std::string examples_rspecifier = po.GetArg(1); SequentialNnetExampleReader example_reader(examples_rspecifier); + RandomAccessTokenReader output_reader(output_str); + RandomAccessBaseFloatReader egs_weight_reader(weight_str); int32 num_outputs = po.NumArgs() - 1; std::vector example_writers(num_outputs); @@ -382,7 +424,7 @@ int main(int argc, char *argv[]) { std::sort(keep_outputs.begin(), keep_outputs.end()); } - int64 num_read = 0, num_written = 0; + int64 num_read = 0, num_written = 0, num_err = 0; for (; !example_reader.Done(); example_reader.Next(), num_read++) { // count is normally 1; could be 0, or possibly >1. 
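      // (Here GetCount() converts --keep-proportion into a per-eg copy count:
      // values below 1.0 keep an eg with roughly that probability, values above
      // 1.0 duplicate it that many times in expectation.)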
int32 count = GetCount(keep_proportion); @@ -399,16 +441,61 @@ int main(int argc, char *argv[]) { frame_shift == 0) { if (remove_zero_deriv_outputs) if (!RemoveZeroDerivOutputs(&eg)) continue; + if (!weight_str.empty()) { + if (!egs_weight_reader.HasKey(key)) { + KALDI_WARN << "No weight for example key " << key; + num_err++; + continue; + } + BaseFloat weight = egs_weight_reader.Value(key); + for (int32 i = 0; i < eg.io.size(); i++) + if (eg.io[i].name.find("output") != std::string::npos) + eg.io[i].features.Scale(weight); + } + if (!output_str.empty()) { + if (!output_reader.HasKey(key)) { + KALDI_WARN << "No new output-name for example key " << key; + num_err++; + continue; + } + std::string new_output_name = output_reader.Value(key); + // rename output io name to $new_output_name. + std::string rename_io_names = "output/" + new_output_name; + RenameIoNames(rename_io_names, &eg); + } example_writers[index]->Write(key, eg); num_written++; } else { // the --frame option or context options were set. NnetExample eg_modified; if (SelectFromExample(eg, frame_str, left_context, right_context, frame_shift, &eg_modified)) { - // this branch of the if statement will almost always be taken (should only - // not be taken for shorter-than-normal egs from the end of a file. if (remove_zero_deriv_outputs) if (!RemoveZeroDerivOutputs(&eg_modified)) continue; + if (!weight_str.empty()) { + // scale the supervision weight for egs + if (!egs_weight_reader.HasKey(key)) { + KALDI_WARN << "No weight for example key " << key; + num_err++; + continue; + } + int32 weight = egs_weight_reader.Value(key); + for (int32 i = 0; i < eg_modified.io.size(); i++) + if (eg_modified.io[i].name.find("output") != std::string::npos) + eg_modified.io[i].features.Scale(weight); + } + if (!output_str.empty()) { + if (!output_reader.HasKey(key)) { + KALDI_WARN << "No new output-name for example key " << key; + num_err++; + continue; + } + std::string new_output_name = output_reader.Value(key); + // rename output io name to $new_output_name. + std::string rename_io_names = "output/" + new_output_name; + RenameIoNames(rename_io_names, &eg_modified); + } + // this branch of the if statement will almost always be taken (should only + // not be taken for shorter-than-normal egs from the end of a file. 
example_writers[index]->Write(key, eg_modified); num_written++; } From c1799f1be12294884b84bbdf5f71e0e5ca40c285 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 19 Dec 2016 19:59:17 -0500 Subject: [PATCH 109/213] asr_diarization: Fix bugs and restructure multiple egs targets source --- src/nnet3/nnet-example-utils.cc | 21 +- .../nnet3-get-egs-multiple-targets.cc | 208 +++++++----------- 2 files changed, 96 insertions(+), 133 deletions(-) diff --git a/src/nnet3/nnet-example-utils.cc b/src/nnet3/nnet-example-utils.cc index 548fb842385..2d9a01550b9 100644 --- a/src/nnet3/nnet-example-utils.cc +++ b/src/nnet3/nnet-example-utils.cc @@ -159,21 +159,22 @@ static void MergeIo(const std::vector &src, } Vector &this_deriv_weights = merged_eg->io[f].deriv_weights; - if (output_deriv_weights[f][0]->Dim() > 0) { - this_deriv_weights.Resize( - merged_eg->io[f].indexes.size(), kUndefined); - KALDI_ASSERT(this_deriv_weights.Dim() == - merged_eg->io[f].features.NumRows()); + this_deriv_weights.Resize( + merged_eg->io[f].indexes.size(), kUndefined); + this_deriv_weights.Set(1.0); + KALDI_ASSERT(this_deriv_weights.Dim() == + merged_eg->io[f].features.NumRows()); - std::vector const*>::const_iterator - it = output_deriv_weights[f].begin(), - end = output_deriv_weights[f].end(); + std::vector const*>::const_iterator + it = output_deriv_weights[f].begin(), + end = output_deriv_weights[f].end(); - for (int32 i = 0, cur_offset = 0; it != end; ++it, i++) { + for (int32 i = 0, cur_offset = 0; it != end; ++it, i++) { + if((*it)->Dim() > 0) { KALDI_ASSERT((*it)->Dim() == output_lists[f][i]->NumRows()); this_deriv_weights.Range(cur_offset, (*it)->Dim()).CopyFromVec(**it); - cur_offset += (*it)->Dim(); } + cur_offset += output_lists[f][i]->NumRows(); } } } diff --git a/src/nnet3bin/nnet3-get-egs-multiple-targets.cc b/src/nnet3bin/nnet3-get-egs-multiple-targets.cc index 2c5fb364309..63ebce5ab0e 100644 --- a/src/nnet3bin/nnet3-get-egs-multiple-targets.cc +++ b/src/nnet3bin/nnet3-get-egs-multiple-targets.cc @@ -44,26 +44,28 @@ bool ToBool(std::string str) { return false; // never reached } -static void ProcessFile(const MatrixBase &feats, - const MatrixBase *ivector_feats, - const std::vector &output_names, - const std::vector &output_dims, - const std::vector* > &dense_target_matrices, - const std::vector &posteriors, - const std::vector* > &deriv_weights, - const std::string &utt_id, - bool compress_input, - int32 input_compress_format, - const std::vector &compress_targets, - const std::vector &targets_compress_formats, - int32 left_context, - int32 right_context, - int32 frames_per_eg, - std::vector *num_frames_written, - std::vector *num_egs_written, - NnetExampleWriter *example_writer) { +static void ProcessFile( + const MatrixBase &feats, + const MatrixBase *ivector_feats, + const std::vector &output_names, + const std::vector &output_dims, + const std::vector* > &dense_target_matrices, + const std::vector &posteriors, + const std::vector* > &deriv_weights, + const std::string &utt_id, + bool compress_input, + int32 input_compress_format, + const std::vector &compress_targets, + const std::vector &targets_compress_formats, + int32 left_context, + int32 right_context, + int32 frames_per_eg, + std::vector *num_frames_written, + std::vector *num_egs_written, + NnetExampleWriter *example_writer) { + KALDI_ASSERT(output_names.size() > 0); - //KALDI_ASSERT(feats.NumRows() == static_cast(targets.NumRows())); + for (int32 t = 0; t < feats.NumRows(); t += frames_per_eg) { int32 tot_frames = left_context + 
frames_per_eg + right_context; @@ -113,16 +115,15 @@ static void ProcessFile(const MatrixBase &feats, // At the end of the file, we pad with the last frame repeated // so that all examples have the same structure (prevents the need // for recompilations). - int32 actual_frames_per_eg = std::min(std::min(frames_per_eg, - feats.NumRows() - t), deriv_weights[n]->Dim() - t); + int32 actual_frames_per_eg = std::min( + std::min(frames_per_eg, feats.NumRows() - t), + deriv_weights[n]->Dim() - t); this_deriv_weights.Resize(frames_per_eg); int32 frames_to_copy = std::min(t + actual_frames_per_eg, deriv_weights[n]->Dim()) - t; - this_deriv_weights.Range(0, frames_to_copy).CopyFromVec(deriv_weights[n]->Range(t, frames_to_copy)); - if (this_deriv_weights.Sum() == 0) { - continue; // Ignore frames that have frame weights 0 - } + this_deriv_weights.Range(0, frames_to_copy).CopyFromVec( + deriv_weights[n]->Range(t, frames_to_copy)); } if (dense_target_matrices[n]) { @@ -133,8 +134,9 @@ static void ProcessFile(const MatrixBase &feats, // At the end of the file, we pad with the last frame repeated // so that all examples have the same structure (prevents the need // for recompilations). - int32 actual_frames_per_eg = std::min(std::min(frames_per_eg, - feats.NumRows() - t), targets.NumRows() - t); + int32 actual_frames_per_eg = std::min( + std::min(frames_per_eg, feats.NumRows() - t), + targets.NumRows() - t); for (int32 i = 0; i < actual_frames_per_eg; i++) { // Copy the i^th row of the target matrix from the (t+i)^th row of the @@ -150,12 +152,14 @@ static void ProcessFile(const MatrixBase &feats, // input targets matrix KALDI_ASSERT(t + actual_frames_per_eg - 1 == targets.NumRows() - 1); SubVector this_target_dest(targets_dest, i); - SubVector this_target_src(targets, t+actual_frames_per_eg-1); + SubVector this_target_src(targets, + t + actual_frames_per_eg - 1); this_target_dest.CopyFromVec(this_target_src); } if (deriv_weights[n]) { - eg.io.push_back(NnetIo(output_names[n], this_deriv_weights, 0, targets_dest)); + eg.io.push_back(NnetIo(output_names[n], this_deriv_weights, + 0, targets_dest)); } else { eg.io.push_back(NnetIo(output_names[n], 0, targets_dest)); } @@ -166,8 +170,9 @@ static void ProcessFile(const MatrixBase &feats, // At the end of the file, we pad with the last frame repeated // so that all examples have the same structure (prevents the need // for recompilations). - int32 actual_frames_per_eg = std::min(std::min(frames_per_eg, - feats.NumRows() - t), static_cast(pdf_post.size()) - t); + int32 actual_frames_per_eg = std::min( + std::min(frames_per_eg, feats.NumRows() - t), + static_cast(pdf_post.size()) - t); Posterior labels(frames_per_eg); for (int32 i = 0; i < actual_frames_per_eg; i++) @@ -175,7 +180,8 @@ static void ProcessFile(const MatrixBase &feats, // remaining posteriors for frames are empty. if (deriv_weights[n]) { - eg.io.push_back(NnetIo(output_names[n], this_deriv_weights, output_dims[n], 0, labels)); + eg.io.push_back(NnetIo(output_names[n], this_deriv_weights, + output_dims[n], 0, labels)); } else { eg.io.push_back(NnetIo(output_names[n], output_dims[n], 0, labels)); } @@ -185,11 +191,13 @@ static void ProcessFile(const MatrixBase &feats, eg.io.back().Compress(targets_compress_formats[n]); num_outputs_added++; - (*num_frames_written)[n] += frames_per_eg; // Actually actual_frames_per_eg, but that depends on the different output. For simplification, frames_per_eg is used. + // Actually actual_frames_per_eg, but that depends on the different + // output. 
For simplification, frames_per_eg is used. + (*num_frames_written)[n] += frames_per_eg; (*num_egs_written)[n] += 1; } - if (num_outputs_added == 0) continue; + if (num_outputs_added != output_names.size()) continue; std::ostringstream os; os << utt_id << "-" << t; @@ -236,7 +244,6 @@ int main(int argc, char *argv[]) { bool compress_input = true; - bool add_fake_targets= true; int32 input_compress_format = 0; int32 left_context = 0, right_context = 0, num_frames = 1, length_tolerance = 2; @@ -248,17 +255,14 @@ int main(int argc, char *argv[]) { std::string output_names_str; ParseOptions po(usage); - po.Register("add-fake-targets", &add_fake_targets, - "Add fake targets so that " - "all the egs contain the same number of outputs"); po.Register("compress-input", &compress_input, "If true, write egs in " "compressed format."); po.Register("input-compress-format", &input_compress_format, "Format for " "compressing input feats e.g. Use 2 for compressing wave"); po.Register("compress-targets", &compress_targets_str, "CSL of whether " "targets must be compressed for each of the outputs"); - po.Register("targets-compress-formats", &targets_compress_formats_str, "Format for " - "compressing all feats in general"); + po.Register("targets-compress-formats", &targets_compress_formats_str, + "Format for compressing all feats in general"); po.Register("left-context", &left_context, "Number of frames of left " "context the neural net requires."); po.Register("right-context", &right_context, "Number of frames of right " @@ -271,11 +275,6 @@ int main(int argc, char *argv[]) { "difference in num-frames between feat and ivector matrices"); po.Register("output-dims", &output_dims_str, "CSL of output node dims"); po.Register("output-names", &output_names_str, "CSL of output node names"); - //po.Register("deriv-weights-rspecifiers", &deriv_weights_rspecifiers_str, - // "CSL of per-frame weights (only binary - 0 or 1) that specifies " - // "whether a frame's gradient must be backpropagated or not. 
" - // "Not specifying this is equivalent to specifying a vector of " - // "all 1s."); po.Read(argc, argv); @@ -295,12 +294,12 @@ int main(int argc, char *argv[]) { int32 num_outputs = (po.NumArgs() - 2) / 2; KALDI_ASSERT(num_outputs > 0); - std::vector deriv_weights_readers(num_outputs, - static_cast(NULL)); - std::vector dense_targets_readers(num_outputs, - static_cast(NULL)); - std::vector sparse_targets_readers(num_outputs, - static_cast(NULL)); + std::vector deriv_weights_readers( + num_outputs, static_cast(NULL)); + std::vector dense_targets_readers( + num_outputs, static_cast(NULL)); + std::vector sparse_targets_readers( + num_outputs, static_cast(NULL)); std::vector compress_targets(1, true); std::vector compress_targets_vector; @@ -338,7 +337,8 @@ int main(int argc, char *argv[]) { } if (targets_compress_formats.size() != num_outputs) { - KALDI_ERR << "Mismatch in length of targets-compress-formats and num-outputs; " + KALDI_ERR << "Mismatch in length of targets-compress-formats " + << " and num-outputs; " << targets_compress_formats.size() << " vs " << num_outputs; } @@ -349,25 +349,9 @@ int main(int argc, char *argv[]) { std::vector output_names(num_outputs); SplitStringToVector(output_names_str, ":,", true, &output_names); - //std::vector deriv_weights_rspecifiers; - //if (!deriv_weights_rspecifiers_str.empty()) { - // std::vector parts; - // SplitStringToVector(deriv_weights_rspecifiers_str, ":,", - // false, &deriv_weights_rspecifiers); - - // if (deriv_weights_rspecifiers.size() != num_outputs) { - // KALDI_ERR << "Expecting the number of deriv-weights-rspecifiers to " - // << "be equal to the number of outputs"; - // } - //} - std::vector targets_rspecifiers(num_outputs); std::vector deriv_weights_rspecifiers(num_outputs); - std::vector > fake_dense_targets(num_outputs); - std::vector > fake_deriv_weights(num_outputs); - std::vector fake_sparse_targets(num_outputs); - for (int32 n = 0; n < num_outputs; n++) { const std::string &targets_rspecifier = po.GetArg(2*n + 2); const std::string &deriv_weights_rspecifier = po.GetArg(2*n + 3); @@ -376,19 +360,24 @@ int main(int argc, char *argv[]) { deriv_weights_rspecifiers[n] = deriv_weights_rspecifier; if (output_dims[n] >= 0) { - sparse_targets_readers[n] = new RandomAccessPosteriorReader(targets_rspecifier); + sparse_targets_readers[n] = new RandomAccessPosteriorReader( + targets_rspecifier); } else { - dense_targets_readers[n] = new RandomAccessBaseFloatMatrixReader(targets_rspecifier); + dense_targets_readers[n] = new RandomAccessBaseFloatMatrixReader( + targets_rspecifier); } if (!deriv_weights_rspecifier.empty()) - deriv_weights_readers[n] = new RandomAccessBaseFloatVectorReader(deriv_weights_rspecifier); + deriv_weights_readers[n] = new RandomAccessBaseFloatVectorReader( + deriv_weights_rspecifier); KALDI_LOG << "output-name=" << output_names[n] << " target-dim=" << output_dims[n] << " targets-rspecifier=\"" << targets_rspecifiers[n] << "\"" - << " deriv-weights-rspecifier=\"" << deriv_weights_rspecifiers[n] << "\"" - << " compress-target=" << (compress_targets[n] ? "true" : "false") + << " deriv-weights-rspecifier=\"" + << deriv_weights_rspecifiers[n] << "\"" + << " compress-target=" + << (compress_targets[n] ? 
"true" : "false") << " target-compress-format=" << targets_compress_formats[n]; } @@ -405,7 +394,6 @@ int main(int argc, char *argv[]) { if (!ivector_rspecifier.empty()) { if (!ivector_reader.HasKey(key)) { KALDI_WARN << "No iVectors for utterance " << key; - num_err++; continue; } else { // this address will be valid until we call HasKey() or Value() @@ -424,9 +412,12 @@ int main(int argc, char *argv[]) { continue; } - std::vector* > dense_targets(num_outputs, static_cast* >(NULL)); - std::vector sparse_targets(num_outputs, static_cast(NULL)); - std::vector* > deriv_weights(num_outputs, static_cast* >(NULL)); + std::vector* > dense_targets( + num_outputs, static_cast* >(NULL)); + std::vector sparse_targets( + num_outputs, static_cast(NULL)); + std::vector* > deriv_weights( + num_outputs, static_cast* >(NULL)); int32 num_outputs_found = 0; for (int32 n = 0; n < num_outputs; n++) { @@ -435,26 +426,16 @@ int main(int argc, char *argv[]) { KALDI_WARN << "No dense targets matrix for key " << key << " in " << "rspecifier " << targets_rspecifiers[n] << " for output " << output_names[n]; - - if (add_fake_targets) { - fake_dense_targets[n].Resize(feats.NumRows(), -output_dims[n]); - dense_targets[n] = &(fake_dense_targets[n]); - - fake_deriv_weights[n].Resize(feats.NumRows()); - deriv_weights[n] = &(fake_deriv_weights[n]); - - num_outputs_found++; - } - continue; + break; } - const MatrixBase *target_matrix = &(dense_targets_readers[n]->Value(key)); + const MatrixBase *target_matrix = &( + dense_targets_readers[n]->Value(key)); if ((target_matrix->NumRows() - feats.NumRows()) > length_tolerance) { KALDI_WARN << "Length difference between feats " << feats.NumRows() << " and target matrix " << target_matrix->NumRows() << "exceeds tolerance " << length_tolerance; - num_err++; - continue; + break; } dense_targets[n] = target_matrix; @@ -463,22 +444,16 @@ int main(int argc, char *argv[]) { KALDI_WARN << "No sparse target matrix for key " << key << " in " << "rspecifier " << targets_rspecifiers[n] << " for output " << output_names[n]; - - if (add_fake_targets) { - fake_sparse_targets[n].resize(feats.NumRows()); - sparse_targets[n] = &(fake_sparse_targets[n]); - num_outputs_found++; - } - continue; + break; } const Posterior *posterior = &(sparse_targets_readers[n]->Value(key)); - if (abs(static_cast(posterior->size()) - feats.NumRows()) > length_tolerance + if (abs(static_cast(posterior->size()) - feats.NumRows()) + > length_tolerance || posterior->size() < feats.NumRows()) { KALDI_WARN << "Posterior has wrong size " << posterior->size() << " versus " << feats.NumRows(); - num_err++; - continue; + break; } sparse_targets[n] = posterior; @@ -489,10 +464,7 @@ int main(int argc, char *argv[]) { KALDI_WARN << "No deriv weights for key " << key << " in " << "rspecifier " << deriv_weights_rspecifiers[n] << " for output " << output_names[n]; - num_err++; - sparse_targets[n] = NULL; - dense_targets[n] = NULL; - continue; + break; } else { // this address will be valid until we call HasKey() or Value() // again. 
@@ -500,29 +472,20 @@ int main(int argc, char *argv[]) { } } - if (deriv_weights[n] && - (abs(feats.NumRows() - deriv_weights[n]->Dim()) > length_tolerance - || deriv_weights[n]->Dim() == 0)) { + if (deriv_weights[n] + && (abs(feats.NumRows() - deriv_weights[n]->Dim()) + > length_tolerance + || deriv_weights[n]->Dim() == 0)) { KALDI_WARN << "Length difference between feats " << feats.NumRows() << " and deriv weights " << deriv_weights[n]->Dim() << " exceeds tolerance " << length_tolerance; - num_err++; - sparse_targets[n] = NULL; - dense_targets[n] = NULL; - deriv_weights[n] = NULL; - continue; + break; } num_outputs_found++; } - if (num_outputs_found == 0) { - KALDI_WARN << "No output found for key " << key; - num_err++; - continue; - } - - if (add_fake_targets && num_outputs_found != output_names.size()) { + if (num_outputs_found != num_outputs) { KALDI_WARN << "Not all outputs found for key " << key; num_err++; continue; @@ -553,7 +516,8 @@ int main(int argc, char *argv[]) { KALDI_LOG << "Finished generating examples, " << "successfully processed " << num_done - << " feature files, wrote at most " << max_num_egs_written << " examples, " + << " feature files, wrote at most " << max_num_egs_written + << " examples, " << " with at most " << max_num_frames_written << " egs in total; " << num_err << " files had errors."; @@ -563,5 +527,3 @@ int main(int argc, char *argv[]) { return -1; } } - - From cfb71e500eaec2b7796277ac5945b537e9bd6777 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 19 Dec 2016 20:00:07 -0500 Subject: [PATCH 110/213] asr_diarization: Minor fixes to get_egs_multiple_targets --- .../steps/nnet3/get_egs_multiple_targets.py | 183 +++++++++++++----- 1 file changed, 132 insertions(+), 51 deletions(-) diff --git a/egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py b/egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py index fa8a68f5c64..72b0cb4edd3 100755 --- a/egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py +++ b/egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py @@ -22,7 +22,7 @@ logger.setLevel(logging.INFO) handler = logging.StreamHandler() handler.setLevel(logging.INFO) -formatter = logging.Formatter("%(asctime)s [%(filename)s:%(lineno)s - " +formatter = logging.Formatter("%(asctime)s [%(pathname)s:%(lineno)s - " "%(funcName)s - %(levelname)s ] %(message)s") handler.setFormatter(formatter) logger.addHandler(handler) @@ -133,6 +133,9 @@ def get_args(): parser.add_argument("--srand", type=int, default=0, help="Rand seed for nnet3-copy-egs and " "nnet3-shuffle-egs") + parser.add_argument("--generate-egs-scp", type=str, + default=False, action=common_lib.StrToBoolAction, + help="Generate scp files in addition to archives") parser.add_argument("--targets-parameters", type=str, action='append', required=True, dest='targets_para_array', @@ -186,9 +189,10 @@ def check_for_required_files(feat_dir, targets_scps, online_ivector_dir=None): '{0}/cmvn.scp'.format(feat_dir)] if online_ivector_dir is not None: required_files.append('{0}/ivector_online.scp'.format( - online_ivector_dir)) + online_ivector_dir)) required_files.append('{0}/ivector_period'.format( - online_ivector_dir)) + online_ivector_dir)) + required_files.extend(targets_scps) for file in required_files: if not os.path.isfile(file): @@ -229,9 +233,9 @@ def parse_targets_parameters_array(para_array): if not os.path.isfile(t.targets_scp): raise Exception("Expected {0} to exist.".format(t.targets_scp)) - if (t.target_type == "dense"): + if t.target_type == "dense": dim = common_lib.get_feat_dim_from_scp(t.targets_scp) 
- if (t.dim != -1 and t.dim != dim): + if t.dim != -1 and t.dim != dim: raise Exception('Mismatch in --dim provided and feat dim for ' 'file {0}; {1} vs {2}'.format(t.targets_scp, t.dim, dim)) @@ -272,7 +276,10 @@ def sample_utts(feat_dir, num_utts_subset, min_duration, exclude_list=None): for utt in utts2add: sampled_utts.append(utt) - index = index + 1 + else: + logger.info("Skipping utterance %s of length %f", + utt2uniq[utt2durs[index][0]], utt2durs[index][1]) + index = index + 1 num_trials = num_trials + 1 if exclude_list is not None: assert(len(set(exclude_list).intersection(sampled_utts)) == 0) @@ -311,13 +318,13 @@ def get_feat_ivector_strings(dir, feat_dir, split_feat_dir, "{dir}/valid_uttlist {sdir}/JOB/feats.scp | " "apply-cmvn {cmvn} --utt2spk=ark:{sdir}/JOB/utt2spk " "scp:{sdir}/JOB/cmvn.scp scp:- ark:- |".format( - dir=dir, sdir=split_feat_dir, - cmvn=cmvn_opt_string)) + dir=dir, sdir=split_feat_dir, + cmvn=cmvn_opt_string)) valid_feats = ("ark,s,cs:utils/filter_scp.pl {dir}/valid_uttlist " "{fdir}/feats.scp | " "apply-cmvn {cmvn} --utt2spk=ark:{fdir}/utt2spk " "scp:{fdir}/cmvn.scp scp:- ark:- |".format( - dir=dir, fdir=feat_dir, cmvn=cmvn_opt_string)) + dir=dir, fdir=feat_dir, cmvn=cmvn_opt_string)) train_subset_feats = ("ark,s,cs:utils/filter_scp.pl " "{dir}/train_subset_uttlist {fdir}/feats.scp | " "apply-cmvn {cmvn} --utt2spk=ark:{fdir}/utt2spk " @@ -470,7 +477,24 @@ def generate_valid_train_subset_egs(dir, targets_parameters, num_train_egs_combine, num_valid_egs_combine, num_egs_diagnostic, cmd, - num_jobs=1): + num_jobs=1, + generate_egs_scp=False): + + if generate_egs_scp: + valid_combine_output = ("ark,scp:{0}/valid_combine.egs," + "{0}/valid_combine.egs.scp".format(dir)) + valid_diagnostic_output = ("ark,scp:{0}/valid_diagnostic.egs," + "{0}/valid_diagnostic.egs.scp".format(dir)) + train_combine_output = ("ark,scp:{0}/train_combine.egs," + "{0}/train_combine.egs.scp".format(dir)) + train_diagnostic_output = ("ark,scp:{0}/train_diagnostic.egs," + "{0}/train_diagnostic.egs.scp".format(dir)) + else: + valid_combine_output = "ark:{0}/valid_combine.egs".format(dir) + valid_diagnostic_output = "ark:{0}/valid_diagnostic.egs".format(dir) + train_combine_output = "ark:{0}/train_combine.egs".format(dir) + train_diagnostic_output = "ark:{0}/train_diagnostic.egs".format(dir) + wait_pids = [] logger.info("Creating validation and train subset examples.") @@ -481,7 +505,8 @@ def generate_valid_train_subset_egs(dir, targets_parameters, valid_pid = common_lib.run_kaldi_command( """{cmd} JOB=1:{nj} {dir}/log/create_valid_subset.JOB.log \ nnet3-get-egs-multiple-targets {v_iv_opt} {v_egs_opt} "{v_feats}" \ - {targets} ark:{dir}/valid_all.JOB.egs""".format( + {targets} ark,scp:{dir}/valid_all.JOB.egs,""" + """{dir}/valid_all.JOB.egs.scp""".format( cmd=cmd, nj=num_jobs, dir=dir, v_egs_opt=egs_opts['valid_egs_opts'], v_iv_opt=feat_ivector_strings['valid_ivector_opts'], @@ -495,7 +520,8 @@ def generate_valid_train_subset_egs(dir, targets_parameters, train_pid = common_lib.run_kaldi_command( """{cmd} JOB=1:{nj} {dir}/log/create_train_subset.JOB.log \ nnet3-get-egs-multiple-targets {t_iv_opt} {v_egs_opt} "{t_feats}" \ - {targets} ark:{dir}/train_subset_all.JOB.egs""".format( + {targets} ark,scp:{dir}/train_subset_all.JOB.egs,""" + """{dir}/train_subset_all.JOB.egs.scp""".format( cmd=cmd, nj=num_jobs, dir=dir, v_egs_opt=egs_opts['valid_egs_opts'], t_iv_opt=feat_ivector_strings['train_subset_ivector_opts'], @@ -514,50 +540,56 @@ def generate_valid_train_subset_egs(dir, targets_parameters, if 
pid.returncode != 0: raise Exception(stderr) - valid_egs_all = ' '.join(['{dir}/valid_all.{n}.egs'.format(dir=dir, n=n) - for n in range(1, num_jobs + 1)]) - train_subset_egs_all = ' '.join(['{dir}/train_subset_all.{n}.egs'.format( - dir=dir, n=n) - for n in range(1, num_jobs + 1)]) + valid_egs_all = ' '.join( + ['{dir}/valid_all.{n}.egs.scp'.format(dir=dir, n=n) + for n in range(1, num_jobs + 1)]) + train_subset_egs_all = ' '.join( + ['{dir}/train_subset_all.{n}.egs.scp'.format(dir=dir, n=n) + for n in range(1, num_jobs + 1)]) wait_pids = [] logger.info("... Getting subsets of validation examples for diagnostics " " and combination.") pid = common_lib.run_kaldi_command( """{cmd} {dir}/log/create_valid_subset_combine.log \ - cat {valid_egs_all} \| nnet3-subset-egs --n={nve_combine} ark:- \ - ark:{dir}/valid_combine.egs""".format( + cat {valid_egs_all} \| nnet3-subset-egs --n={nve_combine} \ + scp:- {valid_combine_output}""".format( cmd=cmd, dir=dir, valid_egs_all=valid_egs_all, - nve_combine=num_valid_egs_combine), + nve_combine=num_valid_egs_combine, + valid_combine_output=valid_combine_output), wait=False) wait_pids.append(pid) pid = common_lib.run_kaldi_command( """{cmd} {dir}/log/create_valid_subset_diagnostic.log \ - cat {valid_egs_all} \| nnet3-subset-egs --n={ne_diagnostic} ark:- \ - ark:{dir}/valid_diagnostic.egs""".format( + cat {valid_egs_all} \| nnet3-subset-egs --n={ne_diagnostic} \ + scp:- {valid_diagnostic_output}""".format( cmd=cmd, dir=dir, valid_egs_all=valid_egs_all, - ne_diagnostic=num_egs_diagnostic), + ne_diagnostic=num_egs_diagnostic, + valid_diagnostic_output=valid_diagnostic_output), wait=False) wait_pids.append(pid) pid = common_lib.run_kaldi_command( """{cmd} {dir}/log/create_train_subset_combine.log \ cat {train_subset_egs_all} \| \ - nnet3-subset-egs --n={nte_combine} ark:- \ - ark:{dir}/train_combine.egs""".format( + nnet3-subset-egs --n={nte_combine} \ + scp:- {train_combine_output}""".format( cmd=cmd, dir=dir, train_subset_egs_all=train_subset_egs_all, - nte_combine=num_train_egs_combine), + nte_combine=num_train_egs_combine, + train_combine_output=train_combine_output), wait=False) wait_pids.append(pid) pid = common_lib.run_kaldi_command( """{cmd} {dir}/log/create_train_subset_diagnostic.log \ cat {train_subset_egs_all} \| \ - nnet3-subset-egs --n={ne_diagnostic} ark:- \ - ark:{dir}/train_diagnostic.egs""".format( + nnet3-subset-egs --n={ne_diagnostic} \ + scp:- {train_diagnostic_output}""".format( cmd=cmd, dir=dir, train_subset_egs_all=train_subset_egs_all, - ne_diagnostic=num_egs_diagnostic), wait=False) + ne_diagnostic=num_egs_diagnostic, + train_diagnostic_output=train_diagnostic_output), + wait=False) wait_pids.append(pid) for pid in wait_pids: @@ -569,6 +601,14 @@ def generate_valid_train_subset_egs(dir, targets_parameters, """cat {dir}/valid_combine.egs {dir}/train_combine.egs > \ {dir}/combine.egs""".format(dir=dir)) + if generate_egs_scp: + common_lib.run_kaldi_command( + """cat {dir}/valid_combine.egs.scp {dir}/train_combine.egs.scp > \ + {dir}/combine.egs.scp""".format(dir=dir)) + common_lib.run_kaldi_command( + "rm {dir}/valid_combine.egs.scp {dir}/train_combine.egs.scp" + "".format(dir=dir)) + # perform checks for file_name in ('{0}/combine.egs {0}/train_diagnostic.egs ' '{0}/valid_diagnostic.egs'.format(dir).split()): @@ -577,6 +617,7 @@ def generate_valid_train_subset_egs(dir, targets_parameters, # clean-up for x in ('{0}/valid_all.*.egs {0}/train_subset_all.*.egs ' + '{0}/valid_all.*.egs.scp {0}/train_subset_all.*.egs.scp ' 
'{0}/train_combine.egs ' '{0}/valid_combine.egs'.format(dir).split()): for file_name in glob.glob(x): @@ -591,7 +632,8 @@ def generate_training_examples_internal(dir, targets_parameters, feat_dir, samples_per_iter, cmd, srand=0, reduce_frames_per_eg=True, only_shuffle=False, - dry_run=False): + dry_run=False, + generate_egs_scp=False): # The examples will go round-robin to egs_list. Note: we omit the # 'normalization.fst' argument while creating temporary egs: the phase of @@ -605,8 +647,8 @@ def generate_training_examples_internal(dir, targets_parameters, feat_dir, num_archives = (num_frames) / (frames_per_eg * samples_per_iter) + 1 reduced = False - while (reduce_frames_per_eg and frames_per_eg > 1 and - num_frames / ((frames_per_eg-1)*samples_per_iter) == 0): + while (reduce_frames_per_eg and frames_per_eg > 1 + and num_frames / ((frames_per_eg-1)*samples_per_iter) == 0): frames_per_eg -= 1 num_archives = 1 reduced = True @@ -652,9 +694,9 @@ def generate_training_examples_internal(dir, targets_parameters, feat_dir, for y in range(1, num_jobs + 1)]) split_feat_dir = "{0}/split{1}".format(feat_dir, num_jobs) - egs_list = ' '.join(['ark:{dir}/egs_orig.JOB.{ark_num}.ark'.format( - dir=dir, ark_num=x) - for x in range(1, num_archives_intermediate + 1)]) + egs_list = ' '.join( + ['ark:{dir}/egs_orig.JOB.{ark_num}.ark'.format(dir=dir, ark_num=x) + for x in range(1, num_archives_intermediate + 1)]) if not only_shuffle: common_lib.run_kaldi_command( @@ -678,20 +720,43 @@ def generate_training_examples_internal(dir, targets_parameters, feat_dir, if archives_multiple == 1: # there are no intermediate archives so just shuffle egs across # jobs and dump them into a single output + + if generate_egs_scp: + output_archive = ("ark,scp:{dir}/egs.JOB.ark," + "{dir}/egs.JOB.scp".format(dir=dir)) + else: + output_archive = "ark:{dir}/egs.JOB.ark".format(dir=dir) + common_lib.run_kaldi_command( """{cmd} --max-jobs-run {msjr} JOB=1:{nai} \ {dir}/log/shuffle.JOB.log \ nnet3-shuffle-egs --srand=$[JOB+{srand}] \ - "ark:cat {egs_list}|" ark:{dir}/egs.JOB.ark""".format( + "ark:cat {egs_list}|" {output_archive}""".format( cmd=cmd, msjr=num_jobs, nai=num_archives_intermediate, srand=srand, - dir=dir, egs_list=egs_list)) + dir=dir, egs_list=egs_list, + output_archive=output_archive)) + + if generate_egs_scp: + out_egs_handle = open("{0}/egs.scp".format(dir), 'w') + for i in range(1, num_archives_intermediate + 1): + for line in open("{0}/egs.{1}.scp".format(dir, i)): + print (line, file=out_egs_handle) + out_egs_handle.close() else: # there are intermediate archives so we shuffle egs across jobs # and split them into archives_multiple output archives - output_archives = ' '.join(["ark:{dir}/egs.JOB.{ark_num}.ark".format( - dir=dir, ark_num=x) - for x in range(1, archives_multiple + 1)]) + if generate_egs_scp: + output_archives = ' '.join( + ["ark,scp:{dir}/egs.JOB.{ark_num}.ark," + "{dir}/egs.JOB.{ark_num}.scp".format( + dir=dir, ark_num=x) + for x in range(1, archives_multiple + 1)]) + else: + output_archives = ' '.join( + ["ark:{dir}/egs.JOB.{ark_num}.ark".format( + dir=dir, ark_num=x) + for x in range(1, archives_multiple + 1)]) # archives were created as egs.x.y.ark # linking them to egs.i.ark format which is expected by the training # scripts @@ -712,6 +777,14 @@ def generate_training_examples_internal(dir, targets_parameters, feat_dir, nai=num_archives_intermediate, srand=srand, dir=dir, egs_list=egs_list, oarks=output_archives)) + if generate_egs_scp: + out_egs_handle = open("{0}/egs.scp".format(dir), 'w') + 
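+                # the loop below concatenates the per-archive scp files
+                # (egs.1.scp .. egs.<num_archives_intermediate>.scp) into a
+                # single egs.scp, preserving archive order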
for i in range(1, num_archives_intermediate + 1): + for j in range(1, archives_multiple + 1): + for line in open("{0}/egs.{1}.{2}.scp".format(dir, i, j)): + print (line, file=out_egs_handle) + out_egs_handle.close() + cleanup(dir, archives_multiple) return {'num_frames': num_frames, 'num_archives': num_archives, @@ -744,7 +817,8 @@ def generate_training_examples(dir, targets_parameters, feat_dir, feat_ivector_strings, egs_opts, frame_shift, frames_per_eg, samples_per_iter, cmd, num_jobs, srand=0, - only_shuffle=False, dry_run=False): + only_shuffle=False, dry_run=False, + generate_egs_scp=False): # generate the training options string with the given chunk_width train_egs_opts = egs_opts['train_egs_opts'] @@ -769,7 +843,8 @@ def generate_training_examples(dir, targets_parameters, feat_dir, samples_per_iter=samples_per_iter, cmd=cmd, srand=srand, only_shuffle=only_shuffle, - dry_run=dry_run) + dry_run=dry_run, + generate_egs_scp=generate_egs_scp) return info @@ -792,13 +867,15 @@ def generate_egs(egs_dir, feat_dir, targets_para_array, cmvn_opts=None, apply_cmvn_sliding=False, compress_input=True, input_compress_format=0, - num_utts_subset=300, + num_utts_subset_train=300, + num_utts_subset_valid=300, num_train_egs_combine=1000, num_valid_egs_combine=0, num_egs_diagnostic=4000, samples_per_iter=400000, num_jobs=6, - srand=0): + srand=0, + generate_egs_scp=False): for directory in '{0}/log {0}/info'.format(egs_dir).split(): create_directory(directory) @@ -817,9 +894,9 @@ def generate_egs(egs_dir, feat_dir, targets_para_array, frame_shift = data_lib.get_frame_shift(feat_dir) min_duration = frames_per_eg * frame_shift - valid_utts = sample_utts(feat_dir, num_utts_subset, min_duration)[0] - train_subset_utts = sample_utts(feat_dir, num_utts_subset, min_duration, - exclude_list=valid_utts)[0] + valid_utts = sample_utts(feat_dir, num_utts_subset_valid, min_duration)[0] + train_subset_utts = sample_utts(feat_dir, num_utts_subset_train, + min_duration, exclude_list=valid_utts)[0] train_utts, train_utts_durs = sample_utts(feat_dir, None, -1, exclude_list=valid_utts) @@ -857,7 +934,8 @@ def generate_egs(egs_dir, feat_dir, targets_para_array, num_valid_egs_combine=num_valid_egs_combine, num_egs_diagnostic=num_egs_diagnostic, cmd=cmd, - num_jobs=num_jobs) + num_jobs=num_jobs, + generate_egs_scp=generate_egs_scp) logger.info("Generating training examples on disk.") info = generate_training_examples( @@ -873,7 +951,8 @@ def generate_egs(egs_dir, feat_dir, targets_para_array, num_jobs=num_jobs, srand=srand, only_shuffle=True if stage > 3 else False, - dry_run=True if stage > 4 else False) + dry_run=True if stage > 4 else False, + generate_egs_scp=generate_egs_scp) info['feat_dim'] = feat_ivector_strings['feat_dim'] info['ivector_dim'] = feat_ivector_strings['ivector_dim'] @@ -898,13 +977,15 @@ def main(): apply_cmvn_sliding=args.apply_cmvn_sliding, compress_input=args.compress_input, input_compress_format=args.input_compress_format, - num_utts_subset=args.num_utts_subset, + num_utts_subset_train=args.num_utts_subset_train, + num_utts_subset_valid=args.num_utts_subset_valid, num_train_egs_combine=args.num_train_egs_combine, num_valid_egs_combine=args.num_valid_egs_combine, num_egs_diagnostic=args.num_egs_diagnostic, samples_per_iter=args.samples_per_iter, num_jobs=args.num_jobs, - srand=args.srand) + srand=args.srand, + generate_egs_scp=args.generate_egs_scp) if __name__ == "__main__": From 54cc83675e91555c8a099a293f00ce1d14165190 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 19 Dec 2016 20:00:55 
-0500 Subject: [PATCH 111/213] asr_diarization: Add objective-scale to xconfig output --- .../steps/libs/nnet3/xconfig/basic_layers.py | 51 +++++++++++++++---- 1 file changed, 41 insertions(+), 10 deletions(-) diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py index 38ff36622ec..f74137da48b 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py @@ -448,7 +448,8 @@ def set_default_configs(self): 'max-change' : 1.5, 'param-stddev' : 0.0, 'bias-stddev' : 0.0, - 'output-delay' : 0 + 'output-delay' : 0, + 'objective-scale': 1.0 } def check_configs(self): @@ -513,6 +514,7 @@ def get_full_config(self): bias_stddev = self.config['bias-stddev'] output_delay = self.config['output-delay'] max_change = self.config['max-change'] + objective_scale = self.config['objective-scale'] # note: ref.config is used only for getting the left-context and # right-context of the network; @@ -553,6 +555,18 @@ def get_full_config(self): ans.append((config_name, line)) cur_node = '{0}.fixed-scale'.format(self.name) + if objective_scale != 1.0: + line = ('component name={0}.objective-scale' + ' type=ScaleGradientComponent scale={1} dim={2}' + ''.format(self.name, objective_scale, output_dim)) + ans.append((config_name, line)) + + line = ('component-node name={0}.objective-scale' + ' component={0}.objective-scale input={1}' + ''.format(self.name, cur_node)) + ans.append((config_name, line)) + cur_node = '{0}.objective-scale'.format(self.name) + if include_log_softmax: line = ('component name={0}.log-softmax' ' type=LogSoftmaxComponent dim={1}' @@ -611,7 +625,24 @@ def set_default_configs(self): 'max-change' : 0.75, 'self-repair-scale' : 1.0e-05, 'target-rms' : 1.0, - 'ng-affine-options' : ''} + 'ng-affine-options' : '', + 'add-log-stddev' : False } + + def set_derived_configs(self): + output_dim = self.config['dim'] + # If not set, the output-dim defaults to the input-dim. + if output_dim <= 0: + self.config['dim'] = self.descriptors['input']['dim'] + + if self.config['add-log-stddev']: + split_layer_name = self.layer_type.split('-') + assert split_layer_name[-1] == 'layer' + nonlinearities = split_layer_name[:-1] + + for nonlinearity in nonlinearities: + if nonlinearity == "renorm": + output_dim += 1 + self.config['output-dim'] = output_dim def check_configs(self): if self.config['dim'] < 0: @@ -633,12 +664,7 @@ def output_name(self, auxiliary_output=None): return '{0}.{1}'.format(self.name, last_nonlinearity) def output_dim(self, auxiliary_output = None): - output_dim = self.config['dim'] - # If not set, the output-dim defaults to the input-dim. - if output_dim <= 0: - output_dim = self.descriptors['input']['dim'] - return output_dim - + return self.config['output-dim'] def get_full_config(self): ans = [] @@ -668,11 +694,13 @@ def _generate_config(self): return self._add_components(input_desc, input_dim, nonlinearities) def _add_components(self, input_desc, input_dim, nonlinearities): - output_dim = self.output_dim() + output_dim = self.config['dim'] self_repair_scale = self.config['self-repair-scale'] target_rms = self.config['target-rms'] max_change = self.config['max-change'] ng_opt_str = self.config['ng-affine-options'] + add_log_stddev = ("true" if self.config['add-log-stddev'] + else "false") configs = [] # First the affine node. 
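        # Note: when add-log-stddev is true, the NormalizeComponent appends an
        # extra log-stddev dimension, so the effective output dim of the layer
        # becomes dim + 1; set_derived_configs() above accounts for this when
        # setting 'output-dim'.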
@@ -718,8 +746,11 @@ def _add_components(self, input_desc, input_dim, nonlinearities): line = ('component name={0}.{1}' ' type=NormalizeComponent dim={2}' ' target-rms={3}' + ' add-log-stddev={4}' ''.format(self.name, nonlinearity, output_dim, - target_rms)) + target_rms, add_log_stddev)) + if self.config['add-log-stddev']: + output_dim += 1 else: raise xparser_error("Unknown nonlinearity type:" From eb04aabe61a8fa12f80ce22dd2dea2f95450577a Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 19 Dec 2016 20:01:39 -0500 Subject: [PATCH 112/213] asr_diarization: Support multitask egs at script level --- egs/wsj/s5/steps/libs/nnet3/train/common.py | 12 ++ .../nnet3/train/frame_level_objf/common.py | 111 +++++++++++++----- egs/wsj/s5/steps/nnet3/train_raw_rnn.py | 4 +- 3 files changed, 99 insertions(+), 28 deletions(-) diff --git a/egs/wsj/s5/steps/libs/nnet3/train/common.py b/egs/wsj/s5/steps/libs/nnet3/train/common.py index 90ee209a092..7ae44cdffae 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/common.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/common.py @@ -559,6 +559,18 @@ def __init__(self): action=common_lib.NullstrToNoneAction, help="""String to provide options directly to steps/nnet3/get_egs.sh script""") + self.parser.add_argument("--egs.use-multitask-egs", type=str, + dest='use_multitask_egs', + default=True, choices=["true", "false"], + action=common_lib.StrToBoolAction, + help="""Use mutlitask egs created using + allocate_multilingual_egs.py.""") + self.parser.add_argument("--egs.rename-multitask-outputs", type=str, + dest='rename_multitask_outputs', + default=True, choices=["true", "false"], + action=common_lib.StrToBoolAction, + help="""Rename multitask outputs created using + allocate_multilingual_egs.py.""") # trainer options self.parser.add_argument("--trainer.srand", type=int, dest='srand', diff --git a/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py b/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py index b8a28d2e2bf..508445e331e 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py @@ -33,7 +33,8 @@ def train_new_models(dir, iter, srand, num_jobs, frames_per_eg=-1, min_deriv_time=None, max_deriv_time=None, min_left_context=None, min_right_context=None, - extra_egs_copy_cmd=""): + extra_egs_copy_cmd="", use_multitask_egs=False, + rename_multitask_outputs=False): """ Called from train_one_iteration(), this model does one iteration of training with 'num_jobs' jobs, and writes files like exp/tdnn_a/24.{1,2,3,..}.raw @@ -98,11 +99,49 @@ def train_new_models(dir, iter, srand, num_jobs, cache_write_opt = "--write-cache={dir}/cache.{iter}".format( dir=dir, iter=iter+1) - minibatch_opts = "--minibatch-size={0}".format(minibatch_size) - - if chunk_level_training: - minibatch_opts = "{0} --measure-output-frames=false".format( - minibatch_size) + if use_multitask_egs: + output_rename_opt = "" + if rename_multitask_outputs: + output_rename_opt = ( + "--output=ark:{egs_dir}" + "/output.{archive_index}".format( + egs_dir=egs_dir, archive_index=archive_index)) + egs_rspecifier = ( + "ark,bg:nnet3-copy-egs {frame_opts} {context_opts} " + "{output_rename_opt} " + "--weights=ark:{egs_dir}/weight.{archive_index} " + "scp:{egs_dir}/egs.{archive_index} ark:- | " + "{extra_egs_copy_cmd}" + "nnet3-merge-egs --minibatch-size={minibatch_size} " + "--measure-output-frames=false " + "--discard-partial-minibatches=true ark:- ark:- | " + "nnet3-shuffle-egs --buffer-size={shuffle_buffer_size} 
" + "ark:- ark:- |".format( + frame_opts=("" if chunk_level_training + else "--frame={0}".format(frame)), + context_opts=context_opts, egs_dir=egs_dir, + output_rename_opt=output_rename_opt, + archive_index=archive_index, + shuffle_buffer_size=shuffle_buffer_size, + extra_egs_copy_cmd=extra_egs_copy_cmd, + minibatch_size=minibatch_size)) + else: + egs_rspecifier = ( + "ark,bg:nnet3-copy-egs {frame_opts} {context_opts} " + "ark:{egs_dir}/egs.{archive_index}.ark ark:- |" + "{extra_egs_copy_cmd}" + "nnet3-shuffle-egs --buffer-size={shuffle_buffer_size} " + "--srand={srand} ark:- ark:- | " + "nnet3-merge-egs --minibatch-size={minibatch_size} " + "--measure-output-frames=false " + "--discard-partial-minibatches=true ark:- ark:- |".format( + frame_opts=("" if chunk_level_training + else "--frame={0}".format(frame)), + context_opts=context_opts, egs_dir=egs_dir, + archive_index=archive_index, + shuffle_buffer_size=shuffle_buffer_size, + extra_egs_copy_cmd=extra_egs_copy_cmd, + minibatch_size=minibatch_size)) process_handle = common_lib.run_job( """{command} {train_queue_opt} {dir}/log/train.{iter}.{job}.log \ @@ -111,12 +150,7 @@ def train_new_models(dir, iter, srand, num_jobs, --momentum={momentum} \ --max-param-change={max_param_change} \ {deriv_time_opts} "{raw_model}" \ - "ark,bg:nnet3-copy-egs {frame_opts} {context_opts} """ - """ark:{egs_dir}/egs.{archive_index}.ark ark:- |{extra_egs_copy_cmd}""" - """nnet3-shuffle-egs --buffer-size={shuffle_buffer_size} """ - """--srand={srand} ark:- ark:- | """ - """nnet3-merge-egs {minibatch_opts} """ - """--discard-partial-minibatches=true ark:- ark:- |" \ + "{egs_rspecifier}" \ {dir}/{next_iter}.{job}.raw""".format( command=run_opts.command, train_queue_opt=run_opts.train_queue_opt, @@ -126,17 +160,10 @@ def train_new_models(dir, iter, srand, num_jobs, parallel_train_opts=run_opts.parallel_train_opts, cache_read_opt=cache_read_opt, cache_write_opt=cache_write_opt, - frame_opts=("" - if chunk_level_training - else "--frame={0}".format(frame)), - minibatch_opts=minibatch_opts, momentum=momentum, max_param_change=max_param_change, deriv_time_opts=" ".join(deriv_time_opts), - raw_model=raw_model_string, context_opts=context_opts, - egs_dir=egs_dir, archive_index=archive_index, - shuffle_buffer_size=shuffle_buffer_size, - extra_egs_copy_cmd=extra_egs_copy_cmd), - wait=False) + raw_model=raw_model_string, + egs_rspecifier=egs_rspecifier), wait=False) processes.append(process_handle) @@ -166,7 +193,8 @@ def train_one_iteration(dir, iter, srand, egs_dir, shrinkage_value=1.0, dropout_proportions=None, get_raw_nnet_from_am=True, background_process_handler=None, - extra_egs_copy_cmd=""): + extra_egs_copy_cmd="", use_multitask_egs=False, + rename_multitask_outputs=False): """ Called from steps/nnet3/train_*.py scripts for one iteration of neural network training @@ -309,7 +337,9 @@ def train_one_iteration(dir, iter, srand, egs_dir, max_deriv_time=max_deriv_time, min_left_context=min_left_context, min_right_context=min_right_context, - extra_egs_copy_cmd=extra_egs_copy_cmd) + extra_egs_copy_cmd=extra_egs_copy_cmd, + use_multitask_egs=use_multitask_egs, + rename_multitask_outputs=rename_multitask_outputs) [models_to_average, best_model] = common_train_lib.get_successful_models( num_jobs, '{0}/log/train.{1}.%.log'.format(dir, iter)) @@ -419,15 +449,22 @@ def compute_train_cv_probabilities(dir, iter, egs_dir, left_context, context_opts = "--left-context={lc} --right-context={rc}".format( lc=left_context, rc=right_context) + if 
os.path.isfile("{0}/valid_diagnostic.egs".format(egs_dir)): + valid_diagnostic_egs = "ark:{0}/valid_diagnostic.egs".format(egs_dir) + else: + valid_diagnostic_egs = "scp:{0}/valid_diagnostic.egs.1".format( + egs_dir) + common_lib.run_job( """ {command} {dir}/log/compute_prob_valid.{iter}.log \ nnet3-compute-prob "{model}" \ "ark,bg:nnet3-copy-egs {context_opts} \ - ark:{egs_dir}/valid_diagnostic.egs ark:- |{extra_egs_copy_cmd} \ + {egs_rspecifier} ark:- |{extra_egs_copy_cmd} \ nnet3-merge-egs --minibatch-size={mb_size} ark:- \ ark:- |" """.format(command=run_opts.command, dir=dir, iter=iter, + egs_rspecifier=valid_diagnostic_egs, context_opts=context_opts, mb_size=mb_size, model=model, @@ -435,15 +472,22 @@ def compute_train_cv_probabilities(dir, iter, egs_dir, left_context, extra_egs_copy_cmd=extra_egs_copy_cmd), wait=wait, background_process_handler=background_process_handler) + if os.path.isfile("{0}/train_diagnostic.egs".format(egs_dir)): + train_diagnostic_egs = "ark:{0}/train_diagnostic.egs".format(egs_dir) + else: + train_diagnostic_egs = "scp:{0}/train_diagnostic.egs.1".format( + egs_dir) + common_lib.run_job( """{command} {dir}/log/compute_prob_train.{iter}.log \ nnet3-compute-prob "{model}" \ "ark,bg:nnet3-copy-egs {context_opts} \ - ark:{egs_dir}/train_diagnostic.egs ark:- |{extra_egs_copy_cmd} \ + {egs_rspecifier} ark:- | {extra_egs_copy_cmd} \ nnet3-merge-egs --minibatch-size={mb_size} ark:- \ ark:- |" """.format(command=run_opts.command, dir=dir, iter=iter, + egs_rspecifier=train_diagnostic_egs, context_opts=context_opts, mb_size=mb_size, model=model, @@ -468,16 +512,23 @@ def compute_progress(dir, iter, egs_dir, left_context, right_context, context_opts = "--left-context={lc} --right-context={rc}".format( lc=left_context, rc=right_context) + if os.path.isfile("{0}/train_diagnostic.egs".format(egs_dir)): + train_diagnostic_egs = "ark:{0}/train_diagnostic.egs".format(egs_dir) + else: + train_diagnostic_egs = "scp:{0}/train_diagnostic.egs.1".format( + egs_dir) + common_lib.run_job( """{command} {dir}/log/progress.{iter}.log \ nnet3-info "{model}" '&&' \ nnet3-show-progress --use-gpu=no "{prev_model}" "{model}" \ "ark,bg:nnet3-copy-egs {context_opts} \ - ark:{egs_dir}/train_diagnostic.egs ark:- |{extra_egs_copy_cmd} \ + {egs_rspecifier} ark:- |{extra_egs_copy_cmd} \ nnet3-merge-egs --minibatch-size={mb_size} ark:- \ ark:- |" """.format(command=run_opts.command, dir=dir, iter=iter, + egs_rspecifier=train_diagnostic_egs, model=model, context_opts=context_opts, mb_size=mb_size, @@ -532,19 +583,25 @@ def combine_models(dir, num_iters, models_to_combine, egs_dir, context_opts = "--left-context={lc} --right-context={rc}".format( lc=left_context, rc=right_context) + if os.path.isfile("{0}/combine.egs".format(egs_dir)): + combine_egs = "ark:{0}/combine.egs".format(egs_dir) + else: + combine_egs = "scp:{0}/combine.egs.1".format(egs_dir) + common_lib.run_job( """{command} {combine_queue_opt} {dir}/log/combine.log \ nnet3-combine --num-iters=40 \ --enforce-sum-to-one=true --enforce-positive-weights=true \ --verbose=3 {raw_models} \ "ark,bg:nnet3-copy-egs {context_opts} \ - ark:{egs_dir}/combine.egs ark:- |{extra_egs_copy_cmd} \ + {egs_rspecifier} ark:- |{extra_egs_copy_cmd} \ nnet3-merge-egs --measure-output-frames=false \ --minibatch-size={mbsize} ark:- ark:- |" \ "{out_model}" """.format(command=run_opts.command, combine_queue_opt=run_opts.combine_queue_opt, dir=dir, raw_models=" ".join(raw_model_strings), + egs_rspecifier=combine_egs, context_opts=context_opts, mbsize=mbsize, 
out_model=out_model, diff --git a/egs/wsj/s5/steps/nnet3/train_raw_rnn.py b/egs/wsj/s5/steps/nnet3/train_raw_rnn.py index 4a2424e54f5..2bea66dbcbf 100755 --- a/egs/wsj/s5/steps/nnet3/train_raw_rnn.py +++ b/egs/wsj/s5/steps/nnet3/train_raw_rnn.py @@ -466,7 +466,9 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): run_opts=run_opts, get_raw_nnet_from_am=False, background_process_handler=background_process_handler, - extra_egs_copy_cmd=args.extra_egs_copy_cmd) + extra_egs_copy_cmd=args.extra_egs_copy_cmd, + use_multitask_egs=args.use_multitask_egs, + rename_multitask_outputs=args.rename_multitask_outputs) if args.cleanup: # do a clean up everythin but the last 2 models, under certain From 8174b3ce5ec5371c7f491f84aa3ac5b3f3fbfc52 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 19 Dec 2016 20:02:19 -0500 Subject: [PATCH 113/213] asr_diarization: Support multi output training diagnostics correctly --- src/nnet3/nnet-chain-training.cc | 4 ++-- src/nnet3/nnet-training.cc | 5 ++--- src/nnet3/nnet-training.h | 4 ++-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/nnet3/nnet-chain-training.cc b/src/nnet3/nnet-chain-training.cc index d9d43006601..91dc0d8ec19 100644 --- a/src/nnet3/nnet-chain-training.cc +++ b/src/nnet3/nnet-chain-training.cc @@ -137,14 +137,14 @@ void NnetChainTrainer::ProcessOutputs(const NnetChainExample &eg, computer->AcceptOutputDeriv(sup.name, &nnet_output_deriv); objf_info_[sup.name].UpdateStats(sup.name, opts_.nnet_config.print_interval, - num_minibatches_processed_++, tot_weight, tot_objf, tot_l2_term); - + if (use_xent) { xent_deriv.Scale(opts_.chain_config.xent_regularize); computer->AcceptOutputDeriv(xent_name, &xent_deriv); } } + num_minibatches_processed_++; } void NnetChainTrainer::UpdateParamsWithMaxChange() { diff --git a/src/nnet3/nnet-training.cc b/src/nnet3/nnet-training.cc index bdbe244a648..803f3570b1a 100644 --- a/src/nnet3/nnet-training.cc +++ b/src/nnet3/nnet-training.cc @@ -95,10 +95,10 @@ void NnetTrainer::ProcessOutputs(const NnetExample &eg, supply_deriv, computer, &tot_weight, &tot_objf, deriv_weights); objf_info_[io.name].UpdateStats(io.name, config_.print_interval, - num_minibatches_processed_++, tot_weight, tot_objf); } } + num_minibatches_processed_++; } void NnetTrainer::UpdateParamsWithMaxChange() { @@ -226,11 +226,10 @@ void NnetTrainer::PrintMaxChangeStats() const { void ObjectiveFunctionInfo::UpdateStats( const std::string &output_name, int32 minibatches_per_phase, - int32 minibatch_counter, BaseFloat this_minibatch_weight, BaseFloat this_minibatch_tot_objf, BaseFloat this_minibatch_tot_aux_objf) { - int32 phase = minibatch_counter / minibatches_per_phase; + int32 phase = num_minibatches++ / minibatches_per_phase; if (phase != current_phase) { KALDI_ASSERT(phase == current_phase + 1); // or doesn't really make sense. PrintStatsForThisPhase(output_name, minibatches_per_phase); diff --git a/src/nnet3/nnet-training.h b/src/nnet3/nnet-training.h index 7b22bc75211..fefaf9ea122 100644 --- a/src/nnet3/nnet-training.h +++ b/src/nnet3/nnet-training.h @@ -98,6 +98,7 @@ struct NnetTrainerOptions { // Also see struct AccuracyInfo, in nnet-diagnostics.h. 
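// Note: num_minibatches (added below) is incremented inside UpdateStats()
// rather than being passed in by the trainer, so when a network has several
// output nodes each output's ObjectiveFunctionInfo advances through its
// logging phases independently.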
struct ObjectiveFunctionInfo { int32 current_phase; + int32 num_minibatches; double tot_weight; double tot_objf; @@ -110,7 +111,7 @@ struct ObjectiveFunctionInfo { double tot_aux_objf_this_phase; ObjectiveFunctionInfo(): - current_phase(0), + current_phase(0), num_minibatches(0), tot_weight(0.0), tot_objf(0.0), tot_aux_objf(0.0), tot_weight_this_phase(0.0), tot_objf_this_phase(0.0), tot_aux_objf_this_phase(0.0) { } @@ -121,7 +122,6 @@ struct ObjectiveFunctionInfo { // control how frequently we print logging messages. void UpdateStats(const std::string &output_name, int32 minibatches_per_phase, - int32 minibatch_counter, BaseFloat this_minibatch_weight, BaseFloat this_minibatch_tot_objf, BaseFloat this_minibatch_tot_aux_objf = 0.0); From 59c9a2d3ef7f89f1776be939063f56f38dce3a93 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 19 Dec 2016 20:02:51 -0500 Subject: [PATCH 114/213] asr_diarization: Add data to libs __init__ --- egs/wsj/s5/steps/libs/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/wsj/s5/steps/libs/__init__.py b/egs/wsj/s5/steps/libs/__init__.py index 013c95d0b3f..8f3540643c8 100644 --- a/egs/wsj/s5/steps/libs/__init__.py +++ b/egs/wsj/s5/steps/libs/__init__.py @@ -8,4 +8,4 @@ import common -__all__ = ["common"] +__all__ = ["common", "data"] From 38f2515ba275e3836f8af328600cfce3a51894ee Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 19 Dec 2016 20:03:21 -0500 Subject: [PATCH 115/213] asr_diarization: Adding new overlapped speech recipe --- .../tuning/train_stats_sad_overlap_1a.sh | 51 ++-- .../tuning/train_stats_sad_overlap_1b.sh | 239 ++++++++++++++++++ 2 files changed, 266 insertions(+), 24 deletions(-) create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1b.sh diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1a.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1a.sh index aae1fd995e0..c8a7c887fef 100644 --- a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1a.sh +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1a.sh @@ -20,7 +20,7 @@ egs_opts= # Directly passed to get_egs_multiple_targets.py # TDNN options relu_dim=256 -chunk_width=20 # We use chunk training for training TDNN +chunk_width=40 # We use chunk training for training TDNN extra_left_context=100 # Maximum left context in egs apart from TDNN's left context extra_right_context=20 # Maximum right context in egs apart from TDNN's right context @@ -41,28 +41,23 @@ max_param_change=0.2 # Small max-param change for small network extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs # such as removing one of the targets -num_utts_subset_valid=50 # "utts" is actually recording. So this is prettly small. 
-num_utts_subset_train=50 - # target options -train_data_dir=data/train_azteec_whole_sp_corrupted_hires - -snr_scp= -speech_feat_scp= -overlapped_speech_labels_scp= - -deriv_weights_scp= -deriv_weights_for_overlapped_speech_scp= +train_data_dir=data/train_aztec_small_unsad_a +speech_feat_scp=data/train_aztec_small_unsad_a/speech_feat.scp +deriv_weights_scp=data/train_aztec_small_unsad_a/deriv_weights.scp -train_data_dir=data/train_aztec_small_unsad_whole_sad_ovlp_corrupted_sp -speech_feat_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp/speech_feat.scp -deriv_weights_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400/deriv_weights.scp +#train_data_dir=data/train_aztec_small_unsad_whole_sad_ovlp_corrupted_sp +#speech_feat_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp/speech_feat.scp +#deriv_weights_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400/deriv_weights.scp +#data/train_aztec_small_unsad_whole_all_corrupted_sp_hires_bp +# Only for SAD snr_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp/irm_targets.scp deriv_weights_for_irm_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp/deriv_weights_manual_seg.scp -deriv_weights_for_overlapped_speech_scp= -overlapped_speech_labels_scp= +# Only for overlapped speech detection +deriv_weights_for_overlapped_speech_scp=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp/deriv_weights_for_overlapped_speech.scp +overlapped_speech_labels_scp=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp/overlapped_speech_labels.scp #extra_left_context=79 #extra_right_context=11 @@ -79,6 +74,10 @@ affix=a . ./path.sh . ./utils/parse_options.sh +num_utts=`cat $train_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 
4000 : $n)' $num_utts` + if [ -z "$dir" ]; then dir=exp/nnet3_stats_sad_ovlp_snr/nnet_tdnn fi @@ -109,11 +108,15 @@ if [ $stage -le 3 ]; then stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1, tdnn2_stats) dim=256 relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=256 - relu-renorm-layer name=tdnn4 dim=256 - output-layer name=output-speech include-log-softmax=true dim=2 - output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective=quadratic - output-layer name=output-overlapped_speech include-log-softmax=true dim=2 + relu-renorm-layer name=pre-final-speech dim=256 input=tdnn3 + output-layer name=output-speech include-log-softmax=true dim=2 objective-scale=`perl -e 'print (1.0/6)'` + + relu-renorm-layer name=pre-final-snr dim=256 input=tdnn3 + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print 1.0/$num_snr_bins"` + + relu-renorm-layer name=pre-final-overlapped_speech dim=256 input=tdnn3 + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 EOF steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ --config-dir $dir/configs/ \ @@ -145,7 +148,7 @@ if [ -z "$egs_dir" ]; then --num-utts-subset-valid=$num_utts_subset_valid \ --samples-per-iter=20000 \ --stage=$get_egs_stage \ - --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$snr_scp --deriv-weights-scp=$deriv_weights_scp" \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$snr_scp --deriv-weights-scp=$deriv_weights_for_irm_scp" \ --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$speech_feat_scp --deriv-weights-scp=$deriv_weights_scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$overlapped_speech_labels_scp --deriv-weights-scp=$deriv_weights_for_overlapped_speech_scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ --dir=$dir/egs @@ -155,7 +158,7 @@ fi if [ $stage -le 5 ]; then steps/nnet3/train_raw_rnn.py --stage=$train_stage \ --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ - --egs.chunk-width=20 \ + --egs.chunk-width=$chunk_width \ --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ --egs.chunk-left-context=$extra_left_context \ --egs.chunk-right-context=$extra_right_context \ @@ -169,7 +172,7 @@ if [ $stage -le 5 ]; then --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ --trainer.optimization.final-effective-lrate=$final_effective_lrate \ --trainer.optimization.shrink-value=1.0 \ - --trainer.rnn.num-chunk-per-minibatch=64 \ + --trainer.rnn.num-chunk-per-minibatch=128 \ --trainer.deriv-truncate-margin=8 \ --trainer.max-param-change=$max_param_change \ --cmd="$decode_cmd" --nj 40 \ diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1b.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1b.sh new file mode 100644 index 00000000000..888c25295d6 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1b.sh @@ -0,0 +1,239 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for overlapped speech activity detection +# using statistic pooling component for long-context information. 
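# Compared to the 1a script, this version builds separate egs directories for
# the SAD data (egs_speech) and the overlapped-speech data (egs_ovlp) with
# --generate-egs-scp=true, combines them into egs_multi using
# steps/nnet3/multilingual/get_egs.sh, and trains with
# --egs.use-multitask-egs=true.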
+ +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=40 # We use chunk training for training TDNN +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=1 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +sad_data_dir=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400 +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +#extra_left_context=79 +#extra_right_context=11 + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=b + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $sad_data_dir/utt2spk $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_ovlp_snr/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1, tdnn2_stats) dim=256 + relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=256 + + relu-renorm-layer name=pre-final-speech dim=256 input=tdnn3 + output-layer name=output-speech include-log-softmax=true dim=2 objective-scale=`perl -e "print ($num_frames_ovlp / $num_frames_sad) ** 0.25"` + + relu-renorm-layer name=pre-final-snr dim=256 input=tdnn3 + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print (($num_frames_ovlp / $num_frames_sad) ** 0.25) / $num_snr_bins"` + + relu-renorm-layer name=pre-final-overlapped_speech dim=256 input=tdnn3 + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs_speech/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_feat.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --minibatch-size $[chunk_width * num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_ovlp $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + 
${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$speech_feat_scp" \ + --dir=$dir || exit 1 +fi + From 6c9efb6f54aaae4fb4d0274d5c0f3c0e6b4d011f Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 19 Dec 2016 20:04:16 -0500 Subject: [PATCH 116/213] asr_diarization: Add iter option to run_segmentation_ami --- egs/aspire/s5/local/segmentation/run_segmentation_ami.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh b/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh index f9374aaf55a..4b98eec9f43 100755 --- a/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh +++ b/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh @@ -14,6 +14,7 @@ stage=-1 nnet_dir=exp/nnet3_sad_snr/nnet_tdnn_k_n4 extra_left_context=100 extra_right_context=20 +iter=final . utils/parse_options.sh @@ -107,7 +108,7 @@ if [ $stage -le 7 ]; then steps/segmentation/do_segmentation_data_dir.sh --reco-nj 18 \ --mfcc-config conf/mfcc_hires_bp.conf --feat-affix bp --do-downsampling true \ --extra-left-context $extra_left_context --extra-right-context $extra_right_context \ - --output-name output-speech --frame-subsampling-factor 6 \ + --output-name output-speech --frame-subsampling-factor 6 --iter $iter \ $src_dir/data/sdm1/dev $nnet_dir mfcc_hires_bp $hyp_dir fi From 7de8a831d1dc397e4c5ec6e299fe29d140bc2b27 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 19 Dec 2016 20:04:30 -0500 Subject: [PATCH 117/213] asr_diarization: Add iter to aspire segmentation --- egs/aspire/s5/local/nnet3/prep_test_aspire_segmentation.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/aspire/s5/local/nnet3/prep_test_aspire_segmentation.sh b/egs/aspire/s5/local/nnet3/prep_test_aspire_segmentation.sh index 5f38f6de51f..e7f70c0c07f 100755 --- a/egs/aspire/s5/local/nnet3/prep_test_aspire_segmentation.sh +++ b/egs/aspire/s5/local/nnet3/prep_test_aspire_segmentation.sh @@ -16,6 +16,8 @@ decode_num_jobs=30 num_jobs=30 affix= +sad_iter=final + # ivector opts max_count=75 # parameter for extract_ivectors.sh sub_speaker_frames=6000 @@ -73,7 +75,7 @@ fi if [ $stage -le 1 ]; then steps/segmentation/do_segmentation_data_dir.sh --reco-nj $num_jobs \ - --mfcc-config conf/mfcc_hires_bp.conf --feat-affix bp \ + --mfcc-config conf/mfcc_hires_bp.conf --feat-affix bp --iter $sad_iter \ --do-downsampling false --extra-left-context 100 --extra-right-context 20 \ --output-name output-speech --frame-subsampling-factor 6 \ data/${data_set} $sad_nnet_dir mfcc_hires_bp data/${data_set} From 885d17ec728ba891600fb87799088126b2b3110b Mon Sep 17 00:00:00 2001 From: 
Vimal Manohar Date: Mon, 19 Dec 2016 20:05:18 -0500 Subject: [PATCH 118/213] asr_diarization: Optional resolve_ctm overlaps in multicondition get_ctm aspire --- egs/aspire/s5/local/multi_condition/get_ctm.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/egs/aspire/s5/local/multi_condition/get_ctm.sh b/egs/aspire/s5/local/multi_condition/get_ctm.sh index 6fc87fec7b0..67c2c0bd87b 100755 --- a/egs/aspire/s5/local/multi_condition/get_ctm.sh +++ b/egs/aspire/s5/local/multi_condition/get_ctm.sh @@ -64,6 +64,8 @@ lattice-to-ctm-conf $frame_shift_opt --decode-mbr=$decode_mbr ark:- $decode_dir/ # combine the segment-wise ctm files, while resolving overlaps if $resolve_overlaps; then steps/resolve_ctm_overlaps.py $data_dir/segments $decode_dir/score_$LMWT/penalty_$wip/ctm.overlapping $decode_dir/score_$LMWT/penalty_$wip/ctm.merged || exit 1; +else + cp $decode_dir/score_$LMWT/penalty_$wip/ctm.overlapping $decode_dir/score_$LMWT/penalty_$wip/ctm.merged || exit 1; fi merged_ctm=$decode_dir/score_$LMWT/penalty_$wip/ctm.merged From 58d62ab0e9e136c7bdf0a2500447c52f9e73edd9 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 19 Dec 2016 20:05:43 -0500 Subject: [PATCH 119/213] asr_diarization: Adding tuning scripts for music and SAD --- .../tuning/train_stats_sad_music_1d.sh | 184 ++++++++++++++ .../tuning/train_stats_sad_music_1e.sh | 229 ++++++++++++++++++ 2 files changed, 413 insertions(+) create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1d.sh create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1e.sh diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1d.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1d.sh new file mode 100644 index 00000000000..a013fcc49a7 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1d.sh @@ -0,0 +1,184 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=20 # We use chunk training for training TDNN +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +num_utts_subset_valid=50 # "utts" is actually recording. So this is prettly small. 
+num_utts_subset_train=50 + +# target options +train_data_dir=data/train_azteec_whole_sp_corrupted_hires + +speech_feat_scp= +music_labels_scp= + +deriv_weights_scp= + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_music/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$train_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-9, tdnn1@-3, tdnn1, tdnn1@3, tdnn2_stats) dim=256 + stats-layer name=tdnn3_stats config=mean+count(-108:9:27:108) + relu-renorm-layer name=tdnn3 input=Append(tdnn2@-27, tdnn2@-9, tdnn2, tdnn2@9, tdnn3_stats) dim=256 + + output-layer name=output-speech include-log-softmax=true dim=2 input=tdnn3 + output-layer name=output-music include-log-softmax=true dim=2 input=tdnn3 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$train_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=20000 \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$speech_feat_scp --deriv-weights-scp=$deriv_weights_scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_labels_scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --dir=$dir/egs + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=20 \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + 
--trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=64 \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$train_data_dir \ + --targets-scp="$speech_feat_scp" \ + --dir=$dir || exit 1 +fi + +if [ $stage -le 6 ]; then + $train_cmd JOB=1:100 $dir/log/compute_post_output-speech.JOB.log \ + extract-column "scp:utils/split_scp.pl -j 100 \$[JOB-1] $speech_feat_scp |" ark,t:- \| \ + steps/segmentation/quantize_vector.pl \| \ + ali-to-post ark,t:- ark:- \| \ + weight-post ark:- scp:$deriv_weights_scp ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-speech.vec.JOB + eval vector-sum $dir/post_output-speech.vec.{`seq -s, 100`} $dir/post_output-speech.vec + + $train_cmd JOB=1:100 $dir/log/compute_post_output-music.JOB.log \ + ali-to-post "scp:utils/split_scp.pl -j 100 \$[JOB-1] $music_labels_scp |" ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-music.vec.JOB + eval vector-sum $dir/post_output-music.vec.{`seq -s, 100`} $dir/post_output-music.vec +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1e.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1e.sh new file mode 100644 index 00000000000..703865b8ad5 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1e.sh @@ -0,0 +1,229 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. +# This script is same as 1d, but add add-log-stddev to norm layers. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=20 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=79 # Maximum left context in egs apart from TDNN's left context +extra_right_context=11 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=79 +min_extra_right_context=11 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +num_utts_subset_valid=50 # "utts" is actually recording. So this is prettly small. 
+num_utts_subset_train=50 + +# target options +train_data_dir=data/train_aztec_small_unsad_whole_all_corrupted_sp_hires_bp + +speech_feat_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400/speech_feat.scp +deriv_weights_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400/deriv_weights.scp +music_labels_scp=data/train_aztec_small_unsad_whole_music_corrupted_sp_hires_bp/music_labels.scp + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_music/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$train_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 add-log-stddev=true + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-9, tdnn1@-3, tdnn1, tdnn1@3, tdnn2_stats) dim=256 add-log-stddev=true + stats-layer name=tdnn3_stats config=mean+count(-108:9:27:108) + relu-renorm-layer name=tdnn3 input=Append(tdnn2@-27, tdnn2@-9, tdnn2, tdnn2@9, tdnn3_stats) dim=256 add-log-stddev=true + + output-layer name=output-speech include-log-softmax=true dim=2 input=tdnn3 + output-layer name=output-music include-log-softmax=true dim=2 input=tdnn3 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` +speech_data_dir=$dir/`basename $train_data_dir`_speech +music_data_dir=$dir/`basename $train_data_dir`_music + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + + . $dir/configs/vars + + utils/subset_data_dir.sh --utt-list $speech_feat_scp ${train_data_dir} $dir/`basename ${train_data_dir}`_speech + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$speech_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$speech_feat_scp --deriv-weights-scp=$deriv_weights_scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . $dir/configs/vars + utils/subset_data_dir.sh --utt-list $music_labels_scp ${train_data_dir} $dir/`basename ${train_data_dir}`_music + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_labels_scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + steps/nnet3/multilingual/get_egs.sh \ + --minibatch-size $[chunk_width * num_chunk_per_minibatch] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$train_data_dir \ + --targets-scp="$speech_feat_scp" \ + --dir=$dir || exit 1 +fi + +if [ $stage -le 6 ]; then + $train_cmd JOB=1:100 $dir/log/compute_post_output-speech.JOB.log \ + extract-column "scp:utils/split_scp.pl -j 100 \$[JOB-1] $speech_feat_scp |" ark,t:- \| \ + steps/segmentation/quantize_vector.pl \| \ + ali-to-post ark,t:- ark:- \| \ + weight-post ark:- scp:$deriv_weights_scp ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-speech.vec.JOB + eval vector-sum $dir/post_output-speech.vec.{`seq -s, 100`} $dir/post_output-speech.vec + + $train_cmd JOB=1:100 $dir/log/compute_post_output-music.JOB.log \ + ali-to-post "scp:utils/split_scp.pl -j 100 \$[JOB-1] $music_labels_scp |" ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-music.vec.JOB + eval vector-sum 
$dir/post_output-music.vec.{`seq -s, 100`} $dir/post_output-music.vec +fi + From 61d6f1ecd7befc6c0cfb36ad14588287583c81cb Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 2 Jan 2017 18:30:53 -0500 Subject: [PATCH 120/213] segmentation: Modify segmentation codes --- src/segmenter/segment.h | 6 + src/segmenter/segmentation-utils.cc | 42 +++--- src/segmenter/segmentation-utils.h | 4 +- src/segmenter/segmentation.cc | 56 +++++--- src/segmenterbin/Makefile | 3 +- .../class-counts-per-frame-to-labels.cc | 115 +++++++++++++++ .../segmentation-combine-segments.cc | 17 ++- src/segmenterbin/segmentation-copy.cc | 54 +++++-- src/segmenterbin/segmentation-get-stats.cc | 34 +++-- ...ntation-init-from-additive-signals-info.cc | 48 +++---- .../segmentation-init-from-ali.cc | 8 +- .../segmentation-merge-recordings.cc | 3 +- src/segmenterbin/segmentation-to-rttm.cc | 133 +++++++++++------- 13 files changed, 371 insertions(+), 152 deletions(-) create mode 100644 src/segmenterbin/class-counts-per-frame-to-labels.cc diff --git a/src/segmenter/segment.h b/src/segmenter/segment.h index 1657affc875..b54b5367c73 100644 --- a/src/segmenter/segment.h +++ b/src/segmenter/segment.h @@ -48,6 +48,12 @@ struct Segment { static size_t SizeInBytes() { return (sizeof(int32) + sizeof(int32) + sizeof(int32)); } + + void Reset() { + start_frame = -1; + end_frame = -1; + class_id = -1; + } }; /** diff --git a/src/segmenter/segmentation-utils.cc b/src/segmenter/segmentation-utils.cc index 3adc178d66d..c69d7ff3397 100644 --- a/src/segmenter/segmentation-utils.cc +++ b/src/segmenter/segmentation-utils.cc @@ -46,14 +46,26 @@ void MergeLabels(const std::vector &merge_labels, void RelabelSegmentsUsingMap(const unordered_map &label_map, Segmentation *segmentation) { + int32 default_label = -1; + unordered_map::const_iterator it = label_map.find(-1); + if (it != label_map.end()) { + default_label = it->second; + KALDI_ASSERT(default_label != -1); + } + for (SegmentList::iterator it = segmentation->Begin(); it != segmentation->End(); ++it) { unordered_map::const_iterator map_it = label_map.find( it->Label()); - if (map_it == label_map.end()) - KALDI_ERR << "Could not find label " << it->Label() << " in label map."; - - it->SetLabel(map_it->second); + if (map_it == label_map.end()) { + if (default_label == -1) + KALDI_ERR << "Could not find label " << it->Label() + << " in label map."; + else + it->SetLabel(default_label); + } else { + it->SetLabel(map_it->second); + } } } @@ -294,7 +306,7 @@ void IntersectSegmentationAndAlignment(const Segmentation &in_segmentation, it != in_segmentation.End(); ++it) { Segmentation filter_segmentation; InsertFromAlignment(alignment, it->start_frame, - std::min(it->end_frame + 1, + std::min(it->end_frame + 1, static_cast(alignment.size())), 0, &filter_segmentation, NULL); @@ -444,7 +456,7 @@ void WidenSegments(int32 label, int32 length, Segmentation *segmentation) { // overlaps the current segment. So remove the current segment. it = segmentation->Erase(it); // So that we can increment in the for loop - --it; // TODO(Vimal): This is buggy. + --it; // TODO(Vimal): This is buggy. } else if (prev_it->end_frame >= it->start_frame) { // The extended previous segment in Line (1) reduces the length of // this segment. 
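// Note: frame_counts_per_class is now a std::map<int32, int64> rather than a
// std::vector<int64>, so class ids no longer have to be small non-negative
// integers and counts for label 0 are recorded as well (see the
// InsertFromAlignment changes below).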
@@ -539,7 +551,7 @@ bool ConvertToAlignment(const Segmentation &segmentation, for (; it != segmentation.End(); ++it) { if (length != -1 && it->end_frame >= length + tolerance) { KALDI_WARN << "End frame (" << it->end_frame << ") " - << ">= length (" << length + << ">= length (" << length << ") + tolerance (" << tolerance << ")." << "Conversion failed."; return false; @@ -565,7 +577,7 @@ int32 InsertFromAlignment(const std::vector &alignment, int32 start, int32 end, int32 start_time_offset, Segmentation *segmentation, - std::vector *frame_counts_per_class) { + std::map *frame_counts_per_class) { KALDI_ASSERT(segmentation); if (end <= start) return 0; // nothing to insert @@ -593,12 +605,8 @@ int32 InsertFromAlignment(const std::vector &alignment, i-1 + start_time_offset, state); num_segments++; - if (frame_counts_per_class && state > 0) { - if (frame_counts_per_class->size() <= state) { - frame_counts_per_class->resize(state + 1, 0); - } + if (frame_counts_per_class) (*frame_counts_per_class)[state] += i - start_frame; - } } start_frame = i; state = alignment[i]; @@ -609,12 +617,8 @@ int32 InsertFromAlignment(const std::vector &alignment, segmentation->EmplaceBack(start_frame + start_time_offset, end-1 + start_time_offset, state); num_segments++; - if (frame_counts_per_class && state > 0) { - if (frame_counts_per_class->size() <= state) { - frame_counts_per_class->resize(state + 1, 0); - } + if (frame_counts_per_class) (*frame_counts_per_class)[state] += end - start_frame; - } #ifdef KALDI_PARANOID segmentation->Check(); @@ -637,7 +641,7 @@ int32 InsertFromSegmentation( for (SegmentList::const_iterator it = in_segmentation.Begin(); it != in_segmentation.End(); ++it) { out_segmentation->EmplaceBack(it->start_frame + start_time_offset, - it->end_frame + start_time_offset, + it->end_frame + start_time_offset, it->Label()); num_segments++; if (frame_counts_per_class) { diff --git a/src/segmenter/segmentation-utils.h b/src/segmenter/segmentation-utils.h index 9401722ccb7..30136ab0a5a 100644 --- a/src/segmenter/segmentation-utils.h +++ b/src/segmenter/segmentation-utils.h @@ -265,7 +265,7 @@ int32 InsertFromAlignment(const std::vector &alignment, int32 start, int32 end, int32 start_time_offset, Segmentation *segmentation, - std::vector *frame_counts_per_class = NULL); + std::map *frame_counts_per_class = NULL); /** * Insert segments from in_segmentation, but shift them by @@ -291,7 +291,7 @@ void ExtendSegmentation(const Segmentation &in_segmentation, bool sort, /** * This function is used to get per-frame count of number of classes. * The output is in the format of a vector of maps. - * class_counts_per_frame: A pointer to a vector of maps use to get the output. + * class_counts_per_frame: A pointer to a vector of maps used to get the output. * The size of the vector is the number of frames. 
* For each frame, there is a map from the "class_id" * to the number of segments where the label the diff --git a/src/segmenter/segmentation.cc b/src/segmenter/segmentation.cc index fb83ed5476b..01f8b0e8057 100644 --- a/src/segmenter/segmentation.cc +++ b/src/segmenter/segmentation.cc @@ -85,22 +85,36 @@ void Segmentation::Read(std::istream &is, bool binary) { } dim_ = segmentssz; } else { - if (int c = is.peek() != static_cast('[')) { - KALDI_ERR << "Segmentation::Read: expected to see [, saw " - << static_cast(c) << ", at file position " << is.tellg(); + Segment seg; + while (1) { + int i = is.peek(); + if (i == -1) { + KALDI_ERR << "Unexpected EOF"; + } else if (static_cast(i) == '\n') { + if (seg.start_frame != -1) { + KALDI_ERR << "No semicolon before newline (wrong format)"; + } else { + is.get(); + break; + } + } else if (std::isspace(i)) { + is.get(); + } else if (static_cast(i) == ';') { + if (seg.start_frame != -1) { + segments_.push_back(seg); + dim_++; + seg.Reset(); + } else { + is.get(); + KALDI_ASSERT(static_cast(is.peek()) == '\n'); + is.get(); + break; + } + is.get(); + } else { + seg.Read(is, false); + } } - is.get(); // consume the '[' - is >> std::ws; - while (is.peek() != static_cast(']')) { - KALDI_ASSERT(!is.eof()); - Segment seg; - seg.Read(is, binary); - segments_.push_back(seg); - dim_++; - is >> std::ws; - } - is.get(); - KALDI_ASSERT(!is.eof()); } #ifdef KALDI_PARANOID Check(); @@ -126,12 +140,14 @@ void Segmentation::Write(std::ostream &os, bool binary) const { it->Write(os, binary); } } else { - os << "[ "; + if (Dim() == 0) { + os << ";"; + } for (; it != End(); ++it) { it->Write(os, binary); - os << std::endl; + os << "; "; } - os << "]" << std::endl; + os << std::endl; } } @@ -175,8 +191,8 @@ void Segmentation::GenRandomSegmentation(int32 max_length, int32 st = 0; int32 end = 0; - while (st > max_length) { - int32 segment_length = RandInt(0, max_segment_length); + while (st < max_length) { + int32 segment_length = RandInt(1, max_segment_length); end = st + segment_length - 1; diff --git a/src/segmenterbin/Makefile b/src/segmenterbin/Makefile index 1f0efe71181..22a74e70551 100644 --- a/src/segmenterbin/Makefile +++ b/src/segmenterbin/Makefile @@ -16,7 +16,8 @@ BINFILES = segmentation-copy segmentation-get-stats \ segmentation-combine-segments-to-recordings \ segmentation-create-overlapped-subsegments \ segmentation-intersect-segments \ - segmentation-init-from-additive-signals-info #\ + segmentation-init-from-additive-signals-info \ + class-counts-per-frame-to-labels#\ gmm-acc-pdf-stats-segmentation \ gmm-est-segmentation gmm-update-segmentation \ segmentation-init-from-diarization \ diff --git a/src/segmenterbin/class-counts-per-frame-to-labels.cc b/src/segmenterbin/class-counts-per-frame-to-labels.cc new file mode 100644 index 00000000000..85676794e95 --- /dev/null +++ b/src/segmenterbin/class-counts-per-frame-to-labels.cc @@ -0,0 +1,115 @@ +// segmenterbin/class-counts-per-frame-to-labels.cc + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "hmm/posterior.h" + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Converts class-counts-per-frame in the format of vectors of vectors of " + "integers into labels for overlapping SAD.\n" + "If there is a junk-label in the classes in the frame, then the label " + "for the frame is set to the junk-label no matter what other labels " + "are present.\n" + "If there is only a 0 (silence) in the classes in the frame, then the " + "label for the frame is set to 0.\n" + "If there is only one non-zero non-junk class, then the label is set " + "to 1.\n" + "Otherwise, the label is set to 2 (overlapping speakers)\n" + "\n" + "Usage: class-counts-per-frame-to-labels [options] " + " \n"; + + int32 junk_label = -1; + ParseOptions po(usage); + + po.Register("junk-label", &junk_label, + "The label used for segments that are junk. If a frame has " + "a junk label, it will be considered junk segment, no matter " + "what other labels the frame contains. Also frames with no " + "classes seen are labeled junk."); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string in_fn = po.GetArg(1), + out_fn = po.GetArg(2); + + int num_done = 0; + Int32VectorWriter writer(out_fn); + SequentialPosteriorReader reader(in_fn); + for (; !reader.Done(); reader.Next(), num_done++) { + const Posterior &class_counts_per_frame = reader.Value(); + std::vector labels(class_counts_per_frame.size(), junk_label); + + for (size_t i = 0; i < class_counts_per_frame.size(); i++) { + const std::vector > &class_counts = + class_counts_per_frame[i]; + + if (class_counts.size() == 0) { + labels[i] = junk_label; + } else { + bool silence_found = false; + std::vector >::const_iterator it = + class_counts.begin(); + int32 class_counts_in_frame = 0; + for (; it != class_counts.end(); ++it) { + KALDI_ASSERT(it->second > 0); + if (it->first == 0) { + silence_found = true; + } else { + class_counts_in_frame += static_cast(it->second); + if (it->first == junk_label) { + labels[i] = junk_label; + break; + } + } + } + + if (class_counts_in_frame == 0) { + KALDI_ASSERT(silence_found); + labels[i] = 0; + } else if (class_counts_in_frame == 1) { + labels[i] = 1; + } else { + labels[i] = 2; + } + } + } + writer.Write(reader.Key(), labels); + } + KALDI_LOG << "Copied " << num_done << " items."; + return (num_done != 0 ? 
0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + + diff --git a/src/segmenterbin/segmentation-combine-segments.cc b/src/segmenterbin/segmentation-combine-segments.cc index 7034a8a1734..09b789a0921 100644 --- a/src/segmenterbin/segmentation-combine-segments.cc +++ b/src/segmenterbin/segmentation-combine-segments.cc @@ -44,8 +44,16 @@ int main(int argc, char *argv[]) { "segmentation-merge, segmentatin-merge-recordings, " "segmentation-post-process --merge-adjacent-segments\n"; + bool include_missing = false; + ParseOptions po(usage); + po.Register("include-missing-utt-level-segmentations", &include_missing, + "If true, then the segmentations missing in " + "utt-level-segmentation-rspecifier is included in the " + "final output with the label taken from the " + "kaldi-segments-segmentation-rspecifier"); + po.Read(argc, argv); if (po.NumArgs() != 4) { @@ -96,12 +104,17 @@ int main(int argc, char *argv[]) { if (!utt_segmentation_reader.HasKey(*it)) { KALDI_WARN << "Could not find utterance " << *it << " in " << "segmentation " << utt_segmentation_rspecifier; - num_err++; + if (!include_missing) { + num_err++; + } else { + out_segmentation.PushBack(segment); + num_segments++; + } continue; } + const Segmentation &utt_segmentation = utt_segmentation_reader.Value(*it); - num_segments += InsertFromSegmentation(utt_segmentation, segment.start_frame, false, &out_segmentation, NULL); diff --git a/src/segmenterbin/segmentation-copy.cc b/src/segmenterbin/segmentation-copy.cc index 26d0f47682d..e3384170805 100644 --- a/src/segmenterbin/segmentation-copy.cc +++ b/src/segmenterbin/segmentation-copy.cc @@ -31,9 +31,8 @@ int main(int argc, char *argv[]) { "Copy segmentation or archives of segmentation.\n" "If label-map is supplied, then apply the mapping to the labels \n" "when copying.\n" - "If utt2label-rspecifier is supplied, then ignore the \n" - "original labels, and map all the segments of an utterance using \n" - "the supplied utt2label map.\n" + "If utt2label-map-rspecifier is supplied, then an utterance-specific " + "mapping is applied on the original labels\n" "\n" "Usage: segmentation-copy [options] " "\n" @@ -44,7 +43,7 @@ int main(int argc, char *argv[]) { " e.g.: segmentation-copy --binary=false foo -\n"; bool binary = true; - std::string label_map_rxfilename, utt2label_rspecifier; + std::string label_map_rxfilename, utt2label_map_rspecifier; std::string include_rxfilename, exclude_rxfilename; int32 keep_label = -1; BaseFloat frame_subsampling_factor = 1; @@ -58,8 +57,13 @@ int main(int argc, char *argv[]) { "File with mapping from old to new labels"); po.Register("frame-subsampling-factor", &frame_subsampling_factor, "Change frame rate by this factor"); - po.Register("utt2label-rspecifier", &utt2label_rspecifier, - "Mapping for each utterance to an integer label"); + po.Register("utt2label-map-rspecifier", &utt2label_map_rspecifier, + "Utterance-specific mapping from old to new labels. " + "The first column is the utterance id. The next columns are " + "pairs :. If is -1, then " + "that represents the default label map. i.e. 
Any old label " + "for which the mapping is not defined, will be mapped to the " + "label corresponding to old-label -1."); po.Register("keep-label", &keep_label, "If supplied, only segments of this label are written out"); po.Register("include", &include_rxfilename, @@ -162,8 +166,8 @@ int main(int argc, char *argv[]) { ScaleFrameShift(frame_subsampling_factor, &segmentation); } - if (!utt2label_rspecifier.empty()) - KALDI_ERR << "It makes no sense to specify utt2label-rspecifier " + if (!utt2label_map_rspecifier.empty()) + KALDI_ERR << "It makes no sense to specify utt2label-map-rspecifier " << "when not reading segmentation archives."; Output ko(segmentation_out_fn, binary); @@ -172,7 +176,8 @@ int main(int argc, char *argv[]) { KALDI_LOG << "Copied segmentation to " << segmentation_out_fn; return 0; } else { - RandomAccessInt32Reader utt2label_reader(utt2label_rspecifier); + RandomAccessTokenVectorReader utt2label_map_reader( + utt2label_map_rspecifier); SegmentationWriter writer(segmentation_out_fn); SequentialSegmentationReader reader(segmentation_in_fn); @@ -190,24 +195,43 @@ int main(int argc, char *argv[]) { if (label_map_rxfilename.empty() && frame_subsampling_factor == 1.0 && - utt2label_rspecifier.empty() && + utt2label_map_rspecifier.empty() && keep_label == -1) { writer.Write(key, reader.Value()); } else { Segmentation segmentation = reader.Value(); if (!label_map_rxfilename.empty()) RelabelSegmentsUsingMap(label_map, &segmentation); - if (!utt2label_rspecifier.empty()) { - if (!utt2label_reader.HasKey(key)) { + + if (!utt2label_map_rspecifier.empty()) { + if (!utt2label_map_reader.HasKey(key)) { KALDI_WARN << "Utterance " << key - << " not found in utt2label map " - << utt2label_rspecifier; + << " not found in utt2label_map " + << utt2label_map_rspecifier; num_err++; continue; } - RelabelAllSegments(utt2label_reader.Value(key), &segmentation); + unordered_map utt_label_map; + + const std::vector &utt_label_map_vec = + utt2label_map_reader.Value(key); + std::vector::const_iterator it = + utt_label_map_vec.begin(); + + for (; it != utt_label_map_vec.end(); ++it) { + std::vector vec; + SplitStringToFloats(*it, ":", false, &vec); + if (vec.size() != 2) { + KALDI_ERR << "Invalid utt-label-map " << *it; + } + utt_label_map[static_cast(vec[0])] = + static_cast(vec[1]); + } + + RelabelSegmentsUsingMap(utt_label_map, &segmentation); } + if (keep_label != -1) KeepSegments(keep_label, &segmentation); diff --git a/src/segmenterbin/segmentation-get-stats.cc b/src/segmenterbin/segmentation-get-stats.cc index b25d6913f06..1e39bafec44 100644 --- a/src/segmenterbin/segmentation-get-stats.cc +++ b/src/segmenterbin/segmentation-get-stats.cc @@ -17,7 +17,9 @@ // See the Apache 2 License for the specific language governing permissions and // limitations under the License. 
+#include #include "base/kaldi-common.h" +#include "hmm/posterior.h" #include "util/common-utils.h" #include "segmenter/segmentation-utils.h" @@ -33,9 +35,10 @@ int main(int argc, char *argv[]) { " num-classes: Number of distinct classes common to this frame\n" "\n" "Usage: segmentation-get-stats [options] " - " \n" + " " + "\n" " e.g.: segmentation-get-stats ark:1.seg ark:/dev/null " - "ark:num_classes.ark\n"; + "ark:num_classes.ark ark:/dev/null\n"; ParseOptions po(usage); @@ -51,20 +54,23 @@ int main(int argc, char *argv[]) { po.Read(argc, argv); - if (po.NumArgs() != 3) { + if (po.NumArgs() != 4) { po.PrintUsage(); exit(1); } std::string segmentation_rspecifier = po.GetArg(1), num_overlaps_wspecifier = po.GetArg(2), - num_classes_wspecifier = po.GetArg(3); + num_classes_wspecifier = po.GetArg(3), + class_counts_per_frame_wspecifier = po.GetArg(4); int64 num_done = 0, num_err = 0; SequentialSegmentationReader reader(segmentation_rspecifier); Int32VectorWriter num_overlaps_writer(num_overlaps_wspecifier); Int32VectorWriter num_classes_writer(num_classes_wspecifier); + PosteriorWriter class_counts_per_frame_writer( + class_counts_per_frame_wspecifier); RandomAccessInt32Reader lengths_reader(lengths_rspecifier); @@ -82,34 +88,42 @@ int main(int argc, char *argv[]) { length = lengths_reader.Value(key); } - std::vector > class_counts_per_frame; + std::vector > class_counts_map_per_frame; if (!GetClassCountsPerFrame(segmentation, length, length_tolerance, - &class_counts_per_frame)) { + &class_counts_map_per_frame)) { KALDI_WARN << "Failed getting stats for key " << key; num_err++; continue; } if (length == -1) - length = class_counts_per_frame.size(); + length = class_counts_map_per_frame.size(); std::vector num_classes_per_frame(length, 0); std::vector num_overlaps_per_frame(length, 0); + Posterior class_counts_per_frame(length, + std::vector >()); - for (int32 i = 0; i < class_counts_per_frame.size(); i++) { - std::map &class_counts = class_counts_per_frame[i]; + for (int32 i = 0; i < class_counts_map_per_frame.size(); i++) { + std::map &class_counts = class_counts_map_per_frame[i]; for (std::map::const_iterator it = class_counts.begin(); it != class_counts.end(); ++it) { - if (it->second > 0) + if (it->second > 0) { num_classes_per_frame[i]++; + class_counts_per_frame[i].push_back( + std::make_pair(it->first, it->second)); + } num_overlaps_per_frame[i] += it->second; } + std::sort(class_counts_per_frame[i].begin(), + class_counts_per_frame[i].end()); } num_classes_writer.Write(key, num_classes_per_frame); num_overlaps_writer.Write(key, num_overlaps_per_frame); + class_counts_per_frame_writer.Write(key, class_counts_per_frame); num_done++; } diff --git a/src/segmenterbin/segmentation-init-from-additive-signals-info.cc b/src/segmenterbin/segmentation-init-from-additive-signals-info.cc index 139048ac17b..ccddb4c2b60 100644 --- a/src/segmenterbin/segmentation-init-from-additive-signals-info.cc +++ b/src/segmenterbin/segmentation-init-from-additive-signals-info.cc @@ -35,9 +35,9 @@ int main(int argc, char *argv[]) { "ark:reco_segmentation.ark ark,t:overlapped_segments_info.txt ark:-\n"; BaseFloat frame_shift = 0.01; + int32 junk_label = -1; std::string lengths_rspecifier; std::string additive_signals_segmentation_rspecifier; - std::string unreliable_segmentation_wspecifier; ParseOptions po(usage); @@ -49,10 +49,9 @@ int main(int argc, char *argv[]) { &additive_signals_segmentation_rspecifier, "Archive of segmentation of the additive signal which will used " "instead of an all 1 segmentation"); 
- po.Register("unreliable-segmentation-wspecifier", - &unreliable_segmentation_wspecifier, - "Applicable when additive-signals-segmentation-rspecifier is " - "provided and some utterances in it are missing"); + po.Register("junk-label", &junk_label, + "If specified, then unreliable regions are labeled with this " + "label"); po.Read(argc, argv); @@ -70,7 +69,6 @@ int main(int argc, char *argv[]) { SegmentationWriter writer(segmentation_wspecifier); RandomAccessSegmentationReader additive_signals_segmentation_reader(additive_signals_segmentation_rspecifier); - SegmentationWriter unreliable_writer(unreliable_segmentation_wspecifier); RandomAccessInt32Reader lengths_reader(lengths_rspecifier); @@ -84,18 +82,20 @@ int main(int argc, char *argv[]) { num_missing++; continue; } - const std::vector &additive_signals_info = additive_signals_info_reader.Value(key); + const std::vector &additive_signals_info = + additive_signals_info_reader.Value(key); Segmentation segmentation(reco_segmentation_reader.Value()); - Segmentation unreliable_segmentation; for (size_t i = 0; i < additive_signals_info.size(); i++) { std::vector parts; SplitStringToVector(additive_signals_info[i], ",:", false, &parts); if (parts.size() != 3) { - KALDI_ERR << "Invalid format of overlap info " << additive_signals_info[i] - << "for key " << key << " in " << additive_signals_info_rspecifier; + KALDI_ERR << "Invalid format of overlap info " + << additive_signals_info[i] + << "for key " << key << " in " + << additive_signals_info_rspecifier; } const std::string &utt_id = parts[0]; double start_time; @@ -110,17 +110,22 @@ int main(int argc, char *argv[]) { << "segmentation " << additive_signals_segmentation_rspecifier; if (duration < 0) { KALDI_ERR << "duration < 0 for utt_id " << utt_id << " in " - << "additive_signals_info " << additive_signals_info_rspecifier - << "; additive-signals-segmentation must be provided in such a case"; + << "additive_signals_info " + << additive_signals_info_rspecifier + << "; additive-signals-segmentation must be provided " + << "in such a case"; } num_err++; - unreliable_segmentation.EmplaceBack(start_frame, start_frame + duration - 1, 0); + int32 length = round(duration / frame_shift); + segmentation.EmplaceBack(start_frame, start_frame + length - 1, + junk_label); continue; // Treated as non-overlapping even though there // is overlap } - InsertFromSegmentation(additive_signals_segmentation_reader.Value(utt_id), - start_frame, false, &segmentation); + InsertFromSegmentation( + additive_signals_segmentation_reader.Value(utt_id), + start_frame, false, &segmentation); } Sort(&segmentation); @@ -134,19 +139,6 @@ int main(int argc, char *argv[]) { } writer.Write(key, segmentation); - if (!unreliable_segmentation_wspecifier.empty()) { - Sort(&unreliable_segmentation); - if (!lengths_rspecifier.empty()) { - if (!lengths_reader.HasKey(key)) { - KALDI_WARN << "Could not find length for the recording " << key - << "in " << lengths_rspecifier; - continue; - } - TruncateToLength(lengths_reader.Value(key), &unreliable_segmentation); - } - unreliable_writer.Write(key, unreliable_segmentation); - } - num_done++; } diff --git a/src/segmenterbin/segmentation-init-from-ali.cc b/src/segmenterbin/segmentation-init-from-ali.cc index a98a54368c9..452ff56c2d8 100644 --- a/src/segmenterbin/segmentation-init-from-ali.cc +++ b/src/segmenterbin/segmentation-init-from-ali.cc @@ -54,7 +54,7 @@ int main(int argc, char *argv[]) { int64 num_segments = 0; int64 num_err = 0; - std::vector frame_counts_per_class; + std::map 
frame_counts_per_class; SequentialInt32VectorReader alignment_reader(ali_rspecifier); @@ -80,7 +80,11 @@ int main(int argc, char *argv[]) { << "wrote " << num_segmentations << " segmentations " << "with a total of " << num_segments << " segments."; KALDI_LOG << "Number of frames for the different classes are : "; - WriteIntegerVector(KALDI_LOG, false, frame_counts_per_class); + + std::map::const_iterator it = frame_counts_per_class.begin(); + for (; it != frame_counts_per_class.end(); ++it) { + KALDI_LOG << it->first << " " << it->second << " ; "; + } return ((num_done > 0 && num_err < num_done) ? 0 : 1); } catch(const std::exception &e) { diff --git a/src/segmenterbin/segmentation-merge-recordings.cc b/src/segmenterbin/segmentation-merge-recordings.cc index 85b5108be29..dccd82b0595 100644 --- a/src/segmenterbin/segmentation-merge-recordings.cc +++ b/src/segmenterbin/segmentation-merge-recordings.cc @@ -92,7 +92,8 @@ int main(int argc, char *argv[]) { << "created overall " << num_segments << " segments; " << "failed to merge " << num_err << " old segmentations"; - return (num_new_segmentations > 0 && num_err < num_old_segmentations / 2); + return (num_new_segmentations > 0 && num_err < num_old_segmentations / 2 ? + 0 : 1); } catch(const std::exception &e) { std::cerr << e.what(); return -1; diff --git a/src/segmenterbin/segmentation-to-rttm.cc b/src/segmenterbin/segmentation-to-rttm.cc index 6ffd1a8b1e8..8f22d78f3bc 100644 --- a/src/segmenterbin/segmentation-to-rttm.cc +++ b/src/segmenterbin/segmentation-to-rttm.cc @@ -17,6 +17,7 @@ // See the Apache 2 License for the specific language governing permissions and // limitations under the License. +#include #include "base/kaldi-common.h" #include "util/common-utils.h" #include "segmenter/segmentation.h" @@ -40,42 +41,60 @@ namespace segmenter { * The function retunns the largest class_id that it encounters. 
**/ -int32 WriteRttm(const Segmentation &segmentation, - std::ostream &os, const std::string &file_id, - const std::string &channel, - BaseFloat frame_shift, BaseFloat start_time, - bool map_to_speech_and_sil) { +void WriteRttm(const Segmentation &segmentation, + const std::string &file_id, + const std::string &channel, + BaseFloat frame_shift, BaseFloat start_time, + bool map_to_speech_and_sil, + int32 no_score_label, std::ostream &os) { SegmentList::const_iterator it = segmentation.Begin(); - int32 largest_class = 0; + + unordered_map classes_map; + std::vector classes_vec; + for (; it != segmentation.End(); ++it) { + if (no_score_label > 0 && it->Label() == no_score_label) { + os << "NOSCORE " << file_id << " " << channel << " " + << it->start_frame * frame_shift + start_time << " " + << (it->Length()) * frame_shift << " \n"; + continue; + } os << "SPEAKER " << file_id << " " << channel << " " - << it->start_frame * frame_shift + start_time << " " + << it->start_frame * frame_shift + start_time << " " << (it->Length()) * frame_shift << " "; if (map_to_speech_and_sil) { switch (it->Label()) { - case 1: - os << "SPEECH "; + case 0: + os << "SILENCE "; break; default: - os << "SILENCE "; + os << "SPEECH "; break; } - largest_class = 1; } else { if (it->Label() >= 0) { os << it->Label() << " "; - if (it->Label() > largest_class) - largest_class = it->Label(); + if (classes_map.count(it->Label()) == 0) { + classes_map[it->Label()] = true; + classes_vec.push_back(it->Label()); + } } } os << "" << std::endl; - } - return largest_class; -} + } -} + if (!map_to_speech_and_sil) { + for (std::vector::const_iterator it = classes_vec.begin(); + it != classes_vec.end(); ++it) { + os << "SPKR-INFO " << file_id << " " << channel + << " unknown " << *it << " \n"; + } + } } +} // namespace segmenter +} // namespace kaldi + int main(int argc, char *argv[]) { try { using namespace kaldi; @@ -84,20 +103,27 @@ int main(int argc, char *argv[]) { const char *usage = "Convert segmentation into RTTM\n" "\n" - "Usage: segmentation-to-rttm [options] \n" + "Usage: segmentation-to-rttm [options] " + "\n" " e.g.: segmentation-to-rttm ark:1.seg -\n"; - + bool map_to_speech_and_sil = true; + int32 no_score_label = -1; BaseFloat frame_shift = 0.01; std::string segments_rxfilename; std::string reco2file_and_channel_rxfilename; ParseOptions po(usage); - + po.Register("frame-shift", &frame_shift, "Frame shift in seconds"); po.Register("segments", &segments_rxfilename, "Segments file"); - po.Register("reco2file-and-channel", &reco2file_and_channel_rxfilename, "reco2file_and_channel file"); - po.Register("map-to-speech-and-sil", &map_to_speech_and_sil, "Map all classes to SPEECH and SILENCE"); + po.Register("reco2file-and-channel", &reco2file_and_channel_rxfilename, + "reco2file_and_channel file"); + po.Register("map-to-speech-and-sil", &map_to_speech_and_sil, + "Map all classes other than 0 to SPEECH"); + po.Register("no-score-label", &no_score_label, + "If specified, then a NOSCORE region is added to RTTM " + "when this label occurs in the segmentation."); po.Read(argc, argv); @@ -105,20 +131,20 @@ int main(int argc, char *argv[]) { po.PrintUsage(); exit(1); } - + unordered_map utt2file; unordered_map utt2start_time; if (!segments_rxfilename.empty()) { - Input ki(segments_rxfilename); // no binary argment: never binary. + Input ki(segments_rxfilename); // no binary argment: never binary. 
int32 i = 0; std::string line; /* read each line from segments file */ while (std::getline(ki.Stream(), line)) { std::vector split_line; // Split the line by space or tab and check the number of fields in each - // line. There must be 4 fields--segment name , reacording wav file name, - // start time, end time; 5th field (channel info) is optional. + // line. There must be 4 fields--segment name , reacording wav file + // name, start time, end time; 5th field (channel info) is optional. SplitStringToVector(line, " \t\r", true, &split_line); if (split_line.size() != 4 && split_line.size() != 5) { KALDI_WARN << "Invalid line in segments file: " << line; @@ -128,7 +154,7 @@ int main(int argc, char *argv[]) { utterance = split_line[1], start_str = split_line[2], end_str = split_line[3]; - + // Convert the start time and endtime to real from string. Segment is // ignored if start or end time cannot be converted to real. double start, end; @@ -143,15 +169,18 @@ int main(int argc, char *argv[]) { // start time must not be negative; start time must not be greater than // end time, except if end time is -1 if (start < 0 || end <= 0 || start >= end) { - KALDI_WARN << "Invalid line in segments file [empty or invalid segment]: " - << line; + KALDI_WARN << "Invalid line in segments file " + << "[empty or invalid segment]: " + << line; continue; } int32 channel = -1; // means channel info is unspecified. - // if each line has 5 elements then 5th element must be channel identifier - if(split_line.size() == 5) { + // if each line has 5 elements then 5th element must be channel + // identifier + if (split_line.size() == 5) { if (!ConvertStringToInteger(split_line[4], &channel) || channel < 0) { - KALDI_WARN << "Invalid line in segments file [bad channel]: " << line; + KALDI_WARN << "Invalid line in segments file " + << "[bad channel]: " << line; continue; } } @@ -163,10 +192,12 @@ int main(int argc, char *argv[]) { KALDI_LOG << "Read " << i << " lines from " << segments_rxfilename; } - unordered_map , StringHasher> reco2file_and_channel; + unordered_map, + StringHasher> reco2file_and_channel; if (!reco2file_and_channel_rxfilename.empty()) { - Input ki(reco2file_and_channel_rxfilename); // no binary argment: never binary. + // no binary argment: never binary. 
+ Input ki(reco2file_and_channel_rxfilename); int32 i = 0; std::string line; @@ -183,11 +214,13 @@ int main(int argc, char *argv[]) { const std::string &file_id = split_line[1]; const std::string &channel = split_line[2]; - reco2file_and_channel.insert(std::make_pair(reco_id, std::make_pair(file_id, channel))); + reco2file_and_channel.insert( + std::make_pair(reco_id, std::make_pair(file_id, channel))); i++; } - KALDI_LOG << "Read " << i << " lines from " << reco2file_and_channel_rxfilename; + KALDI_LOG << "Read " << i << " lines from " + << reco2file_and_channel_rxfilename; } unordered_set seen_files; @@ -196,18 +229,18 @@ int main(int argc, char *argv[]) { rttm_out_wxfilename = po.GetArg(2); int64 num_done = 0, num_err = 0; - + Output ko(rttm_out_wxfilename, false); SequentialSegmentationReader reader(segmentation_rspecifier); for (; !reader.Done(); reader.Next(), num_done++) { Segmentation segmentation(reader.Value()); const std::string &key = reader.Key(); - std::string reco_id = key; + std::string reco_id = key; BaseFloat start_time = 0.0; if (!segments_rxfilename.empty()) { if (utt2file.count(key) == 0 || utt2start_time.count(key) == 0) - KALDI_ERR << "Could not find key " << key << " in segments " + KALDI_ERR << "Could not find key " << key << " in segments " << segments_rxfilename; KALDI_ASSERT(utt2file.count(key) > 0 && utt2start_time.count(key) > 0); reco_id = utt2file[key]; @@ -216,8 +249,8 @@ int main(int argc, char *argv[]) { std::string file_id, channel; if (!reco2file_and_channel_rxfilename.empty()) { - if (reco2file_and_channel.count(reco_id) == 0) - KALDI_ERR << "Could not find recording " << reco_id + if (reco2file_and_channel.count(reco_id) == 0) + KALDI_ERR << "Could not find recording " << reco_id << " in " << reco2file_and_channel_rxfilename; file_id = reco2file_and_channel[reco_id].first; channel = reco2file_and_channel[reco_id].second; @@ -226,18 +259,18 @@ int main(int argc, char *argv[]) { channel = "1"; } - int32 largest_class = WriteRttm(segmentation, ko.Stream(), file_id, channel, frame_shift, start_time, map_to_speech_and_sil); + WriteRttm(segmentation, file_id, + channel, frame_shift, start_time, + map_to_speech_and_sil, no_score_label, ko.Stream()); if (map_to_speech_and_sil) { if (seen_files.count(reco_id) == 0) { - ko.Stream() << "SPKR-INFO " << file_id << " " << channel << " unknown SILENCE \n"; - ko.Stream() << "SPKR-INFO " << file_id << " " << channel << " unknown SPEECH \n"; + ko.Stream() << "SPKR-INFO " << file_id << " " << channel + << " unknown SILENCE \n"; + ko.Stream() << "SPKR-INFO " << file_id << " " << channel + << " unknown SPEECH \n"; seen_files.insert(reco_id); } - } else { - for (int32 i = 0; i < largest_class; i++) { - ko.Stream() << "SPKR-INFO " << file_id << " " << channel << " unknown " << i << " \n"; - } } } @@ -249,7 +282,3 @@ int main(int argc, char *argv[]) { return -1; } } - - - - From 5ac90c8df123f58eab8236484b9fdbdfe3bdfa38 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 2 Jan 2017 18:31:35 -0500 Subject: [PATCH 121/213] asr_diarization: Support objective type in basic_layers --- egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py index f74137da48b..be2776c90b8 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py @@ -582,7 +582,9 @@ def get_full_config(self): if output_delay != 0: 
cur_node = 'Offset({0}, {1})'.format(cur_node, output_delay) - line = ('output-node name={0} input={1}'.format(self.name, cur_node)) + line = ('output-node name={0} input={1} ' + 'objective={2}'.format( + self.name, cur_node, objective_type)) ans.append((config_name, line)) return ans From 6e3889b83c8ad7a455561f37875d8b61d95abc2f Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 2 Jan 2017 18:32:17 -0500 Subject: [PATCH 122/213] asr_diarization: Update multilingual egs creation --- .../nnet3/multilingual/allocate_multilingual_examples.py | 8 +++++++- egs/wsj/s5/steps/nnet3/multilingual/get_egs.sh | 3 +++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/egs/wsj/s5/steps/nnet3/multilingual/allocate_multilingual_examples.py b/egs/wsj/s5/steps/nnet3/multilingual/allocate_multilingual_examples.py index cba804b1a66..9bc6da53705 100644 --- a/egs/wsj/s5/steps/nnet3/multilingual/allocate_multilingual_examples.py +++ b/egs/wsj/s5/steps/nnet3/multilingual/allocate_multilingual_examples.py @@ -214,7 +214,13 @@ def Main(): start_egs = lang2len[lang_id] - lang_len[lang_id] this_egs.append((lang_id, start_egs, args.minibatch_size)) for scpline in range(args.minibatch_size): - print("{0} {1}".format(scp_files[lang_id].readline().splitlines()[0], lang_id), file = archfile) + lines = scp_files[lang_id].readline().splitlines() + try: + print("{0} {1}".format(lines[0], lang_id), file=archfile) + except Exception: + logger.error("Failure to read from file %s, got %s", + scp_files[lang_id].name, lines) + raise lang_len[lang_id] = lang_len[lang_id] - args.minibatch_size num_egs = num_egs + args.minibatch_size; diff --git a/egs/wsj/s5/steps/nnet3/multilingual/get_egs.sh b/egs/wsj/s5/steps/nnet3/multilingual/get_egs.sh index aa9a911ffb2..58ef965de3e 100755 --- a/egs/wsj/s5/steps/nnet3/multilingual/get_egs.sh +++ b/egs/wsj/s5/steps/nnet3/multilingual/get_egs.sh @@ -18,6 +18,7 @@ set -u # Begin configuration section. cmd=run.pl minibatch_size=512 # multiple of minibatch used during training. +minibatch_size= num_jobs=10 # This can be set to max number of jobs to run in parallel; # Helps for better randomness across languages # per archive. @@ -85,6 +86,8 @@ for lang in $(seq 0 $[$num_langs-1]);do done done +cp ${multi_egs_dir[$lang]}/cmvn_opts $megs_dir + if [ $stage -le 0 ]; then echo "$0: allocating multilingual examples for training." # Generate egs.*.scp for multilingual setup. 
From 3d10480b573d724cf990eb58dbd10cb1b191b442 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 2 Jan 2017 18:33:01 -0500 Subject: [PATCH 123/213] asr_diarization: Add per-dim accuracy to diagnostics --- egs/wsj/s5/steps/libs/nnet3/train/common.py | 16 +- .../nnet3/train/frame_level_objf/common.py | 166 ++++++++++-------- src/nnet3/nnet-diagnostics.cc | 74 ++++++-- src/nnet3/nnet-diagnostics.h | 15 +- 4 files changed, 176 insertions(+), 95 deletions(-) diff --git a/egs/wsj/s5/steps/libs/nnet3/train/common.py b/egs/wsj/s5/steps/libs/nnet3/train/common.py index 7ae44cdffae..503c3ba622d 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/common.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/common.py @@ -197,9 +197,10 @@ def verify_egs_dir(egs_dir, feat_dim, ivector_dim, return [egs_left_context, egs_right_context, frames_per_eg, num_archives] - except (IOError, ValueError) as e: - raise Exception("The egs dir {0} has missing or " - "malformed files: {1}".format(egs_dir, e.strerr)) + except (IOError, ValueError): + logger.error("The egs dir {0} has missing or " + "malformed files.".format(egs_dir)) + raise def compute_presoftmax_prior_scale(dir, alidir, num_jobs, run_opts, @@ -561,7 +562,7 @@ def __init__(self): to steps/nnet3/get_egs.sh script""") self.parser.add_argument("--egs.use-multitask-egs", type=str, dest='use_multitask_egs', - default=True, choices=["true", "false"], + default=False, choices=["true", "false"], action=common_lib.StrToBoolAction, help="""Use mutlitask egs created using allocate_multilingual_egs.py.""") @@ -683,6 +684,13 @@ def __init__(self): lstm*=0,0.2,0'. More general should precede less general patterns, as they are applied sequentially.""") + self.parser.add_argument("--trainer.compute-per-dim-accuracy", + dest='compute_per_dim_accuracy', + type=str, choices=['true', 'false'], + default=False, + action=common_lib.StrToBoolAction, + help="Compute train and validation " + "accuracy per-dim") # General options self.parser.add_argument("--stage", type=int, default=-4, diff --git a/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py b/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py index 508445e331e..9c8b5d0ee95 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py @@ -115,13 +115,14 @@ def train_new_models(dir, iter, srand, num_jobs, "nnet3-merge-egs --minibatch-size={minibatch_size} " "--measure-output-frames=false " "--discard-partial-minibatches=true ark:- ark:- | " - "nnet3-shuffle-egs --buffer-size={shuffle_buffer_size} " + "nnet3-shuffle-egs " + "--buffer-size={shuffle_buffer_size} --srand={srand} " "ark:- ark:- |".format( frame_opts=("" if chunk_level_training else "--frame={0}".format(frame)), context_opts=context_opts, egs_dir=egs_dir, output_rename_opt=output_rename_opt, - archive_index=archive_index, + archive_index=archive_index, srand=iter + srand, shuffle_buffer_size=shuffle_buffer_size, extra_egs_copy_cmd=extra_egs_copy_cmd, minibatch_size=minibatch_size)) @@ -138,7 +139,7 @@ def train_new_models(dir, iter, srand, num_jobs, frame_opts=("" if chunk_level_training else "--frame={0}".format(frame)), context_opts=context_opts, egs_dir=egs_dir, - archive_index=archive_index, + archive_index=archive_index, srand=iter + srand, shuffle_buffer_size=shuffle_buffer_size, extra_egs_copy_cmd=extra_egs_copy_cmd, minibatch_size=minibatch_size)) @@ -154,9 +155,7 @@ def train_new_models(dir, iter, srand, num_jobs, {dir}/{next_iter}.{job}.raw""".format( 
command=run_opts.command, train_queue_opt=run_opts.train_queue_opt, - dir=dir, iter=iter, srand=iter + srand, - next_iter=iter + 1, - job=job, + dir=dir, iter=iter, next_iter=iter + 1, job=job, parallel_train_opts=run_opts.parallel_train_opts, cache_read_opt=cache_read_opt, cache_write_opt=cache_write_opt, @@ -194,7 +193,8 @@ def train_one_iteration(dir, iter, srand, egs_dir, get_raw_nnet_from_am=True, background_process_handler=None, extra_egs_copy_cmd="", use_multitask_egs=False, - rename_multitask_outputs=False): + rename_multitask_outputs=False, + compute_per_dim_accuracy=False): """ Called from steps/nnet3/train_*.py scripts for one iteration of neural network training @@ -223,9 +223,10 @@ def train_one_iteration(dir, iter, srand, egs_dir, if os.path.exists('{0}/srand'.format(dir)): try: saved_srand = int(open('{0}/srand'.format(dir)).readline().strip()) - except (IOError, ValueError) as e: - raise Exception("Exception while reading the random seed " - "for training: {0}".format(e.str())) + except (IOError, ValueError): + logger.error("Exception while reading the random seed " + "for training.") + raise if srand != saved_srand: logger.warning("The random seed provided to this iteration " "(srand={0}) is different from the one saved last " @@ -244,7 +245,8 @@ def train_one_iteration(dir, iter, srand, egs_dir, mb_size=cv_minibatch_size, get_raw_nnet_from_am=get_raw_nnet_from_am, wait=False, background_process_handler=background_process_handler, - extra_egs_copy_cmd=extra_egs_copy_cmd) + extra_egs_copy_cmd=extra_egs_copy_cmd, + compute_per_dim_accuracy=compute_per_dim_accuracy) if iter > 0: # Runs in the background @@ -395,25 +397,25 @@ def compute_preconditioning_matrix(dir, egs_dir, num_lda_jobs, run_opts, # Write stats with the same format as stats for LDA. common_lib.run_job( - """{command} JOB=1:{num_lda_jobs} {dir}/log/get_lda_stats.JOB.log \ - nnet3-acc-lda-stats --rand-prune={rand_prune} \ - {dir}/init.raw "ark:{egs_dir}/egs.JOB.ark" \ - {dir}/JOB.lda_stats""".format( - command=run_opts.command, - num_lda_jobs=num_lda_jobs, - dir=dir, - egs_dir=egs_dir, - rand_prune=rand_prune)) + """{command} JOB=1:{num_lda_jobs} {dir}/log/get_lda_stats.JOB.log """ + """ nnet3-acc-lda-stats --rand-prune={rand_prune}""" + """ {dir}/init.raw "ark:{egs_dir}/egs.JOB.ark" """ + """ {dir}/JOB.lda_stats""".format( + command=run_opts.command, + num_lda_jobs=num_lda_jobs, + dir=dir, + egs_dir=egs_dir, + rand_prune=rand_prune)) # the above command would have generated dir/{1..num_lda_jobs}.lda_stats lda_stat_files = map(lambda x: '{0}/{1}.lda_stats'.format(dir, x), range(1, num_lda_jobs + 1)) common_lib.run_job( - """{command} {dir}/log/sum_transform_stats.log \ - sum-lda-accs {dir}/lda_stats {lda_stat_files}""".format( - command=run_opts.command, - dir=dir, lda_stat_files=" ".join(lda_stat_files))) + "{command} {dir}/log/sum_transform_stats.log " + "sum-lda-accs {dir}/lda_stats {lda_stat_files}".format( + command=run_opts.command, + dir=dir, lda_stat_files=" ".join(lda_stat_files))) for file in lda_stat_files: try: @@ -426,11 +428,11 @@ def compute_preconditioning_matrix(dir, egs_dir, num_lda_jobs, run_opts, # variant of an LDA transform but without dimensionality reduction. 
common_lib.run_job( - """{command} {dir}/log/get_transform.log \ - nnet-get-feature-transform {lda_opts} {dir}/lda.mat \ - {dir}/lda_stats""".format( - command=run_opts.command, dir=dir, - lda_opts=lda_opts if lda_opts is not None else "")) + "{command} {dir}/log/get_transform.log" + " nnet-get-feature-transform {lda_opts} {dir}/lda.mat" + " {dir}/lda_stats".format( + command=run_opts.command, dir=dir, + lda_opts=lda_opts if lda_opts is not None else "")) common_lib.force_symlink("../lda.mat", "{0}/configs/lda.mat".format(dir)) @@ -439,7 +441,8 @@ def compute_train_cv_probabilities(dir, iter, egs_dir, left_context, right_context, run_opts, mb_size=256, wait=False, background_process_handler=None, get_raw_nnet_from_am=True, - extra_egs_copy_cmd=""): + extra_egs_copy_cmd="", + compute_per_dim_accuracy=False): if get_raw_nnet_from_am: model = "nnet3-am-copy --raw=true {dir}/{iter}.mdl - |".format( dir=dir, iter=iter) @@ -455,21 +458,26 @@ def compute_train_cv_probabilities(dir, iter, egs_dir, left_context, valid_diagnostic_egs = "scp:{0}/valid_diagnostic.egs.1".format( egs_dir) + opts = [] + if compute_per_dim_accuracy: + opts.append("--compute-per-dim-accuracy") + common_lib.run_job( - """ {command} {dir}/log/compute_prob_valid.{iter}.log \ - nnet3-compute-prob "{model}" \ - "ark,bg:nnet3-copy-egs {context_opts} \ - {egs_rspecifier} ark:- |{extra_egs_copy_cmd} \ - nnet3-merge-egs --minibatch-size={mb_size} ark:- \ - ark:- |" """.format(command=run_opts.command, - dir=dir, - iter=iter, - egs_rspecifier=valid_diagnostic_egs, - context_opts=context_opts, - mb_size=mb_size, - model=model, - egs_dir=egs_dir, - extra_egs_copy_cmd=extra_egs_copy_cmd), + """{command} {dir}/log/compute_prob_valid.{iter}.log""" + """ nnet3-compute-prob {opts} "{model}" """ + """ "ark,bg:nnet3-copy-egs {context_opts}""" + """ {egs_rspecifier} ark:- |{extra_egs_copy_cmd}""" + """ nnet3-merge-egs --minibatch-size={mb_size} ark:-""" + """ ark:- |" """.format(command=run_opts.command, + opts=' '.join(opts), + dir=dir, + iter=iter, + egs_rspecifier=valid_diagnostic_egs, + context_opts=context_opts, + mb_size=mb_size, + model=model, + egs_dir=egs_dir, + extra_egs_copy_cmd=extra_egs_copy_cmd), wait=wait, background_process_handler=background_process_handler) if os.path.isfile("{0}/train_diagnostic.egs".format(egs_dir)): @@ -479,20 +487,21 @@ def compute_train_cv_probabilities(dir, iter, egs_dir, left_context, egs_dir) common_lib.run_job( - """{command} {dir}/log/compute_prob_train.{iter}.log \ - nnet3-compute-prob "{model}" \ - "ark,bg:nnet3-copy-egs {context_opts} \ - {egs_rspecifier} ark:- | {extra_egs_copy_cmd} \ - nnet3-merge-egs --minibatch-size={mb_size} ark:- \ - ark:- |" """.format(command=run_opts.command, - dir=dir, - iter=iter, - egs_rspecifier=train_diagnostic_egs, - context_opts=context_opts, - mb_size=mb_size, - model=model, - egs_dir=egs_dir, - extra_egs_copy_cmd=extra_egs_copy_cmd), + """{command} {dir}/log/compute_prob_train.{iter}.log""" + """ nnet3-compute-prob {opts} "{model}" """ + """ "ark,bg:nnet3-copy-egs {context_opts}""" + """ {egs_rspecifier} ark:- | {extra_egs_copy_cmd}""" + """ nnet3-merge-egs --minibatch-size={mb_size} ark:-""" + """ ark:- |" """.format(command=run_opts.command, + opts=' '.join(opts), + dir=dir, + iter=iter, + egs_rspecifier=train_diagnostic_egs, + context_opts=context_opts, + mb_size=mb_size, + model=model, + egs_dir=egs_dir, + extra_egs_copy_cmd=extra_egs_copy_cmd), wait=wait, background_process_handler=background_process_handler) @@ -519,30 +528,29 @@ def compute_progress(dir, 
iter, egs_dir, left_context, right_context, egs_dir) common_lib.run_job( - """{command} {dir}/log/progress.{iter}.log \ - nnet3-info "{model}" '&&' \ - nnet3-show-progress --use-gpu=no "{prev_model}" "{model}" \ - "ark,bg:nnet3-copy-egs {context_opts} \ - {egs_rspecifier} ark:- |{extra_egs_copy_cmd} \ - nnet3-merge-egs --minibatch-size={mb_size} ark:- \ - ark:- |" """.format(command=run_opts.command, - dir=dir, - iter=iter, - egs_rspecifier=train_diagnostic_egs, - model=model, - context_opts=context_opts, - mb_size=mb_size, - prev_model=prev_model, - egs_dir=egs_dir, - extra_egs_copy_cmd=extra_egs_copy_cmd), - wait=wait, background_process_handler=background_process_handler) + """{command} {dir}/log/progress.{iter}.log nnet3-info "{model}" """ + """ '&&' nnet3-show-progress --use-gpu=no "{prev_model}" "{model}" """ + """ "ark,bg:nnet3-copy-egs {context_opts}""" + """ {egs_rspecifier} ark:- |{extra_egs_copy_cmd}""" + """ nnet3-merge-egs --minibatch-size={mb_size} ark:-""" + """ ark:- |" """.format(command=run_opts.command, + dir=dir, + iter=iter, + egs_rspecifier=train_diagnostic_egs, + model=model, + context_opts=context_opts, + mb_size=mb_size, + prev_model=prev_model, + egs_dir=egs_dir, + extra_egs_copy_cmd=extra_egs_copy_cmd), + wait=wait, background_process_handler=background_process_handler) def combine_models(dir, num_iters, models_to_combine, egs_dir, left_context, right_context, run_opts, background_process_handler=None, chunk_width=None, get_raw_nnet_from_am=True, - extra_egs_copy_cmd=""): + extra_egs_copy_cmd="", compute_per_dim_accuracy=False): """ Function to do model combination In the nnet3 setup, the logic @@ -617,7 +625,8 @@ def combine_models(dir, num_iters, models_to_combine, egs_dir, left_context=left_context, right_context=right_context, run_opts=run_opts, wait=False, background_process_handler=background_process_handler, - extra_egs_copy_cmd=extra_egs_copy_cmd) + extra_egs_copy_cmd=extra_egs_copy_cmd, + compute_per_dim_accuracy=compute_per_dim_accuracy) else: compute_train_cv_probabilities( dir=dir, iter='final', egs_dir=egs_dir, @@ -625,7 +634,8 @@ def combine_models(dir, num_iters, models_to_combine, egs_dir, run_opts=run_opts, wait=False, background_process_handler=background_process_handler, get_raw_nnet_from_am=False, - extra_egs_copy_cmd=extra_egs_copy_cmd) + extra_egs_copy_cmd=extra_egs_copy_cmd, + compute_per_dim_accuracy=compute_per_dim_accuracy) def get_realign_iters(realign_times, num_iters, diff --git a/src/nnet3/nnet-diagnostics.cc b/src/nnet3/nnet-diagnostics.cc index 64abe8a0578..d00dd31b245 100644 --- a/src/nnet3/nnet-diagnostics.cc +++ b/src/nnet3/nnet-diagnostics.cc @@ -108,22 +108,33 @@ void NnetComputeProb::ProcessOutputs(const NnetExample &eg, } if (config_.compute_accuracy) { BaseFloat tot_weight, tot_accuracy; + PerDimObjectiveInfo &totals = accuracy_info_[io.name]; + + if (config_.compute_per_dim_accuracy && + totals.tot_objective_vec.Dim() == 0) { + totals.tot_objective_vec.Resize(output.NumCols()); + totals.tot_weight_vec.Resize(output.NumCols()); + } + ComputeAccuracy(io.features, output, - &tot_weight, &tot_accuracy, deriv_weights); - SimpleObjectiveInfo &totals = accuracy_info_[io.name]; + &tot_weight, &tot_accuracy, deriv_weights, + config_.compute_per_dim_accuracy ? + &totals.tot_weight_vec : NULL, + config_.compute_per_dim_accuracy ? 
+ &totals.tot_objective_vec : NULL); totals.tot_weight += tot_weight; totals.tot_objective += tot_accuracy; } - num_minibatches_processed_++; } } + num_minibatches_processed_++; } bool NnetComputeProb::PrintTotalStats() const { bool ans = false; - unordered_map::const_iterator - iter, end; { // First print regular objectives + unordered_map::const_iterator iter, end; iter = objf_info_.begin(); end = objf_info_.end(); for (; iter != end; ++iter) { @@ -141,15 +152,34 @@ bool NnetComputeProb::PrintTotalStats() const { ans = true; } } - { // now print accuracies. + { + unordered_map::const_iterator iter, end; + // now print accuracies. iter = accuracy_info_.begin(); end = accuracy_info_.end(); for (; iter != end; ++iter) { const std::string &name = iter->first; - const SimpleObjectiveInfo &info = iter->second; + const PerDimObjectiveInfo &info = iter->second; KALDI_LOG << "Overall accuracy for '" << name << "' is " << (info.tot_objective / info.tot_weight) << " per frame" << ", over " << info.tot_weight << " frames."; + + if (info.tot_weight_vec.Dim() > 0) { + Vector accuracy_vec(info.tot_weight_vec.Dim()); + for (size_t j = 0; j < info.tot_weight_vec.Dim(); j++) { + if (info.tot_weight_vec(j) != 0) { + accuracy_vec(j) = info.tot_objective_vec(j) + / info.tot_weight_vec(j); + } else { + accuracy_vec(j) = -1.0; + } + } + + KALDI_LOG << "Overall per-dim accuracy vector for '" << name + << "' is " << accuracy_vec << " per frame" + << ", over " << info.tot_weight << " frames."; + } // don't bother changing ans; the loop over the regular objective should // already have set it to true if we got any data. } @@ -161,12 +191,19 @@ void ComputeAccuracy(const GeneralMatrix &supervision, const CuMatrixBase &nnet_output, BaseFloat *tot_weight_out, BaseFloat *tot_accuracy_out, - const Vector *deriv_weights) { + const Vector *deriv_weights, + Vector *tot_weight_vec, + Vector *tot_accuracy_vec) { int32 num_rows = nnet_output.NumRows(), num_cols = nnet_output.NumCols(); KALDI_ASSERT(supervision.NumRows() == num_rows && supervision.NumCols() == num_cols); + if (tot_accuracy_vec || tot_weight_vec) + KALDI_ASSERT(tot_accuracy_vec && tot_weight_vec && + tot_accuracy_vec->Dim() == num_cols && + tot_weight_vec->Dim() == num_cols); + CuArray best_index(num_rows); nnet_output.FindRowMaxId(&best_index); std::vector best_index_cpu; @@ -192,8 +229,13 @@ void ComputeAccuracy(const GeneralMatrix &supervision, if (deriv_weights) row_sum *= (*deriv_weights)(r); tot_weight += row_sum; - if (best_index == best_index_cpu[r]) + if (tot_weight_vec) + (*tot_weight_vec)(best_index) += row_sum; + if (best_index == best_index_cpu[r]) { tot_accuracy += row_sum; + if (tot_accuracy_vec) + (*tot_accuracy_vec)(best_index) += row_sum; + } } break; } @@ -208,8 +250,13 @@ void ComputeAccuracy(const GeneralMatrix &supervision, if (deriv_weights) row_sum *= (*deriv_weights)(r); tot_weight += row_sum; - if (best_index == best_index_cpu[r]) + if (tot_weight_vec) + (*tot_weight_vec)(best_index) += row_sum; + if (best_index == best_index_cpu[r]) { tot_accuracy += row_sum; + if (tot_accuracy_vec) + (*tot_accuracy_vec)(best_index) += row_sum; + } } break; } @@ -224,8 +271,13 @@ void ComputeAccuracy(const GeneralMatrix &supervision, row_sum *= (*deriv_weights)(r); KALDI_ASSERT(best_index < num_cols); tot_weight += row_sum; - if (best_index == best_index_cpu[r]) + if (tot_weight_vec) + (*tot_weight_vec)(best_index) += row_sum; + if (best_index == best_index_cpu[r]) { tot_accuracy += row_sum; + if (tot_accuracy_vec) + (*tot_accuracy_vec)(best_index) 
+= row_sum; + } } break; } diff --git a/src/nnet3/nnet-diagnostics.h b/src/nnet3/nnet-diagnostics.h index 59f0cd16f47..a333f0ac6fe 100644 --- a/src/nnet3/nnet-diagnostics.h +++ b/src/nnet3/nnet-diagnostics.h @@ -38,11 +38,17 @@ struct SimpleObjectiveInfo { tot_objective(0.0) { } }; +struct PerDimObjectiveInfo : SimpleObjectiveInfo { + Vector tot_weight_vec; + Vector tot_objective_vec; + PerDimObjectiveInfo(): SimpleObjectiveInfo() { } +}; struct NnetComputeProbOptions { bool debug_computation; bool compute_deriv; bool compute_accuracy; + bool compute_per_dim_accuracy; bool apply_deriv_weights; NnetOptimizeOptions optimize_config; @@ -51,6 +57,7 @@ struct NnetComputeProbOptions { debug_computation(false), compute_deriv(false), compute_accuracy(true), + compute_per_dim_accuracy(false), apply_deriv_weights(true) { } void Register(OptionsItf *opts) { // compute_deriv is not included in the command line options @@ -59,6 +66,8 @@ struct NnetComputeProbOptions { "debug for the actual computation (very verbose!)"); opts->Register("compute-accuracy", &compute_accuracy, "If true, compute " "accuracy values as well as objective functions"); + opts->Register("compute-per-dim-accuracy", &compute_per_dim_accuracy, + "If true, compute accuracy values per-dim"); opts->Register("apply-deriv-weights", &apply_deriv_weights, "Apply per-frame deriv weights"); @@ -128,7 +137,7 @@ class NnetComputeProb { unordered_map objf_info_; - unordered_map accuracy_info_; + unordered_map accuracy_info_; }; @@ -164,7 +173,9 @@ void ComputeAccuracy(const GeneralMatrix &supervision, const CuMatrixBase &nnet_output, BaseFloat *tot_weight, BaseFloat *tot_accuracy, - const Vector *deriv_weights = NULL); + const Vector *deriv_weights = NULL, + Vector *tot_weight_vec = NULL, + Vector *tot_accuracy_vec = NULL); } // namespace nnet3 From a5d78816f06aa279417feebef92ccacab33ba13d Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 2 Jan 2017 18:34:16 -0500 Subject: [PATCH 124/213] sar_diarization: Minor bug fix in ../egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py --- egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py b/egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py index 72b0cb4edd3..8e6f1442c7a 100755 --- a/egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py +++ b/egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py @@ -741,7 +741,7 @@ def generate_training_examples_internal(dir, targets_parameters, feat_dir, out_egs_handle = open("{0}/egs.scp".format(dir), 'w') for i in range(1, num_archives_intermediate + 1): for line in open("{0}/egs.{1}.scp".format(dir, i)): - print (line, file=out_egs_handle) + print (line.strip(), file=out_egs_handle) out_egs_handle.close() else: # there are intermediate archives so we shuffle egs across jobs @@ -782,7 +782,7 @@ def generate_training_examples_internal(dir, targets_parameters, feat_dir, for i in range(1, num_archives_intermediate + 1): for j in range(1, archives_multiple + 1): for line in open("{0}/egs.{1}.{2}.scp".format(dir, i, j)): - print (line, file=out_egs_handle) + print (line.strip(), file=out_egs_handle) out_egs_handle.close() cleanup(dir, archives_multiple) From 894279b07c2e026913555dd8db5b58545aaba406 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 2 Jan 2017 18:35:03 -0500 Subject: [PATCH 125/213] asr_diarization: Some deep restructuring to decode and segmentation --- .../segmentation/decode_sad_to_segments.sh | 7 ++- 
.../segmentation/do_segmentation_data_dir.sh | 51 +++++++++---------- .../internal/convert_ali_to_vad.sh | 2 +- .../post_process_sad_to_subsegments.sh | 21 ++++++-- 4 files changed, 45 insertions(+), 36 deletions(-) mode change 100644 => 100755 egs/wsj/s5/steps/segmentation/post_process_sad_to_subsegments.sh diff --git a/egs/wsj/s5/steps/segmentation/decode_sad_to_segments.sh b/egs/wsj/s5/steps/segmentation/decode_sad_to_segments.sh index 8f4ed60dfda..de8ab0d90e8 100755 --- a/egs/wsj/s5/steps/segmentation/decode_sad_to_segments.sh +++ b/egs/wsj/s5/steps/segmentation/decode_sad_to_segments.sh @@ -88,10 +88,9 @@ if [ $stage -le 6 ]; then 1 0 2 1 EOF - steps/segmentation/post_process_sad_to_segments.sh \ - --phone2sad-map $lang/phone2sad_map \ - --ali-suffix "" --segmentation-config $segmentation_config \ + steps/segmentation/post_process_sad_to_subsegments.sh \ + --segmentation-config $segmentation_config \ --frame-subsampling-factor $frame_subsampling_factor \ - $data $lang $dir $dir $out_data + $data $lang/phone2sad_map $dir $dir $out_data fi diff --git a/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir.sh b/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir.sh index 9feb421ccd3..c1e690af366 100755 --- a/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir.sh +++ b/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir.sh @@ -36,7 +36,10 @@ do_downsampling=false # Segmentation configs min_silence_duration=30 min_speech_duration=30 +sil_prior=0.5 +speech_prior=0.5 segmentation_config=conf/segmentation_speech.conf +convert_data_dir_to_whole=true echo $* @@ -66,33 +69,27 @@ export PATH="$KALDI_ROOT/tools/sph2pipe_v2.5/:$PATH" whole_data_dir=${sad_dir}/${data_id}_whole -if [ $stage -le 0 ]; then - utils/data/convert_data_dir_to_whole.sh $src_data_dir ${whole_data_dir} - - if $do_downsampling; then - freq=`cat $mfcc_config | perl -pe 's/\s*#.*//g' | grep "sample-frequency=" | awk -F'=' '{if (NF == 0) print 16000; else print $2}'` - sox=`which sox` - - cat $src_data_dir/wav.scp | python -c "import sys -for line in sys.stdin.readlines(): - splits = line.strip().split() - if splits[-1] == '|': - out_line = line.strip() + ' $sox -t wav - -r $freq -c 1 -b 16 -t wav - downsample |' - else: - out_line = 'cat {0} {1} | $sox -t wav - -r $freq -c 1 -b 16 -t wav - downsample |'.format(splits[0], ' '.join(splits[1:])) - print (out_line)" > ${whole_data_dir}/wav.scp - fi - - utils/copy_data_dir.sh ${whole_data_dir} ${whole_data_dir}${feat_affix}_hires -fi +if $convert_data_dir_to_whole; then + if [ $stage -le 0 ]; then + utils/data/convert_data_dir_to_whole.sh $src_data_dir ${whole_data_dir} + + if $do_downsampling; then + freq=`cat $mfcc_config | perl -pe 's/\s*#.*//g' | grep "sample-frequency=" | awk -F'=' '{if (NF == 0) print 16000; else print $2}'` + utils/data/downsample_data_dir.sh $freq $whole_data_dir + fi -test_data_dir=${whole_data_dir}${feat_affix}_hires + utils/copy_data_dir.sh ${whole_data_dir} ${whole_data_dir}${feat_affix}_hires + fi -if [ $stage -le 1 ]; then - steps/make_mfcc.sh --mfcc-config $mfcc_config --nj $reco_nj --cmd "$train_cmd" \ - ${whole_data_dir}${feat_affix}_hires exp/make_hires/${data_id}_whole${feat_affix} $mfcc_dir - steps/compute_cmvn_stats.sh ${whole_data_dir}${feat_affix}_hires exp/make_hires/${data_id}_whole${feat_affix} $mfcc_dir - utils/fix_data_dir.sh ${whole_data_dir}${feat_affix}_hires + if [ $stage -le 1 ]; then + steps/make_mfcc.sh --mfcc-config $mfcc_config --nj $reco_nj --cmd "$train_cmd" \ + ${whole_data_dir}${feat_affix}_hires 
exp/make_hires/${data_id}_whole${feat_affix} $mfcc_dir + steps/compute_cmvn_stats.sh ${whole_data_dir}${feat_affix}_hires exp/make_hires/${data_id}_whole${feat_affix} $mfcc_dir + utils/fix_data_dir.sh ${whole_data_dir}${feat_affix}_hires + fi + test_data_dir=${whole_data_dir}${feat_affix}_hires +else + test_data_dir=$src_data_dir fi post_vec=$sad_nnet_dir/post_${output_name}.vec @@ -118,6 +115,8 @@ if [ $stage -le 3 ]; then --frame-subsampling-factor $frame_subsampling_factor \ --min-silence-duration $min_silence_duration \ --min-speech-duration $min_speech_duration \ + --sil-prior $sil_prior \ + --speech-prior $speech_prior \ --segmentation-config $segmentation_config --cmd "$train_cmd" \ ${test_data_dir} $sad_dir $seg_dir ${data_dir}_seg fi @@ -125,7 +124,7 @@ fi # Subsegment data directory if [ $stage -le 4 ]; then rm ${data_dir}_seg/feats.scp || true - utils/data/get_reco2num_frames.sh ${test_data_dir} + utils/data/get_reco2num_frames.sh --cmd "$train_cmd" --nj $reco_nj ${test_data_dir} awk '{print $1" "$2}' ${data_dir}_seg/segments | \ utils/apply_map.pl -f 2 ${test_data_dir}/reco2num_frames > \ ${data_dir}_seg/utt2max_frames diff --git a/egs/wsj/s5/steps/segmentation/internal/convert_ali_to_vad.sh b/egs/wsj/s5/steps/segmentation/internal/convert_ali_to_vad.sh index 353e6d4664e..234b5020797 100755 --- a/egs/wsj/s5/steps/segmentation/internal/convert_ali_to_vad.sh +++ b/egs/wsj/s5/steps/segmentation/internal/convert_ali_to_vad.sh @@ -13,7 +13,7 @@ frame_subsampling_factor=1 . parse_options.sh -if [ $# -ne 4 ]; then +if [ $# -ne 3 ]; then echo "This script converts the alignment in the alignment directory " echo "to speech activity segments based on the provided phone-map." echo "Usage: $0 exp/tri3_ali data/lang/phones/sad.map exp/tri3_ali_vad" diff --git a/egs/wsj/s5/steps/segmentation/post_process_sad_to_subsegments.sh b/egs/wsj/s5/steps/segmentation/post_process_sad_to_subsegments.sh old mode 100644 new mode 100755 index 8cfcaa40cda..0ca6b3dd126 --- a/egs/wsj/s5/steps/segmentation/post_process_sad_to_subsegments.sh +++ b/egs/wsj/s5/steps/segmentation/post_process_sad_to_subsegments.sh @@ -12,6 +12,7 @@ stage=-10 segmentation_config=conf/segmentation.conf nj=18 +frame_subsampling_factor=1 frame_shift=0.01 . utils/parse_options.sh @@ -39,26 +40,36 @@ if [ $stage -le 0 ]; then $cmd JOB=1:$nj $dir/log/segmentation.JOB.log \ segmentation-init-from-ali \ "ark:gunzip -c $vad_dir/ali.JOB.gz |" ark:- \| \ - segmentation-copy --label-map=$phone2sad_map ark:- \ + segmentation-copy --label-map=$phone2sad_map \ + --frame-subsampling-factor=$frame_subsampling_factor ark:- \ "ark:| gzip -c > $dir/orig_segmentation.JOB.gz" fi echo $nj > $dir/num_jobs +# Create a temporary directory into which we can create the new segments +# file. 
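+# Illustrative note (not from the original patch): the $phone2sad_map passed to
+# segmentation-copy above is a two-column map with one "<label> <sad-class>"
+# pair per line; decode_sad_to_segments.sh, for instance, writes
+#   1 0
+#   2 1
+# where mapping 1 -> 0 and 2 -> 1 is assumed to mean silence -> 0 and
+# speech -> 1.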
 if [ $stage -le 1 ]; then
   rm -r $segmented_data_dir || true
   utils/data/convert_data_dir_to_whole.sh $data_dir $segmented_data_dir || exit 1
   rm $segmented_data_dir/text || true
 fi
 
-steps/segmentation/internal/post_process_segments.sh \
-  --stage $stage --cmd "$cmd" \
-  --config $segmentation_config --frame-shift $frame_shift \
-  $data_dir $dir $segmented_data_dir
+if [ $stage -le 2 ]; then
+  steps/segmentation/internal/post_process_segments.sh \
+    --stage $stage --cmd "$cmd" \
+    --config $segmentation_config --frame-shift $frame_shift \
+    $data_dir $dir $segmented_data_dir
+fi
 
 mv $segmented_data_dir/segments $segmented_data_dir/sub_segments
 utils/data/subsegment_data_dir.sh $data_dir $segmented_data_dir/sub_segments $segmented_data_dir
 
+utils/data/get_reco2num_frames.sh ${data_dir}
+mv $segmented_data_dir/feats.scp $segmented_data_dir/feats.scp.tmp
+cat $segmented_data_dir/segments | utils/apply_map.pl -f 2 $data_dir/reco2num_frames > $segmented_data_dir/utt2max_frames
+cat $segmented_data_dir/feats.scp.tmp | utils/data/fix_subsegmented_feats.pl $segmented_data_dir/utt2max_frames > $segmented_data_dir/feats.scp
+
 utils/utt2spk_to_spk2utt.pl $segmented_data_dir/utt2spk > $segmented_data_dir/spk2utt || exit 1
 utils/fix_data_dir.sh $segmented_data_dir

From 73eb9431d868336f1fff761b8fc33fcf2f310b9b Mon Sep 17 00:00:00 2001
From: Vimal Manohar
Date: Mon, 2 Jan 2017 18:35:36 -0500
Subject: [PATCH 126/213] asr_diarization: Bug fix in get_reco2num_frames.sh

---
 egs/wsj/s5/utils/data/get_reco2num_frames.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/egs/wsj/s5/utils/data/get_reco2num_frames.sh b/egs/wsj/s5/utils/data/get_reco2num_frames.sh
index 03ab7b40616..8df5afdb156 100755
--- a/egs/wsj/s5/utils/data/get_reco2num_frames.sh
+++ b/egs/wsj/s5/utils/data/get_reco2num_frames.sh
@@ -20,7 +20,7 @@ if [ -f $data/reco2num_frames ]; then
   exit 0;
 fi
 
-utils/data/get_reco2dur.sh $data
+utils/data/get_reco2dur.sh --cmd "$cmd" --nj $nj $data
 
 awk -v fs=$frame_shift -v fovlp=$frame_overlap \
   '{print $1" "int( ($2 - fovlp) / fs)}' $data/reco2dur > $data/reco2num_frames

From 017774350ff1b4b5af05c6139a7474dedff0f6f3 Mon Sep 17 00:00:00 2001
From: Vimal Manohar
Date: Mon, 2 Jan 2017 18:35:56 -0500
Subject: [PATCH 127/213] asr_diarization: Relax some errors in
 normalize_data_range

---
 egs/wsj/s5/utils/data/normalize_data_range.pl | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/egs/wsj/s5/utils/data/normalize_data_range.pl b/egs/wsj/s5/utils/data/normalize_data_range.pl
index f7936d98a31..a7a144fd82e 100755
--- a/egs/wsj/s5/utils/data/normalize_data_range.pl
+++ b/egs/wsj/s5/utils/data/normalize_data_range.pl
@@ -45,14 +45,13 @@ sub combine_ranges {
   # though they are supported at the C++ level.
   if ($start1 eq "" || $start2 eq "" || $end1 eq "" || $end2 == "") {
     chop $line;
-    print("normalize_data_range.pl: could not make sense of line $line\n");
+    print STDERR ("normalize_data_range.pl: could not make sense of line $line\n");
     exit(1)
   }
   if ($start1 + $end2 > $end1) {
     chop $line;
-    print("normalize_data_range.pl: could not make sense of line $line " .
+    print STDERR ("normalize_data_range.pl: could not make sense of line $line " .
           "[second $row_or_column range too large vs first range, $start1 + $end2 > $end1]\n");
-    exit(1);
   }
   return ($start2+$start1, $end2+$start1);
 }
@@ -72,7 +71,7 @@ sub combine_ranges {
     # sometimes in scp files, we use the command concat-feats to splice together
     # two feature matrices.
Handling this correctly is complicated and we don't # anticipate needing it, so we just refuse to process this type of data. - print "normalize_data_range.pl: this script cannot [yet] normalize the data ranges " . + print STDERR "normalize_data_range.pl: this script cannot [yet] normalize the data ranges " . "if concat-feats was in the input data\n"; exit(1); } From a638ccad32b7baaedddb5c984e61ac938db0d136 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 2 Jan 2017 18:37:34 -0500 Subject: [PATCH 128/213] asr_diarization: more tuning scripts for music detection --- .../tuning/train_stats_sad_music_1a.sh | 0 .../tuning/train_stats_sad_music_1b.sh | 0 .../tuning/train_stats_sad_music_1c.sh | 0 .../tuning/train_stats_sad_music_1d.sh | 0 .../tuning/train_stats_sad_music_1e.sh | 0 .../tuning/train_stats_sad_music_1f.sh | 227 +++++++++++++++++ .../tuning/train_stats_sad_music_1g.sh | 234 ++++++++++++++++++ 7 files changed, 461 insertions(+) mode change 100644 => 100755 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1a.sh mode change 100644 => 100755 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1b.sh mode change 100644 => 100755 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1c.sh mode change 100644 => 100755 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1d.sh mode change 100644 => 100755 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1e.sh create mode 100755 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1f.sh create mode 100755 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1g.sh diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1a.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1a.sh old mode 100644 new mode 100755 diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1b.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1b.sh old mode 100644 new mode 100755 diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1c.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1c.sh old mode 100644 new mode 100755 diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1d.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1d.sh old mode 100644 new mode 100755 diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1e.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1e.sh old mode 100644 new mode 100755 diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1f.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1f.sh new file mode 100755 index 00000000000..0afdd0072ac --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1f.sh @@ -0,0 +1,227 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. +# This script is same as 1e, but removes the stats component. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
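+
+# Illustrative note (not from the original patch): with the chunk_width=20 set
+# below, the samples_per_iter value computed later in this script works out to
+#   perl -e "print int(400000 / 20)"    # -> 20000
+# which is the same figure as the hard-coded --trainer.samples-per-iter=20000
+# passed to train_raw_rnn.py near the end of the script.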
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=20 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=79 # Maximum left context in egs apart from TDNN's left context +extra_right_context=11 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +num_utts_subset_valid=50 # "utts" is actually recording. So this is prettly small. +num_utts_subset_train=50 + +# target options +train_data_dir=data/train_aztec_small_unsad_whole_all_corrupted_sp_hires_bp + +speech_feat_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400/speech_feat.scp +deriv_weights_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400/deriv_weights.scp +music_labels_scp=data/train_aztec_small_unsad_whole_music_corrupted_sp_hires_bp/music_labels.scp + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_music/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$train_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-9, tdnn1@-3, tdnn1, tdnn1@3) dim=256 add-log-stddev=true + relu-renorm-layer name=tdnn3 input=Append(tdnn2@-27, tdnn2@-9, tdnn2, tdnn2@9) dim=256 add-log-stddev=true + + output-layer name=output-speech include-log-softmax=true dim=2 input=tdnn3 + output-layer name=output-music include-log-softmax=true dim=2 input=tdnn3 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` +speech_data_dir=$dir/`basename $train_data_dir`_speech +music_data_dir=$dir/`basename $train_data_dir`_music + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + + . 
$dir/configs/vars + + utils/subset_data_dir.sh --utt-list $speech_feat_scp ${train_data_dir} $dir/`basename ${train_data_dir}`_speech + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$speech_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$speech_feat_scp --deriv-weights-scp=$deriv_weights_scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . $dir/configs/vars + utils/subset_data_dir.sh --utt-list $music_labels_scp ${train_data_dir} $dir/`basename ${train_data_dir}`_music + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_labels_scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + steps/nnet3/multilingual/get_egs.sh \ + --minibatch-size $[chunk_width * num_chunk_per_minibatch] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + 
--cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$train_data_dir \ + --targets-scp="$speech_feat_scp" \ + --dir=$dir || exit 1 +fi + +if [ $stage -le 6 ]; then + $train_cmd JOB=1:100 $dir/log/compute_post_output-speech.JOB.log \ + extract-column "scp:utils/split_scp.pl -j 100 \$[JOB-1] $speech_feat_scp |" ark,t:- \| \ + steps/segmentation/quantize_vector.pl \| \ + ali-to-post ark,t:- ark:- \| \ + weight-post ark:- scp:$deriv_weights_scp ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-speech.vec.JOB + eval vector-sum $dir/post_output-speech.vec.{`seq -s, 100`} $dir/post_output-speech.vec + + $train_cmd JOB=1:100 $dir/log/compute_post_output-music.JOB.log \ + ali-to-post "scp:utils/split_scp.pl -j 100 \$[JOB-1] $music_labels_scp |" ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-music.vec.JOB + eval vector-sum $dir/post_output-music.vec.{`seq -s, 100`} $dir/post_output-music.vec +fi + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1g.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1g.sh new file mode 100755 index 00000000000..e411b94c893 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1g.sh @@ -0,0 +1,234 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. +# This script is same as 1e, but removes the stats component in the 3rd layer. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=20 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=79 # Maximum left context in egs apart from TDNN's left context +extra_right_context=11 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +sad_data_dir=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400 +music_data_dir=data/train_aztec_unsad_whole_music_corrupted_sp_hires_bp + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 
4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_music/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 add-log-stddev=true + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-9, tdnn1@-3, tdnn1, tdnn1@3, tdnn2_stats) dim=256 add-log-stddev=true + relu-renorm-layer name=tdnn3 input=Append(tdnn2@-27, tdnn2@-9, tdnn2, tdnn2@9) dim=256 add-log-stddev=true + + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print (($num_frames_music / $num_frames_sad) ** 0.25) / $num_snr_bins"` input=tdnn3 + output-layer name=output-speech include-log-softmax=true dim=2 input=tdnn3 objective-scale=`perl -e "print (($num_frames_music / $num_frames_sad) ** 0.25)"` + output-layer name=output-music include-log-softmax=true dim=2 input=tdnn3 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_feat.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_labels_scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_feat.scp" \ + --dir=$dir || exit 1 +fi + From 47bf4fd5aa4ee972a7195f904cfca6eb9fed7141 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 2 Jan 2017 18:38:01 -0500 Subject: [PATCH 129/213] asr_diarization: Add more tuning scripts for sad overlap --- .../tuning/train_lstm_overlapping_sad_1b.sh | 262 +++++++++++++++++ .../tuning/train_lstm_stats_overlap_1a.sh | 202 +++++++++++++ .../tuning/train_lstm_stats_sad_overlap_1a.sh | 259 ++++++++++++++++ .../train_lstm_stats_sad_overlap_ami_1a.sh | 192 ++++++++++++ .../train_lstm_stats_sad_overlap_ami_1b.sh | 192 ++++++++++++ .../tuning/train_rnn_overlap_1a.sh | 184 ++++++++++++ .../tuning/train_rnn_overlap_1b.sh | 184 ++++++++++++ .../tuning/train_stats_overlap_1f.sh | 200 +++++++++++++ .../tuning/train_stats_overlap_1g.sh | 202 +++++++++++++ .../tuning/train_stats_overlap_1h.sh | 202 +++++++++++++ .../tuning/train_stats_overlap_1i.sh | 202 +++++++++++++ .../tuning/train_stats_sad_overlap_1a.sh | 0 .../tuning/train_stats_sad_overlap_1b.sh | 53 ++-- 
.../tuning/train_stats_sad_overlap_1c.sh | 239 +++++++++++++++ .../tuning/train_stats_sad_overlap_1d.sh | 262 +++++++++++++++++ .../tuning/train_stats_sad_overlap_1f.sh | 272 +++++++++++++++++ .../tuning/train_stats_sad_overlap_1g.sh | 275 +++++++++++++++++ .../tuning/train_stats_sad_overlap_1h.sh | 276 ++++++++++++++++++ 18 files changed, 3632 insertions(+), 26 deletions(-) create mode 100755 egs/aspire/s5/local/segmentation/tuning/train_lstm_overlapping_sad_1b.sh create mode 100755 egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_overlap_1a.sh create mode 100755 egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_sad_overlap_1a.sh create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_sad_overlap_ami_1a.sh create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_sad_overlap_ami_1b.sh create mode 100755 egs/aspire/s5/local/segmentation/tuning/train_rnn_overlap_1a.sh create mode 100755 egs/aspire/s5/local/segmentation/tuning/train_rnn_overlap_1b.sh create mode 100755 egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1f.sh create mode 100755 egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1g.sh create mode 100755 egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1h.sh create mode 100755 egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1i.sh mode change 100644 => 100755 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1a.sh mode change 100644 => 100755 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1b.sh create mode 100755 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1c.sh create mode 100755 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1d.sh create mode 100755 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1f.sh create mode 100755 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1g.sh create mode 100755 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1h.sh diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_overlapping_sad_1b.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_overlapping_sad_1b.sh new file mode 100755 index 00000000000..a634060b317 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_overlapping_sad_1b.sh @@ -0,0 +1,262 @@ +#!/bin/bash + +# This is a script to train a LSTM for overlapped speech activity detection +# and SAD. This uses a larger LSTM-TDNN architecture and trains on +# ternary overlapping SAD labels. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
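+
+# Illustrative note (an assumption, not stated in the original patch): the
+# dim=3 "output-overlapping_sad" target used further down is read here as a
+# per-frame ternary label, e.g. 0 = silence, 1 = single-speaker speech and
+# 2 = overlapping speech, so one hypothetical line of the labels archive
+# behind overlapping_sad_labels_fixed.scp would look like
+#   rec1-1 0 0 1 1 1 2 2 1 0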
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=20 +num_chunk_per_minibatch=64 + +extra_left_context=80 # Maximum left context in egs apart from TDNN's left context +extra_right_context=40 # Maximum right context in egs apart from TDNN's right context + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +sad_data_dir=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400 +ovlp_sad_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $sad_data_dir/utt2spk $ovlp_sad_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_ovlp_snr/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-2,-1,0,1,2) + + relu-renorm-layer name=tdnn1 input=Append(input@-2, input@-1, input, input@1, input@2) dim=256 + relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=256 + relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=256 + fast-lstmp-layer name=lstm1 cell-dim=256 recurrent-projection-dim=128 non-recurrent-projection-dim=128 delay=-3 + relu-renorm-layer name=tdnn4 input=Append(-6,0,6,12) dim=256 + fast-lstmp-layer name=lstm2 cell-dim=256 recurrent-projection-dim=128 non-recurrent-projection-dim=128 delay=-6 + + output-layer name=output-speech include-log-softmax=true dim=2 objective-scale=$speech_scale input=lstm2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.05 + + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print $speech_scale / $num_snr_bins"` input=lstm2 max-change=0.75 learning-rate-factor=0.5 + + output-layer name=output-overlapping_sad include-log-softmax=true dim=3 objective-scale=$ovlp_scale input=lstm2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapping_sad.txt max-change=0.75 learning-rate-factor=0.02 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{01,02,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_feat.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{01,02,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_sad_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapping_sad --target-type=sparse --dim=3 --targets-scp=$ovlp_sad_data_dir/overlapping_sad_labels_fixed.scp --deriv-weights-scp=$ovlp_sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_ovlp $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + 
--trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_feat.scp" \ + --dir=$dir || exit 1 +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_overlap_1a.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_overlap_1a.sh new file mode 100755 index 00000000000..adc4fc81c08 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_overlap_1a.sh @@ -0,0 +1,202 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for overlapped speech activity detection +# using statistic pooling component for long-context information. + +# This scripts is similar to 1f but adds max-change=0.75 and learning-rate-factor=0.02 to the final affine. +# And changed relu-dim to 512. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +chunk_width=40 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-overlapped_speech ark:- ark:- |" # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +#extra_left_context=79 +#extra_right_context=11 + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=f + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_ovlp/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$ovlp_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=512 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1, tdnn2_stats) dim=512 + relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=512 + relu-renorm-layer name=tdnn4 dim=512 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapped_speech.txt max-change=0.75 learning-rate-factor=0.02 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-overlapped_speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_ovlp + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + 
--trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$ovlp_data_dir \ + --targets-scp="$ovlp_data_dir/overlapped_spech_labels.scp" \ + --dir=$dir || exit 1 +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_sad_overlap_1a.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_sad_overlap_1a.sh new file mode 100755 index 00000000000..52a15686d28 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_sad_overlap_1a.sh @@ -0,0 +1,259 @@ +#!/bin/bash + +# This is a script to train a LSTM for overlapped speech activity detection +# and SAD. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=40 +num_chunk_per_minibatch=64 + +extra_left_context=40 # Maximum left context in egs apart from TDNN's left context +extra_right_context=0 # Maximum right context in egs apart from TDNN's right context + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +sad_data_dir=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400 +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $sad_data_dir/utt2spk $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_ovlp_snr/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-2,-1,0,1,2) + + relu-renorm-layer name=tdnn1 input=Append(input@-2, input@-1, input, input@1, input@2) dim=512 + lstmp-layer name=lstm1 cell-dim=512 recurrent-projection-dim=256 non-recurrent-projection-dim=256 delay=-3 + relu-renorm-layer name=tdnn2 input=Append(-6,0,6) dim=512 + lstmp-layer name=lstm2 cell-dim=512 recurrent-projection-dim=256 non-recurrent-projection-dim=256 delay=-6 + + output-layer name=output-speech include-log-softmax=true dim=2 objective-scale=$speech_scale input=lstm2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.05 + + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print $speech_scale / $num_snr_bins"` input=lstm2 max-change=0.75 learning-rate-factor=0.5 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 objective-scale=$ovlp_scale input=lstm2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapped_speech.txt max-change=0.75 learning-rate-factor=0.02 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{01,02,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_feat.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{01,02,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_ovlp $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_feat.scp" \ + --dir=$dir || exit 1 +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_sad_overlap_ami_1a.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_sad_overlap_ami_1a.sh new file mode 100644 index 00000000000..d003f746c4b --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_sad_overlap_ami_1a.sh @@ -0,0 +1,192 @@ +#!/bin/bash + +# This is a script to train a LSTM for overlapped speech activity detection +# and SAD. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. 
+# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=20 +num_chunk_per_minibatch=128 + +extra_left_context=40 # Maximum left context in egs apart from TDNN's left context +extra_right_context=0 # Maximum right context in egs apart from TDNN's right context + +# training options +num_epochs=8 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +data_dir=data/ami_sdm1_train_whole_hires_bp +labels_scp=exp/sad_ami_sdm1_train/ref/overlapping_sad_labels.scp +deriv_weights_scp=exp/sad_ami_sdm1_train/ref/deriv_weights_for_overlapping_sad.scp + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); $n = ($n > 4000 ? 4000 : $n); print ($n < 6 ? 6 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); $n = ($n > 4000 ? 4000 : $n); print ($n < 6 ? 6 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_ovlp_sad_ami/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$data_dir/feats.scp -` name=input + output name=output-temp input=Append(-2,-1,0,1,2) + + relu-renorm-layer name=tdnn1 input=Append(input@-2, input@-1, input, input@1, input@2) dim=256 + relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=256 + relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=256 + lstmp-layer name=lstm1 cell-dim=256 recurrent-projection-dim=128 non-recurrent-projection-dim=128 delay=-3 + relu-renorm-layer name=tdnn4 input=Append(-6,0,6) dim=256 + lstmp-layer name=lstm2 cell-dim=256 recurrent-projection-dim=128 non-recurrent-projection-dim=128 delay=-6 + + output-layer name=output-overlapping_sad include-log-softmax=true dim=3 input=lstm2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapping_sad.txt learning-rate-factor=0.05 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-overlapping_sad new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_overlapping_sad + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_overlapping_sad/storage ]; then + utils/create_split_dir.pl \ + /export/b{01,02,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_overlapping_sad/storage $dir/egs_overlapping_sad/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-overlapping_sad --target-type=sparse --dim=3 --targets-scp=$labels_scp --deriv-weights-scp=$deriv_weights_scp --scp2ark-cmd=\"ali-to-post scp:- ark: |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_overlapping_sad + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=false --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$data_dir \ + --targets-scp="$labels_scp" \ + --dir=$dir || exit 1 +fi + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_sad_overlap_ami_1b.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_sad_overlap_ami_1b.sh new file mode 100644 index 00000000000..3aa4f28f99a --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_sad_overlap_ami_1b.sh @@ -0,0 +1,192 @@ +#!/bin/bash + +# This is a script to train a LSTM for overlapped speech activity detection +# and SAD. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
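+
+# Illustrative note (not from the original patch): the subset sizes computed
+# below are 0.5% of the number of training utterances, clamped to the range
+# [6, 4000]; for example
+#   perl -e '$n=int(800 * 0.005); $n = ($n > 4000 ? 4000 : $n); print ($n < 6 ? 6 : $n)'      # -> 6
+#   perl -e '$n=int(1000000 * 0.005); $n = ($n > 4000 ? 4000 : $n); print ($n < 6 ? 6 : $n)'  # -> 4000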
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=20 +num_chunk_per_minibatch=128 + +extra_left_context=40 # Maximum left context in egs apart from TDNN's left context +extra_right_context=0 # Maximum right context in egs apart from TDNN's right context + +# training options +num_epochs=8 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +data_dir=data/ami_sdm1_train_whole_hires_bp +labels_scp=exp/sad_ami_sdm1_train/ref/overlapping_sad_labels.scp +deriv_weights_scp=exp/sad_ami_sdm1_train/ref/deriv_weights_for_overlapping_sad.scp + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); $n = ($n > 4000 ? 4000 : $n); print ($n < 6 ? 6 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); $n = ($n > 4000 ? 4000 : $n); print ($n < 6 ? 6 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_ovlp_sad_ami/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$data_dir/feats.scp -` name=input + output name=output-temp input=Append(-2,-1,0,1,2) + + relu-renorm-layer name=tdnn1 input=Append(input@-2, input@-1, input, input@1, input@2) dim=256 + relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=256 + relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=256 + lstmp-layer name=lstm1 cell-dim=256 recurrent-projection-dim=128 non-recurrent-projection-dim=128 delay=-3 + relu-renorm-layer name=tdnn4 input=Append(-6,0,6) dim=256 + lstmp-layer name=lstm2 cell-dim=256 recurrent-projection-dim=128 non-recurrent-projection-dim=128 delay=-6 + + output-layer name=output-overlapping_sad include-log-softmax=true dim=3 input=lstm2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapping_sad.txt learning-rate-factor=0.05 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-overlapping_sad new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_overlapping_sad + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_overlapping_sad/storage ]; then + utils/create_split_dir.pl \ + /export/b{01,02,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_overlapping_sad/storage $dir/egs_overlapping_sad/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-overlapping_sad --target-type=sparse --dim=3 --targets-scp=$labels_scp --deriv-weights-scp=$deriv_weights_scp --scp2ark-cmd=\"ali-to-post scp:- ark: |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_overlapping_sad + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=false --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$data_dir \ + --targets-scp="$labels_scp" \ + --dir=$dir || exit 1 +fi + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_rnn_overlap_1a.sh b/egs/aspire/s5/local/segmentation/tuning/train_rnn_overlap_1a.sh new file mode 100755 index 00000000000..e63c5d8a063 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_rnn_overlap_1a.sh @@ -0,0 +1,184 @@ +#!/bin/bash + +# This is a script to train a lstm for overlapped speech activity detection. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=40 +num_chunk_per_minibatch=64 + +extra_left_context=40 # Maximum left context in egs apart from TDNN's left context +extra_right_context=0 # Maximum right context in egs apart from TDNN's right context + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-overlapped_speech ark:- ark:- |" # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=f + +. cmd.sh +. ./path.sh +. 
./utils/parse_options.sh + +num_utts=`cat $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_ovlp/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$ovlp_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-2,-1,0,1,2) + + relu-renorm-layer name=tdnn1 dim=256 input=Append(input@-2, input@-1, input, input@1, input@2) + lstmp-layer name=lstm1 cell-dim=256 recurrent-projection-dim=128 non-recurrent-projection-dim=128 delay=-3 + relu-renorm-layer name=tdnn2 input=Append(-6,0,6) dim=256 + lstmp-layer name=lstm2 cell-dim=256 recurrent-projection-dim=128 non-recurrent-projection-dim=128 delay=-6 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 input=lstm2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapped_speech.txt max-change=0.75 learning-rate-factor=0.02 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-overlapped_speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_ovlp + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$ovlp_data_dir \ + --targets-scp="$ovlp_data_dir/overlapped_spech_labels.scp" \ + --dir=$dir || exit 1 +fi + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_rnn_overlap_1b.sh b/egs/aspire/s5/local/segmentation/tuning/train_rnn_overlap_1b.sh new file mode 100755 index 00000000000..15235882f90 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_rnn_overlap_1b.sh @@ -0,0 +1,184 @@ +#!/bin/bash + +# This is a script to train a LSTM for overlapped speech activity detection. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. 
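+# Compared to train_rnn_overlap_1a.sh, this version roughly doubles the layer sizes:
+# the TDNN and LSTM cell dims are 512 and the LSTM projection dims 256 (vs. 256/128 in 1a).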
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=40 +num_chunk_per_minibatch=64 + +extra_left_context=40 # Maximum left context in egs apart from TDNN's left context +extra_right_context=0 # Maximum right context in egs apart from TDNN's right context + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-overlapped_speech ark:- ark:- |" # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=b + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_ovlp/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$ovlp_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-2,-1,0,1,2) + + relu-renorm-layer name=tdnn1 dim=512 input=Append(input@-2, input@-1, input, input@1, input@2) + lstmp-layer name=lstm1 cell-dim=512 recurrent-projection-dim=256 non-recurrent-projection-dim=256 delay=-3 + relu-renorm-layer name=tdnn2 input=Append(-6,0,6) dim=512 + lstmp-layer name=lstm2 cell-dim=512 recurrent-projection-dim=256 non-recurrent-projection-dim=256 delay=-6 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 input=lstm2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapped_speech.txt max-change=0.75 learning-rate-factor=0.02 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-overlapped_speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_ovlp + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$ovlp_data_dir \ + --targets-scp="$ovlp_data_dir/overlapped_spech_labels.scp" \ + --dir=$dir || exit 1 +fi + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1f.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1f.sh new file mode 100755 index 00000000000..2201f9fd8d1 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1f.sh @@ -0,0 +1,200 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for overlapped speech activity detection +# using statistic pooling component for long-context information. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
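+# The egs are dumped with both SAD (output-speech) and overlapped-speech targets, but
+# this network has only the output-overlapped_speech output; extra_egs_copy_cmd below
+# drops the unused output-speech targets from the egs before training.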
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=40 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-overlapped_speech ark:- ark:- |" # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +#extra_left_context=79 +#extra_right_context=11 + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=f + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_ovlp/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$ovlp_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1, tdnn2_stats) dim=256 + relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=256 + relu-renorm-layer name=tdnn4 dim=256 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapped_speech.txt +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-overlapped_speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_ovlp + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . 
$dir/configs/vars
+
+    steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \
+      $egs_opts \
+      --feat.dir="$ovlp_data_dir" \
+      --feat.cmvn-opts="--norm-means=false --norm-vars=false" \
+      --frames-per-eg=$chunk_width \
+      --left-context=$[model_left_context + extra_left_context] \
+      --right-context=$[model_right_context + extra_right_context] \
+      --num-utts-subset-train=$num_utts_subset_train \
+      --num-utts-subset-valid=$num_utts_subset_valid \
+      --samples-per-iter=$samples_per_iter \
+      --stage=$get_egs_stage \
+      --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \
+      --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \
+      --generate-egs-scp=true \
+      --dir=$dir/egs_ovlp
+  fi
+fi
+
+if [ $stage -le 5 ]; then
+  steps/nnet3/train_raw_rnn.py --stage=$train_stage \
+    --feat.cmvn-opts="--norm-means=false --norm-vars=false" \
+    --egs.chunk-width=$chunk_width \
+    --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \
+    --egs.chunk-left-context=$extra_left_context \
+    --egs.chunk-right-context=$extra_right_context \
+    ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \
+    --trainer.min-chunk-left-context=$min_extra_left_context \
+    --trainer.min-chunk-right-context=$min_extra_right_context \
+    --trainer.num-epochs=$num_epochs \
+    --trainer.samples-per-iter=20000 \
+    --trainer.optimization.num-jobs-initial=$num_jobs_initial \
+    --trainer.optimization.num-jobs-final=$num_jobs_final \
+    --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \
+    --trainer.optimization.final-effective-lrate=$final_effective_lrate \
+    --trainer.optimization.shrink-value=1.0 \
+    --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \
+    --trainer.deriv-truncate-margin=8 \
+    --trainer.max-param-change=$max_param_change \
+    --cmd="$decode_cmd" --nj 40 \
+    --cleanup=true \
+    --cleanup.remove-egs=$remove_egs \
+    --cleanup.preserve-model-interval=10 \
+    --use-gpu=true \
+    --use-dense-targets=false \
+    --feat-dir=$ovlp_data_dir \
+    --targets-scp="$ovlp_data_dir/overlapped_speech_labels.scp" \
+    --dir=$dir || exit 1
+fi
diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1g.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1g.sh
new file mode 100755
index 00000000000..81febb5fa09
--- /dev/null
+++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1g.sh
@@ -0,0 +1,202 @@
+#!/bin/bash
+
+# This is a script to train a time-delay neural network for overlapped speech activity detection
+# using a statistics pooling component for long-context information.
+
+# This script is similar to 1f but adds max-change=0.75 and learning-rate-factor=0.1 to the final affine layer.
+
+set -o pipefail
+set -e
+set -u
+
+. cmd.sh
+
+# At this script level we don't support not running on GPU, as it would be painfully slow.
+# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false,
+# --num-threads 16 and --minibatch-size 128.
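+# As in 1f, egs carry both SAD and overlapped-speech targets and extra_egs_copy_cmd
+# keeps only output-overlapped_speech; the network differs from 1f only in the
+# max-change and learning-rate-factor options on the final output layer.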
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=40 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-overlapped_speech ark:- ark:- |" # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +#extra_left_context=79 +#extra_right_context=11 + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=f + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_ovlp/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$ovlp_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1, tdnn2_stats) dim=256 + relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=256 + relu-renorm-layer name=tdnn4 dim=256 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapped_speech.txt max-change=0.75 learning-rate-factor=0.1 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-overlapped_speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_ovlp + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . 
$dir/configs/vars
+
+    steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \
+      $egs_opts \
+      --feat.dir="$ovlp_data_dir" \
+      --feat.cmvn-opts="--norm-means=false --norm-vars=false" \
+      --frames-per-eg=$chunk_width \
+      --left-context=$[model_left_context + extra_left_context] \
+      --right-context=$[model_right_context + extra_right_context] \
+      --num-utts-subset-train=$num_utts_subset_train \
+      --num-utts-subset-valid=$num_utts_subset_valid \
+      --samples-per-iter=$samples_per_iter \
+      --stage=$get_egs_stage \
+      --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \
+      --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \
+      --generate-egs-scp=true \
+      --dir=$dir/egs_ovlp
+  fi
+fi
+
+if [ $stage -le 5 ]; then
+  steps/nnet3/train_raw_rnn.py --stage=$train_stage \
+    --feat.cmvn-opts="--norm-means=false --norm-vars=false" \
+    --egs.chunk-width=$chunk_width \
+    --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \
+    --egs.chunk-left-context=$extra_left_context \
+    --egs.chunk-right-context=$extra_right_context \
+    ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \
+    --trainer.min-chunk-left-context=$min_extra_left_context \
+    --trainer.min-chunk-right-context=$min_extra_right_context \
+    --trainer.num-epochs=$num_epochs \
+    --trainer.samples-per-iter=20000 \
+    --trainer.optimization.num-jobs-initial=$num_jobs_initial \
+    --trainer.optimization.num-jobs-final=$num_jobs_final \
+    --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \
+    --trainer.optimization.final-effective-lrate=$final_effective_lrate \
+    --trainer.optimization.shrink-value=1.0 \
+    --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \
+    --trainer.deriv-truncate-margin=8 \
+    --trainer.max-param-change=$max_param_change \
+    --cmd="$decode_cmd" --nj 40 \
+    --cleanup=true \
+    --cleanup.remove-egs=$remove_egs \
+    --cleanup.preserve-model-interval=10 \
+    --use-gpu=true \
+    --use-dense-targets=false \
+    --feat-dir=$ovlp_data_dir \
+    --targets-scp="$ovlp_data_dir/overlapped_speech_labels.scp" \
+    --dir=$dir || exit 1
+fi
diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1h.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1h.sh
new file mode 100755
index 00000000000..adc4fc81c08
--- /dev/null
+++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1h.sh
@@ -0,0 +1,202 @@
+#!/bin/bash
+
+# This is a script to train a time-delay neural network for overlapped speech activity detection
+# using a statistics pooling component for long-context information.
+
+# This script is similar to 1f but adds max-change=0.75 and learning-rate-factor=0.02 to the final affine layer,
+# and changes relu-dim to 512.
+
+set -o pipefail
+set -e
+set -u
+
+. cmd.sh
+
+# At this script level we don't support not running on GPU, as it would be painfully slow.
+# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false,
+# --num-threads 16 and --minibatch-size 128.
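+# Relative to 1g, the TDNN layer dimensions are increased from 256 to 512 and the
+# final layer uses learning-rate-factor=0.02 instead of 0.1.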
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +chunk_width=40 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-overlapped_speech ark:- ark:- |" # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +#extra_left_context=79 +#extra_right_context=11 + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=f + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_ovlp/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$ovlp_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=512 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1, tdnn2_stats) dim=512 + relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=512 + relu-renorm-layer name=tdnn4 dim=512 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapped_speech.txt max-change=0.75 learning-rate-factor=0.02 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-overlapped_speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_ovlp + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . 
$dir/configs/vars
+
+    steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \
+      $egs_opts \
+      --feat.dir="$ovlp_data_dir" \
+      --feat.cmvn-opts="--norm-means=false --norm-vars=false" \
+      --frames-per-eg=$chunk_width \
+      --left-context=$[model_left_context + extra_left_context] \
+      --right-context=$[model_right_context + extra_right_context] \
+      --num-utts-subset-train=$num_utts_subset_train \
+      --num-utts-subset-valid=$num_utts_subset_valid \
+      --samples-per-iter=$samples_per_iter \
+      --stage=$get_egs_stage \
+      --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \
+      --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \
+      --generate-egs-scp=true \
+      --dir=$dir/egs_ovlp
+  fi
+fi
+
+if [ $stage -le 5 ]; then
+  steps/nnet3/train_raw_rnn.py --stage=$train_stage \
+    --feat.cmvn-opts="--norm-means=false --norm-vars=false" \
+    --egs.chunk-width=$chunk_width \
+    --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \
+    --egs.chunk-left-context=$extra_left_context \
+    --egs.chunk-right-context=$extra_right_context \
+    ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \
+    --trainer.min-chunk-left-context=$min_extra_left_context \
+    --trainer.min-chunk-right-context=$min_extra_right_context \
+    --trainer.num-epochs=$num_epochs \
+    --trainer.samples-per-iter=20000 \
+    --trainer.optimization.num-jobs-initial=$num_jobs_initial \
+    --trainer.optimization.num-jobs-final=$num_jobs_final \
+    --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \
+    --trainer.optimization.final-effective-lrate=$final_effective_lrate \
+    --trainer.optimization.shrink-value=1.0 \
+    --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \
+    --trainer.deriv-truncate-margin=8 \
+    --trainer.max-param-change=$max_param_change \
+    --cmd="$decode_cmd" --nj 40 \
+    --cleanup=true \
+    --cleanup.remove-egs=$remove_egs \
+    --cleanup.preserve-model-interval=10 \
+    --use-gpu=true \
+    --use-dense-targets=false \
+    --feat-dir=$ovlp_data_dir \
+    --targets-scp="$ovlp_data_dir/overlapped_speech_labels.scp" \
+    --dir=$dir || exit 1
+fi
diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1i.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1i.sh
new file mode 100755
index 00000000000..dcd11ad2aa6
--- /dev/null
+++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1i.sh
@@ -0,0 +1,202 @@
+#!/bin/bash
+
+# This is a script to train a time-delay neural network for overlapped speech activity detection
+# using a statistics pooling component for long-context information.
+
+# This script is similar to 1f but adds max-change=0.75 and learning-rate-factor=0.02 to the final affine layer.
+# It is similar to 1g but moves the stats pooling to a higher layer and changes the splicing from -9 to -12.
+
+set -o pipefail
+set -e
+set -u
+
+. cmd.sh
+
+# At this script level we don't support not running on GPU, as it would be painfully slow.
+# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false,
+# --num-threads 16 and --minibatch-size 128.
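+# Relative to 1h, the stats-pooling layer is moved up (tdnn3_stats, computed with
+# mean+count(-96:6:12:96)) and tdnn3 splices tdnn2 at offsets -12, 0 and 6; the default
+# extra left/right context is reduced to 90/15 frames.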
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +chunk_width=40 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=90 # Maximum left context in egs apart from TDNN's left context +extra_right_context=15 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-overlapped_speech ark:- ark:- |" # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +#extra_left_context=79 +#extra_right_context=11 + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=f + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_ovlp/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$ovlp_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=512 + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1) dim=512 + stats-layer name=tdnn3_stats config=mean+count(-96:6:12:96) + relu-renorm-layer name=tdnn3 input=Append(tdnn2@-12,tdnn2,tdnn2@6, tdnn3_stats) dim=512 + relu-renorm-layer name=tdnn4 dim=512 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapped_speech.txt max-change=0.75 learning-rate-factor=0.02 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-overlapped_speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_ovlp + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$ovlp_data_dir \ + --targets-scp="$ovlp_data_dir/overlapped_spech_labels.scp" \ + --dir=$dir || exit 1 +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1a.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1a.sh old mode 100644 new mode 100755 diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1b.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1b.sh old mode 100644 new mode 100755 index 888c25295d6..b562a83f6c3 --- a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1b.sh +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1b.sh @@ -1,10 +1,10 @@ #!/bin/bash -# This is a script to train a time-delay neural network for overlapped speech activity detection +# This is a script to train a time-delay neural network for overlapped speech activity detection # using statistic pooling component for long-context information. set -o pipefail -set -e +set -e set -u . 
cmd.sh @@ -21,17 +21,19 @@ egs_opts= # Directly passed to get_egs_multiple_targets.py # TDNN options relu_dim=256 chunk_width=40 # We use chunk training for training TDNN -extra_left_context=100 # Maximum left context in egs apart from TDNN's left context -extra_right_context=20 # Maximum right context in egs apart from TDNN's right context +num_chunk_per_minibatch=64 + +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context # We randomly select an extra {left,right} context for each job between # min_extra_*_context and extra_*_context so that the network can get used # to different contexts used to compute statistics. -min_extra_left_context=20 +min_extra_left_context=20 min_extra_right_context=0 # training options -num_epochs=1 +num_epochs=2 initial_effective_lrate=0.0003 final_effective_lrate=0.00003 num_jobs_initial=3 @@ -44,7 +46,7 @@ extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs sad_data_dir=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400 ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp -#extra_left_context=79 +#extra_left_context=79 #extra_right_context=11 egs_dir= @@ -70,13 +72,13 @@ fi dir=$dir${affix:+_$affix} if ! cuda-compiled; then - cat < $dir/configs/network.xconfig input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input output name=output-temp input=Append(-3,-2,-1,0,1,2,3) - + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1, tdnn2_stats) dim=256 relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=256 + relu-renorm-layer name=tdnn4 dim=256 - relu-renorm-layer name=pre-final-speech dim=256 input=tdnn3 - output-layer name=output-speech include-log-softmax=true dim=2 objective-scale=`perl -e "print ($num_frames_ovlp / $num_frames_sad) ** 0.25"` + output-layer name=output-speech include-log-softmax=true dim=2 objective-scale=`perl -e "print ($num_frames_ovlp / $num_frames_sad) ** 0.25"` input=tdnn4 - relu-renorm-layer name=pre-final-snr dim=256 input=tdnn3 - output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print (($num_frames_ovlp / $num_frames_sad) ** 0.25) / $num_snr_bins"` + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print (($num_frames_ovlp / $num_frames_sad) ** 0.25) / $num_snr_bins"` input=tdnn4 - relu-renorm-layer name=pre-final-overlapped_speech dim=256 input=tdnn3 - output-layer name=output-overlapped_speech include-log-softmax=true dim=2 + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 input=tdnn4 EOF steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ --config-dir $dir/configs/ \ --nnet-edits="rename-node old-name=output-speech new-name=output" - + cat <> $dir/configs/vars add_lda=false EOF @@ -148,11 +148,11 @@ if [ -z "$egs_dir" ]; then if [ $stage -le 2 ]; then if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs_speech/storage ]; then utils/create_split_dir.pl \ - /export/b{03,04,05,06}/$USER/kaldi-data/egs_speech/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage fi - + . $dir/configs/vars - + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ $egs_opts \ --feat.dir="$sad_data_dir" \ @@ -177,7 +177,7 @@ if [ -z "$egs_dir" ]; then fi . $dir/configs/vars - + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ $egs_opts \ --feat.dir="$ovlp_data_dir" \ @@ -196,10 +196,11 @@ if [ -z "$egs_dir" ]; then fi if [ $stage -le 4 ]; then - # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use # the same egs with a different num_chunk_per_minibatch steps/nnet3/multilingual/get_egs.sh \ - --minibatch-size $[chunk_width * num_chunk_per_minibatch * 4] \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ --samples-per-iter $samples_per_iter \ 2 $dir/egs_speech $dir/egs_ovlp $dir/egs_multi fi @@ -233,7 +234,7 @@ if [ $stage -le 5 ]; then --use-gpu=true \ --use-dense-targets=false \ --feat-dir=$sad_data_dir \ - --targets-scp="$speech_feat_scp" \ + --targets-scp="$sad_data_dir/speech_feat.scp" \ --dir=$dir || exit 1 fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1c.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1c.sh new file mode 100755 index 00000000000..7041b0b3e9b --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1c.sh @@ -0,0 +1,239 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for overlapped speech activity detection +# using statistic pooling component for long-context information. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=40 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +sad_data_dir=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400 +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +#extra_left_context=79 +#extra_right_context=11 + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=b + +. cmd.sh +. ./path.sh +. 
./utils/parse_options.sh + +num_utts=`cat $sad_data_dir/utt2spk $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_ovlp_snr/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1, tdnn2_stats) dim=256 + relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=256 + relu-renorm-layer name=tdnn4 dim=256 + + output-layer name=output-speech include-log-softmax=true dim=2 objective-scale=`perl -e "print ($num_frames_ovlp / $num_frames_sad) ** 0.25"` input=tdnn4 + + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print (($num_frames_ovlp / $num_frames_sad) ** 0.25) / $num_snr_bins"` input=tdnn4 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 input=tdnn4 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_feat.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_ovlp $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_feat.scp" \ + --dir=$dir || exit 1 +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1d.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1d.sh new file mode 100755 index 00000000000..a361435baa1 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1d.sh @@ -0,0 +1,262 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for overlapped speech activity detection +# using statistic pooling component for long-context information. + +set -o pipefail +set -e +set -u + +. 
cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=40 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +sad_data_dir=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400 +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +#extra_left_context=79 +#extra_right_context=11 + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=d + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $sad_data_dir/utt2spk $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_ovlp_snr/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1, tdnn2_stats) dim=256 + relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=256 + relu-renorm-layer name=tdnn4 dim=256 + + output-layer name=output-speech include-log-softmax=true dim=2 objective-scale=`perl -e "print ($num_frames_ovlp / $num_frames_sad) ** 0.25"` input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt + + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print (($num_frames_ovlp / $num_frames_sad) ** 0.25) / $num_snr_bins"` input=tdnn4 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapped_speech.txt + +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_feat.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_ovlp $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + 
${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_feat.scp" \ + --dir=$dir || exit 1 +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1f.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1f.sh new file mode 100755 index 00000000000..7048c40f62b --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1f.sh @@ -0,0 +1,272 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for overlapped speech activity detection +# using statistic pooling component for long-context information. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=40 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +sad_data_dir=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400 +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +#extra_left_context=79 +#extra_right_context=11 + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=d + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $sad_data_dir/utt2spk $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_ovlp_snr/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1, tdnn2_stats) dim=256 + relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=256 + relu-renorm-layer name=tdnn4 dim=256 + + output-layer name=output-speech include-log-softmax=true dim=2 objective-scale=$speech_scale input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt + + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print $speech_scale / $num_snr_bins"` input=tdnn4 max-change=0.75 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 objective-scale=$ovlp_scale input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapped_speech.txt max-change=0.75 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_feat.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_ovlp $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_feat.scp" \ + --dir=$dir || exit 1 +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1g.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1g.sh new file mode 100755 index 00000000000..72e26b5347b --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1g.sh @@ -0,0 +1,275 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for overlapped speech activity detection +# using statistic pooling component for long-context information. 
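+
+# A note on the statistics pooling used here (a rough sketch): the stats-layer
+# config below, mean+count(-99:3:9:99), is read as pooling a mean (plus a
+# count-based feature) over roughly +/-99 frames, i.e. about a second of context
+# on each side at a 10 ms frame shift.  Mean-pooling over a window is
+# conceptually just:
+#   echo "1 2 3 4 5" | awk '{s=0; for (i=1; i<=NF; i++) s+=$i; print s/NF}'
+#   # prints 3, the mean of the values in this toy window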
+ +# This script is same as 1e but adds max-change=0.75 for snr and overlapped_speech outputs +# and learning rate factor 0.1 for the final affine components. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=40 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +sad_data_dir=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400 +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +#extra_left_context=79 +#extra_right_context=11 + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=g + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $sad_data_dir/utt2spk $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_ovlp_snr/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1, tdnn2_stats) dim=256 + relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=256 + relu-renorm-layer name=tdnn4 dim=256 + + output-layer name=output-speech include-log-softmax=true dim=2 objective-scale=$speech_scale input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 + + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print $speech_scale / $num_snr_bins"` input=tdnn4 max-change=0.75 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 objective-scale=$ovlp_scale input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapped_speech.txt max-change=0.75 learning-rate-factor=0.1 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_feat.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_ovlp $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_feat.scp" \ + --dir=$dir || exit 1 +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1h.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1h.sh new file mode 100755 index 00000000000..fb1616b9ac7 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1h.sh @@ -0,0 +1,276 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for overlapped speech activity detection +# using statistic pooling component for long-context information. 
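+
+# A quick sanity check of the train/valid subset sizes computed further down in
+# this script: they are 0.5% of the combined utterance count, capped at 4000.
+# With a made-up count of 500000 utterances:
+#   perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' 500000
+#   # prints 2500; with 1000000 utterances it would print 4000 (the cap)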
+ +# This script is same as 1e but adds max-change=0.75 for snr and overlapped_speech outputs +# and learning rate factor 0.01 for the final affine components. +# Decreased learning rate factor of overlapped speech to 0.025 and 0.05 for speech. +# Changed relu-dim to 512 + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +chunk_width=40 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +sad_data_dir=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400 +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +#extra_left_context=79 +#extra_right_context=11 + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=g + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $sad_data_dir/utt2spk $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_ovlp_snr/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=512 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1, tdnn2_stats) dim=512 + relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=512 + relu-renorm-layer name=tdnn4 dim=512 + + output-layer name=output-speech include-log-softmax=true dim=2 objective-scale=$speech_scale input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.05 + + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print $speech_scale / $num_snr_bins"` input=tdnn4 max-change=0.75 learning-rate-factor=0.5 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 objective-scale=$ovlp_scale input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapped_speech.txt max-change=0.75 learning-rate-factor=0.025 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_feat.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_ovlp $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_feat.scp" \ + --dir=$dir || exit 1 +fi From e738191624bce6d8a79de7e410121a59ae89af1b Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 2 Jan 2017 18:38:28 -0500 Subject: [PATCH 130/213] asr_diarization: Modify overlapping sad recipe for AMI --- ...o_corruption_data_dir_overlapped_speech.sh | 38 ++- .../prepare_unsad_overlapped_speech_data.sh | 67 ++++- ...are_unsad_overlapped_speech_data_simple.sh | 156 ++++++++++ .../segmentation/run_segmentation_ami.sh | 270 ++++++++++++++++-- 4 files changed, 494 insertions(+), 37 deletions(-) create mode 100755 
egs/aspire/s5/local/segmentation/prepare_unsad_overlapped_speech_data_simple.sh diff --git a/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh index aa1d9adc3e9..991bec96308 100755 --- a/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh +++ b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh @@ -126,8 +126,44 @@ else corrupted_data_dir=${corrupted_data_dir}_$feat_suffix fi -exit 0 +if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d log_energy/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/log_energy/storage log_energy/storage +fi + +if [ $stage -le 5 ]; then + utils/copy_data_dir.sh $clean_data_dir ${clean_data_dir}_log_energy + steps/make_mfcc.sh --mfcc-config conf/log_energy.conf \ + --cmd "$cmd" --nj $nj ${clean_data_dir}_log_energy \ + exp/make_log_energy/${clean_data_id} log_energy +fi + +if [ $stage -le 6 ]; then + utils/copy_data_dir.sh $noise_data_dir ${noise_data_dir}_log_energy + steps/make_mfcc.sh --mfcc-config conf/log_energy.conf \ + --cmd "$cmd" --nj $nj ${noise_data_dir}_log_energy \ + exp/make_log_energy/${noise_data_id} log_energy +fi + +targets_dir=log_snr +if [ $stage -le 7 ]; then + mkdir -p exp/make_log_snr/${corrupted_data_id} + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $targets_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$targets_dir/storage $targets_dir/storage + fi + + # Get log-SNR targets + steps/segmentation/make_snr_targets.sh \ + --nj $nj --cmd "$cmd" \ + --target-type Snr --compress false \ + ${clean_data_dir}_log_energy ${noise_data_dir}_log_energy ${corrupted_data_dir} \ + exp/make_log_snr/${corrupted_data_id} $targets_dir +fi + +exit 0 + if [ $stage -le 5 ]; then # clean here is the reverberated first-speaker signal utils/copy_data_dir.sh $clean_data_dir ${clean_data_dir}_$feat_suffix diff --git a/egs/aspire/s5/local/segmentation/prepare_unsad_overlapped_speech_data.sh b/egs/aspire/s5/local/segmentation/prepare_unsad_overlapped_speech_data.sh index 36eb4de2afe..6d21859d7fe 100755 --- a/egs/aspire/s5/local/segmentation/prepare_unsad_overlapped_speech_data.sh +++ b/egs/aspire/s5/local/segmentation/prepare_unsad_overlapped_speech_data.sh @@ -12,13 +12,14 @@ set -o pipefail num_data_reps=5 nj=40 cmd=queue.pl +snr_db_threshold=10 stage=-1 . utils/parse_options.sh if [ $# -ne 5 ]; then echo "Usage: $0 " - echo " e.g.: $0 data/fisher_train_100k_sp_75k_hires_bp data/fisher_train_100k_sp_75k/overlapped_segments_info.txt exp/unsad/make_unsad_fisher_train_100k_sp/tri4_ali_fisher_train_100k_sp_vad_fisher_train_100k_sp exp/unsad/make_overlap_labels/fisher_train_100k_sp_75k overlap_labels" + echo " e.g.: $0 data/fisher_train_100k_sp_75k_seg_ovlp_corrupted_hires_bp data/fisher_train_100k_sp_75k_seg_ovlp_corrupted exp/unsad/make_unsad_fisher_train_100k/tri4a_ali_fisher_train_100k_sp_vad_fisher_train_100k_sp exp/unsad overlap_labels" exit 1 fi @@ -56,6 +57,15 @@ if [ $stage -le 1 ]; then utils/data/get_utt2num_frames.sh $corrupted_data_dir utils/split_data.sh ${corrupted_data_dir} $nj + # 1) segmentation-init-from-additive-signals-info converts the informtation + # written out but by steps/data/make_corrupted_data_dir.py in overlapped_segments_info.txt + # and converts it to segments. 
It then adds those segments to the + # segments already present ($corrupted_data_dir/sad_seg.scp) + # 2) Retain only the speech segments (label 1) from these. + # 3) Convert this to overlap stats using segmentation-get-stats, which + # writes for each frame the number of overlapping segments. + # 4) Convert this per-frame "alignment" information to segmentation + # ($overlap_dir/overlap_seg.*.gz). $cmd JOB=1:$nj $overlap_dir/log/get_overlap_seg.JOB.log \ segmentation-init-from-additive-signals-info --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ --additive-signals-segmentation-rspecifier=scp:$utt_vad_dir/sad_seg.scp \ @@ -69,6 +79,8 @@ if [ $stage -le 1 ]; then fi if [ $stage -le 2 ]; then + # Retain labels >2, i.e. regions where more than 1 speaker overlap. + # Write this out in alignment format as "overlapped_speech_labels" $cmd JOB=1:$nj $overlap_dir/log/get_overlapped_speech_labels.JOB.log \ gunzip -c $overlap_dir/overlap_seg.JOB.gz \| \ segmentation-post-process --remove-labels=0:1 ark:- ark:- \| \ @@ -82,13 +94,15 @@ if [ $stage -le 2 ]; then fi if [ $stage -le 3 ]; then - # First convert the unreliable segments into a segmentation. - # Initialize a segmentation from utt2num_frames and set to 0, the regions - # of unreliable segments. At this stage deriv weights is 1 for all but the - # unreliable segment regions. - # Initialize a segmentation from the overlap labels and retain regions where - # there is speech from at least one speaker. - # Intersect this with the deriv weights segmentation from above. + # 1) Initialize a segmentation where all the frames have label 1 using + # segmentation-init-from-length. + # 2) Use the program segmentation-create-subsegments to set to 0 + # the regions of unreliable segments read from unreliable_seg.*.gz. + # This is the initial deriv weights. At this stage deriv weights is 1 for all + # but the unreliable segment regions. + # 3) Initialize a segmentation from the overlap labels (overlap_seg.*.gz) + # and retain regions where there is speech from at least one speaker. + # 4) Intersect this with the deriv weights segmentation from above. # At this stage deriv weights is 1 for only the regions where there is # at least one speaker and the the overlapping segment is not unreliable. # Convert this to deriv weights. @@ -110,8 +124,8 @@ if [ $stage -le 3 ]; then fi if [ $stage -le 4 ]; then - # Get only first speaker labels as speech_feat as we are not sure of the energy levels of the other speaker. - $cmd JOB=1:$nj $overlap_dir/log/get_first_speaker_labels.JOB.log \ + # Find regions where there is at least one speaker speaking. + $cmd JOB=1:$nj $overlap_dir/log/get_speech_labels.JOB.log \ gunzip -c $overlap_dir/overlap_seg.JOB.gz \| \ segmentation-post-process --remove-labels=0 --merge-labels=1:2:3:4:5:6:7:8:9:10 --merge-dst-label=1 ark:- ark:- \| \ segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- ark,t:- \| \ @@ -125,6 +139,8 @@ if [ $stage -le 4 ]; then fi if [ $stage -le 5 ]; then + # Deriv weights speech / non-speech labels is 1 everywhere but the + # unreliable regions. 
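+  # For example (toy numbers): an utterance with 8 frames whose 4th through 6th
+  # frames fall inside an unreliable segment would end up with the weight vector
+  #   utt1  [ 1 1 1 0 0 0 1 1 ]
+  # so those frames contribute nothing to the derivatives for this output.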
$cmd JOB=1:$nj $unreliable_dir/log/get_deriv_weights.JOB.log \ utils/filter_scp.pl $corrupted_data_dir/split$nj/JOB/utt2spk $corrupted_data_dir/utt2num_frames \| \ segmentation-init-from-lengths ark,t:- ark:- \| \ @@ -139,6 +155,37 @@ if [ $stage -le 5 ]; then cat $overlap_labels_dir/deriv_weights_$corrupted_data_id.${n}.scp done > $corrupted_data_dir/deriv_weights.scp fi + +snr_threshold=`perl -e "print $snr_db_threshold / 10.0 * log(10.0)"` + +cat < $overlap_dir/invert_labels.map +0 1 +1 0 +EOF + +if [ $stage -le 6 ]; then + if [ ! -f $corrupted_data_dir/log_snr.scp ]; then + echo "$0: Could not find $corrupted_data_dir/log_snr.scp. Run local/segmentation/do_corruption_data_dir_overlapped_speech.sh." + exit 1 + fi + + $cmd JOB=1:$nj $overlap_dir/log/fix_overlapped_speech_labels.JOB.log \ + copy-matrix --apply-power=1 \ + "scp:utils/filter_scp.pl $corrupted_data_dir/split$nj/JOB/utt2spk $corrupted_data_dir/log_snr.scp |" \ + ark:- \| extract-column ark:- ark,t:- \| \ + steps/segmentation/quantize_vector.pl $snr_threshold \| \ + segmentation-init-from-ali ark,t:- ark:- \| \ + segmentation-copy --label-map=$overlap_dir/invert_labels.map ark:- ark:- \| \ + segmentation-intersect-segments --mismatch-label=1000 \ + "ark:utils/filter_scp.pl $corrupted_data_dir/split$nj/JOB/utt2spk $corrupted_data_dir/overlapped_speech_labels.scp | segmentation-init-from-ali scp:- ark:- | segmentation-copy --keep-label=1 ark:- ark:- |" ark:- ark:- \| \ + segmentation-copy --keep-label=1 ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ + ark:- ark,scp:$overlap_labels_dir/overlapped_speech_labels_fixed_${corrupted_data_id}.JOB.ark,$overlap_labels_dir/overlapped_speech_labels_fixed_${corrupted_data_id}.JOB.scp + + for n in `seq $nj`; do + cat $overlap_labels_dir/overlapped_speech_labels_fixed_${corrupted_data_id}.$n.scp + done > $corrupted_data_dir/overlapped_speech_labels_fixed.scp +fi exit 0 diff --git a/egs/aspire/s5/local/segmentation/prepare_unsad_overlapped_speech_data_simple.sh b/egs/aspire/s5/local/segmentation/prepare_unsad_overlapped_speech_data_simple.sh new file mode 100755 index 00000000000..73f2abca566 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/prepare_unsad_overlapped_speech_data_simple.sh @@ -0,0 +1,156 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + +set -e +set -u +set -o pipefail + +. path.sh + +num_data_reps=5 +nj=40 +cmd=queue.pl +snr_db_threshold=10 +stage=-1 + +. utils/parse_options.sh + +if [ $# -ne 5 ]; then + echo "Usage: $0 " + echo " e.g.: $0 data/fisher_train_100k_sp_75k_seg_ovlp_corrupted_hires_bp data/fisher_train_100k_sp_75k_seg_ovlp_corrupted exp/unsad/make_unsad_fisher_train_100k/tri4a_ali_fisher_train_100k_sp_vad_fisher_train_100k_sp exp/unsad overlapping_sad_labels" + exit 1 +fi + +corrupted_data_dir=$1 +orig_corrupted_data_dir=$2 +utt_vad_dir=$3 +tmpdir=$4 +overlap_labels_dir=$5 + +overlapped_segments_info=$orig_corrupted_data_dir/overlapped_segments_info.txt +corrupted_data_id=`basename $orig_corrupted_data_dir` + +for f in $corrupted_data_dir/feats.scp $overlapped_segments_info $utt_vad_dir/sad_seg.scp; do + [ ! -f $f ] && echo "Could not find file $f" && exit 1 +done + +overlap_dir=$tmpdir/make_overlapping_sad_labels_${corrupted_data_id} + +# make $overlap_labels_dir an absolute pathname. 
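+# For example (made-up paths), invoking the one-liner below as
+#   perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir;' overlap_labels /tmp/s5
+# prints /tmp/s5/overlap_labels, while an already-absolute argument is returned
+# unchanged.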
+overlap_labels_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $overlap_labels_dir ${PWD}` +mkdir -p $overlap_labels_dir + +# Combine the VAD from the base recording and the VAD from the overlapping segments +# to create per-frame labels of the number of overlapping speech segments +# Unreliable segments are regions where no VAD labels were available for the +# overlapping segments. These can be later removed by setting deriv weights to 0. + +if [ $stage -le 1 ]; then + for n in `seq $num_data_reps`; do + cat $utt_vad_dir/sad_seg.scp | \ + awk -v n=$n '{print "ovlp"n"_"$0}' + done | sort -k1,1 > ${corrupted_data_dir}/sad_seg.scp + utils/data/get_utt2num_frames.sh $corrupted_data_dir + utils/split_data.sh ${corrupted_data_dir} $nj + + # 1) segmentation-init-from-additive-signals-info converts the informtation + # written out but by steps/data/make_corrupted_data_dir.py in overlapped_segments_info.txt + # and converts it to segments. It then adds those segments to the + # segments already present ($corrupted_data_dir/sad_seg.scp) + # 2) Retain only the speech segments (label 1) from these. + # 3) Convert this to overlap stats using segmentation-get-stats, which + # writes for each frame the number of overlapping segments. + # 4) Convert this per-frame "alignment" information to segmentation + # ($overlap_dir/overlap_seg.*.gz). + $cmd JOB=1:$nj $overlap_dir/log/get_overlapping_sad_seg.JOB.log \ + segmentation-init-from-additive-signals-info --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ + --junk-label=10000 \ + --additive-signals-segmentation-rspecifier=scp:$utt_vad_dir/sad_seg.scp \ + "scp:utils/filter_scp.pl ${corrupted_data_dir}/split${nj}/JOB/utt2spk $corrupted_data_dir/sad_seg.scp |" \ + ark,t:$orig_corrupted_data_dir/overlapped_segments_info.txt ark:- \| \ + segmentation-get-stats --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ + ark:- ark:/dev/null ark:/dev/null ark:- \| \ + classes-per-frame-to-labels --junk-label=10000 ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + "ark:| gzip -c > $overlap_dir/overlap_sad_seg.JOB.gz" +fi + +if [ $stage -le 2 ]; then + # Call labels >2, i.e. regions where more than 1 speaker overlap as overlapping speech. labels = 1 is single speaker and labels = 0 is silence. + # Write this out in alignment format as "overlapping_sad_labels" + $cmd JOB=1:$nj $overlap_dir/log/get_overlapping_sad_labels.JOB.log \ + gunzip -c $overlap_dir/overlap_sad_seg.JOB.gz \| \ + segmentation-post-process --remove-labels=10000 ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- \ + ark,scp:$overlap_labels_dir/overlapping_sad_labels_${corrupted_data_id}.JOB.ark,$overlap_labels_dir/overlapping_sad_labels_${corrupted_data_id}.JOB.scp + + for n in `seq $nj`; do + cat $overlap_labels_dir/overlapping_sad_labels_${corrupted_data_id}.$n.scp + done > ${corrupted_data_dir}/overlapping_sad_labels.scp +fi + +if [ $stage -le 3 ]; then + # Find regions where there is at least one speaker speaking. 
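+  # As a toy illustration of the pipeline below: if the per-frame speaker counts
+  # for an utterance are "0 1 2 2 1 0", merging labels 1 and 2 into label 1
+  # gives the per-frame speech indicator "0 1 1 1 1 0", which is then written
+  # out in feature format (speech_feat) via vector-to-feat.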
+ $cmd JOB=1:$nj $overlap_dir/log/get_speech_feat.JOB.log \ + gunzip -c $overlap_dir/overlap_sad_seg.JOB.gz \| \ + segmentation-post-process --remove-labels=10000 ark:- ark:- \| \ + segmentation-post-process --merge-labels=1:2 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- ark,t:- \| \ + steps/segmentation/convert_ali_to_vec.pl \| \ + vector-to-feat ark:- \ + ark,scp:$overlap_labels_dir/speech_feat_${corrupted_data_id}.JOB.ark,$overlap_labels_dir/speech_feat_${corrupted_data_id}.JOB.scp + + for n in `seq $nj`; do + cat $overlap_labels_dir/speech_feat_${corrupted_data_id}.$n.scp + done > ${corrupted_data_dir}/speech_feat.scp +fi + +if [ $stage -le 4 ]; then + # Deriv weights is 1 everywhere but the + # unreliable regions. + $cmd JOB=1:$nj $overlap_dir/log/get_deriv_weights.JOB.log \ + gunzip -c $overlap_dir/overlap_sad_seg.JOB.gz \| \ + segmentation-post-process --merge-labels=0:1:2 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-post-process --merge-labels=10000 --merge-dst-label=0 ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- ark,t:- \| \ + steps/segmentation/convert_ali_to_vec.pl \| copy-vector ark,t:- \ + ark,scp:$overlap_labels_dir/deriv_weights_${corrupted_data_id}.JOB.ark,$overlap_labels_dir/deriv_weights_${corrupted_data_id}.JOB.scp + + for n in `seq $nj`; do + cat $overlap_labels_dir/deriv_weights_$corrupted_data_id.${n}.scp + done > $corrupted_data_dir/deriv_weights.scp +fi + +snr_threshold=`perl -e "print $snr_db_threshold / 10.0 * log(10.0)"` + +cat < $overlap_dir/invert_labels.map +0 2 +1 1 +EOF + +if [ $stage -le 5 ]; then + if [ ! -f $corrupted_data_dir/log_snr.scp ]; then + echo "$0: Could not find $corrupted_data_dir/log_snr.scp. Run local/segmentation/do_corruption_data_dir_overlapped_speech.sh." 
+ exit 1 + fi + + $cmd JOB=1:$nj $overlap_dir/log/fix_overlapping_sad_labels.JOB.log \ + copy-matrix --apply-power=1 \ + "scp:utils/filter_scp.pl $corrupted_data_dir/split$nj/JOB/utt2spk $corrupted_data_dir/log_snr.scp |" \ + ark:- \| extract-column ark:- ark,t:- \| \ + steps/segmentation/quantize_vector.pl $snr_threshold \| \ + segmentation-init-from-ali ark,t:- ark:- \| \ + segmentation-copy --label-map=$overlap_dir/invert_labels.map ark:- ark:- \| \ + segmentation-create-subsegments --filter-label=1 --subsegment-label=1 \ + "ark:utils/filter_scp.pl $corrupted_data_dir/split$nj/JOB/utt2spk $corrupted_data_dir/overlapping_sad_labels.scp | segmentation-init-from-ali scp:- ark:- |" ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ + ark:- ark,scp:$overlap_labels_dir/overlapping_sad_labels_fixed_${corrupted_data_id}.JOB.ark,$overlap_labels_dir/overlapping_sad_labels_fixed_${corrupted_data_id}.JOB.scp + + for n in `seq $nj`; do + cat $overlap_labels_dir/overlapping_sad_labels_fixed_${corrupted_data_id}.$n.scp + done > $corrupted_data_dir/overlapping_sad_labels_fixed.scp +fi + +exit 0 diff --git a/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh b/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh index 4b98eec9f43..733c6aa53fe 100755 --- a/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh +++ b/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh @@ -14,8 +14,16 @@ stage=-1 nnet_dir=exp/nnet3_sad_snr/nnet_tdnn_k_n4 extra_left_context=100 extra_right_context=20 +task=SAD iter=final +segmentation_stage=-1 +sil_prior=0.7 +speech_prior=0.3 +min_silence_duration=30 +min_speech_duration=10 +frame_subsampling_factor=3 + . utils/parse_options.sh export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH @@ -77,58 +85,268 @@ echo "A 1" > $dir/channel_map cat $src_dir/data/sdm1/dev/reco2file_and_channel | \ utils/apply_map.pl -f 3 $dir/channel_map > $dir/reco2file_and_channel +cat $src_dir/data/sdm1/dev_ihmdata/reco2utt | \ + awk 'BEGIN{i=1} {print $1" "i; i++;}' > \ + $src_dir/data/sdm1/dev_ihmdata/reco.txt + if [ $stage -le 5 ]; then + utils/data/get_reco2num_frames.sh --frame-shift 0.01 --frame-overlap 0.015 \ + --cmd queue.pl --nj 18 \ + $src_dir/data/sdm1/dev + + # Get a filter that selects only regions within the manual segments. + $train_cmd $dir/log/get_manual_segments_regions.log \ + segmentation-init-from-segments --shift-to-zero=false $src_dir/data/sdm1/dev/segments ark:- \| \ + segmentation-combine-segments-to-recordings ark:- ark,t:$src_dir/data/sdm1/dev/reco2utt ark:- \| \ + segmentation-create-subsegments --filter-label=1 --subsegment-label=1 \ + "ark:segmentation-init-from-lengths --label=0 ark,t:$src_dir/data/sdm1/dev/reco2num_frames ark:- |" ark:- ark,t:- \| \ + perl -ane '$F[3] = 10000; $F[$#F-1] = 10000; print join(" ", @F) . 
"\n";' \| \ + segmentation-post-process --merge-labels=0:1 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-post-process --merge-labels=10000 --merge-dst-label=0 --merge-adjacent-segments \ + --max-intersegment-length=10000 ark,t:- \ + "ark:| gzip -c > $dir/manual_segments_regions.seg.gz" +fi + +if [ $stage -le 6 ]; then # Reference RTTM where SPEECH frames are obtainted by combining IHM VAD alignments - $train_cmd $dir/log/get_ref_rttm.log \ + $train_cmd $dir/log/get_ref_spk_seg.log \ segmentation-combine-segments scp:$dir/sad_seg.scp \ "ark:segmentation-init-from-segments --shift-to-zero=false $src_dir/data/sdm1/dev_ihmdata/segments ark:- |" \ ark,t:$src_dir/data/sdm1/dev_ihmdata/reco2utt ark:- \| \ + segmentation-copy --keep-label=1 ark:- ark:- \| \ + segmentation-copy --utt2label-rspecifier=ark,t:$src_dir/data/sdm1/dev_ihmdata/reco.txt \ + ark:- ark:- \| \ segmentation-merge-recordings \ "ark,t:utils/utt2spk_to_spk2utt.pl $src_dir/data/sdm1/dev_ihmdata/ihm2sdm_reco |" \ - ark:- ark:- \| \ - segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ - ark:- $dir/ref.rttm + ark:- "ark:| gzip -c > $dir/ref_spk_seg.gz" fi - -if [ $stage -le 6 ]; then - # Get an UEM which evaluates only on the manual segments. - $train_cmd $dir/log/get_uem.log \ - segmentation-init-from-segments --shift-to-zero=false $src_dir/data/sdm1/dev/segments ark:- \| \ - segmentation-combine-segments-to-recordings ark:- ark,t:$src_dir/data/sdm1/dev/reco2utt ark:- \| \ - segmentation-post-process --remove-labels=0 --merge-adjacent-segments \ - --max-intersegment-length=10000 ark:- ark:- \| \ + +if [ $stage -le 7 ]; then + # To get the actual RTTM, we need to add no-score + $train_cmd $dir/log/get_ref_rttm.log \ + segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/dev/reco2num_frames \ + "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \ + ark:/dev/null ark:- \| \ + segmentation-init-from-ali ark:- ark:- \| \ + segmentation-post-process --merge-labels=1:2:3:4:5:6:7:8:9:10 --merge-dst-label=1 \ + --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-create-subsegments --filter-label=0 --subsegment-label=10000 \ + ark:- "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ - ark:- - \| grep SPEECH \| grep SPEAKER \| \ - rttmSmooth.pl -s 0 \| awk '{ print $2" "$3" "$4" "$5+$4 }' '>' $dir/uem + --no-score-label=10000 ark:- $dir/ref.rttm + + # Get RTTM for overlapped speech detection with 3 classes + # 0 -> SILENCE, 1 -> SINGLE_SPEAKER, 2 -> OVERLAP + $train_cmd $dir/log/get_overlapping_rttm.log \ + segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/dev/reco2num_frames \ + "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \ + ark:/dev/null ark:- \| \ + segmentation-init-from-ali ark:- ark:- \| \ + segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=2 \ + --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-create-subsegments --filter-label=0 --subsegment-label=10000 \ + ark:- "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-to-rttm --map-to-speech-and-sil=false 
--reco2file-and-channel=$dir/reco2file_and_channel \ + --no-score-label=10000 ark:- $dir/overlapping_speech_ref.rttm fi +if [ $stage -le 8 ]; then + # Get a filter that selects only regions of speech + $train_cmd $dir/log/get_speech_filter.log \ + segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/dev/reco2num_frames \ + "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \ + ark:/dev/null ark:- \| \ + segmentation-init-from-ali ark:- ark:- \| \ + segmentation-post-process --merge-labels=1:2:3:4:5:6:7:8:9:10 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-create-subsegments --filter-label=0 --subsegment-label=0 \ + ark:- "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 \ + ark:- "ark:| gzip -c > $dir/manual_segments_speech_regions.seg.gz" +fi + hyp_dir=${nnet_dir}/segmentation_ami_sdm1_dev_whole_bp/ami_sdm1_dev -if [ $stage -le 7 ]; then +if [ $stage -le 9 ]; then steps/segmentation/do_segmentation_data_dir.sh --reco-nj 18 \ --mfcc-config conf/mfcc_hires_bp.conf --feat-affix bp --do-downsampling true \ --extra-left-context $extra_left_context --extra-right-context $extra_right_context \ - --output-name output-speech --frame-subsampling-factor 6 --iter $iter \ + --output-name output-speech --frame-subsampling-factor $frame_subsampling_factor --iter $iter \ + --stage $segmentation_stage \ $src_dir/data/sdm1/dev $nnet_dir mfcc_hires_bp $hyp_dir fi +sad_dir=${nnet_dir}/sad_ami_sdm1_dev_whole_bp/ hyp_dir=${hyp_dir}_seg -if [ $stage -le 8 ]; then +if [ $stage -le 10 ]; then utils/data/get_reco2utt.sh $src_dir/data/sdm1/dev_ihmdata - - steps/segmentation/convert_utt2spk_and_segments_to_rttm.py \ - $hyp_dir/utt2spk \ - $hyp_dir/segments \ - $dir/reco2file_and_channel \ - /dev/stdout | spkr2sad.pl > $hyp_dir/sys.rttm + utils/data/get_reco2utt.sh $hyp_dir + + segmentation-init-from-segments --shift-to-zero=false $hyp_dir/segments ark:- | \ + segmentation-combine-segments-to-recordings ark:- ark,t:$hyp_dir/reco2utt ark:- | \ + segmentation-to-ali --length-tolerance=48 --lengths-rspecifier=ark,t:$src_dir/data/sdm1/dev/reco2num_frames \ + ark:- ark:- | \ + segmentation-init-from-ali ark:- ark:- | \ + segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel ark:- $hyp_dir/sys.rttm + + #steps/segmentation/convert_utt2spk_and_segments_to_rttm.py \ + # $hyp_dir/utt2spk \ + # $hyp_dir/segments \ + # $dir/reco2file_and_channel \ + # /dev/stdout | spkr2sad.pl > $hyp_dir/sys.rttm fi -if [ $stage -le 9 ]; then - md-eval.pl -s <(cat $hyp_dir/sys.rttm | grep speech | rttmSmooth.pl -s 0) \ - -r <(cat $dir/ref.rttm | grep SPEECH | rttmSmooth.pl -s 0 ) \ +if [ $stage -le 11 ]; then + cat < $likes_dir/log_likes.JOB.gz" + cp $sad_dir/num_jobs $likes_dir + fi + else + if [ $stage -le 12 ]; then + steps/segmentation/do_segmentation_data_dir_generic.sh --reco-nj 18 \ + --mfcc-config conf/mfcc_hires_bp.conf --feat-affix bp --do-downsampling true \ + --extra-left-context $extra_left_context --extra-right-context $extra_right_context \ + --segmentation-config conf/segmentation_ovlp.conf \ + --output-name output-overlapping_sad \ + --min-durations 30:10:10 --priors 0.5:0.35:0.15 \ + --sad-name ovlp_sad --segmentation-name segmentation_ovlp_sad \ + --frame-subsampling-factor $frame_subsampling_factor --iter $iter \ + --stage $segmentation_stage \ + $src_dir/data/sdm1/dev $nnet_dir mfcc_hires_bp $hyp_dir + fi + + 
likes_dir=${nnet_dir}/ovlp_sad_ami_sdm1_dev_whole_bp/ + fi + + hyp_dir=${hyp_dir}_seg + mkdir -p $hyp_dir + + seg_dir=${nnet_dir}/segmentation_ovlp_sad_ami_sdm1_dev_whole_bp/ + lang=${seg_dir}/lang + + if [ $stage -le 14 ]; then + mkdir -p $lang + steps/segmentation/internal/prepare_sad_lang.py \ + --phone-transition-parameters="--phone-list=1 --min-duration=10 --end-transition-probability=0.1" \ + --phone-transition-parameters="--phone-list=2 --min-duration=3 --end-transition-probability=0.1" \ + --phone-transition-parameters="--phone-list=3 --min-duration=3 --end-transition-probability=0.1" $lang + cp $lang/phones.txt $lang/words.txt + + feat_dim=2 # dummy. We don't need this. + $train_cmd $seg_dir/log/create_transition_model.log gmm-init-mono \ + $lang/topo $feat_dim - $seg_dir/tree \| \ + copy-transition-model --binary=false - $seg_dir/trans.mdl || exit 1 +fi + + if [ $stage -le 15 ]; then + + cat > $lang/word2prior < $lang/G.fst +fi + + if [ $stage -le 16 ]; then + $train_cmd $seg_dir/log/make_vad_graph.log \ + steps/segmentation/internal/make_sad_graph.sh --iter trans \ + $lang $seg_dir $seg_dir/graph_test || exit 1 + fi + + if [ $stage -le 17 ]; then + steps/segmentation/decode_sad.sh \ + --acwt 1 --beam 10 --max-active 7000 \ + $seg_dir/graph_test $likes_dir $seg_dir + fi + + if [ $stage -le 18 ]; then + cat < $hyp_dir/labels_map +1 0 +2 1 +3 2 +EOF + gunzip -c $seg_dir/ali.*.gz | \ + segmentation-init-from-ali ark:- ark:- | \ + segmentation-copy --frame-subsampling-factor=$frame_subsampling_factor \ + --label-map=$hyp_dir/labels_map ark:- ark:- | \ + segmentation-to-rttm --map-to-speech-and-sil=false \ + --reco2file-and-channel=$dir/reco2file_and_channel ark:- $hyp_dir/sys.rttm + fi + # Get RTTM for overlapped speech detection with 3 classes + # 0 -> SILENCE, 1 -> SINGLE_SPEAKER, 2 -> OVERLAP + $train_cmd $dir/log/get_overlapping_rttm.log \ + segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/dev/reco2num_frames \ + "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \ + ark:/dev/null ark:- \| \ + segmentation-init-from-ali ark:- ark:- \| \ + segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=2 ark:- ark:- \| \ + segmentation-create-subsegments --filter-label=0 --subsegment-label=10000 \ + ark:- "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-to-rttm --map-to-speech-and-sil=false --reco2file-and-channel=$dir/reco2file_and_channel \ + --no-score-label=10000 ark:- $dir/overlapping_speech_ref.rttm + + if [ $stage -le 19 ]; then + cat < Date: Mon, 2 Jan 2017 18:40:24 -0500 Subject: [PATCH 131/213] asr_diarization: Fisher+ Babel SAD recipe --- .../prepare_babel_data_overlapped_speech.sh | 112 +++++++++++++++++ .../prepare_fisher_data_overlapped_speech.sh | 113 ++++++++++++++++++ .../s5/local/segmentation/run_fisher_babel.sh | 2 + 3 files changed, 227 insertions(+) create mode 100644 egs/aspire/s5/local/segmentation/prepare_babel_data_overlapped_speech.sh create mode 100644 egs/aspire/s5/local/segmentation/prepare_fisher_data_overlapped_speech.sh create mode 100644 egs/aspire/s5/local/segmentation/run_fisher_babel.sh diff --git a/egs/aspire/s5/local/segmentation/prepare_babel_data_overlapped_speech.sh b/egs/aspire/s5/local/segmentation/prepare_babel_data_overlapped_speech.sh new file mode 100644 index 00000000000..2136f42f322 --- /dev/null +++ 
b/egs/aspire/s5/local/segmentation/prepare_babel_data_overlapped_speech.sh @@ -0,0 +1,112 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +# This script prepares Babel data for training speech activity detection, +# music detection, and overlapped speech detection systems. + +. path.sh +. cmd.sh + +set -e +set -o pipefail +set -u + +lang_id=assamese +subset=25 # Number of recordings to keep before speed perturbation and corruption +utt_subset=30000 # Number of utterances to keep after speed perturbation for adding overlapped-speech + +# All the paths below can be modified to any absolute path. +ROOT_DIR=/home/vimal/workspace_waveform/egs/babel/s5c_assamese/ + +. utils/parse_options.sh + +if [ $# -ne 0 ]; then + echo "Usage: $0" + echo "This script is to serve as an example recipe." + echo "Edit the script to change variables if needed." + exit 1 +fi + +dir=exp/unsad/make_unsad_babel_${lang_id}_train # Work dir + +# The original data directory which will be converted to a whole (recording-level) directory. +train_data_dir=$ROOT_DIR/data/train + +model_dir=$ROOT_DIR/exp/tri4 # Model directory used for decoding +sat_model_dir=$ROOT_DIR/exp/tri5 # Model directory used for getting alignments +lang=$ROOT_DIR/data/lang # Language directory +lang_test=$ROOT_DIR/data/lang # Language directory used to build graph + +# Hard code the mapping from phones to SAD labels +# 0 for silence, 1 for speech, 2 for noise, 3 for unk +cat < $dir/babel_sad.map + 3 +_B 3 +_E 3 +_I 3 +_S 3 + 2 +_B 2 +_E 2 +_I 2 +_S 2 + 2 +_B 2 +_E 2 +_I 2 +_S 2 +SIL 0 +SIL_B 0 +SIL_E 0 +SIL_I 0 +SIL_S 0 +EOF + +# Expecting the user to have done run.sh to have $model_dir, +# $sat_model_dir, $lang, $lang_test, $train_data_dir +local/segmentation/prepare_unsad_data.sh \ + --sad-map $dir/babel_sad.map \ + --config-dir $ROOT_DIR/conf \ + --reco-nj 40 --nj 100 --cmd "$train_cmd" \ + --sat-model $sat_model_dir \ + --lang-test $lang_test \ + $train_data_dir $lang $model_dir $dir + +orig_data_dir=${train_data_dir}_sp + +data_dir=${train_data_dir}_whole + +if [ ! -z $subset ]; then + # Work on a subset + utils/subset_data_dir.sh ${data_dir} $subset \ + ${data_dir}_$subset + data_dir=${data_dir}_$subset +fi + +reco_vad_dir=$dir/`basename $model_dir`_reco_vad_`basename $train_data_dir`_sp + +# Add noise from MUSAN corpus to data directory and create a new data directory +local/segmentation/do_corruption_data_dir.sh + --data-dir $data_dir \ + --reco-vad-dir $reco_vad_dir + --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf + +# Add music from MUSAN corpus to data directory and create a new data directory +local/segmentation/do_corruption_data_dir_music.sh + --data-dir $data_dir \ + --reco-vad-dir $reco_vad_dir + --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf + +if [ ! 
-z $utt_subset ]; then + utils/subset_data_dir.sh ${orig_data_dir} $utt_subset \ + ${orig_data_dir}_`echo $utt_subset | perl -e 's/000$/k/'` + orig_data_dir=${orig_data_dir}_`echo $utt_subset | perl -e 's/000$/k/'` +fi + +# Add overlapping speech from $orig_data_dir/segments and create a new data directory +utt_vad_dir=$dir/`baseline $sat_model_dir`_ali_`basename $train_data_dir`_sp_vad_`basename $train_data_dir`_sp +local/segmentation/do_corruption_data_dir_overlapped_speech.sh \ + --data-dir ${orig_data_dir} \ + --utt-vad-dir $utt_vad_dir diff --git a/egs/aspire/s5/local/segmentation/prepare_fisher_data_overlapped_speech.sh b/egs/aspire/s5/local/segmentation/prepare_fisher_data_overlapped_speech.sh new file mode 100644 index 00000000000..79a03fa9e9d --- /dev/null +++ b/egs/aspire/s5/local/segmentation/prepare_fisher_data_overlapped_speech.sh @@ -0,0 +1,113 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +# This script prepares Fisher data for training speech activity detection, +# music detection, and overlapped speech detection systems. + +. path.sh +. cmd.sh + +if [ $# -ne 0 ]; then + echo "Usage: $0" + echo "This script is to serve as an example recipe." + echo "Edit the script to change variables if needed." + exit 1 +fi + +dir=exp/unsad/make_unsad_fisher_train_100k # Work dir +subset=60 # Number of recordings to keep before speed perturbation and corruption +utt_subset=75000 # Number of utterances to keep after speed perturbation for adding overlapped-speech + +# All the paths below can be modified to any absolute path. + +# The original data directory which will be converted to a whole (recording-level) directory. +train_data_dir=data/fisher_train_100k + +model_dir=exp/tri3a # Model directory used for decoding +sat_model_dir=exp/tri4a # Model directory used for getting alignments +lang=data/lang # Language directory +lang_test=data/lang_test # Language directory used to build graph + +# Hard code the mapping from phones to SAD labels +# 0 for silence, 1 for speech, 2 for noise, 3 for unk +cat < $dir/fisher_sad.map +sil 0 +sil_B 0 +sil_E 0 +sil_I 0 +sil_S 0 +laughter 2 +laughter_B 2 +laughter_E 2 +laughter_I 2 +laughter_S 2 +noise 2 +noise_B 2 +noise_E 2 +noise_I 2 +noise_S 2 +oov 3 +oov_B 3 +oov_E 3 +oov_I 3 +oov_S 3 +EOF + +# Expecting the user to have done run.sh to have $model_dir, +# $sat_model_dir, $lang, $lang_test, $train_data_dir +local/segmentation/prepare_unsad_data.sh \ + --sad-map $dir/fisher_sad.map \ + --config-dir conf \ + --reco-nj 40 --nj 100 --cmd "$train_cmd" \ + --sat-model $sat_model_dir \ + --lang-test $lang_test \ + $train_data_dir $lang $model_dir $dir + +orig_data_dir=${train_data_dir}_sp + +data_dir=${train_data_dir}_whole + +if [ ! -z $subset ]; then + # Work on a subset + utils/subset_data_dir.sh ${data_dir} $subset \ + ${data_dir}_$subset + data_dir=${data_dir}_$subset +fi + +reco_vad_dir=$dir/`basename $model_dir`_reco_vad_`basename $train_data_dir`_sp + +# Add noise from MUSAN corpus to data directory and create a new data directory +local/segmentation/do_corruption_data_dir.sh \ + --num-data-reps 5 \ + --data-dir $data_dir \ + --reco-vad-dir $reco_vad_dir + --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf + +# Add music from MUSAN corpus to data directory and create a new data directory +local/segmentation/do_corruption_data_dir_music.sh \ + --num-data-reps 5 \ + --data-dir $data_dir \ + --reco-vad-dir $reco_vad_dir + --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf + +if [ ! 
-z $utt_subset ]; then + utils/subset_data_dir.sh ${orig_data_dir} $utt_subset \ + ${orig_data_dir}_`echo $utt_subset | perl -e 's/000$/k/'` + orig_data_dir=${orig_data_dir}_`echo $utt_subset | perl -e 's/000$/k/'` +fi + +# Add overlapping speech from $orig_data_dir/segments and create a new data directory +utt_vad_dir=$dir/`baseline $sat_model_dir`_ali_`basename $train_data_dir`_sp_vad_`basename $train_data_dir`_sp +local/segmentation/do_corruption_data_dir_overlapped_speech.sh \ + --nj 40 --cmd queue.pl \ + --num-data-reps 1 \ + --data-dir ${orig_data_dir} \ + --utt-vad-dir $utt_vad_dir + +local/segmentation/prepare_unsad_overlapped_speech_labels.sh \ + --num-data-reps 1 --nj 40 --cmd queue.pl \ + ${orig_data_dir}_ovlp_corrupted_hires_bp \ + ${orig_data_dir}_ovlp_corrupted/overlapped_segments_info.txt \ + $utt_vad_dir exp/make_overlap_labels overlap_labels diff --git a/egs/aspire/s5/local/segmentation/run_fisher_babel.sh b/egs/aspire/s5/local/segmentation/run_fisher_babel.sh new file mode 100644 index 00000000000..bdf6d3585f7 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/run_fisher_babel.sh @@ -0,0 +1,2 @@ + +utils/combine_data.sh From d16de41d4b0ccf28f114947bd05b80c5004e88fb Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 2 Jan 2017 18:40:54 -0500 Subject: [PATCH 132/213] asr_diarization: Prepare labels for AMI --- .../s5/local/segmentation/prepare_ami.sh | 213 ++++++++++++++++++ 1 file changed, 213 insertions(+) create mode 100755 egs/aspire/s5/local/segmentation/prepare_ami.sh diff --git a/egs/aspire/s5/local/segmentation/prepare_ami.sh b/egs/aspire/s5/local/segmentation/prepare_ami.sh new file mode 100755 index 00000000000..38ed9559c89 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/prepare_ami.sh @@ -0,0 +1,213 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +. cmd.sh +. path.sh + +set -e +set -o pipefail +set -u + +stage=-1 + +dataset=dev +nj=18 + +. utils/parse_options.sh + +export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH + +src_dir=/export/a09/vmanoha1/workspace_asr_diarization/egs/ami/s5b # AMI src_dir +dir=exp/sad_ami_sdm1_${dataset}/ref + +mkdir -p $dir + +# Expecting user to have done run.sh to run the AMI recipe in $src_dir for +# both sdm and ihm microphone conditions + +if [ $stage -le 1 ]; then + ( + cd $src_dir + local/prepare_parallel_train_data.sh --train-set ${dataset} sdm1 + + awk '{print $1" "$2}' $src_dir/data/ihm/${dataset}/segments > \ + $src_dir/data/ihm/${dataset}/utt2reco + awk '{print $1" "$2}' $src_dir/data/sdm1/${dataset}/segments > \ + $src_dir/data/sdm1/${dataset}/utt2reco + + cat $src_dir/data/sdm1/${dataset}_ihmdata/ihmutt2utt | \ + utils/filter_scp.pl -f 1 $src_dir/data/ihm/${dataset}/utt2reco | \ + utils/apply_map.pl -f 1 $src_dir/data/ihm/${dataset}/utt2reco | \ + utils/filter_scp.pl -f 2 $src_dir/data/sdm1/${dataset}/utt2reco | \ + utils/apply_map.pl -f 2 $src_dir/data/sdm1/${dataset}/utt2reco | \ + sort -u > $src_dir/data/sdm1/${dataset}_ihmdata/ihm2sdm_reco + ) +fi + +[ ! -s $src_dir/data/sdm1/${dataset}_ihmdata/ihm2sdm_reco ] && echo "Empty $src_dir/data/sdm1/${dataset}_ihmdata/ihm2sdm_reco!" 
&& exit 1 + +phone_map=$dir/phone_map +if [ $stage -le 2 ]; then + ( + cd $src_dir + utils/data/get_reco2utt.sh $src_dir/data/sdm1/${dataset} + + steps/make_mfcc.sh --nj $nj --cmd "$train_cmd" \ + data/sdm1/${dataset}_ihmdata exp/sdm1/make_mfcc mfcc_sdm1 + steps/compute_cmvn_stats.sh \ + data/sdm1/${dataset}_ihmdata exp/sdm1/make_mfcc mfcc_sdm1 + utils/fix_data_dir.sh data/sdm1/${dataset}_ihmdata + ) + + steps/segmentation/get_sad_map.py \ + $src_dir/data/lang | utils/sym2int.pl -f 1 $src_dir/data/lang/phones.txt > \ + $phone_map +fi + +if [ $stage -le 3 ]; then + # Expecting user to have run local/run_cleanup_segmentation.sh in $src_dir + ( + cd $src_dir + steps/align_fmllr.sh --nj $nj --cmd "$train_cmd" \ + data/sdm1/${dataset}_ihmdata data/lang \ + exp/ihm/tri3_cleaned \ + exp/sdm1/tri3_cleaned_${dataset}_ihmdata + ) +fi + +if [ $stage -le 4 ]; then + steps/segmentation/internal/convert_ali_to_vad.sh --cmd "$train_cmd" \ + $src_dir/exp/sdm1/tri3_cleaned_${dataset}_ihmdata $phone_map $dir +fi + +echo "A 1" > $dir/channel_map +cat $src_dir/data/sdm1/${dataset}/reco2file_and_channel | \ + utils/apply_map.pl -f 3 $dir/channel_map > $dir/reco2file_and_channel + +utils/data/get_reco2utt.sh $src_dir/data/sdm1/${dataset}_ihmdata +cat $src_dir/data/sdm1/${dataset}_ihmdata/reco2utt | \ + awk 'BEGIN{i=1} {print $1" "i; i++;}' > \ + $src_dir/data/sdm1/${dataset}_ihmdata/reco.txt + +if [ $stage -le 5 ]; then + # Reference RTTM where SPEECH frames are obtainted by combining IHM VAD alignments + cat $src_dir/data/sdm1/${dataset}_ihmdata/reco.txt | \ + awk '{print $1" 1:"$2" 10000:10000 0:0"}' > $dir/ref_spk2label_map + + $train_cmd $dir/log/get_ref_spk_seg.log \ + segmentation-combine-segments --include-missing-utt-level-segmentations scp:$dir/sad_seg.scp \ + "ark:segmentation-init-from-segments --segment-label=10000 --shift-to-zero=false $src_dir/data/sdm1/${dataset}_ihmdata/segments ark:- |" \ + ark,t:$src_dir/data/sdm1/${dataset}_ihmdata/reco2utt ark:- \| \ + segmentation-copy --utt2label-map-rspecifier=ark,t:$dir/ref_spk2label_map \ + ark:- ark:- \| \ + segmentation-merge-recordings \ + "ark,t:utils/utt2spk_to_spk2utt.pl $src_dir/data/sdm1/${dataset}_ihmdata/ihm2sdm_reco |" \ + ark:- "ark:| gzip -c > $dir/ref_spk_seg.gz" +fi + +if [ $stage -le 6 ]; then + utils/data/get_reco2num_frames.sh --frame-shift 0.01 --frame-overlap 0.015 \ + --cmd queue.pl --nj $nj \ + $src_dir/data/sdm1/${dataset} + + # Get a filter that selects only regions within the manual segments. + $train_cmd $dir/log/get_manual_segments_regions.log \ + segmentation-init-from-segments --shift-to-zero=false $src_dir/data/sdm1/${dataset}/segments ark:- \| \ + segmentation-combine-segments-to-recordings ark:- ark,t:$src_dir/data/sdm1/${dataset}/reco2utt ark:- \| \ + segmentation-create-subsegments --filter-label=1 --subsegment-label=1 \ + "ark:segmentation-init-from-lengths --label=0 ark,t:$src_dir/data/sdm1/${dataset}/reco2num_frames ark:- |" ark:- ark,t:- \| \ + perl -ane '$F[3] = 10000; $F[$#F-1] = 10000; print join(" ", @F) . 
"\n";' \| \ + segmentation-create-subsegments --filter-label=10000 --subsegment-label=10000 \ + ark,t:- "ark:gunzip -c $dir/ref_spk_seg.gz |" ark:- \| \ + segmentation-post-process --merge-labels=0:1 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-post-process --merge-labels=10000 --merge-dst-label=0 --merge-adjacent-segments \ + --max-intersegment-length=10000 ark,t:- \ + "ark:| gzip -c > $dir/manual_segments_regions.seg.gz" +fi + +if [ $stage -le 7 ]; then + # To get the actual RTTM, we need to add no-score + $train_cmd $dir/log/get_ref_rttm.log \ + segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/${dataset}/reco2num_frames \ + "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0:10000 ark:- ark:- |" \ + ark:/dev/null ark:- \| \ + segmentation-init-from-ali ark:- ark:- \| \ + segmentation-post-process --merge-labels=1:2:3:4:5:6:7:8:9:10 --merge-dst-label=1 \ + --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-create-subsegments --filter-label=0 --subsegment-label=10000 \ + ark:- "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ + --no-score-label=10000 ark:- $dir/ref.rttm +fi + + +if [ $stage -le 8 ]; then + # Get RTTM for overlapped speech detection with 3 classes + # 0 -> SILENCE, 1 -> SINGLE_SPEAKER, 2 -> OVERLAP + $train_cmd $dir/log/get_overlapping_rttm.log \ + segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/${dataset}/reco2num_frames \ + "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0:10000 ark:- ark:- |" \ + ark:/dev/null ark:- \| \ + segmentation-init-from-ali ark:- ark:- \| \ + segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=2 \ + --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-create-subsegments --filter-label=0 --subsegment-label=10000 \ + ark:- "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-to-rttm --map-to-speech-and-sil=false --reco2file-and-channel=$dir/reco2file_and_channel \ + --no-score-label=10000 ark:- $dir/overlapping_speech_ref.rttm +fi + +if [ $stage -le 9 ]; then + # Get a filter that selects only regions of speech + $train_cmd $dir/log/get_speech_filter.log \ + segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/${dataset}/reco2num_frames \ + "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0:10000 ark:- ark:- |" \ + ark:/dev/null ark:- \| \ + segmentation-init-from-ali ark:- ark:- \| \ + segmentation-post-process --merge-labels=1:2:3:4:5:6:7:8:9:10 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-create-subsegments --filter-label=0 --subsegment-label=0 \ + ark:- "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 \ + ark:- "ark:| gzip -c > $dir/manual_segments_speech_regions.seg.gz" +fi + +# make $dir an absolute pathname. 
+dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $dir ${PWD}` + +if [ $stage -le 10 ]; then + $train_cmd $dir/log/get_overlapping_sad.log \ + segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/${dataset}/reco2num_frames \ + "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0:10000 ark:- ark:- |" \ + ark:/dev/null ark:- \| \ + segmentation-init-from-ali ark:- ark:- \| \ + segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=2 \ + --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-to-ali ark:- ark,scp:$dir/overlapping_sad_labels.ark,$dir/overlapping_sad_labels.scp + + $train_cmd $dir/log/get_deriv_weights_for_overlapping_sad.log \ + segmentation-to-ali "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark,t:- \| \ + steps/segmentation/convert_ali_to_vec.pl \| \ + copy-vector ark,t: ark,scp:$dir/deriv_weights_for_overlapping_sad.ark,$dir/deriv_weights_for_overlapping_sad.scp +fi + +if false && [ $stage -le 11 ]; then + utils/data/convert_data_dir_to_whole.sh \ + $src_dir/data/sdm1/${dataset} data/ami_sdm1_${dataset}_whole + utils/fix_data_dir.sh \ + data/ami_sdm1_${dataset}_whole + utils/copy_data_dir.sh \ + data/ami_sdm1_${dataset}_whole data/ami_sdm1_${dataset}_whole_hires_bp + utils/data/downsample_data_dir.sh 8000 data/ami_sdm1_${dataset}_whole_hires_bp + + steps/make_mfcc.sh --mfcc-config conf/mfcc_hires_bp.conf --nj $nj \ + data/ami_sdm1_${dataset}_whole_hires_bp exp/make_hires_bp mfcc_hires_bp + steps/compute_cmvn_stats.sh --fake \ + data/ami_sdm1_${dataset}_whole_hires_bp exp/make_hires_bp mfcc_hires_bp + utils/fix_data_dir.sh \ + data/ami_sdm1_${dataset}_whole_hires_bp +fi From 5ac841e33cb47a46f88b66f813e2f1961bea7134 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 2 Jan 2017 18:41:29 -0500 Subject: [PATCH 133/213] asr_diarization: segmentation configs --- egs/aspire/s5/conf/segmentation_music.conf | 14 ++++++++++++++ egs/aspire/s5/conf/segmentation_ovlp.conf | 14 ++++++++++++++ egs/aspire/s5/conf/segmentation_speech.conf | 14 ++++++++++++++ 3 files changed, 42 insertions(+) create mode 100644 egs/aspire/s5/conf/segmentation_music.conf create mode 100644 egs/aspire/s5/conf/segmentation_ovlp.conf create mode 100644 egs/aspire/s5/conf/segmentation_speech.conf diff --git a/egs/aspire/s5/conf/segmentation_music.conf b/egs/aspire/s5/conf/segmentation_music.conf new file mode 100644 index 00000000000..28b5feaf5d5 --- /dev/null +++ b/egs/aspire/s5/conf/segmentation_music.conf @@ -0,0 +1,14 @@ +# General segmentation options +pad_length=-1 # Pad speech segments by this many frames on either side +max_blend_length=-1 # Maximum duration of speech that will be removed as part + # of smoothing process. This is only if there are no other + # speech segments nearby. +max_intersegment_length=0 # Merge nearby speech segments if the silence + # between them is less than this many frames. +post_pad_length=-1 # Pad speech segments by this many frames on either side + # after the merging process using max_intersegment_length +max_segment_length=1000 # Segments that are longer than this are split into + # overlapping frames. +overlap_length=250 # Overlapping frames when segments are split. + # See the above option. 
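+# (Illustrative note: the lengths above are all in frames; assuming the usual
+#  10 ms frame shift, max_segment_length=1000 corresponds to roughly 10
+#  seconds and overlap_length=250 to roughly 2.5 seconds.)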
+min_silence_length=100000 # Min silence length at which to split very long segments diff --git a/egs/aspire/s5/conf/segmentation_ovlp.conf b/egs/aspire/s5/conf/segmentation_ovlp.conf new file mode 100644 index 00000000000..28b5feaf5d5 --- /dev/null +++ b/egs/aspire/s5/conf/segmentation_ovlp.conf @@ -0,0 +1,14 @@ +# General segmentation options +pad_length=-1 # Pad speech segments by this many frames on either side +max_blend_length=-1 # Maximum duration of speech that will be removed as part + # of smoothing process. This is only if there are no other + # speech segments nearby. +max_intersegment_length=0 # Merge nearby speech segments if the silence + # between them is less than this many frames. +post_pad_length=-1 # Pad speech segments by this many frames on either side + # after the merging process using max_intersegment_length +max_segment_length=1000 # Segments that are longer than this are split into + # overlapping frames. +overlap_length=250 # Overlapping frames when segments are split. + # See the above option. +min_silence_length=100000 # Min silence length at which to split very long segments diff --git a/egs/aspire/s5/conf/segmentation_speech.conf b/egs/aspire/s5/conf/segmentation_speech.conf new file mode 100644 index 00000000000..c4c75b212fc --- /dev/null +++ b/egs/aspire/s5/conf/segmentation_speech.conf @@ -0,0 +1,14 @@ +# General segmentation options +pad_length=20 # Pad speech segments by this many frames on either side +max_relabel_length=10 # Maximum duration of speech that will be removed as part + # of smoothing process. This is only if there are no other + # speech segments nearby. +max_intersegment_length=30 # Merge nearby speech segments if the silence + # between them is less than this many frames. +post_pad_length=10 # Pad speech segments by this many frames on either side + # after the merging process using max_intersegment_length +max_segment_length=1000 # Segments that are longer than this are split into + # overlapping frames. +overlap_length=250 # Overlapping frames when segments are split. + # See the above option. 
+min_silence_length=20 # Min silence length at which to split very long segments From baa5bf48b9947e1c62d59892cfb8c4c650c7b74d Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 16 Jan 2017 13:39:38 -0500 Subject: [PATCH 134/213] asr_diarization: Support per-utt gmm global --- src/gmmbin/gmm-global-copy.cc | 44 ++- src/gmmbin/gmm-global-get-post.cc | 77 +++-- .../gmm-global-init-models-from-feats.cc | 291 ++++++++++++++++++ 3 files changed, 385 insertions(+), 27 deletions(-) create mode 100644 src/gmmbin/gmm-global-init-models-from-feats.cc diff --git a/src/gmmbin/gmm-global-copy.cc b/src/gmmbin/gmm-global-copy.cc index af31b03aa9a..b850cdced51 100644 --- a/src/gmmbin/gmm-global-copy.cc +++ b/src/gmmbin/gmm-global-copy.cc @@ -29,11 +29,13 @@ int main(int argc, char *argv[]) { const char *usage = "Copy a diagonal-covariance GMM\n" "Usage: gmm-global-copy [options] \n" + " or gmm-global-copy [options] \n" "e.g.: gmm-global-copy --binary=false 1.model - | less"; bool binary_write = true; ParseOptions po(usage); - po.Register("binary", &binary_write, "Write output in binary mode"); + po.Register("binary", &binary_write, + "Write in binary mode (only relevant if output is a wxfilename)"); po.Read(argc, argv); @@ -45,15 +47,39 @@ int main(int argc, char *argv[]) { std::string model_in_filename = po.GetArg(1), model_out_filename = po.GetArg(2); - DiagGmm gmm; - { - bool binary_read; - Input ki(model_in_filename, &binary_read); - gmm.Read(ki.Stream(), binary_read); - } - WriteKaldiObject(gmm, model_out_filename, binary_write); + // all these "fn"'s are either rspecifiers or filenames. + + bool in_is_rspecifier = + (ClassifyRspecifier(model_in_filename, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(model_out_filename, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier) + KALDI_ERR << "Cannot mix archives with regular files (copying gmm models)"; - KALDI_LOG << "Written model to " << model_out_filename; + if (!in_is_rspecifier) { + DiagGmm gmm; + { + bool binary_read; + Input ki(model_in_filename, &binary_read); + gmm.Read(ki.Stream(), binary_read); + } + WriteKaldiObject(gmm, model_out_filename, binary_write); + + KALDI_LOG << "Written model to " << model_out_filename; + } else { + SequentialDiagGmmReader gmm_reader(model_in_filename); + DiagGmmWriter gmm_writer(model_out_filename); + + int32 num_done = 0; + for (; !gmm_reader.Done(); gmm_reader.Next(), num_done++) { + gmm_writer.Write(gmm_reader.Key(), gmm_reader.Value()); + } + + KALDI_LOG << "Wrote " << num_done << " GMM models to " << model_out_filename; + } } catch(const std::exception &e) { std::cerr << e.what() << '\n'; return -1; diff --git a/src/gmmbin/gmm-global-get-post.cc b/src/gmmbin/gmm-global-get-post.cc index b364c33cab4..2092d1348f0 100644 --- a/src/gmmbin/gmm-global-get-post.cc +++ b/src/gmmbin/gmm-global-get-post.cc @@ -36,35 +36,51 @@ int main(int argc, char *argv[]) { " (e.g. 
in training UBMs, SGMMs, tied-mixture systems)\n" " For each frame, gives a list of the n best Gaussian indices,\n" " sorted from best to worst.\n" - "Usage: gmm-global-get-post [options] \n" - "e.g.: gmm-global-get-post --n=20 1.gmm \"ark:feature-command |\" \"ark,t:|gzip -c >post.1.gz\"\n"; + "Usage: gmm-global-get-post [options] []\n" + "e.g.: gmm-global-get-post --n=20 1.gmm \"ark:feature-command |\" \"ark,t:|gzip -c >post.1.gz\"\n" + " or : gmm-global-get-post --n=20 ark:1.gmm \"ark:feature-command |\" \"ark,t:|gzip -c >post.1.gz\"\n"; ParseOptions po(usage); int32 num_post = 50; BaseFloat min_post = 0.0; + std::string utt2spk_rspecifier; + po.Register("n", &num_post, "Number of Gaussians to keep per frame\n"); po.Register("min-post", &min_post, "Minimum posterior we will output " "before pruning and renormalizing (e.g. 0.01)"); + po.Register("utt2spk", &utt2spk_rspecifier, + "rspecifier for utterance to speaker map"); po.Read(argc, argv); - if (po.NumArgs() != 3) { + if (po.NumArgs() < 3 || po.NumArgs() > 4) { po.PrintUsage(); exit(1); } - std::string model_filename = po.GetArg(1), + std::string model_in_filename = po.GetArg(1), feature_rspecifier = po.GetArg(2), - post_wspecifier = po.GetArg(3); + post_wspecifier = po.GetArg(3), + frame_loglikes_wspecifier = po.GetOptArg(4); - DiagGmm gmm; - ReadKaldiObject(model_filename, &gmm); + RandomAccessDiagGmmReaderMapped *gmm_reader = NULL; + DiagGmm *gmm = NULL; + KALDI_ASSERT(num_post > 0); KALDI_ASSERT(min_post < 1.0); - int32 num_gauss = gmm.NumGauss(); - if (num_post > num_gauss) { - KALDI_WARN << "You asked for " << num_post << " Gaussians but GMM " - << "only has " << num_gauss << ", returning this many. "; - num_post = num_gauss; + + if (ClassifyRspecifier(model_in_filename, NULL, NULL) + != kNoRspecifier) { // reading models from a Table. + gmm_reader = new RandomAccessDiagGmmReaderMapped(model_in_filename, + utt2spk_rspecifier); + } else { + gmm = new DiagGmm(); + ReadKaldiObject(model_in_filename, gmm); + int32 num_gauss = gmm->NumGauss(); + if (num_post > num_gauss) { + KALDI_WARN << "You asked for " << num_post << " Gaussians but GMM " + << "only has " << num_gauss << ", returning this many. 
"; + num_post = num_gauss; + } } double tot_like = 0.0; @@ -72,10 +88,11 @@ int main(int argc, char *argv[]) { SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); PosteriorWriter post_writer(post_wspecifier); + BaseFloatVectorWriter likes_writer(frame_loglikes_wspecifier); int32 num_done = 0, num_err = 0; for (; !feature_reader.Done(); feature_reader.Next()) { - std::string utt = feature_reader.Key(); + const std::string &utt = feature_reader.Key(); const Matrix &feats = feature_reader.Value(); int32 T = feats.NumRows(); if (T == 0) { @@ -83,9 +100,20 @@ int main(int argc, char *argv[]) { num_err++; continue; } - if (feats.NumCols() != gmm.Dim()) { + + if (gmm_reader) { + if (!gmm_reader.HasKey(utt)) { + KALDI_WARN << "Could not find GMM for utterance " << utt; + num_err++; + continue; + } + gmm = gmm_reader.Value(utt); + } + + if (feats.NumCols() != gmm->Dim()) { KALDI_WARN << "Dimension mismatch for utterance " << utt - << ": got " << feats.NumCols() << ", expected " << gmm.Dim(); + << ": got " << feats.NumCols() << ", expected " + << gmm->Dim(); num_err++; continue; } @@ -93,15 +121,22 @@ int main(int argc, char *argv[]) { Matrix loglikes; - gmm.LogLikelihoods(feats, &loglikes); + gmm->LogLikelihoods(feats, &loglikes); + + Vector frame_loglikes; + if (!frame_loglikes_wspecifier.empty()) frame_loglikes.Resize(T); Posterior post(T); double log_like_this_file = 0.0; for (int32 t = 0; t < T; t++) { - log_like_this_file += - VectorToPosteriorEntry(loglikes.Row(t), num_post, + double log_like_this_frame = + VectorToPosteriorEntry(loglikes.Row(t), + num_post > num_gauss ? num_gauss : num_post, min_post, &(post[t])); + if (!frame_loglikes_wspecifier.empty()) + frame_loglikes(t) = log_like_this_frame; + log_like_this_file += log_like_this_frame; } KALDI_VLOG(1) << "Processed utterance " << utt << ", average likelihood " << (log_like_this_file / T) << " over " << T << " frames"; @@ -109,8 +144,14 @@ int main(int argc, char *argv[]) { tot_t += T; post_writer.Write(utt, post); + if (!frame_loglikes_wspecifier.empty()) + frame_loglikes.Write(utt, frame_loglikes); + num_done++; } + + delete gmm_reader; + delete gmm; KALDI_LOG << "Done " << num_done << " files, " << num_err << " with errors, average UBM log-likelihood is " diff --git a/src/gmmbin/gmm-global-init-models-from-feats.cc b/src/gmmbin/gmm-global-init-models-from-feats.cc new file mode 100644 index 00000000000..486ba5af27b --- /dev/null +++ b/src/gmmbin/gmm-global-init-models-from-feats.cc @@ -0,0 +1,291 @@ +// gmmbin/gmm-global-init-models-from-feats.cc + +// Copyright 2013 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "gmm/model-common.h" +#include "gmm/full-gmm.h" +#include "gmm/diag-gmm.h" +#include "gmm/mle-full-gmm.h" + +namespace kaldi { + +// We initialize the GMM parameters by setting the variance to the global +// variance of the features, and the means to distinct randomly chosen frames. +void InitGmmFromRandomFrames(const MatrixBase &feats, DiagGmm *gmm) { + int32 num_gauss = gmm->NumGauss(), num_frames = feats.NumRows(), + dim = feats.NumCols(); + KALDI_ASSERT(num_frames >= 10 * num_gauss && "Too few frames to train on"); + Vector mean(dim), var(dim); + for (int32 i = 0; i < num_frames; i++) { + mean.AddVec(1.0 / num_frames, feats.Row(i)); + var.AddVec2(1.0 / num_frames, feats.Row(i)); + } + var.AddVec2(-1.0, mean); + if (var.Max() <= 0.0) + KALDI_ERR << "Features do not have positive variance " << var; + + DiagGmmNormal gmm_normal(*gmm); + + std::set used_frames; + for (int32 g = 0; g < num_gauss; g++) { + int32 random_frame = RandInt(0, num_frames - 1); + while (used_frames.count(random_frame) != 0) + random_frame = RandInt(0, num_frames - 1); + used_frames.insert(random_frame); + gmm_normal.weights_(g) = 1.0 / num_gauss; + gmm_normal.means_.Row(g).CopyFromVec(feats.Row(random_frame)); + gmm_normal.vars_.Row(g).CopyFromVec(var); + } + gmm->CopyFromNormal(gmm_normal); + gmm->ComputeGconsts(); +} + +void TrainOneIter(const MatrixBase &feats, + const MleDiagGmmOptions &gmm_opts, + int32 iter, + int32 num_threads, + DiagGmm *gmm) { + AccumDiagGmm gmm_acc(*gmm, kGmmAll); + + Vector frame_weights(feats.NumRows(), kUndefined); + frame_weights.Set(1.0); + + double tot_like; + tot_like = gmm_acc.AccumulateFromDiagMultiThreaded(*gmm, feats, frame_weights, + num_threads); + + KALDI_LOG << "Likelihood per frame on iteration " << iter + << " was " << (tot_like / feats.NumRows()) << " over " + << feats.NumRows() << " frames."; + + BaseFloat objf_change, count; + MleDiagGmmUpdate(gmm_opts, gmm_acc, kGmmAll, gmm, &objf_change, &count); + + KALDI_LOG << "Objective-function change on iteration " << iter << " was " + << (objf_change / count) << " over " << count << " frames."; +} + +void TrainGmm(const MatrixBase &feats, + const MleDiagGmmOptions &gmm_opts, + int32 num_gauss, int32 num_gauss_init, int32 num_iters, + int32 num_threads, DiagGmm *gmm) { + KALDI_LOG << "Initializing GMM means from random frames to " + << num_gauss_init << " Gaussians."; + InitGmmFromRandomFrames(feats, gmm); + + // we'll increase the #Gaussians by splitting, + // till halfway through training. 
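+  // Illustrative sketch with hypothetical numbers (not this binary's
+  // defaults): for num_gauss = 100, num_gauss_init = 10 and num_iters = 50,
+  // gauss_inc = (100 - 10) / (50 / 2) = 3, so about 3 Gaussians are added
+  // after each iteration; all 100 components are in place by roughly
+  // iteration 30, and the remaining iterations just refine them.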
+ int32 cur_num_gauss = num_gauss_init, + gauss_inc = (num_gauss - num_gauss_init) / (num_iters / 2); + + for (int32 iter = 0; iter < num_iters; iter++) { + TrainOneIter(feats, gmm_opts, iter, num_threads, gmm); + + int32 next_num_gauss = std::min(num_gauss, cur_num_gauss + gauss_inc); + if (next_num_gauss > gmm->NumGauss()) { + KALDI_LOG << "Splitting to " << next_num_gauss << " Gaussians."; + gmm->Split(next_num_gauss, 0.1); + cur_num_gauss = next_num_gauss; + } + } +} + +} // namespace kaldi + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "This program initializes a single diagonal GMM and does multiple iterations of\n" + "training from features stored in memory.\n" + "Usage: gmm-global-init-from-feats [options] \n" + "e.g.: gmm-global-init-from-feats scp:train.scp ark:1.ark\n"; + + ParseOptions po(usage); + MleDiagGmmOptions gmm_opts; + + bool binary = true; + int32 num_gauss = 100; + int32 num_gauss_init = 0; + int32 num_iters = 50; + int32 num_frames = 200000; + int32 srand_seed = 0; + int32 num_threads = 4; + std::string spk2utt_rspecifier; + + po.Register("binary", &binary, "Write output in binary mode"); + po.Register("num-gauss", &num_gauss, "Number of Gaussians in the model"); + po.Register("num-gauss-init", &num_gauss_init, "Number of Gaussians in " + "the model initially (if nonzero and less than num_gauss, " + "we'll do mixture splitting)"); + po.Register("num-iters", &num_iters, "Number of iterations of training"); + po.Register("num-frames", &num_frames, "Number of feature vectors to store in " + "memory and train on (randomly chosen from the input features)"); + po.Register("srand", &srand_seed, "Seed for random number generator "); + po.Register("num-threads", &num_threads, "Number of threads used for " + "statistics accumulation"); + po.Register("spk2utt-rspecifier", &spk2utt_rspecifier, + "If specified, estimates models per-speaker"); + + gmm_opts.Register(&po); + + po.Read(argc, argv); + + srand(srand_seed); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + if (num_gauss_init <= 0 || num_gauss_init > num_gauss) + num_gauss_init = num_gauss; + + std::string feature_rspecifier = po.GetArg(1), + model_wspecifier = po.GetArg(2); + + DiagGmmWriter gmm_writer(model_wspecifier); + + KALDI_ASSERT(num_frames > 0); + + KALDI_LOG << "Reading features (will keep " << num_frames << " frames " + << "per utterance.)"; + + int32 dim = 0; + + if (!spk2utt_rspecifier.empty()) { + SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); + for (; !feature_reader.Done(); feature_reader.Next()) { + const Matrix &this_feats = feature_reader.Value(); + if (dim == 0) { + dim = this_feats.NumCols(); + } else if (this_feats.NumCols() != dim) { + KALDI_ERR << "Features have inconsistent dims " + << this_feats.NumCols() << " vs. 
" << dim + << " (current utt is) " << feature_reader.Key(); + } + + Matrix feats(num_frames, dim); + int64 num_read = 0; + + for (int32 t = 0; t < this_feats.NumRows(); t++) { + num_read++; + if (num_read <= num_frames) { + feats.Row(num_read - 1).CopyFromVec(this_feats.Row(t)); + } else { + BaseFloat keep_prob = num_frames / static_cast(num_read); + if (WithProb(keep_prob)) { // With probability "keep_prob" + feats.Row(RandInt(0, num_frames - 1)).CopyFromVec(this_feats.Row(t)); + } + } + } + + if (num_read < num_frames) { + KALDI_WARN << "For utterance " << feature_reader.Key() << ", " + << "number of frames read " << num_read << " was less than " + << "target number " << num_frames << ", using all we read."; + feats.Resize(num_read, dim, kCopyData); + } else { + BaseFloat percent = num_frames * 100.0 / num_read; + KALDI_LOG << "For utterance " << feature_reader.Key() << ", " + << "kept " << num_frames << " out of " << num_read + << " input frames = " << percent << "%."; + } + + DiagGmm gmm(num_gauss_init, dim); + TrainGmm(feats, gmm_opts, num_gauss, num_gauss_init, num_iters, + num_threads, &gmm); + + gmm_writer.Write(feature_reader.Key(), gmm); + } + KALDI_LOG << "Done initializing GMMs."; + } else { + SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier); + RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier); + + int32 num_err = 0; + for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) { + Matrix feats; + int64 num_read = 0; + + const std::vector &uttlist = spk2utt_reader.Value(); + + for (std::vector::const_iterator it = uttlist.begin(); + it != uttlist.end(); ++it) { + if (!feature_reader.HasKey(*it)) { + KALDI_WARN << "Could not find features for utterance " << *it; + num_err++; + } + + const Matrix &this_feats = feature_reader.Value(*it); + if (dim == 0) { + dim = this_feats.NumCols(); + feats.Resize(num_frames, dim); + } else if (this_feats.NumCols() != dim) { + KALDI_ERR << "Features have inconsistent dims " + << this_feats.NumCols() << " vs. " << dim + << " (current utt is) " << *it; + } + + for (int32 t = 0; t < this_feats.NumRows(); t++) { + num_read++; + if (num_read <= num_frames) { + feats.Row(num_read - 1).CopyFromVec(this_feats.Row(t)); + } else { + BaseFloat keep_prob = num_frames / static_cast(num_read); + if (WithProb(keep_prob)) { // With probability "keep_prob" + feats.Row(RandInt(0, num_frames - 1)).CopyFromVec(this_feats.Row(t)); + } + } + } + } + + if (num_read < num_frames) { + KALDI_WARN << "For speaker " << spk2utt_reader.Key() << ", " + << "number of frames read " << num_read << " was less than " + << "target number " << num_frames << ", using all we read."; + feats.Resize(num_read, dim, kCopyData); + } else { + BaseFloat percent = num_frames * 100.0 / num_read; + KALDI_LOG << "For spekear " << spk2utt_reader.Key() << ", " + << "kept " << num_frames << " out of " << num_read + << " input frames = " << percent << "%."; + } + + DiagGmm gmm(num_gauss_init, dim); + TrainGmm(feats, gmm_opts, num_gauss, num_gauss_init, num_iters, + num_threads, &gmm); + + gmm_writer.Write(spk2utt_reader.Key(), gmm); + } + + KALDI_LOG << "Done initializing GMMs. 
Failed getting features for " + << num_err << "utterances"; + } + + return 0; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + From dafec02a4c04f3ac6c1e23d0d7ce2b9cc53cadc6 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 19 Jan 2017 18:07:12 -0500 Subject: [PATCH 135/213] asr_diarization: Fix some bugs in segmenter code and make it simpler --- src/segmenter/Makefile | 8 +-- src/segmenter/segment.h | 19 +++++++ src/segmenter/segmentation-post-processor.cc | 3 +- src/segmenter/segmentation-post-processor.h | 6 ++- src/segmenter/segmentation-utils.cc | 21 ++++++-- src/segmenter/segmentation-utils.h | 11 +++- src/segmenterbin/Makefile | 4 +- src/segmenterbin/segmentation-copy.cc | 3 +- ...ntation-init-from-additive-signals-info.cc | 53 +++++++++---------- .../segmentation-init-from-segments.cc | 43 +++++++-------- .../segmentation-remove-segments.cc | 8 ++- 11 files changed, 115 insertions(+), 64 deletions(-) diff --git a/src/segmenter/Makefile b/src/segmenter/Makefile index 03df6132050..8a9b37cad75 100644 --- a/src/segmenter/Makefile +++ b/src/segmenter/Makefile @@ -2,14 +2,16 @@ all: include ../kaldi.mk -TESTFILES = segmentation-io-test +TESTFILES = segmentation-io-test information-bottleneck-clusterable-test OBJFILES = segment.o segmentation.o segmentation-utils.o \ - segmentation-post-processor.o + segmentation-post-processor.o \ + information-bottleneck-clusterable.o \ + information-bottleneck-cluster-utils.o LIBNAME = kaldi-segmenter -ADDLIBS = ../gmm/kaldi-gmm.a \ +ADDLIBS = ../tree/kaldi-tree.a ../gmm/kaldi-gmm.a \ ../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a ../thread/kaldi-thread.a include ../makefiles/default_rules.mk diff --git a/src/segmenter/segment.h b/src/segmenter/segment.h index b54b5367c73..f7ada5b92ee 100644 --- a/src/segmenter/segment.h +++ b/src/segmenter/segment.h @@ -1,3 +1,22 @@ +// segmenter/segment.h" + +// Copyright 2016 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
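+// (Descriptive note, inferred from how Segment is used elsewhere in this
+//  patch series: a Segment holds a start frame, an end frame and an integer
+//  class label, with accessors such as Label(), SetLabel() and Length().)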
+ #ifndef KALDI_SEGMENTER_SEGMENT_H_ #define KALDI_SEGMENTER_SEGMENT_H_ diff --git a/src/segmenter/segmentation-post-processor.cc b/src/segmenter/segmentation-post-processor.cc index 2c97e31db56..e8c7747c8c4 100644 --- a/src/segmenter/segmentation-post-processor.cc +++ b/src/segmenter/segmentation-post-processor.cc @@ -177,7 +177,8 @@ void SegmentationPostProcessor::DoBlendingShortSegments( void SegmentationPostProcessor::DoRemovingSegments(Segmentation *seg) const { if (!IsRemovingSegmentsToBeDone(opts_)) return; - RemoveSegments(remove_labels_, seg); + RemoveSegments(remove_labels_, opts_.max_remove_length, + seg); } void SegmentationPostProcessor::DoMergingAdjacentSegments( diff --git a/src/segmenter/segmentation-post-processor.h b/src/segmenter/segmentation-post-processor.h index 01a23b93b1b..0de54d026e1 100644 --- a/src/segmenter/segmentation-post-processor.h +++ b/src/segmenter/segmentation-post-processor.h @@ -47,6 +47,7 @@ struct SegmentationPostProcessingOptions { int32 max_blend_length; std::string remove_labels_csl; + int32 max_remove_length; bool merge_adjacent_segments; int32 max_intersegment_length; @@ -63,7 +64,7 @@ struct SegmentationPostProcessingOptions { blend_short_segments_class(-1), max_blend_length(-1), merge_adjacent_segments(false), max_intersegment_length(0), max_segment_length(-1), overlap_length(0), - post_process_label(-1) { } + max_remove_length(-1), post_process_label(-1) { } void Register(OptionsItf *opts) { opts->Register("merge-labels", &merge_labels_csl, "Merge labels into a " @@ -109,6 +110,9 @@ struct SegmentationPostProcessingOptions { "Remove any segment whose label is contained in " "remove_labels_csl. " "Refer to the RemoveLabels() code for details."); + opts->Register("max-remove-length", &max_remove_length, + "If provided, specifies the maximum length of segments " + "that will be removed by --remove-labels option"); opts->Register("merge-adjacent-segments", &merge_adjacent_segments, "Merge adjacent segments of the same label if they are " "within max-intersegment-length distance. " diff --git a/src/segmenter/segmentation-utils.cc b/src/segmenter/segmentation-utils.cc index c69d7ff3397..3cece810d45 100644 --- a/src/segmenter/segmentation-utils.cc +++ b/src/segmenter/segmentation-utils.cc @@ -54,18 +54,27 @@ void RelabelSegmentsUsingMap(const unordered_map &label_map, } for (SegmentList::iterator it = segmentation->Begin(); - it != segmentation->End(); ++it) { + it != segmentation->End(); ) { unordered_map::const_iterator map_it = label_map.find( it->Label()); + int32 dest_label = -100; if (map_it == label_map.end()) { if (default_label == -1) KALDI_ERR << "Could not find label " << it->Label() << " in label map."; else - it->SetLabel(default_label); + dest_label = default_label; } else { - it->SetLabel(map_it->second); + dest_label = map_it->second; } + + if (dest_label == -1) { + // Remove segments that will be mapped to label -1. 
+ it = segmentation->Erase(it); + continue; + } + it->SetLabel(dest_label); + ++it; } } @@ -98,6 +107,7 @@ void RemoveSegments(int32 label, Segmentation *segmentation) { } void RemoveSegments(const std::vector &labels, + int32 max_remove_length, Segmentation *segmentation) { // Check if sorted and unique KALDI_ASSERT(std::adjacent_find(labels.begin(), @@ -105,7 +115,10 @@ void RemoveSegments(const std::vector &labels, for (SegmentList::iterator it = segmentation->Begin(); it != segmentation->End(); ) { - if (std::binary_search(labels.begin(), labels.end(), it->Label())) { + if ((max_remove_length == -1 || + it->Length() < max_remove_length) && + std::binary_search(labels.begin(), labels.end(), + it->Label())) { it = segmentation->Erase(it); } else { ++it; diff --git a/src/segmenter/segmentation-utils.h b/src/segmenter/segmentation-utils.h index 30136ab0a5a..4fa3271e874 100644 --- a/src/segmenter/segmentation-utils.h +++ b/src/segmenter/segmentation-utils.h @@ -56,12 +56,19 @@ void ScaleFrameShift(BaseFloat factor, Segmentation *segmentation); void RemoveSegments(int32 label, Segmentation *segmentation); /** - * This is very straight forward. It removes any segment whose label is - * contained in the vector "labels" + * This removes any segment whose label is + * contained in the vector "labels" and has a length smaller than + * max_remove_length. max_remove_length can be provided -1 to + * specify a value of +infinity i.e. to remove segments + * based on only the labels and irrespective of their lengths. **/ void RemoveSegments(const std::vector &labels, + int32 max_remove_length, Segmentation *segmentation); +void RemoveShortSegments(int32 label, int32 min_length, + Segmentation *segmentation); + // Keep only segments of label "label" void KeepSegments(int32 label, Segmentation *segmentation); diff --git a/src/segmenterbin/Makefile b/src/segmenterbin/Makefile index 22a74e70551..6e2fd226019 100644 --- a/src/segmenterbin/Makefile +++ b/src/segmenterbin/Makefile @@ -17,7 +17,9 @@ BINFILES = segmentation-copy segmentation-get-stats \ segmentation-create-overlapped-subsegments \ segmentation-intersect-segments \ segmentation-init-from-additive-signals-info \ - class-counts-per-frame-to-labels#\ + class-counts-per-frame-to-labels \ + agglomerative-cluster-ib \ + intersect-int-vectors #\ gmm-acc-pdf-stats-segmentation \ gmm-est-segmentation gmm-update-segmentation \ segmentation-init-from-diarization \ diff --git a/src/segmenterbin/segmentation-copy.cc b/src/segmenterbin/segmentation-copy.cc index e3384170805..b7e215b55f8 100644 --- a/src/segmenterbin/segmentation-copy.cc +++ b/src/segmenterbin/segmentation-copy.cc @@ -54,7 +54,8 @@ int main(int argc, char *argv[]) { "Write in binary mode " "(only relevant if output is a wxfilename)"); po.Register("label-map", &label_map_rxfilename, - "File with mapping from old to new labels"); + "File with mapping from old to new labels. 
" + "If new label is -1, then that segment is removed."); po.Register("frame-subsampling-factor", &frame_subsampling_factor, "Change frame rate by this factor"); po.Register("utt2label-map-rspecifier", &utt2label_map_rspecifier, diff --git a/src/segmenterbin/segmentation-init-from-additive-signals-info.cc b/src/segmenterbin/segmentation-init-from-additive-signals-info.cc index ccddb4c2b60..abf5aed219b 100644 --- a/src/segmenterbin/segmentation-init-from-additive-signals-info.cc +++ b/src/segmenterbin/segmentation-init-from-additive-signals-info.cc @@ -29,13 +29,13 @@ int main(int argc, char *argv[]) { const char *usage = "Convert overlapping segments information into segmentation\n" "\n" - "Usage: segmentation-init-from-additive-signals-info [options] " + "Usage: segmentation-init-from-additive-signals-info [options] " " \n" " e.g.: segmentation-init-from-additive-signals-info --additive-signals-segmentation-rspecifier=ark:utt_segmentation.ark " - "ark:reco_segmentation.ark ark,t:overlapped_segments_info.txt ark:-\n"; + "ark,t:overlapped_segments_info.txt ark:-\n"; BaseFloat frame_shift = 0.01; - int32 junk_label = -1; + int32 junk_label = -2; std::string lengths_rspecifier; std::string additive_signals_segmentation_rspecifier; @@ -50,42 +50,35 @@ int main(int argc, char *argv[]) { "Archive of segmentation of the additive signal which will used " "instead of an all 1 segmentation"); po.Register("junk-label", &junk_label, - "If specified, then unreliable regions are labeled with this " - "label"); + "The unreliable regions are labeled with this label"); po.Read(argc, argv); - if (po.NumArgs() != 3) { + if (po.NumArgs() != 2) { po.PrintUsage(); exit(1); } - std::string reco_segmentation_rspecifier = po.GetArg(1), - additive_signals_info_rspecifier = po.GetArg(2), - segmentation_wspecifier = po.GetArg(3); + std::string additive_signals_info_rspecifier = po.GetArg(1), + segmentation_wspecifier = po.GetArg(2); - SequentialSegmentationReader reco_segmentation_reader(reco_segmentation_rspecifier); - RandomAccessTokenVectorReader additive_signals_info_reader(additive_signals_info_rspecifier); + SequentialTokenVectorReader additive_signals_info_reader( + additive_signals_info_rspecifier); SegmentationWriter writer(segmentation_wspecifier); - RandomAccessSegmentationReader additive_signals_segmentation_reader(additive_signals_segmentation_rspecifier); - + RandomAccessSegmentationReader additive_signals_segmentation_reader( + additive_signals_segmentation_rspecifier); RandomAccessInt32Reader lengths_reader(lengths_rspecifier); - int32 num_done = 0, num_err = 0, num_missing = 0; + int32 num_done = 0, num_err = 0; - for (; !reco_segmentation_reader.Done(); reco_segmentation_reader.Next()) { - const std::string &key = reco_segmentation_reader.Key(); - - if (!additive_signals_info_reader.HasKey(key)) { - KALDI_WARN << "Could not find additive_signals_info for key " << key; - num_missing++; - continue; - } + for (; !additive_signals_info_reader.Done(); + additive_signals_info_reader.Next()) { + const std::string &key = additive_signals_info_reader.Key(); const std::vector &additive_signals_info = - additive_signals_info_reader.Value(key); + additive_signals_info_reader.Value(); - Segmentation segmentation(reco_segmentation_reader.Value()); + Segmentation segmentation; for (size_t i = 0; i < additive_signals_info.size(); i++) { std::vector parts; @@ -107,7 +100,9 @@ int main(int argc, char *argv[]) { if (!additive_signals_segmentation_reader.HasKey(utt_id)) { KALDI_WARN << "Could not find utterance " << 
utt_id << " in " - << "segmentation " << additive_signals_segmentation_rspecifier; + << "segmentation " + << additive_signals_segmentation_rspecifier + << ". Assiginng the segment --junk-label."; if (duration < 0) { KALDI_ERR << "duration < 0 for utt_id " << utt_id << " in " << "additive_signals_info " @@ -143,14 +138,14 @@ int main(int argc, char *argv[]) { } KALDI_LOG << "Successfully processed " << num_done << " recordings " - << " in additive signals info; failed for " << num_missing - << "; could not get segmentation for " << num_err; + << " in additive signals info" + << "; could not get segmentation for " << num_err + << "additive signals."; - return (num_done > (num_missing/ 2) ? 0 : 1); + return (num_done > num_err / 2 ? 0 : 1); } catch(const std::exception &e) { std::cerr << e.what(); return -1; } } - diff --git a/src/segmenterbin/segmentation-init-from-segments.cc b/src/segmenterbin/segmentation-init-from-segments.cc index c39996b5ef4..469b4ef2965 100644 --- a/src/segmenterbin/segmentation-init-from-segments.cc +++ b/src/segmenterbin/segmentation-init-from-segments.cc @@ -27,15 +27,15 @@ // Beta-001 Beta 0.50 2.66 // Beta-002 Beta 3.50 5.20 // the output segmentation will contain -// Alpha-001 [ 0 16 1 ] -// Alpha-002 [ 0 360 1 ] -// Beta-001 [ 0 216 1 ] -// Beta-002 [ 0 170 1 ] +// Alpha-001 [ 0 15 1 ] +// Alpha-002 [ 0 359 1 ] +// Beta-001 [ 0 215 1 ] +// Beta-002 [ 0 169 1 ] // If --shift-to-zero=false is provided, then the output will contain -// Alpha-001 [ 0 16 1 ] -// Alpha-002 [ 150 410 1 ] -// Beta-001 [ 50 266 1 ] -// Beta-002 [ 350 520 1 ] +// Alpha-001 [ 0 15 1 ] +// Alpha-002 [ 150 409 1 ] +// Beta-001 [ 50 265 1 ] +// Beta-002 [ 350 519 1 ] // // If the following utt2label-rspecifier was provided: // Alpha-001 2 @@ -43,10 +43,10 @@ // Beta-001 4 // Beta-002 4 // then the output segmentation will contain -// Alpha-001 [ 0 16 2 ] -// Alpha-002 [ 0 360 2 ] -// Beta-001 [ 0 216 4 ] -// Beta-002 [ 0 170 4 ] +// Alpha-001 [ 0 15 2 ] +// Alpha-002 [ 0 359 2 ] +// Beta-001 [ 0 215 4 ] +// Beta-002 [ 0 169 4 ] int main(int argc, char *argv[]) { try { @@ -153,15 +153,16 @@ int main(int argc, char *argv[]) { segment_label = utt2label_reader.Value(utt); } - int32 length = round((end - frame_overlap)/ frame_shift) - - round(start / frame_shift); - - if (shift_to_zero) - segmentation.EmplaceBack(0, length, segment_label); - else - segmentation.EmplaceBack(round(start / frame_shift), - round((end-frame_overlap) / frame_shift) - 1, - segment_label); + if (shift_to_zero) { + int32 last_frame = (end-frame_overlap) / frame_shift + - start / frame_shift - 1; + segmentation.EmplaceBack(0, last_frame, segment_label); + } else { + segmentation.EmplaceBack( + static_cast(start / frame_shift + 0.5), + static_cast((end-frame_overlap) / frame_shift - 0.5), + segment_label); + } writer.Write(utt, segmentation); num_done++; diff --git a/src/segmenterbin/segmentation-remove-segments.cc b/src/segmenterbin/segmentation-remove-segments.cc index ce3ef2de6fd..27af1420e54 100644 --- a/src/segmenterbin/segmentation-remove-segments.cc +++ b/src/segmenterbin/segmentation-remove-segments.cc @@ -45,6 +45,7 @@ int main(int argc, char *argv[]) { bool binary = true; int32 remove_label = -1; + int32 max_remove_length = -1; std::string remove_labels_rspecifier = ""; ParseOptions po(usage); @@ -55,6 +56,11 @@ int main(int argc, char *argv[]) { po.Register("remove-label", &remove_label, "Remove segments of this label"); po.Register("remove-labels-rspecifier", &remove_labels_rspecifier, "Specify colon separated 
list of labels for each key"); + po.Register("max-remove-length", &max_remove_length, + "If supplied, this specifies the maximum length of segments " + "will be removed. A value of -1 specifies a length of " + "+infinity i.e. segments will be removed based " + "on only their labels and irrespective of their lengths."); po.Read(argc, argv); @@ -135,7 +141,7 @@ int main(int argc, char *argv[]) { remove_label = remove_labels[0]; - RemoveSegments(remove_labels, &segmentation); + RemoveSegments(remove_labels, max_remove_length, &segmentation); } else { RemoveSegments(remove_label, &segmentation); } From 31a3e79e7c539aafdfd862d906ae76d861966294 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 19 Jan 2017 18:07:58 -0500 Subject: [PATCH 136/213] asr_diarzation: Rename get_subsegmented_feats.sh --- egs/wsj/s5/utils/data/get_subsegment_feats.sh | 47 +------------------ .../s5/utils/data/get_subsegmented_feats.sh | 46 ++++++++++++++++++ 2 files changed, 47 insertions(+), 46 deletions(-) mode change 100755 => 120000 egs/wsj/s5/utils/data/get_subsegment_feats.sh create mode 100755 egs/wsj/s5/utils/data/get_subsegmented_feats.sh diff --git a/egs/wsj/s5/utils/data/get_subsegment_feats.sh b/egs/wsj/s5/utils/data/get_subsegment_feats.sh deleted file mode 100755 index 6baba68eedd..00000000000 --- a/egs/wsj/s5/utils/data/get_subsegment_feats.sh +++ /dev/null @@ -1,46 +0,0 @@ -#! /bin/bash - -# Copyright 2016 Johns Hopkins University (Author: Dan Povey) -# 2016 Vimal Manohar -# Apache 2.0. - -if [ $# -ne 4 ]; then - echo "This scripts gets subsegmented_feats (by adding ranges to data/feats.scp) " - echo "for the subsegments file. This is does one part of the " - echo "functionality in subsegment_data_dir.sh, which additionally " - echo "creates a new subsegmented data directory." - echo "Usage: $0 " - echo " e.g.: $0 data/train/feats.scp 0.01 0.015 subsegments" - exit 1 -fi - -feats=$1 -frame_shift=$2 -frame_overlap=$3 -subsegments=$4 - -# The subsegments format is . -# e.g. 'utt_foo-1 utt_foo 7.21 8.93' -# The first awk command replaces this with the format: -# -# e.g. 'utt_foo-1 utt_foo 721 893' -# and the apply_map.pl command replaces 'utt_foo' (the 2nd field) with its corresponding entry -# from the original wav.scp, so we get a line like: -# e.g. 'utt_foo-1 foo-bar.ark:514231 721 892' -# Note: the reason we subtract one from the last time is that it's going to -# represent the 'last' frame, not the 'end' frame [i.e. not one past the last], -# in the matlab-like, but zero-indexed [first:last] notion. For instance, a segment with 1 frame -# would have start-time 0.00 and end-time 0.01, which would become the frame range -# [0:0] -# The second awk command turns this into something like -# utt_foo-1 foo-bar.ark:514231[721:892] -# It has to be a bit careful because the format actually allows for more general things -# like pipes that might contain spaces, so it has to be able to produce output like the -# following: -# utt_foo-1 some command|[721:892] -# Lastly, utils/data/normalize_data_range.pl will only do something nontrivial if -# the original data-dir already had data-ranges in square brackets. -awk -v s=$frame_shift -v fovlp=$frame_overlap '{print $1, $2, int(($3/s)+0.5), int(($4-fovlp)/s+0.5);}' <$subsegments| \ - utils/apply_map.pl -f 2 $feats | \ - awk '{p=NF-1; for (n=1;n " + echo " e.g.: $0 data/train/feats.scp 0.01 0.015 subsegments" + exit 1 +fi + +feats=$1 +frame_shift=$2 +frame_overlap=$3 +subsegments=$4 + +# The subsegments format is . +# e.g. 
'utt_foo-1 utt_foo 7.21 8.93' +# The first awk command replaces this with the format: +# +# e.g. 'utt_foo-1 utt_foo 721 893' +# and the apply_map.pl command replaces 'utt_foo' (the 2nd field) with its corresponding entry +# from the original wav.scp, so we get a line like: +# e.g. 'utt_foo-1 foo-bar.ark:514231 721 892' +# Note: the reason we subtract one from the last time is that it's going to +# represent the 'last' frame, not the 'end' frame [i.e. not one past the last], +# in the matlab-like, but zero-indexed [first:last] notion. For instance, a segment with 1 frame +# would have start-time 0.00 and end-time 0.01, which would become the frame range +# [0:0] +# The second awk command turns this into something like +# utt_foo-1 foo-bar.ark:514231[721:892] +# It has to be a bit careful because the format actually allows for more general things +# like pipes that might contain spaces, so it has to be able to produce output like the +# following: +# utt_foo-1 some command|[721:892] +# Lastly, utils/data/normalize_data_range.pl will only do something nontrivial if +# the original data-dir already had data-ranges in square brackets. +awk -v s=$frame_shift -v fovlp=$frame_overlap '{print $1, $2, int(($3/s)+0.5), int(($4-fovlp)/s+0.5);}' <$subsegments| \ + utils/apply_map.pl -f 2 $feats | \ + awk '{p=NF-1; for (n=1;n Date: Thu, 19 Jan 2017 18:09:10 -0500 Subject: [PATCH 137/213] asr_diarization: Modify utt2num_frames etc. --- egs/wsj/s5/utils/data/get_reco2num_frames.sh | 2 +- egs/wsj/s5/utils/data/get_utt2num_frames.sh | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/wsj/s5/utils/data/get_reco2num_frames.sh b/egs/wsj/s5/utils/data/get_reco2num_frames.sh index 8df5afdb156..edb16609703 100755 --- a/egs/wsj/s5/utils/data/get_reco2num_frames.sh +++ b/egs/wsj/s5/utils/data/get_reco2num_frames.sh @@ -15,7 +15,7 @@ fi data=$1 -if [ -f $data/reco2num_frames ]; then +if [ -s $data/reco2num_frames ]; then echo "$0: $data/reco2num_frames already present!" exit 0; fi diff --git a/egs/wsj/s5/utils/data/get_utt2num_frames.sh b/egs/wsj/s5/utils/data/get_utt2num_frames.sh index e2921601ec9..3f6d15c45a5 100755 --- a/egs/wsj/s5/utils/data/get_utt2num_frames.sh +++ b/egs/wsj/s5/utils/data/get_utt2num_frames.sh @@ -31,12 +31,12 @@ if [ ! 
-f $data/feats.scp ]; then exit 0 fi -utils/split_data.sh $data $nj || exit 1 +utils/split_data.sh --per-utt $data $nj || exit 1 $cmd JOB=1:$nj $data/log/get_utt2num_frames.JOB.log \ - feat-to-len scp:$data/split${nj}/JOB/feats.scp ark,t:$data/split$nj/JOB/utt2num_frames || exit 1 + feat-to-len scp:$data/split${nj}utt/JOB/feats.scp ark,t:$data/split${nj}utt/JOB/utt2num_frames || exit 1 for n in `seq $nj`; do - cat $data/split$nj/$n/utt2num_frames + cat $data/split${nj}utt/$n/utt2num_frames done > $data/utt2num_frames echo "$0: Computed and wrote $data/utt2num_frames" From bf44dda710cbd6cfa2daba76224128ceded74cc6 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 14:43:38 -0500 Subject: [PATCH 138/213] asr_diarization: gmm-global-get-post to support archives of models --- src/gmmbin/gmm-global-get-post.cc | 28 ++- .../gmm-global-init-models-from-feats.cc | 229 ++++++++++++++++-- 2 files changed, 224 insertions(+), 33 deletions(-) rename src/{gmmbin => segmenterbin}/gmm-global-init-models-from-feats.cc (55%) diff --git a/src/gmmbin/gmm-global-get-post.cc b/src/gmmbin/gmm-global-get-post.cc index 2092d1348f0..35438a7e849 100644 --- a/src/gmmbin/gmm-global-get-post.cc +++ b/src/gmmbin/gmm-global-get-post.cc @@ -49,7 +49,8 @@ int main(int argc, char *argv[]) { po.Register("min-post", &min_post, "Minimum posterior we will output " "before pruning and renormalizing (e.g. 0.01)"); po.Register("utt2spk", &utt2spk_rspecifier, - "rspecifier for utterance to speaker map"); + "rspecifier for utterance to speaker map for reading " + "per-speaker GMM models"); po.Read(argc, argv); if (po.NumArgs() < 3 || po.NumArgs() > 4) { @@ -63,7 +64,7 @@ int main(int argc, char *argv[]) { frame_loglikes_wspecifier = po.GetOptArg(4); RandomAccessDiagGmmReaderMapped *gmm_reader = NULL; - DiagGmm *gmm = NULL; + DiagGmm diag_gmm; KALDI_ASSERT(num_post > 0); KALDI_ASSERT(min_post < 1.0); @@ -73,9 +74,8 @@ int main(int argc, char *argv[]) { gmm_reader = new RandomAccessDiagGmmReaderMapped(model_in_filename, utt2spk_rspecifier); } else { - gmm = new DiagGmm(); - ReadKaldiObject(model_in_filename, gmm); - int32 num_gauss = gmm->NumGauss(); + ReadKaldiObject(model_in_filename, &diag_gmm); + int32 num_gauss = diag_gmm.NumGauss(); if (num_post > num_gauss) { KALDI_WARN << "You asked for " << num_post << " Gaussians but GMM " << "only has " << num_gauss << ", returning this many. "; @@ -88,7 +88,7 @@ int main(int argc, char *argv[]) { SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); PosteriorWriter post_writer(post_wspecifier); - BaseFloatVectorWriter likes_writer(frame_loglikes_wspecifier); + BaseFloatVectorWriter frame_loglikes_writer(frame_loglikes_wspecifier); int32 num_done = 0, num_err = 0; for (; !feature_reader.Done(); feature_reader.Next()) { @@ -101,14 +101,19 @@ int main(int argc, char *argv[]) { continue; } + const DiagGmm *gmm; if (gmm_reader) { - if (!gmm_reader.HasKey(utt)) { + if (!gmm_reader->HasKey(utt)) { KALDI_WARN << "Could not find GMM for utterance " << utt; num_err++; continue; } - gmm = gmm_reader.Value(utt); + gmm = &(gmm_reader->Value(utt)); + } else { + gmm = &diag_gmm; } + int32 num_gauss_to_compute = + num_post > gmm->NumGauss() ? 
gmm->NumGauss() : num_post; if (feats.NumCols() != gmm->Dim()) { KALDI_WARN << "Dimension mismatch for utterance " << utt @@ -117,8 +122,6 @@ int main(int argc, char *argv[]) { num_err++; continue; } - vector > gselect(T); - Matrix loglikes; gmm->LogLikelihoods(feats, &loglikes); @@ -132,7 +135,7 @@ int main(int argc, char *argv[]) { for (int32 t = 0; t < T; t++) { double log_like_this_frame = VectorToPosteriorEntry(loglikes.Row(t), - num_post > num_gauss ? num_gauss : num_post, + num_gauss_to_compute, min_post, &(post[t])); if (!frame_loglikes_wspecifier.empty()) frame_loglikes(t) = log_like_this_frame; @@ -145,13 +148,12 @@ int main(int argc, char *argv[]) { post_writer.Write(utt, post); if (!frame_loglikes_wspecifier.empty()) - frame_loglikes.Write(utt, frame_loglikes); + frame_loglikes_writer.Write(utt, frame_loglikes); num_done++; } delete gmm_reader; - delete gmm; KALDI_LOG << "Done " << num_done << " files, " << num_err << " with errors, average UBM log-likelihood is " diff --git a/src/gmmbin/gmm-global-init-models-from-feats.cc b/src/segmenterbin/gmm-global-init-models-from-feats.cc similarity index 55% rename from src/gmmbin/gmm-global-init-models-from-feats.cc rename to src/segmenterbin/gmm-global-init-models-from-feats.cc index 486ba5af27b..a472b48624c 100644 --- a/src/gmmbin/gmm-global-init-models-from-feats.cc +++ b/src/segmenterbin/gmm-global-init-models-from-feats.cc @@ -1,6 +1,7 @@ // gmmbin/gmm-global-init-models-from-feats.cc // Copyright 2013 Johns Hopkins University (author: Daniel Povey) +// 2016 Vimal Manohar // See ../../COPYING for clarification regarding multiple authors // @@ -58,10 +59,145 @@ void InitGmmFromRandomFrames(const MatrixBase &feats, DiagGmm *gmm) { gmm->ComputeGconsts(); } +void MleDiagGmmSharedVarsUpdate(const MleDiagGmmOptions &config, + const AccumDiagGmm &diag_gmm_acc, + GmmFlagsType flags, + DiagGmm *gmm, + BaseFloat *obj_change_out, + BaseFloat *count_out, + int32 *floored_elements_out, + int32 *floored_gaussians_out, + int32 *removed_gaussians_out) { + KALDI_ASSERT(gmm != NULL); + + if (flags & ~diag_gmm_acc.Flags()) + KALDI_ERR << "Flags in argument do not match the active accumulators"; + + KALDI_ASSERT(diag_gmm_acc.NumGauss() == gmm->NumGauss() && + diag_gmm_acc.Dim() == gmm->Dim()); + + int32 num_gauss = gmm->NumGauss(); + double occ_sum = diag_gmm_acc.occupancy().Sum(); + + int32 elements_floored = 0, gauss_floored = 0; + + // remember old objective value + gmm->ComputeGconsts(); + BaseFloat obj_old = MlObjective(*gmm, diag_gmm_acc); + + // First get the gmm in "normal" representation (not the exponential-model + // form). 
+ DiagGmmNormal ngmm(*gmm); + + Vector shared_var(gmm->Dim()); + + std::vector to_remove; + for (int32 i = 0; i < num_gauss; i++) { + double occ = diag_gmm_acc.occupancy()(i); + double prob; + if (occ_sum > 0.0) + prob = occ / occ_sum; + else + prob = 1.0 / num_gauss; + + if (occ > static_cast(config.min_gaussian_occupancy) + && prob > static_cast(config.min_gaussian_weight)) { + + ngmm.weights_(i) = prob; + + // copy old mean for later normalizations + Vector old_mean(ngmm.means_.Row(i)); + + // update mean, then variance, as far as there are accumulators + if (diag_gmm_acc.Flags() & (kGmmMeans|kGmmVariances)) { + Vector mean(diag_gmm_acc.mean_accumulator().Row(i)); + mean.Scale(1.0 / occ); + // transfer to estimate + ngmm.means_.CopyRowFromVec(mean, i); + } + + if (diag_gmm_acc.Flags() & kGmmVariances) { + KALDI_ASSERT(diag_gmm_acc.Flags() & kGmmMeans); + Vector var(diag_gmm_acc.variance_accumulator().Row(i)); + var.Scale(1.0 / occ); + var.AddVec2(-1.0, ngmm.means_.Row(i)); // subtract squared means. + + // if we intend to only update the variances, we need to compensate by + // adding the difference between the new and old mean + if (!(flags & kGmmMeans)) { + old_mean.AddVec(-1.0, ngmm.means_.Row(i)); + var.AddVec2(1.0, old_mean); + } + shared_var.AddVec(occ, var); + } + } else { // Insufficient occupancy. + if (config.remove_low_count_gaussians && + static_cast(to_remove.size()) < num_gauss-1) { + // remove the component, unless it is the last one. + KALDI_WARN << "Too little data - removing Gaussian (weight " + << std::fixed << prob + << ", occupation count " << std::fixed << diag_gmm_acc.occupancy()(i) + << ", vector size " << gmm->Dim() << ")"; + to_remove.push_back(i); + } else { + KALDI_WARN << "Gaussian has too little data but not removing it because" + << (config.remove_low_count_gaussians ? + " it is the last Gaussian: i = " + : " remove-low-count-gaussians == false: g = ") << i + << ", occ = " << diag_gmm_acc.occupancy()(i) << ", weight = " << prob; + ngmm.weights_(i) = + std::max(prob, static_cast(config.min_gaussian_weight)); + } + } + } + + if (diag_gmm_acc.Flags() & kGmmVariances) { + int32 floored; + if (config.variance_floor_vector.Dim() != 0) { + floored = shared_var.ApplyFloor(config.variance_floor_vector); + } else { + floored = shared_var.ApplyFloor(config.min_variance); + } + if (floored != 0) { + elements_floored += floored; + gauss_floored++; + } + + shared_var.Scale(1.0 / occ_sum); + for (int32 i = 0; i < num_gauss; i++) { + ngmm.vars_.CopyRowFromVec(shared_var, i); + } + } + + // copy to natural representation according to flags + ngmm.CopyToDiagGmm(gmm, flags); + + gmm->ComputeGconsts(); // or MlObjective will fail. 
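+  // Note: obj_new is computed on the updated model, so when variances are
+  // being updated the reported objective change also includes the effect of
+  // tying a single shared variance across all Gaussians.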
+ BaseFloat obj_new = MlObjective(*gmm, diag_gmm_acc); + + if (obj_change_out) + *obj_change_out = (obj_new - obj_old); + if (count_out) *count_out = occ_sum; + if (floored_elements_out) *floored_elements_out = elements_floored; + if (floored_gaussians_out) *floored_gaussians_out = gauss_floored; + + if (to_remove.size() > 0) { + gmm->RemoveComponents(to_remove, true /*renormalize weights*/); + gmm->ComputeGconsts(); + } + if (removed_gaussians_out != NULL) *removed_gaussians_out = to_remove.size(); + + if (gauss_floored > 0) + KALDI_VLOG(2) << gauss_floored << " variances floored in " << gauss_floored + << " Gaussians."; +} + + void TrainOneIter(const MatrixBase &feats, const MleDiagGmmOptions &gmm_opts, int32 iter, int32 num_threads, + bool share_covars, DiagGmm *gmm) { AccumDiagGmm gmm_acc(*gmm, kGmmAll); @@ -86,7 +222,7 @@ void TrainOneIter(const MatrixBase &feats, void TrainGmm(const MatrixBase &feats, const MleDiagGmmOptions &gmm_opts, int32 num_gauss, int32 num_gauss_init, int32 num_iters, - int32 num_threads, DiagGmm *gmm) { + int32 num_threads, bool share_covars, DiagGmm *gmm) { KALDI_LOG << "Initializing GMM means from random frames to " << num_gauss_init << " Gaussians."; InitGmmFromRandomFrames(feats, gmm); @@ -97,7 +233,7 @@ void TrainGmm(const MatrixBase &feats, gauss_inc = (num_gauss - num_gauss_init) / (num_iters / 2); for (int32 iter = 0; iter < num_iters; iter++) { - TrainOneIter(feats, gmm_opts, iter, num_threads, gmm); + TrainOneIter(feats, gmm_opts, iter, num_threads, share_covars, gmm); int32 next_num_gauss = std::min(num_gauss, cur_num_gauss + gauss_inc); if (next_num_gauss > gmm->NumGauss()) { @@ -126,10 +262,14 @@ int main(int argc, char *argv[]) { bool binary = true; int32 num_gauss = 100; int32 num_gauss_init = 0; + int32 max_gauss = 0; + int32 min_gauss = 0; int32 num_iters = 50; int32 num_frames = 200000; int32 srand_seed = 0; int32 num_threads = 4; + BaseFloat num_gauss_fraction = -1; + bool share_covars = false; std::string spk2utt_rspecifier; po.Register("binary", &binary, "Write output in binary mode"); @@ -145,6 +285,16 @@ int main(int argc, char *argv[]) { "statistics accumulation"); po.Register("spk2utt-rspecifier", &spk2utt_rspecifier, "If specified, estimates models per-speaker"); + po.Register("num-gauss-fraction", &num_gauss_fraction, + "If specified, chooses the number of gaussians to be " + "num-gauss-fraction * min(num-frames-available, num-frames). " + "This number is expected to be in the range(0, 0.1)."); + po.Register("max-gauss", &max_gauss, "Maximum number of Gaussians allowed " + "in the model. Applicable when num_gauss_fraction is specified."); + po.Register("min-gauss", &min_gauss, "Minimum number of Gaussians allowed " + "in the model. 
Applicable when num_gauss_fraction is specified."); + po.Register("share-covars", &share_covars, "If true, then the variances " + "of the Gaussian components are tied."); gmm_opts.Register(&po); @@ -157,25 +307,33 @@ int main(int argc, char *argv[]) { exit(1); } - if (num_gauss_init <= 0 || num_gauss_init > num_gauss) - num_gauss_init = num_gauss; - + if (num_gauss_fraction != -1) { + KALDI_ASSERT(num_gauss_fraction > 0 && num_gauss_fraction < 0.1); + } + + KALDI_ASSERT(max_gauss >= 0 && min_gauss >= 0 && max_gauss >= min_gauss); + std::string feature_rspecifier = po.GetArg(1), model_wspecifier = po.GetArg(2); DiagGmmWriter gmm_writer(model_wspecifier); KALDI_ASSERT(num_frames > 0); - - KALDI_LOG << "Reading features (will keep " << num_frames << " frames " - << "per utterance.)"; + + if (spk2utt_rspecifier.empty()) { + KALDI_LOG << "Reading features (will keep " << num_frames << " frames " + << "per utterance.)"; + } else { + KALDI_LOG << "Reading features (will keep " << num_frames << " frames " + << "per speaker.)"; + } int32 dim = 0; - if (!spk2utt_rspecifier.empty()) { + if (spk2utt_rspecifier.empty()) { SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); for (; !feature_reader.Done(); feature_reader.Next()) { - const Matrix &this_feats = feature_reader.Value(); + const Matrix &this_feats = feature_reader.Value(); if (dim == 0) { dim = this_feats.NumCols(); } else if (this_feats.NumCols() != dim) { @@ -198,6 +356,8 @@ int main(int argc, char *argv[]) { } } } + + KALDI_ASSERT(num_read > 0); if (num_read < num_frames) { KALDI_WARN << "For utterance " << feature_reader.Key() << ", " @@ -211,9 +371,23 @@ int main(int argc, char *argv[]) { << " input frames = " << percent << "%."; } - DiagGmm gmm(num_gauss_init, dim); - TrainGmm(feats, gmm_opts, num_gauss, num_gauss_init, num_iters, - num_threads, &gmm); + int32 this_num_gauss_init = num_gauss_init; + int32 this_num_gauss = num_gauss; + + if (num_gauss_fraction != -1) { + this_num_gauss = feats.NumRows() * num_gauss_fraction; + if (this_num_gauss > max_gauss) + this_num_gauss = max_gauss; + if (this_num_gauss < min_gauss) + this_num_gauss = min_gauss; + } + + if (this_num_gauss_init <= 0 || this_num_gauss_init > this_num_gauss) + this_num_gauss_init = this_num_gauss; + + DiagGmm gmm(this_num_gauss_init, dim); + TrainGmm(feats, gmm_opts, this_num_gauss, this_num_gauss_init, + num_iters, num_threads, share_covars, &gmm); gmm_writer.Write(feature_reader.Key(), gmm); } @@ -224,11 +398,11 @@ int main(int argc, char *argv[]) { int32 num_err = 0; for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) { + const std::vector &uttlist = spk2utt_reader.Value(); + Matrix feats; int64 num_read = 0; - const std::vector &uttlist = spk2utt_reader.Value(); - for (std::vector::const_iterator it = uttlist.begin(); it != uttlist.end(); ++it) { if (!feature_reader.HasKey(*it)) { @@ -237,7 +411,7 @@ int main(int argc, char *argv[]) { } const Matrix &this_feats = feature_reader.Value(*it); - if (dim == 0) { + if (feats.NumCols() == 0) { dim = this_feats.NumCols(); feats.Resize(num_frames, dim); } else if (this_feats.NumCols() != dim) { @@ -258,6 +432,8 @@ int main(int argc, char *argv[]) { } } } + + KALDI_ASSERT(num_read > 0); if (num_read < num_frames) { KALDI_WARN << "For speaker " << spk2utt_reader.Key() << ", " @@ -271,9 +447,23 @@ int main(int argc, char *argv[]) { << " input frames = " << percent << "%."; } - DiagGmm gmm(num_gauss_init, dim); - TrainGmm(feats, gmm_opts, num_gauss, num_gauss_init, num_iters, - num_threads, &gmm); + int32 
this_num_gauss_init = num_gauss_init; + int32 this_num_gauss = num_gauss; + + if (num_gauss_fraction != -1) { + this_num_gauss = feats.NumRows() * num_gauss_fraction; + if (this_num_gauss > max_gauss) + this_num_gauss = max_gauss; + if (this_num_gauss < min_gauss) + this_num_gauss = min_gauss; + } + + if (this_num_gauss_init <= 0 || this_num_gauss_init > this_num_gauss) + this_num_gauss_init = this_num_gauss; + + DiagGmm gmm(this_num_gauss_init, dim); + TrainGmm(feats, gmm_opts, this_num_gauss, this_num_gauss_init, + num_iters, num_threads, share_covars, &gmm); gmm_writer.Write(spk2utt_reader.Key(), gmm); } @@ -288,4 +478,3 @@ int main(int argc, char *argv[]) { return -1; } } - From 9bd17f271009808f27ce6e0c463e2fb06e4ae7d5 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 14:44:55 -0500 Subject: [PATCH 139/213] asr_diarization: Add some debugging stuff to segmenter --- src/segmenter/segment.cc | 7 +++++++ src/segmenter/segment.h | 2 ++ src/segmenter/segmentation-post-processor.h | 6 ++++-- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/segmenter/segment.cc b/src/segmenter/segment.cc index b4f485c26bc..65a91a39264 100644 --- a/src/segmenter/segment.cc +++ b/src/segmenter/segment.cc @@ -31,5 +31,12 @@ void Segment::Read(std::istream &is, bool binary) { KALDI_ASSERT(end_frame >= start_frame && start_frame >= 0); } +std::ostream& operator<<(std::ostream& os, const Segment &seg) { + os << "[ "; + seg.Write(os, false); + os << "]"; + return os; +} + } // end namespace segmenter } // end namespace kaldi diff --git a/src/segmenter/segment.h b/src/segmenter/segment.h index f7ada5b92ee..b172fa854a8 100644 --- a/src/segmenter/segment.h +++ b/src/segmenter/segment.h @@ -96,6 +96,8 @@ class SegmentLengthComparator { return lhs.Length() < rhs.Length(); } }; + +std::ostream& operator<<(std::ostream& os, const Segment &seg); } // end namespace segmenter } // end namespace kaldi diff --git a/src/segmenter/segmentation-post-processor.h b/src/segmenter/segmentation-post-processor.h index 0de54d026e1..040d6c44383 100644 --- a/src/segmenter/segmentation-post-processor.h +++ b/src/segmenter/segmentation-post-processor.h @@ -62,9 +62,11 @@ struct SegmentationPostProcessingOptions { pad_label(-1), pad_length(-1), shrink_label(-1), shrink_length(-1), blend_short_segments_class(-1), max_blend_length(-1), - merge_adjacent_segments(false), max_intersegment_length(0), + max_remove_length(-1), + merge_adjacent_segments(false), + max_intersegment_length(0), max_segment_length(-1), overlap_length(0), - max_remove_length(-1), post_process_label(-1) { } + post_process_label(-1) { } void Register(OptionsItf *opts) { opts->Register("merge-labels", &merge_labels_csl, "Merge labels into a " From 1e6b3c9b79316dcd1e087e847c4d1aff6c204e45 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 14:49:48 -0500 Subject: [PATCH 140/213] asr_diarization: Preprare for SimpleHmm --- src/hmm/hmm-utils.cc | 2 -- src/hmm/transition-model.cc | 10 ++++++++++ src/hmm/transition-model.h | 5 ++++- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/hmm/hmm-utils.cc b/src/hmm/hmm-utils.cc index ab0b133f708..f9e1533daac 100644 --- a/src/hmm/hmm-utils.cc +++ b/src/hmm/hmm-utils.cc @@ -231,8 +231,6 @@ GetHmmAsFstSimple(std::vector phone_window, - - // The H transducer has a separate outgoing arc for each of the symbols in ilabel_info. 
fst::VectorFst *GetHTransducer (const std::vector > &ilabel_info, diff --git a/src/hmm/transition-model.cc b/src/hmm/transition-model.cc index 83edbaf5805..7973be69dcd 100644 --- a/src/hmm/transition-model.cc +++ b/src/hmm/transition-model.cc @@ -240,6 +240,16 @@ TransitionModel::TransitionModel(const ContextDependencyInterface &ctx_dep, Check(); } +void TransitionModel::Init(const ContextDependencyInterface &ctx_dep, + const HmmTopology &hmm_topo) { + topo_ = hmm_topo; + // First thing is to get all possible tuples. + ComputeTuples(ctx_dep); + ComputeDerived(); + InitializeProbs(); + Check(); +} + int32 TransitionModel::TupleToTransitionState(int32 phone, int32 hmm_state, int32 pdf, int32 self_loop_pdf) const { Tuple tuple(phone, hmm_state, pdf, self_loop_pdf); // Note: if this ever gets too expensive, which is unlikely, we can refactor diff --git a/src/hmm/transition-model.h b/src/hmm/transition-model.h index 33a0d55443e..c059e319dd5 100644 --- a/src/hmm/transition-model.h +++ b/src/hmm/transition-model.h @@ -128,10 +128,13 @@ class TransitionModel { TransitionModel(const ContextDependencyInterface &ctx_dep, const HmmTopology &hmm_topo); - /// Constructor that takes no arguments: typically used prior to calling Read. TransitionModel() { } + /// Does the same things as the constructor. + void Init(const ContextDependencyInterface &ctx_dep, + const HmmTopology &hmm_topo); + void Read(std::istream &is, bool binary); // note, no symbol table: topo object always read/written w/o symbols. void Write(std::ostream &os, bool binary) const; From eaa56b44bf7fecc8242ee671dedb5d0147461141 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 14:50:16 -0500 Subject: [PATCH 141/213] asr_diarization: Old version of SimpleHmm --- src/simplehmm/Makefile | 16 + src/simplehmm/decodable-simple-hmm.h | 88 ++++ src/simplehmm/simple-hmm-acc-stats-fsts.cc | 173 ++++++++ src/simplehmm/simple-hmm-computation.cc | 5 + src/simplehmm/simple-hmm-test.cc | 76 ++++ src/simplehmm/simple-hmm.cc | 456 +++++++++++++++++++++ src/simplehmm/simple-hmm.h | 274 +++++++++++++ 7 files changed, 1088 insertions(+) create mode 100644 src/simplehmm/Makefile create mode 100644 src/simplehmm/decodable-simple-hmm.h create mode 100644 src/simplehmm/simple-hmm-acc-stats-fsts.cc create mode 100644 src/simplehmm/simple-hmm-computation.cc create mode 100644 src/simplehmm/simple-hmm-test.cc create mode 100644 src/simplehmm/simple-hmm.cc create mode 100644 src/simplehmm/simple-hmm.h diff --git a/src/simplehmm/Makefile b/src/simplehmm/Makefile new file mode 100644 index 00000000000..89c9f70a8c3 --- /dev/null +++ b/src/simplehmm/Makefile @@ -0,0 +1,16 @@ +all: + + +include ../kaldi.mk + +TESTFILES = simple-hmm-test + +OBJFILES = simple-hmm.o simple-hmm-utils.o simple-hmm-graph-compiler.o + +LIBNAME = kaldi-simplehmm +ADDLIBS = ../hmm/kaldi-hmm.a ../decoder/kaldi-decoder.a \ + ../util/kaldi-util.a ../thread/kaldi-thread.a \ + ../matrix/kaldi-matrix.a ../base/kaldi-base.a + +include ../makefiles/default_rules.mk + diff --git a/src/simplehmm/decodable-simple-hmm.h b/src/simplehmm/decodable-simple-hmm.h new file mode 100644 index 00000000000..6f224ee6176 --- /dev/null +++ b/src/simplehmm/decodable-simple-hmm.h @@ -0,0 +1,88 @@ +// simplehmm/decodable-simple-hmm.h + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_SIMPLEHMM_DECODABLE_SIMPLE_HMM_H_ +#define KALDI_SIMPLEHMM_DECODABLE_SIMPLE_HMM_H_ + +#include + +#include "base/kaldi-common.h" +#include "simplehmm/simple-hmm.h" +#include "itf/decodable-itf.h" + +namespace kaldi { +namespace simple_hmm { + +class DecodableMatrixSimpleHmm: public DecodableInterface { + public: + // This constructor creates an object that will not delete "likes" + // when done. + DecodableMatrixSimpleHmm(const SimpleHmm &model, + const Matrix &likes, + BaseFloat scale): + model_(model), likes_(&likes), scale_(scale), delete_likes_(false) + { + if (likes.NumCols() != model.NumPdfs()) + KALDI_ERR << "DecodableMatrixScaledMapped: mismatch, matrix has " + << likes.NumCols() << " rows but transition-model has " + << model.NumPdfs() << " pdf-ids."; + } + + // This constructor creates an object that will delete "likes" + // when done. + DecodableMatrixSimpleHmm(const SimpleHmm &model, + BaseFloat scale, + const Matrix *likes): + model_(model), likes_(likes), scale_(scale), delete_likes_(true) { + if (likes->NumCols() != model.NumPdfs()) + KALDI_ERR << "DecodableMatrixScaledMapped: mismatch, matrix has " + << likes->NumCols() << " rows but transition-model has " + << model.NumPdfs() << " pdf-ids."; + } + + virtual int32 NumFramesReady() const { return likes_->NumRows(); } + + virtual bool IsLastFrame(int32 frame) const { + KALDI_ASSERT(frame < NumFramesReady()); + return (frame == NumFramesReady() - 1); + } + + // Note, frames are numbered from zero. + virtual BaseFloat LogLikelihood(int32 frame, int32 tid) { + return scale_ * (*likes_)(frame, model_.TransitionIdToPdfClass(tid)); + } + + // Indices are one-based! This is for compatibility with OpenFst. + virtual int32 NumIndices() const { return model_.NumTransitionIds(); } + + virtual ~DecodableMatrixSimpleHmm() { + if (delete_likes_) delete likes_; + } + private: + const SimpleHmm &model_; // for tid to pdf mapping + const Matrix *likes_; + BaseFloat scale_; + bool delete_likes_; + KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableMatrixSimpleHmm); +}; + +} // namespace simple_hmm +} // namespace kaldi + +#endif // KALDI_SIMPLEHMM_DECODABLE_SIMPLE_HMM_H_ diff --git a/src/simplehmm/simple-hmm-acc-stats-fsts.cc b/src/simplehmm/simple-hmm-acc-stats-fsts.cc new file mode 100644 index 00000000000..de4a7528836 --- /dev/null +++ b/src/simplehmm/simple-hmm-acc-stats-fsts.cc @@ -0,0 +1,173 @@ +// simplehmmbin/simple-hmm-acc-stats-fsts.cc + +// Copyright 2016 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "simplehmm/simple-hmm.h" +#include "hmm/hmm-utils.h" +#include "fstext/fstext-lib.h" +#include "decoder/decoder-wrappers.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + typedef kaldi::int32 int32; + using fst::SymbolTable; + using fst::VectorFst; + using fst::StdArc; + + const char *usage = + "Accumulate stats for simple HMM models from FSTs directly.\n" + "Usage: simple-hmm-acc-stats-fsts [options] " + " \n" + "e.g.: \n" + " simple-hmm-acc-stats-fsts 1.mdl ark:graphs.fsts scp:likes.scp pdf2class_map 1.stats\n"; + + ParseOptions po(usage); + + BaseFloat acoustic_scale = 1.0; + BaseFloat transition_scale = 1.0; + BaseFloat self_loop_scale = 1.0; + + po.Register("transition-scale", &transition_scale, + "Transition-probability scale [relative to acoustics]"); + po.Register("acoustic-scale", &acoustic_scale, + "Scaling factor for acoustic likelihoods"); + po.Register("self-loop-scale", &self_loop_scale, + "Scale of self-loop versus non-self-loop log probs [relative to acoustics]"); + po.Read(argc, argv); + + + if (po.NumArgs() != 5) { + po.PrintUsage(); + exit(1); + } + + std::string model_in_filename = po.GetArg(1), + fst_rspecifier = po.GetArg(2), + likes_rspecifier = po.GetArg(3), + pdf2class_map_rxfilename = po.GetArg(4), + accs_wxfilename = po.GetArg(5); + + simple_hmm::SimpleHmm model; + ReadKaldiObject(model_in_filename, &model); + + SequentialTableReader fst_reader(fst_rspecifier); + RandomAccessBaseFloatMatrixReader likes_reader(likes_rspecifier); + + std::vector pdf2class; + { + Input ki(pdf2class_map_rxfilename); + std::string line; + while (std::getline(ki.Stream(), line)) { + std::vector parts; + SplitStringToVector(line, " ", true, &parts); + if (parts.size() != 2) { + KALDI_ERR << "Invalid line " << line + << " in pdf2class-map " << pdf2class_map_rxfilename; + } + int32 pdf_id = std::atoi(parts[0].c_str()), + class_id = std::atoi(parts[1].c_str()); + + if (pdf_id != pdf2class.size()) + KALDI_ERR << "pdf2class-map is not sorted or does not contain " + << "pdf " << pdf_id - 1 << " in " + << pdf2class_map_rxfilename; + + if (pdf_id < pdf2class.size()) + KALDI_ERR << "Duplicate pdf " << pdf_id + << " in pdf2class-map " << pdf2class_map_rxfilename; + + pdf2class.push_back(class_id); + } + } + + int32 num_done = 0, num_err = 0; + double tot_like = 0.0, tot_t = 0.0; + int64 frame_count = 0; + + Vector transition_accs; + model.InitStats(&transition_accs); + + SimpleHmmComputation computation(model, pdf2class_map); + + for (; !fst_reader.Done(); fst_reader.Next()) { + const std::string &utt = fst_reader.Key(); + + if (!likes_reader.HasKey(utt)) { + num_err++; + KALDI_WARN << "No likes for utterance " << utt; + continue; + } + + const Matrix &likes = likes_reader.Value(utt); + VectorFst decode_fst(fst_reader.Value()); + fst_reader.FreeCurrent(); // this stops copy-on-write of the fst + // by deleting the fst inside the reader, since we're about to mutate + // the fst by adding transition probs. 
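+      // (The transition log-probs added by AddTransitionProbs() below come
+      // from the model and are scaled by the --transition-scale and
+      // --self-loop-scale options registered above.)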
+ + if (likes.NumRows() == 0) { + KALDI_WARN << "Zero-length utterance: " << utt; + num_err++; + continue; + } + + if (likes.NumCols() != pdf2class.size()) { + KALDI_ERR << "Mismatch in pdf dimension in log-likelihood matrix " + << "and pdf2class map; " << likes.NumCols() << " vs " + << pdf2class.size(); + } + + // Add transition-probs to the FST. + AddTransitionProbs(model, transition_scale, self_loop_scale, + &decode_fst); + + BaseFloat tot_like_this_utt = 0.0, tot_weight = 0.0; + if (!computation.Compute(decode_fst, likes, acoustic_scale, + &transition_accs, + &tot_like_this_utt, &tot_weight)) { + KALDI_WARN << "Failed to do computation for utterance " << utt; + num_err++; + } + tot_like += tot_like_this_utt; + tot_t += tot_weight; + frame_count += likes.NumRows(); + + num_done++; + } + + KALDI_LOG << "Done " << num_done << " files, " << num_err + << " with errors."; + + KALDI_LOG << "Overall avg like per frame = " + << (tot_like/tot_t) << " over " << tot_t << " frames."; + + { + Output ko(accs_wxfilename, binary); + transition_accs.Write(ko.Stream(), binary); + } + KALDI_LOG << "Written accs."; + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/simplehmm/simple-hmm-computation.cc b/src/simplehmm/simple-hmm-computation.cc new file mode 100644 index 00000000000..e20f84169a1 --- /dev/null +++ b/src/simplehmm/simple-hmm-computation.cc @@ -0,0 +1,5 @@ +SimpleHmmComputation::SimpleHmmComputation( + const SimpleHmm &model, + const std::vector &num_pdfs, + VectorFst *decode_fst, + const Matrix &log_likes) diff --git a/src/simplehmm/simple-hmm-test.cc b/src/simplehmm/simple-hmm-test.cc new file mode 100644 index 00000000000..b2de0e05a08 --- /dev/null +++ b/src/simplehmm/simple-hmm-test.cc @@ -0,0 +1,76 @@ +// hmm/simple-hmm-test.cc + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "simplehmm/simple-hmm.h" +#include "hmm/hmm-test-utils.h" + +namespace kaldi { +namespace simple_hmm { + + +SimpleHmm *GenRandSimpleHmm() { + std::vector phones; + phones.push_back(1); + + std::vector num_pdf_classes; + num_pdf_classes.push_back(rand() + 1); + + HmmTopology topo = GenRandTopology(phones, num_pdf_classes); + + SimpleHmm *model = new SimpleHmm(topo); + + return model; +} + + +void TestSimpleHmm() { + + SimpleHmm *model = GenRandSimpleHmm(); + + bool binary = (rand() % 2 == 0); + + std::ostringstream os; + model->Write(os, binary); + + SimpleHmm model2; + std::istringstream is2(os.str()); + model2.Read(is2, binary); + + { + std::ostringstream os1, os2; + model->Write(os1, false); + model2.Write(os2, false); + KALDI_ASSERT(os1.str() == os2.str()); + KALDI_ASSERT(model->Compatible(model2)); + } + delete model; +} + + +} // end namespace simple_hmm +} // end namespace kaldi + + +int main() { + for (int i = 0; i < 2; i++) + kaldi::TestSimpleHmm(); + KALDI_LOG << "Test OK.\n"; +} + + diff --git a/src/simplehmm/simple-hmm.cc b/src/simplehmm/simple-hmm.cc new file mode 100644 index 00000000000..9af077cedc6 --- /dev/null +++ b/src/simplehmm/simple-hmm.cc @@ -0,0 +1,456 @@ +// hmm/simple-hmm.cc + +// Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey) +// Johns Hopkins University (author: Guoguo Chen) +// 2016 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include +#include "simplehmm/simple-hmm.h" + +namespace kaldi { +namespace simple_hmm { + +void SimpleHmm::Initialize() { + KALDI_ASSERT(topo_.GetPhones().size() == 1); + + const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(1); + for (int32 j = 0; j < static_cast(entry.size()); j++) { // for each state... + int32 pdf_class = entry[j].forward_pdf_class; + if (pdf_class != kNoPdf) { + states_.push_back(j); + } + } + + // now states_ is populated with all possible pairs + // (hmm_state, pdf_class). + // sort to enable reverse lookup. + std::sort(states_.begin(), states_.end()); + // this sorting defines the transition-ids. +} + +void SimpleHmm::ComputeDerived() { + state2id_.resize(states_.size()+2); // indexed by transition-state, which + // is one based, but also an entry for one past end of list. + + int32 cur_transition_id = 1; + num_pdfs_ = 0; + for (int32 tstate = 1; + tstate <= static_cast(states_.size()+1); // not a typo. + tstate++) { + state2id_[tstate] = cur_transition_id; + if (static_cast(tstate) <= states_.size()) { + int32 hmm_state = states_[tstate-1]; + const HmmTopology::HmmState &state = topo_.TopologyForPhone(1)[hmm_state]; + int32 pdf_class = state.forward_pdf_class; + num_pdfs_ = std::max(num_pdfs_, pdf_class + 1); + int32 my_num_ids = static_cast(state.transitions.size()); + cur_transition_id += my_num_ids; // # trans out of this state. 
+ } + } + + id2state_.resize(cur_transition_id); // cur_transition_id is #transition-ids+1. + for (int32 tstate = 1; + tstate <= static_cast(states_.size()); tstate++) { + for (int32 tid = state2id_[tstate]; tid < state2id_[tstate+1]; tid++) { + id2state_[tid] = tstate; + } + } +} + +void SimpleHmm::InitializeProbs() { + log_probs_.Resize(NumTransitionIds()+1); // one-based array, zeroth element empty. + for (int32 trans_id = 1; trans_id <= NumTransitionIds(); trans_id++) { + int32 trans_state = id2state_[trans_id]; + int32 trans_index = trans_id - state2id_[trans_state]; + int32 hmm_state = states_[trans_state-1]; + const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(1); + KALDI_ASSERT(static_cast(hmm_state) < entry.size()); + BaseFloat prob = entry[hmm_state].transitions[trans_index].second; + if (prob <= 0.0) + KALDI_ERR << "SimpleHmm::InitializeProbs, zero " + "probability [should remove that entry in the topology]"; + if (prob > 1.0) + KALDI_WARN << "SimpleHmm::InitializeProbs, prob greater than one."; + log_probs_(trans_id) = Log(prob); + } + ComputeDerivedOfProbs(); +} + +void SimpleHmm::Check() const { + KALDI_ASSERT(topo_.GetPhones().size() == 1); + + KALDI_ASSERT(NumTransitionIds() != 0 && NumTransitionStates() != 0); + { + int32 sum = 0; + for (int32 ts = 1; ts <= NumTransitionStates(); ts++) sum += NumTransitionIndices(ts); + KALDI_ASSERT(sum == NumTransitionIds()); + } + for (int32 tid = 1; tid <= NumTransitionIds(); tid++) { + int32 tstate = TransitionIdToTransitionState(tid), + index = TransitionIdToTransitionIndex(tid); + KALDI_ASSERT(tstate > 0 && tstate <=NumTransitionStates() && index >= 0); + KALDI_ASSERT(tid == PairToTransitionId(tstate, index)); + int32 hmm_state = TransitionStateToHmmState(tstate); + KALDI_ASSERT(tstate == HmmStateToTransitionState(hmm_state)); + KALDI_ASSERT(log_probs_(tid) <= 0.0 && + log_probs_(tid) - log_probs_(tid) == 0.0); + // checking finite and non-positive (and not out-of-bounds). + } + + KALDI_ASSERT(num_pdfs_ == topo_.NumPdfClasses(1)); +} + +SimpleHmm::SimpleHmm( + const HmmTopology &hmm_topo): topo_(hmm_topo) { + Initialize(); + ComputeDerived(); + InitializeProbs(); + Check(); +} + +int32 SimpleHmm::HmmStateToTransitionState(int32 hmm_state) const { + // Note: if this ever gets too expensive, which is unlikely, we can refactor + // this code to sort first on pdf_class, and then index on pdf_class, so those + // that have the same pdf_class are in a contiguous range. + std::vector::const_iterator iter = + std::lower_bound(states_.begin(), states_.end(), hmm_state); + if (iter == states_.end() || !(*iter == hmm_state)) { + KALDI_ERR << "SimpleHmm::HmmStateToTransitionState; " + << "HmmState " << hmm_state << " not found." + << " (incompatible model?)"; + } + // states_is indexed by transition_state-1, so add one. 
+ return static_cast((iter - states_.begin())) + 1; +} + + +int32 SimpleHmm::NumTransitionIndices(int32 trans_state) const { + KALDI_ASSERT(static_cast(trans_state) <= states_.size()); + return static_cast(state2id_[trans_state+1]-state2id_[trans_state]); +} + +int32 SimpleHmm::TransitionIdToTransitionState(int32 trans_id) const { + KALDI_ASSERT(trans_id != 0 && + static_cast(trans_id) < id2state_.size()); + return id2state_[trans_id]; +} + +int32 SimpleHmm::TransitionIdToTransitionIndex(int32 trans_id) const { + KALDI_ASSERT(trans_id != 0 && + static_cast(trans_id) < id2state_.size()); + return trans_id - state2id_[id2state_[trans_id]]; +} + +int32 SimpleHmm::TransitionStateToPdfClass(int32 trans_state) const { + KALDI_ASSERT(static_cast(trans_state) <= states_.size()); + int32 hmm_state = states_[trans_state-1]; + const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(1); + KALDI_ASSERT(static_cast(hmm_state) < entry.size()); + return entry[hmm_state].forward_pdf_class; +} + +int32 SimpleHmm::TransitionStateToHmmState(int32 trans_state) const { + KALDI_ASSERT(static_cast(trans_state) <= states_.size()); + return states_[trans_state-1]; +} + +int32 SimpleHmm::PairToTransitionId(int32 trans_state, + int32 trans_index) const { + KALDI_ASSERT(static_cast(trans_state) <= states_.size()); + KALDI_ASSERT(trans_index < state2id_[trans_state+1] - state2id_[trans_state]); + return state2id_[trans_state] + trans_index; +} + +bool SimpleHmm::IsFinal(int32 trans_id) const { + KALDI_ASSERT(static_cast(trans_id) < id2state_.size()); + int32 trans_state = id2state_[trans_id]; + int32 trans_index = trans_id - state2id_[trans_state]; + int32 hmm_state = states_[trans_state-1]; + const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(1); + KALDI_ASSERT(static_cast(hmm_state) < entry.size()); + KALDI_ASSERT(static_cast(trans_index) < + entry[hmm_state].transitions.size()); + // return true if the transition goes to the final state of the + // topology entry. + return (entry[hmm_state].transitions[trans_index].first + 1 == + static_cast(entry.size())); +} + +// returns the self-loop transition-id, +// or zero if does not exist. +int32 SimpleHmm::SelfLoopOf(int32 trans_state) const { + KALDI_ASSERT(static_cast(trans_state-1) < states_.size()); + int32 hmm_state = states_[trans_state-1]; + const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(1); + KALDI_ASSERT(static_cast(hmm_state) < entry.size()); + for (int32 trans_index = 0; + trans_index < static_cast(entry[hmm_state].transitions.size()); + trans_index++) + if (entry[hmm_state].transitions[trans_index].first == hmm_state) + return PairToTransitionId(trans_state, trans_index); + return 0; // invalid transition id. +} + +void SimpleHmm::ComputeDerivedOfProbs() { + // this array indexed by transition-state with nothing in zeroth element. + non_self_loop_log_probs_.Resize(NumTransitionStates()+1); + for (int32 tstate = 1; tstate <= NumTransitionStates(); tstate++) { + int32 tid = SelfLoopOf(tstate); + if (tid == 0) { // no self-loop + non_self_loop_log_probs_(tstate) = 0.0; // log(1.0) + } else { + BaseFloat self_loop_prob = Exp(GetTransitionLogProb(tid)), + non_self_loop_prob = 1.0 - self_loop_prob; + if (non_self_loop_prob <= 0.0) { + KALDI_WARN << "ComputeDerivedOfProbs(): non-self-loop prob is " << non_self_loop_prob; + non_self_loop_prob = 1.0e-10; // just so we can continue... + } + non_self_loop_log_probs_(tstate) = Log(non_self_loop_prob); // will be negative. 
+ } + } +} + +void SimpleHmm::Read(std::istream &is, bool binary) { + ExpectToken(is, binary, ""); + topo_.Read(is, binary); + Initialize(); + ComputeDerived(); + ExpectToken(is, binary, ""); + log_probs_.Read(is, binary); + ExpectToken(is, binary, ""); + ExpectToken(is, binary, ""); + ComputeDerivedOfProbs(); + Check(); +} + +void SimpleHmm::Write(std::ostream &os, bool binary) const { + WriteToken(os, binary, ""); + if (!binary) os << "\n"; + topo_.Write(os, binary); + if (!binary) os << "\n"; + WriteToken(os, binary, ""); + if (!binary) os << "\n"; + log_probs_.Write(os, binary); + WriteToken(os, binary, ""); + if (!binary) os << "\n"; + WriteToken(os, binary, ""); + if (!binary) os << "\n"; +} + +BaseFloat SimpleHmm::GetTransitionProb(int32 trans_id) const { + return Exp(log_probs_(trans_id)); +} + +BaseFloat SimpleHmm::GetTransitionLogProb(int32 trans_id) const { + return log_probs_(trans_id); +} + +BaseFloat SimpleHmm::GetNonSelfLoopLogProb(int32 trans_state) const { + KALDI_ASSERT(trans_state != 0); + return non_self_loop_log_probs_(trans_state); +} + +BaseFloat SimpleHmm::GetTransitionLogProbIgnoringSelfLoops( + int32 trans_id) const { + KALDI_ASSERT(trans_id != 0); + KALDI_PARANOID_ASSERT(!IsSelfLoop(trans_id)); + return log_probs_(trans_id) - GetNonSelfLoopLogProb(TransitionIdToTransitionState(trans_id)); +} + +// stats are counts/weights, indexed by transition-id. +void SimpleHmm::MleUpdate(const Vector &stats, + const MleSimpleHmmUpdateConfig &cfg, + BaseFloat *objf_impr_out, + BaseFloat *count_out) { + BaseFloat count_sum = 0.0, objf_impr_sum = 0.0; + int32 num_skipped = 0, num_floored = 0; + KALDI_ASSERT(stats.Dim() == NumTransitionIds()+1); + for (int32 tstate = 1; tstate <= NumTransitionStates(); tstate++) { + int32 n = NumTransitionIndices(tstate); + KALDI_ASSERT(n>=1); + if (n > 1) { // no point updating if only one transition... + Vector counts(n); + for (int32 tidx = 0; tidx < n; tidx++) { + int32 tid = PairToTransitionId(tstate, tidx); + counts(tidx) = stats(tid); + } + double tstate_tot = counts.Sum(); + count_sum += tstate_tot; + if (tstate_tot < cfg.mincount) { num_skipped++; } + else { + Vector old_probs(n), new_probs(n); + for (int32 tidx = 0; tidx < n; tidx++) { + int32 tid = PairToTransitionId(tstate, tidx); + old_probs(tidx) = new_probs(tidx) = GetTransitionProb(tid); + } + for (int32 tidx = 0; tidx < n; tidx++) + new_probs(tidx) = counts(tidx) / tstate_tot; + for (int32 i = 0; i < 3; i++) { // keep flooring+renormalizing for 3 times.. + new_probs.Scale(1.0 / new_probs.Sum()); + for (int32 tidx = 0; tidx < n; tidx++) + new_probs(tidx) = std::max(new_probs(tidx), cfg.floor); + } + // Compute objf change + for (int32 tidx = 0; tidx < n; tidx++) { + if (new_probs(tidx) == cfg.floor) num_floored++; + double objf_change = counts(tidx) * (Log(new_probs(tidx)) + - Log(old_probs(tidx))); + objf_impr_sum += objf_change; + } + // Commit updated values. + for (int32 tidx = 0; tidx < n; tidx++) { + int32 tid = PairToTransitionId(tstate, tidx); + log_probs_(tid) = Log(new_probs(tidx)); + if (log_probs_(tid) - log_probs_(tid) != 0.0) + KALDI_ERR << "Log probs is inf or NaN: error in update or bad stats?"; + } + } + } + } + KALDI_LOG << "SimpleHmm::Update, objf change is " + << (objf_impr_sum / count_sum) << " per frame over " << count_sum + << " frames. 
"; + KALDI_LOG << num_floored << " probabilities floored, " << num_skipped + << " out of " << NumTransitionStates() << " transition-states " + "skipped due to insuffient data (it is normal to have some skipped.)"; + if (objf_impr_out) *objf_impr_out = objf_impr_sum; + if (count_out) *count_out = count_sum; + ComputeDerivedOfProbs(); +} + + +// stats are counts/weights, indexed by transition-id. +void SimpleHmm::MapUpdate(const Vector &stats, + const MapSimpleHmmUpdateConfig &cfg, + BaseFloat *objf_impr_out, + BaseFloat *count_out) { + KALDI_ASSERT(cfg.tau > 0.0); + BaseFloat count_sum = 0.0, objf_impr_sum = 0.0; + KALDI_ASSERT(stats.Dim() == NumTransitionIds()+1); + for (int32 tstate = 1; tstate <= NumTransitionStates(); tstate++) { + int32 n = NumTransitionIndices(tstate); + KALDI_ASSERT(n>=1); + if (n > 1) { // no point updating if only one transition... + Vector counts(n); + for (int32 tidx = 0; tidx < n; tidx++) { + int32 tid = PairToTransitionId(tstate, tidx); + counts(tidx) = stats(tid); + } + double tstate_tot = counts.Sum(); + count_sum += tstate_tot; + Vector old_probs(n), new_probs(n); + for (int32 tidx = 0; tidx < n; tidx++) { + int32 tid = PairToTransitionId(tstate, tidx); + old_probs(tidx) = new_probs(tidx) = GetTransitionProb(tid); + } + for (int32 tidx = 0; tidx < n; tidx++) + new_probs(tidx) = (counts(tidx) + cfg.tau * old_probs(tidx)) / + (cfg.tau + tstate_tot); + // Compute objf change + for (int32 tidx = 0; tidx < n; tidx++) { + double objf_change = counts(tidx) * (Log(new_probs(tidx)) + - Log(old_probs(tidx))); + objf_impr_sum += objf_change; + } + // Commit updated values. + for (int32 tidx = 0; tidx < n; tidx++) { + int32 tid = PairToTransitionId(tstate, tidx); + log_probs_(tid) = Log(new_probs(tidx)); + if (log_probs_(tid) - log_probs_(tid) != 0.0) + KALDI_ERR << "Log probs is inf or NaN: error in update or bad stats?"; + } + } + } + KALDI_LOG << "Objf change is " << (objf_impr_sum / count_sum) + << " per frame over " << count_sum + << " frames."; + if (objf_impr_out) *objf_impr_out = objf_impr_sum; + if (count_out) *count_out = count_sum; + ComputeDerivedOfProbs(); +} + + +int32 SimpleHmm::TransitionIdToPdfClass(int32 trans_id) const { + KALDI_ASSERT(trans_id != 0 && + static_cast(trans_id) < id2state_.size()); + int32 trans_state = id2state_[trans_id]; + + int32 hmm_state = states_[trans_state-1]; + const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(1); + KALDI_ASSERT(static_cast(hmm_state) < entry.size()); + return entry[hmm_state].forward_pdf_class; +} + +int32 SimpleHmm::TransitionIdToHmmState(int32 trans_id) const { + KALDI_ASSERT(trans_id != 0 && + static_cast(trans_id) < id2state_.size()); + int32 trans_state = id2state_[trans_id]; + return states_[trans_state-1]; +} + +void SimpleHmm::Print(std::ostream &os, + const Vector *occs) { + if (occs != NULL) + KALDI_ASSERT(occs->Dim() == NumPdfs()); + for (int32 tstate = 1; tstate <= NumTransitionStates(); tstate++) { + int32 hmm_state = TransitionStateToHmmState(tstate); + int32 pdf_class = TransitionStateToPdfClass(tstate); + + os << " hmm-state = " << hmm_state; + os << " pdf-class = " << pdf_class << '\n'; + for (int32 tidx = 0; tidx < NumTransitionIndices(tstate); tidx++) { + int32 tid = PairToTransitionId(tstate, tidx); + BaseFloat p = GetTransitionProb(tid); + os << " Transition-id = " << tid << " p = " << p; + if (occs) { + os << " count of pdf-class = " << (*occs)(pdf_class); + } + // now describe what it's a transition to. 
+ if (IsSelfLoop(tid)) { + os << " [self-loop]\n"; + } else { + const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(1); + KALDI_ASSERT(static_cast(hmm_state) < entry.size()); + int32 next_hmm_state = entry[hmm_state].transitions[tidx].first; + KALDI_ASSERT(next_hmm_state != hmm_state); + os << " [" << hmm_state << " -> " << next_hmm_state << "]\n"; + } + } + } +} + +bool SimpleHmm::Compatible(const SimpleHmm &other) const { + return (topo_ == other.topo_ && states_ == other.states_ && + state2id_ == other.state2id_ && id2state_ == other.id2state_ + && NumPdfs() == other.NumPdfs()); +} + +bool SimpleHmm::IsSelfLoop(int32 trans_id) const { + KALDI_ASSERT(static_cast(trans_id) < id2state_.size()); + int32 trans_state = id2state_[trans_id]; + int32 trans_index = trans_id - state2id_[trans_state]; + int32 hmm_state = states_[trans_state-1]; + const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(1); + KALDI_ASSERT(static_cast(hmm_state) < entry.size()); + return (static_cast(trans_index) < entry[hmm_state].transitions.size() + && entry[hmm_state].transitions[trans_index].first == hmm_state); +} + +} // end namespace simple_hmm +} // end namespace kaldi + diff --git a/src/simplehmm/simple-hmm.h b/src/simplehmm/simple-hmm.h new file mode 100644 index 00000000000..ef3a5b9abde --- /dev/null +++ b/src/simplehmm/simple-hmm.h @@ -0,0 +1,274 @@ +// hmm/simple-hmm.h + +// Copyright 2009-2012 Microsoft Corporation +// Johns Hopkins University (author: Guoguo Chen) +// 2016 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_HMM_SIMPLE_HMM_H +#define KALDI_HMM_SIMPLE_HMM_H + +#include "base/kaldi-common.h" +#include "util/const-integer-set.h" +#include "fst/fst-decl.h" // forward declarations. +#include "hmm/hmm-topology.h" +#include "itf/options-itf.h" + +namespace kaldi { +namespace simple_hmm { + +/// \addtogroup hmm_group +/// @{ + +// The class SimpleHmm is a repository for the transition probabilities. +// The model is exactly like a single phone. It has a HMM topology defined in +// hmm-topology.h. Each HMM-state has a number of +// transitions (and final-probs) out of it. Each emitting HMM-state defined in +// the HmmTopology class has an associated class-id. +// The transition model associates the +// transition probs with the (HMM-state, class-id). We associate with +// each such pair a transition-state. Each +// transition-state has a number of associated probabilities to estimate; +// this depends on the number of transitions/final-probs in the topology for +// that HMM-state. Each probability has an associated transition-index. +// We associate with each (transition-state, transition-index) a unique transition-id. +// Each individual probability estimated by the transition-model is asociated with a +// transition-id. 
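+// For example, with a topology that has two emitting HMM-states, each with a
+// self-loop and one forward transition, there are two transition-states and
+// four transition-ids (two out of each transition-state).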
+// +// List of the various types of quantity referred to here and what they mean: +// HMM-state: a number (0, 1, 2...) that indexes TopologyEntry (see hmm-topology.h) +// transition-state: the states for which we estimate transition probabilities for transitions +// out of them. In some topologies, will map one-to-one with pdf-ids. +// One-based, since it appears on FSTs. +// transition-index: identifier of a transition (or final-prob) in the HMM. Indexes the +// "transitions" vector in HmmTopology::HmmState. [if it is out of range, +// equal to transitions.size(), it refers to the final-prob.] +// Zero-based. +// transition-id: identifier of a unique parameter of the +// SimpleHmm. +// Associated with a (transition-state, transition-index) pair. +// One-based, since it appears on FSTs. +// +// List of the possible mappings SimpleHmm can do: +// (HMM-state, class-id) -> transition-state +// (transition-state, transition-index) -> transition-id +// Reverse mappings: +// transition-id -> transition-state +// transition-id -> transition-index +// transition-state -> HMM-state +// transition-state -> class-id +// +// The main things the SimpleHmm object can do are: +// Get initialized (need HmmTopology objects). +// Read/write. +// Update [given a vector of counts indexed by transition-id]. +// Do the various integer mappings mentioned above. +// Get the probability (or log-probability) associated with a particular transition-id. + + +struct MleSimpleHmmUpdateConfig { + BaseFloat floor; + BaseFloat mincount; + MleSimpleHmmUpdateConfig(BaseFloat floor = 0.01, + BaseFloat mincount = 5.0): + floor(floor), mincount(mincount) { } + + void Register (OptionsItf *opts) { + opts->Register("transition-floor", &floor, + "Floor for transition probabilities"); + opts->Register("transition-min-count", &mincount, + "Minimum count required to update transitions from a state"); + } +}; + +struct MapSimpleHmmUpdateConfig { + BaseFloat tau; + MapSimpleHmmUpdateConfig(): tau(5.0) { } + + void Register (OptionsItf *opts) { + opts->Register("transition-tau", &tau, "Tau value for MAP estimation of transition " + "probabilities."); + } +}; + +class SimpleHmm { + + public: + /// Initialize the object [e.g. at the start of training]. + /// The class keeps a copy of the HmmTopology object. + SimpleHmm(const HmmTopology &hmm_topo); + + /// Constructor that takes no arguments: typically used prior to calling Read. + SimpleHmm() { } + + void Read(std::istream &is, bool binary); // note, no symbol table: topo object always read/written w/o symbols. + void Write(std::ostream &os, bool binary) const; + + + /// return reference to HMM-topology object. + const HmmTopology &GetTopo() const { return topo_; } + + /// \name Integer mapping functions + /// @{ + + int32 HmmStateToTransitionState(int32 hmm_state) const; + int32 PairToTransitionId(int32 trans_state, int32 trans_index) const; + int32 TransitionIdToTransitionState(int32 trans_id) const; + int32 TransitionIdToTransitionIndex(int32 trans_id) const; + int32 TransitionStateToHmmState(int32 trans_state) const; + int32 TransitionStateToPdfClass(int32 trans_state) const; + // returns the self-loop transition-id, or zero if + // this state doesn't have a self-loop. + int32 SelfLoopOf(int32 trans_state) const; + + int32 TransitionIdToPdfClass(int32 trans_id) const; + int32 TransitionIdToHmmState(int32 trans_id) const; + + /// @} + + bool IsFinal(int32 trans_id) const; // returns true if this trans_id goes to the final state + // (which is bound to be nonemitting). 
+ bool IsSelfLoop(int32 trans_id) const; // return true if this trans_id corresponds to a self-loop. + + /// Returns the total number of transition-ids (note, these are one-based). + inline int32 NumTransitionIds() const { return id2state_.size()-1; } + + /// Returns the number of transition-indices for a particular transition-state. + /// Note: "Indices" is the plural of "index". Index is not the same as "id", + /// here. A transition-index is a zero-based offset into the transitions + /// out of a particular transition state. + int32 NumTransitionIndices(int32 trans_state) const; + + /// Returns the total number of transition-states (note, these are one-based). + int32 NumTransitionStates() const { return states_.size(); } + + // NumPdfs() in the model. + int32 NumPdfs() const { return num_pdfs_; } + + // Transition-parameter-getting functions: + BaseFloat GetTransitionProb(int32 trans_id) const; + BaseFloat GetTransitionLogProb(int32 trans_id) const; + + // The following functions are more specialized functions for getting + // transition probabilities, that are provided for convenience. + + /// Returns the log-probability of a particular non-self-loop transition + /// after subtracting the probability mass of the self-loop and renormalizing; + /// will crash if called on a self-loop. Specifically: + /// for non-self-loops it returns the log of (that prob divided by (1 minus + /// self-loop-prob-for-that-state)). + BaseFloat GetTransitionLogProbIgnoringSelfLoops(int32 trans_id) const; + + /// Returns the log-prob of the non-self-loop probability + /// mass for this transition state. (you can get the self-loop prob, if a self-loop + /// exists, by calling GetTransitionLogProb(SelfLoopOf(trans_state)). + BaseFloat GetNonSelfLoopLogProb(int32 trans_state) const; + + /// Does Maximum Likelihood estimation. The stats are counts/weights, indexed + /// by transition-id. This was previously called Update(). + void MleUpdate(const Vector &stats, + const MleSimpleHmmUpdateConfig &cfg, + BaseFloat *objf_impr_out, + BaseFloat *count_out); + + /// Does Maximum A Posteriori (MAP) estimation. The stats are counts/weights, + /// indexed by transition-id. + void MapUpdate(const Vector &stats, + const MapSimpleHmmUpdateConfig &cfg, + BaseFloat *objf_impr_out, + BaseFloat *count_out); + + /// Print will print the simple HMM in a human-readable way, + /// for purposes of human + /// inspection. + /// The "occs" are optional (they are indexed by pdf-classes). + void Print(std::ostream &os, + const Vector *occs = NULL); + + + void InitStats(Vector *stats) const { stats->Resize(NumTransitionIds()+1); } + + void Accumulate(BaseFloat prob, int32 trans_id, Vector *stats) const { + KALDI_ASSERT(trans_id <= NumTransitionIds()); + (*stats)(trans_id) += prob; + // This is trivial and doesn't require class members, but leaves us more open + // to design changes than doing it manually. + } + + /// returns true if all the integer class members are identical (but does not + /// compare the transition probabilities. 
+ bool Compatible(const SimpleHmm &other) const; + + private: + void MleUpdateShared(const Vector &stats, + const MleSimpleHmmUpdateConfig &cfg, + BaseFloat *objf_impr_out, BaseFloat *count_out); + void MapUpdateShared(const Vector &stats, + const MapSimpleHmmUpdateConfig &cfg, + BaseFloat *objf_impr_out, BaseFloat *count_out); + + // called from constructor and Read(): initializes states_ + void Initialize(); + // called from constructor and Read(): computes state2id_ and id2state_ + void ComputeDerived(); + // computes quantities derived from log-probs (currently just + // non_self_loop_log_probs_; called whenever log-probs change. + void ComputeDerivedOfProbs(); + void InitializeProbs(); // called from constructor. + void Check() const; + + HmmTopology topo_; + + /// States indexed by transition state minus one; + /// the states are in sorted order which allows us to do the reverse mapping + /// from state to transition state + std::vector states_; + + /// Gives the first transition_id of each transition-state; indexed by + /// the transition-state. Array indexed 1..num-transition-states+1 + /// (the last one is needed so we can know the num-transitions of the last + /// transition-state. + std::vector state2id_; + + /// For each transition-id, the corresponding transition + /// state (indexed by transition-id). + std::vector id2state_; + + /// For each transition-id, the corresponding log-prob. + /// Indexed by transition-id. + Vector log_probs_; + + /// For each transition-state, the log of (1 - self-loop-prob). Indexed by + /// transition-state. + Vector non_self_loop_log_probs_; + + /// This is equal to the one + highest-numbered pdf class. + int32 num_pdfs_; + + + DISALLOW_COPY_AND_ASSIGN(SimpleHmm); + +}; + +/// @} + + +} // end namespace simple_hmm +} // end namespace kaldi + + +#endif From b05406d80a142a929ad9ca2c7683c74d8784c39d Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 14:51:38 -0500 Subject: [PATCH 142/213] asr_diarization: Add SimpleHmm --- src/hmm/simple-hmm-utils.cc | 146 ++++++++++++++++++++++++++++++++++++ src/hmm/simple-hmm-utils.h | 51 +++++++++++++ src/hmm/simple-hmm.cc | 79 +++++++++++++++++++ src/hmm/simple-hmm.h | 95 +++++++++++++++++++++++ 4 files changed, 371 insertions(+) create mode 100644 src/hmm/simple-hmm-utils.cc create mode 100644 src/hmm/simple-hmm-utils.h create mode 100644 src/hmm/simple-hmm.cc create mode 100644 src/hmm/simple-hmm.h diff --git a/src/hmm/simple-hmm-utils.cc b/src/hmm/simple-hmm-utils.cc new file mode 100644 index 00000000000..3406b7b56f8 --- /dev/null +++ b/src/hmm/simple-hmm-utils.cc @@ -0,0 +1,146 @@ +// hmm/simple-hmm-utils.cc + +// Copyright 2009-2011 Microsoft Corporation +// 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
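// Editorial sketch (not part of the original patch): the accumulation and
// update interface declared in src/simplehmm/simple-hmm.h above
// (InitStats() / Accumulate() / MleUpdate()) might be driven roughly as
// follows; `model` and `post` are hypothetical names, with `post[t]` holding
// per-frame (transition-id, weight) pairs:
//
//   Vector<double> stats;
//   model.InitStats(&stats);
//   for (size_t t = 0; t < post.size(); t++)
//     for (size_t i = 0; i < post[t].size(); i++)
//       model.Accumulate(post[t][i].second, post[t][i].first, &stats);
//   BaseFloat objf_impr, count;
//   model.MleUpdate(stats, simple_hmm::MleSimpleHmmUpdateConfig(),
//                   &objf_impr, &count);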
+ +#include + +#include "hmm/simple-hmm-utils.h" +#include "fst/fstlib.h" +#include "fstext/fstext-lib.h" + +namespace kaldi { + +fst::VectorFst* GetHTransducer( + const SimpleHmm &model, + BaseFloat transition_scale, BaseFloat self_loop_scale) { + using namespace fst; + typedef StdArc Arc; + typedef Arc::Weight Weight; + typedef Arc::StateId StateId; + typedef Arc::Label Label; + + VectorFst *fst = GetSimpleHmmAsFst(model, transition_scale, + self_loop_scale); + + for (StateIterator > siter(*fst); + !siter.Done(); siter.Next()) { + Arc::StateId s = siter.Value(); + for (MutableArcIterator > aiter(fst, s); + !aiter.Done(); aiter.Next()) { + Arc arc = aiter.Value(); + if (arc.ilabel == 0) { + KALDI_ASSERT(arc.olabel == 0); + continue; + } + + KALDI_ASSERT(arc.ilabel == arc.olabel && + arc.ilabel <= model.NumTransitionIds()); + + arc.olabel = model.TransitionIdToPdf(arc.ilabel) + 1; + aiter.SetValue(arc); + } + } + + return fst; +} + +fst::VectorFst *GetSimpleHmmAsFst( + const SimpleHmm &model, + BaseFloat transition_scale, BaseFloat self_loop_scale) { + using namespace fst; + typedef StdArc Arc; + typedef Arc::Weight Weight; + typedef Arc::StateId StateId; + typedef Arc::Label Label; + + KALDI_ASSERT(model.NumPdfs() > 0); + const HmmTopology &topo = model.GetTopo(); + // This special Hmm has only one phone + const HmmTopology::TopologyEntry &entry = topo.TopologyForPhone(1); + + VectorFst *ans = new VectorFst; + + // Create a mini-FST with a superfinal state [in case we have emitting + // final-states, which we usually will.] + + std::vector state_ids; + for (size_t i = 0; i < entry.size(); i++) + state_ids.push_back(ans->AddState()); + KALDI_ASSERT(state_ids.size() > 1); // Or invalid topology entry. + ans->SetStart(state_ids[0]); + StateId final_state = state_ids.back(); + ans->SetFinal(final_state, Weight::One()); + + for (int32 hmm_state = 0; + hmm_state < static_cast(entry.size()); + hmm_state++) { + int32 pdf_class = entry[hmm_state].forward_pdf_class; + int32 self_loop_pdf_class = entry[hmm_state].self_loop_pdf_class; + KALDI_ASSERT(self_loop_pdf_class == pdf_class); + + if (pdf_class != kNoPdf) { + KALDI_ASSERT(pdf_class < model.NumPdfs()); + } + + int32 trans_idx; + for (trans_idx = 0; + trans_idx < static_cast(entry[hmm_state].transitions.size()); + trans_idx++) { + BaseFloat log_prob; + Label label; + int32 dest_state = entry[hmm_state].transitions[trans_idx].first; + + if (pdf_class == kNoPdf) { + // no pdf, hence non-estimated probability. very unusual case. [would + // not happen with normal topology] . There is no transition-state + // involved in this case. + KALDI_ASSERT(hmm_state != dest_state); + log_prob = transition_scale + * Log(entry[hmm_state].transitions[trans_idx].second); + label = 0; + } else { // normal probability. + int32 trans_state = + model.TupleToTransitionState(1, hmm_state, pdf_class, pdf_class); + int32 trans_id = + model.PairToTransitionId(trans_state, trans_idx); + + log_prob = model.GetTransitionLogProb(trans_id); + + if (hmm_state == dest_state) + log_prob *= self_loop_scale; + else + log_prob *= transition_scale; + // log_prob is a negative number (or zero)... + label = trans_id; + } + ans->AddArc(state_ids[hmm_state], + Arc(label, label, Weight(-log_prob), + state_ids[dest_state])); + } + } + + fst::RemoveEpsLocal(ans); // this is safe and will not blow up. + // Now apply probability scale. + // We waited till after the possible weight-pushing steps, + // because weight-pushing needs "real" weights in order to work. 
+ // ApplyProbabilityScale(config.transition_scale, ans); + return ans; +} + +} // end namespace kaldi diff --git a/src/hmm/simple-hmm-utils.h b/src/hmm/simple-hmm-utils.h new file mode 100644 index 00000000000..bd0a3a15702 --- /dev/null +++ b/src/hmm/simple-hmm-utils.h @@ -0,0 +1,51 @@ +// hmm/simple-hmm-utils.h + +// Copyright 2009-2011 Microsoft Corporation +// 2016 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_HMM_SIMPLE_HMM_UTILS_H_ +#define KALDI_HMM_SIMPLE_HMM_UTILS_H_ + +#include "hmm/hmm-utils.h" +#include "hmm/simple-hmm.h" +#include "fst/fstlib.h" + +namespace kaldi { + +fst::VectorFst* GetHTransducer( + const SimpleHmm &model, + BaseFloat transition_scale = 1.0, BaseFloat self_loop_scale = 1.0); + +/** + * Converts the SimpleHmm into H tranducer; result owned by caller. + * Caution: our version of + * the H transducer does not include self-loops; you have to add those later. + * See \ref hmm_graph_get_h_transducer. The H transducer has on the + * input transition-ids. + * The output side contains the one-indexed mappings of pdf_ids, typically + * just pdf_id + 1. + */ +fst::VectorFst* +GetSimpleHmmAsFst (const SimpleHmm &model, + BaseFloat transition_scale = 1.0, + BaseFloat self_loop_scale = 1.0); + + +} // end namespace kaldi + +#endif diff --git a/src/hmm/simple-hmm.cc b/src/hmm/simple-hmm.cc new file mode 100644 index 00000000000..2db6bfbf297 --- /dev/null +++ b/src/hmm/simple-hmm.cc @@ -0,0 +1,79 @@ +// hmm/simple-hmm.cc + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
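// Editorial sketch (not part of the original patch): GetHTransducer(), declared
// in simple-hmm-utils.h above, puts transition-ids on the input side and
// (pdf-id + 1) on the output side, so a graph for a class sequence could be
// assembled roughly like this (`model` and `classes` are hypothetical names;
// `classes` is a std::vector<int32> of one-based labels, i.e. pdf-id + 1):
//
//   fst::VectorFst<fst::StdArc> *h = GetHTransducer(model, 1.0, 0.1);
//   fst::VectorFst<fst::StdArc> class_fst, graph;
//   fst::MakeLinearAcceptor(classes, &class_fst);
//   fst::TableCompose(*h, class_fst, &graph);
//   delete h;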
+ +#include "hmm/simple-hmm.h" + +namespace kaldi { + +void SimpleHmm::FakeContextDependency::GetPdfInfo( + const std::vector &phones, // list of phones + const std::vector &num_pdf_classes, // indexed by phone, + std::vector > > *pdf_info) const { + KALDI_ASSERT(phones.size() == 1 && phones[0] == 1); + KALDI_ASSERT(num_pdf_classes.size() == 2 && + num_pdf_classes[1] == NumPdfs()); + KALDI_ASSERT(pdf_info); + pdf_info->resize(NumPdfs(), + std::vector >()); + + for (int32 pdf = 0; pdf < NumPdfs(); pdf++) { + (*pdf_info)[pdf].push_back(std::make_pair(1, pdf)); + } +} + +void SimpleHmm::FakeContextDependency::GetPdfInfo( + const std::vector &phones, + const std::vector > > &pdf_class_pairs, + std::vector > > > *pdf_info) const { + KALDI_ASSERT(pdf_info); + KALDI_ASSERT(phones.size() == 1 && phones[0] == 1); + KALDI_ASSERT(pdf_class_pairs.size() == 2); + + pdf_info->resize(2); + (*pdf_info)[1].resize(pdf_class_pairs[1].size()); + + for (size_t j = 0; j < pdf_class_pairs[1].size(); j++) { + int32 pdf_class = pdf_class_pairs[1][j].first, + self_loop_pdf_class = pdf_class_pairs[1][j].second; + KALDI_ASSERT(pdf_class == self_loop_pdf_class && + pdf_class < NumPdfs()); + + (*pdf_info)[1][j].push_back(std::make_pair(pdf_class, pdf_class)); + } +} + +void SimpleHmm::Read(std::istream &is, bool binary) { + TransitionModel::Read(is, binary); + ctx_dep_.Init(NumPdfs()); + CheckSimpleHmm(); +} + +void SimpleHmm::CheckSimpleHmm() const { + KALDI_ASSERT(NumPhones() == 1); + KALDI_ASSERT(GetPhones()[0] == 1); + const HmmTopology::TopologyEntry &entry = GetTopo().TopologyForPhone(1); + for (int32 j = 0; j < static_cast(entry.size()); j++) { // for each state... + int32 forward_pdf_class = entry[j].forward_pdf_class, + self_loop_pdf_class = entry[j].self_loop_pdf_class; + KALDI_ASSERT(forward_pdf_class == self_loop_pdf_class && + forward_pdf_class < NumPdfs()); + } +} + +} // end namespace kaldi diff --git a/src/hmm/simple-hmm.h b/src/hmm/simple-hmm.h new file mode 100644 index 00000000000..4b40f212401 --- /dev/null +++ b/src/hmm/simple-hmm.h @@ -0,0 +1,95 @@ +// hmm/simple-hmm.h + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_HMM_SIMPLE_HMM_H_ +#define KALDI_HMM_SIMPLE_HMM_H_ + +#include "base/kaldi-common.h" +#include "hmm/transition-model.h" +#include "itf/context-dep-itf.h" + +namespace kaldi { + +class SimpleHmm: public TransitionModel { + public: + SimpleHmm(const HmmTopology &hmm_topo): + ctx_dep_(hmm_topo) { + Init(ctx_dep_, hmm_topo); + CheckSimpleHmm(); + } + + SimpleHmm(): TransitionModel() { } + + void Read(std::istream &is, bool binary); // note, no symbol table: topo object always read/written w/o symbols. 
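  // Editorial sketch (not part of the original patch): typical construction,
  // assuming `topo` is an HmmTopology containing exactly one phone with
  // phone-id 1 (as CheckSimpleHmm() below requires):
  //
  //   SimpleHmm model(topo);  // NumPdfs() should match topo.NumPdfClasses(1)
  //
  // and loading a previously written model (`model_rxfilename` is a
  // hypothetical variable):
  //
  //   bool binary;
  //   Input ki(model_rxfilename, &binary);
  //   SimpleHmm model2;
  //   model2.Read(ki.Stream(), binary);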
+ + private: + void CheckSimpleHmm() const; + + class FakeContextDependency: public ContextDependencyInterface { + public: + int ContextWidth() const { return 1; } + int CentralPosition() const { return 0; } + + bool Compute(const std::vector &phoneseq, int32 pdf_class, + int32 *pdf_id) const { + if (phoneseq.size() == 1 && phoneseq[0] == 1) { + *pdf_id = pdf_class; + return true; + } + return false; + } + + void GetPdfInfo( + const std::vector &phones, // list of phones + const std::vector &num_pdf_classes, // indexed by phone, + std::vector > > *pdf_info) const; + + void GetPdfInfo( + const std::vector &phones, + const std::vector > > &pdf_class_pairs, + std::vector > > > *pdf_info) + const; + + void Init(int32 num_pdfs) { num_pdfs_ = num_pdfs; } + + int32 NumPdfs() const { return num_pdfs_; } + + FakeContextDependency(const HmmTopology &topo) { + KALDI_ASSERT(topo.GetPhones().size() == 1); + num_pdfs_ = topo.NumPdfClasses(1); + } + + FakeContextDependency(): num_pdfs_(0) { } + + ContextDependencyInterface* Copy() const { + FakeContextDependency *copy = new FakeContextDependency(); + copy->Init(num_pdfs_); + return copy; + } + + private: + int32 num_pdfs_; + } ctx_dep_; + + DISALLOW_COPY_AND_ASSIGN(SimpleHmm); +}; + +} // end namespace kaldi + +#endif // KALDI_HMM_SIMPLE_HMM_H_ From 3d4cba868d2b707217e571f5bfc99efb2402d064 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 14:53:11 -0500 Subject: [PATCH 143/213] asr_diarization: Moving SimpleHmm --- src/Makefile | 17 +- src/decoder/Makefile | 3 +- src/decoder/simple-hmm-graph-compiler.cc | 128 ++++++ src/decoder/simple-hmm-graph-compiler.h | 100 +++++ src/hmm/simple-hmm.cc | 79 ---- src/hmm/simple-hmm.h | 95 ----- src/{hmm => simplehmm}/simple-hmm-utils.cc | 0 src/{hmm => simplehmm}/simple-hmm-utils.h | 0 src/simplehmm/simple-hmm.cc | 467 ++------------------- src/simplehmm/simple-hmm.h | 291 +++---------- 10 files changed, 340 insertions(+), 840 deletions(-) create mode 100644 src/decoder/simple-hmm-graph-compiler.cc create mode 100644 src/decoder/simple-hmm-graph-compiler.h delete mode 100644 src/hmm/simple-hmm.cc delete mode 100644 src/hmm/simple-hmm.h rename src/{hmm => simplehmm}/simple-hmm-utils.cc (100%) rename src/{hmm => simplehmm}/simple-hmm-utils.h (100%) diff --git a/src/Makefile b/src/Makefile index a42f78f4742..7a7b672e607 100644 --- a/src/Makefile +++ b/src/Makefile @@ -6,16 +6,16 @@ SHELL := /bin/bash SUBDIRS = base matrix util feat tree thread gmm transform sgmm \ - fstext hmm lm decoder lat kws cudamatrix nnet segmenter \ + fstext hmm lm decoder lat kws cudamatrix nnet segmenter simplehmm \ bin fstbin gmmbin fgmmbin sgmmbin featbin \ nnetbin latbin sgmm2 sgmm2bin nnet2 nnet3 chain nnet3bin nnet2bin kwsbin \ - ivector ivectorbin online2 online2bin lmbin chainbin segmenterbin + ivector ivectorbin online2 online2bin lmbin chainbin segmenterbin simplehmmbin MEMTESTDIRS = base matrix util feat tree thread gmm transform sgmm \ - fstext hmm lm decoder lat nnet kws chain segmenter \ + fstext hmm lm decoder lat nnet kws chain segmenter simplehmm \ bin fstbin gmmbin fgmmbin sgmmbin featbin \ nnetbin latbin sgmm2 nnet2 nnet3 nnet2bin nnet3bin sgmm2bin kwsbin \ - ivector ivectorbin online2 online2bin lmbin segmenterbin + ivector ivectorbin online2 online2bin lmbin segmenterbin simplehmmbin CUDAMEMTESTDIR = cudamatrix @@ -153,9 +153,9 @@ $(EXT_SUBDIRS) : mklibdir # this is necessary for correct parallel compilation #1)The tools depend on all the libraries -bin fstbin gmmbin fgmmbin sgmmbin sgmm2bin featbin nnetbin 
nnet2bin nnet3bin chainbin latbin ivectorbin lmbin kwsbin online2bin: \ +bin fstbin gmmbin fgmmbin sgmmbin sgmm2bin featbin nnetbin nnet2bin nnet3bin chainbin latbin ivectorbin lmbin kwsbin online2bin segmenterbin simplehmmbin: \ base matrix util feat tree thread gmm transform sgmm sgmm2 fstext hmm \ - lm decoder lat cudamatrix nnet nnet2 nnet3 ivector chain kws online2 segmenter + lm decoder lat cudamatrix nnet nnet2 nnet3 ivector chain kws online2 segmenter simplehmm #2)The libraries have inter-dependencies base: base/.depend.mk @@ -171,7 +171,7 @@ sgmm2: base util matrix gmm tree transform thread hmm fstext: base util thread matrix tree hmm: base tree matrix util thread lm: base util thread matrix fstext -decoder: base util thread matrix gmm sgmm hmm tree transform lat +decoder: base util thread matrix gmm sgmm hmm simplehmm tree transform lat lat: base util thread hmm tree matrix cudamatrix: base util thread matrix nnet: base util hmm tree thread matrix cudamatrix @@ -179,7 +179,8 @@ nnet2: base util matrix thread lat gmm hmm tree transform cudamatrix nnet3: base util matrix thread lat gmm hmm tree transform cudamatrix chain fstext chain: lat hmm tree fstext matrix cudamatrix util thread base ivector: base util matrix thread transform tree gmm -segmenter: base matrix util gmm thread +segmenter: base matrix util gmm thread tree +simplehmm: base tree matrix util thread hmm #3)Dependencies for optional parts of Kaldi onlinebin: base matrix util feat tree gmm transform sgmm sgmm2 fstext hmm lm decoder lat cudamatrix nnet nnet2 online thread # python-kaldi-decoding: base matrix util feat tree thread gmm transform sgmm sgmm2 fstext hmm decoder lat online diff --git a/src/decoder/Makefile b/src/decoder/Makefile index fe489d1cb3f..3d2112629a2 100644 --- a/src/decoder/Makefile +++ b/src/decoder/Makefile @@ -7,11 +7,12 @@ TESTFILES = OBJFILES = training-graph-compiler.o lattice-simple-decoder.o lattice-faster-decoder.o \ lattice-faster-online-decoder.o simple-decoder.o faster-decoder.o \ - decoder-wrappers.o + decoder-wrappers.o simple-hmm-graph-compiler.o LIBNAME = kaldi-decoder ADDLIBS = ../lat/kaldi-lat.a ../sgmm/kaldi-sgmm.a ../hmm/kaldi-hmm.a \ + ../simplehmm/kaldi-simplehmm.a \ ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ ../tree/kaldi-tree.a ../util/kaldi-util.a ../thread/kaldi-thread.a \ ../matrix/kaldi-matrix.a ../base/kaldi-base.a diff --git a/src/decoder/simple-hmm-graph-compiler.cc b/src/decoder/simple-hmm-graph-compiler.cc new file mode 100644 index 00000000000..5f91380ca06 --- /dev/null +++ b/src/decoder/simple-hmm-graph-compiler.cc @@ -0,0 +1,128 @@ +// decoder/simple-hmm-graph-compiler.cc + +// Copyright 2009-2011 Microsoft Corporation +// 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "decoder/simple-hmm-graph-compiler.h" +#include "simplehmm/simple-hmm-utils.h" // for GetHTransducer + +namespace kaldi { + +bool SimpleHmmGraphCompiler::CompileGraphFromAlignment( + const std::vector &alignment, + fst::VectorFst *out_fst) { + using namespace fst; + VectorFst class_fst; + MakeLinearAcceptor(alignment, &class_fst); + return CompileGraph(class_fst, out_fst); +} + +bool SimpleHmmGraphCompiler::CompileGraph( + const fst::VectorFst &class_fst, + fst::VectorFst *out_fst) { + using namespace fst; + KALDI_ASSERT(out_fst); + KALDI_ASSERT(class_fst.Start() != kNoStateId); + + if (GetVerboseLevel() >= 4) { + KALDI_VLOG(4) << "Classes FST: "; + WriteFstKaldi(KALDI_LOG, false, class_fst); + } + + VectorFst *H = GetHTransducer(model_, opts_.transition_scale, + opts_.self_loop_scale); + + if (GetVerboseLevel() >= 4) { + KALDI_VLOG(4) << "HTransducer:"; + WriteFstKaldi(KALDI_LOG, false, *H); + } + + // Epsilon-removal and determinization combined. + // This will fail if not determinizable. + DeterminizeStarInLog(H); + + if (GetVerboseLevel() >= 4) { + KALDI_VLOG(4) << "HTransducer determinized:"; + WriteFstKaldi(KALDI_LOG, false, *H); + } + + VectorFst &trans2class_fst = *out_fst; // transition-id to class. + TableCompose(*H, class_fst, &trans2class_fst); + + KALDI_ASSERT(trans2class_fst.Start() != kNoStateId); + + if (GetVerboseLevel() >= 4) { + KALDI_VLOG(4) << "trans2class_fst:"; + WriteFstKaldi(KALDI_LOG, false, trans2class_fst); + } + + // Epsilon-removal and determinization combined. + // This will fail if not determinizable. + DeterminizeStarInLog(&trans2class_fst); + + // we elect not to remove epsilons after this phase, as it is + // a little slow. + if (opts_.rm_eps) + RemoveEpsLocal(&trans2class_fst); + + // Encoded minimization. + MinimizeEncoded(&trans2class_fst); + + delete H; + return true; +} + +bool SimpleHmmGraphCompiler::CompileGraphsFromAlignments( + const std::vector > &alignments, + std::vector*> *out_fsts) { + using namespace fst; + std::vector* > class_fsts(alignments.size()); + for (size_t i = 0; i < alignments.size(); i++) { + VectorFst *class_fst = new VectorFst(); + MakeLinearAcceptor(alignments[i], class_fst); + class_fsts[i] = class_fst; + } + bool ans = CompileGraphs(class_fsts, out_fsts); + for (size_t i = 0; i < alignments.size(); i++) + delete class_fsts[i]; + return ans; +} + +bool SimpleHmmGraphCompiler::CompileGraphs( + const std::vector* > &class_fsts, + std::vector* > *out_fsts) { + + using namespace fst; + KALDI_ASSERT(out_fsts && out_fsts->empty()); + out_fsts->resize(class_fsts.size(), NULL); + if (class_fsts.empty()) return true; + + for (size_t i = 0; i < class_fsts.size(); i++) { + const VectorFst *class_fst = class_fsts[i]; + VectorFst out_fst; + + CompileGraph(*class_fst, &out_fst); + + (*out_fsts)[i] = out_fst.Copy(); + } + + return true; +} + + +} // end namespace kaldi diff --git a/src/decoder/simple-hmm-graph-compiler.h b/src/decoder/simple-hmm-graph-compiler.h new file mode 100644 index 00000000000..dcc8f8fd2ba --- /dev/null +++ b/src/decoder/simple-hmm-graph-compiler.h @@ -0,0 +1,100 @@ +// decoder/simple-hmm-graph-compiler.h + +// Copyright 2009-2011 Microsoft Corporation +// 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+
+//  http://www.apache.org/licenses/LICENSE-2.0
+
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_DECODER_SIMPLE_HMM_GRAPH_COMPILER_H_
+#define KALDI_DECODER_SIMPLE_HMM_GRAPH_COMPILER_H_
+
+#include "base/kaldi-common.h"
+#include "simplehmm/simple-hmm.h"
+#include "fst/fstlib.h"
+#include "fstext/fstext-lib.h"
+
+
+// This header provides functionality to compile a graph directly from the
+// alignment where the alignment is of classes that are simple mappings
+// of 'pdf-ids' (same as pdf classes for SimpleHmm).
+
+namespace kaldi {
+
+struct SimpleHmmGraphCompilerOptions {
+  BaseFloat transition_scale;
+  BaseFloat self_loop_scale;
+  bool rm_eps;
+
+  explicit SimpleHmmGraphCompilerOptions(BaseFloat transition_scale = 1.0,
+                                         BaseFloat self_loop_scale = 1.0):
+      transition_scale(transition_scale),
+      self_loop_scale(self_loop_scale),
+      rm_eps(true) { }
+
+  void Register(OptionsItf *opts) {
+    opts->Register("transition-scale", &transition_scale, "Scale of transition "
+                   "probabilities (excluding self-loops)");
+    opts->Register("self-loop-scale", &self_loop_scale, "Scale of self-loop vs. "
+                   "non-self-loop probability mass ");
+    opts->Register("rm-eps", &rm_eps, "Remove [most] epsilons before minimization (only applicable "
+                   "if disambig symbols present)");
+  }
+};
+
+
+class SimpleHmmGraphCompiler {
+ public:
+  SimpleHmmGraphCompiler(const SimpleHmm &model,  // Maintains reference to this object.
+                         const SimpleHmmGraphCompilerOptions &opts):
+      model_(model), opts_(opts) { }
+
+
+  /// CompileGraph compiles a single training graph: its input is a
+  /// weighted acceptor (G) at the class level, its output is HCLG-type graph.
+  /// Note: G could actually be a transducer; it would also work.
+  /// This function is not const for technical reasons involving the cache.
+  /// If not for "table_compose" we could make it const.
+  bool CompileGraph(const fst::VectorFst<fst::StdArc> &class_fst,
+                    fst::VectorFst<fst::StdArc> *out_fst);
+
+  // CompileGraphs allows you to compile a number of graphs at the same
+  // time.  This consumes more memory but is faster.
+  bool CompileGraphs(
+      const std::vector<const fst::VectorFst<fst::StdArc> *> &class_fsts,
+      std::vector<fst::VectorFst<fst::StdArc> *> *out_fsts);
+
+  // This version creates an FST from the per-frame alignment and calls
+  // CompileGraph.
+  bool CompileGraphFromAlignment(const std::vector<int32> &alignment,
+                                 fst::VectorFst<fst::StdArc> *out_fst);
+
+  // This function creates FSTs from the per-frame alignments and calls
+  // CompileGraphs.
+  bool CompileGraphsFromAlignments(
+      const std::vector<std::vector<int32> > &alignments,
+      std::vector<fst::VectorFst<fst::StdArc> *> *out_fsts);
+
+  ~SimpleHmmGraphCompiler() { }
+ private:
+  const SimpleHmm &model_;
+
+  SimpleHmmGraphCompilerOptions opts_;
+};
+
+
+}  // end namespace kaldi.
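// Editorial sketch (not part of the original patch): rough usage of the
// compiler declared above, assuming `model` is a SimpleHmm and `ali` is a
// frame-level std::vector<int32> of one-based class labels (pdf-id + 1, the
// labels GetHTransducer() produces on its output side):
//
//   SimpleHmmGraphCompilerOptions opts;
//   opts.self_loop_scale = 0.1;
//   SimpleHmmGraphCompiler compiler(model, opts);
//   fst::VectorFst<fst::StdArc> graph;
//   if (!compiler.CompileGraphFromAlignment(ali, &graph))
//     KALDI_ERR << "Graph compilation failed for utterance.";
//
// The resulting `graph` carries transition-ids on its input side, which is
// presumably what downstream alignment/decoding against per-frame class
// scores expects.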
+ +#endif // KALDI_DECODER_SIMPLE_HMM_GRAPH_COMPILER_H_ diff --git a/src/hmm/simple-hmm.cc b/src/hmm/simple-hmm.cc deleted file mode 100644 index 2db6bfbf297..00000000000 --- a/src/hmm/simple-hmm.cc +++ /dev/null @@ -1,79 +0,0 @@ -// hmm/simple-hmm.cc - -// Copyright 2016 Vimal Manohar - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#include "hmm/simple-hmm.h" - -namespace kaldi { - -void SimpleHmm::FakeContextDependency::GetPdfInfo( - const std::vector &phones, // list of phones - const std::vector &num_pdf_classes, // indexed by phone, - std::vector > > *pdf_info) const { - KALDI_ASSERT(phones.size() == 1 && phones[0] == 1); - KALDI_ASSERT(num_pdf_classes.size() == 2 && - num_pdf_classes[1] == NumPdfs()); - KALDI_ASSERT(pdf_info); - pdf_info->resize(NumPdfs(), - std::vector >()); - - for (int32 pdf = 0; pdf < NumPdfs(); pdf++) { - (*pdf_info)[pdf].push_back(std::make_pair(1, pdf)); - } -} - -void SimpleHmm::FakeContextDependency::GetPdfInfo( - const std::vector &phones, - const std::vector > > &pdf_class_pairs, - std::vector > > > *pdf_info) const { - KALDI_ASSERT(pdf_info); - KALDI_ASSERT(phones.size() == 1 && phones[0] == 1); - KALDI_ASSERT(pdf_class_pairs.size() == 2); - - pdf_info->resize(2); - (*pdf_info)[1].resize(pdf_class_pairs[1].size()); - - for (size_t j = 0; j < pdf_class_pairs[1].size(); j++) { - int32 pdf_class = pdf_class_pairs[1][j].first, - self_loop_pdf_class = pdf_class_pairs[1][j].second; - KALDI_ASSERT(pdf_class == self_loop_pdf_class && - pdf_class < NumPdfs()); - - (*pdf_info)[1][j].push_back(std::make_pair(pdf_class, pdf_class)); - } -} - -void SimpleHmm::Read(std::istream &is, bool binary) { - TransitionModel::Read(is, binary); - ctx_dep_.Init(NumPdfs()); - CheckSimpleHmm(); -} - -void SimpleHmm::CheckSimpleHmm() const { - KALDI_ASSERT(NumPhones() == 1); - KALDI_ASSERT(GetPhones()[0] == 1); - const HmmTopology::TopologyEntry &entry = GetTopo().TopologyForPhone(1); - for (int32 j = 0; j < static_cast(entry.size()); j++) { // for each state... - int32 forward_pdf_class = entry[j].forward_pdf_class, - self_loop_pdf_class = entry[j].self_loop_pdf_class; - KALDI_ASSERT(forward_pdf_class == self_loop_pdf_class && - forward_pdf_class < NumPdfs()); - } -} - -} // end namespace kaldi diff --git a/src/hmm/simple-hmm.h b/src/hmm/simple-hmm.h deleted file mode 100644 index 4b40f212401..00000000000 --- a/src/hmm/simple-hmm.h +++ /dev/null @@ -1,95 +0,0 @@ -// hmm/simple-hmm.h - -// Copyright 2016 Vimal Manohar - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#ifndef KALDI_HMM_SIMPLE_HMM_H_ -#define KALDI_HMM_SIMPLE_HMM_H_ - -#include "base/kaldi-common.h" -#include "hmm/transition-model.h" -#include "itf/context-dep-itf.h" - -namespace kaldi { - -class SimpleHmm: public TransitionModel { - public: - SimpleHmm(const HmmTopology &hmm_topo): - ctx_dep_(hmm_topo) { - Init(ctx_dep_, hmm_topo); - CheckSimpleHmm(); - } - - SimpleHmm(): TransitionModel() { } - - void Read(std::istream &is, bool binary); // note, no symbol table: topo object always read/written w/o symbols. - - private: - void CheckSimpleHmm() const; - - class FakeContextDependency: public ContextDependencyInterface { - public: - int ContextWidth() const { return 1; } - int CentralPosition() const { return 0; } - - bool Compute(const std::vector &phoneseq, int32 pdf_class, - int32 *pdf_id) const { - if (phoneseq.size() == 1 && phoneseq[0] == 1) { - *pdf_id = pdf_class; - return true; - } - return false; - } - - void GetPdfInfo( - const std::vector &phones, // list of phones - const std::vector &num_pdf_classes, // indexed by phone, - std::vector > > *pdf_info) const; - - void GetPdfInfo( - const std::vector &phones, - const std::vector > > &pdf_class_pairs, - std::vector > > > *pdf_info) - const; - - void Init(int32 num_pdfs) { num_pdfs_ = num_pdfs; } - - int32 NumPdfs() const { return num_pdfs_; } - - FakeContextDependency(const HmmTopology &topo) { - KALDI_ASSERT(topo.GetPhones().size() == 1); - num_pdfs_ = topo.NumPdfClasses(1); - } - - FakeContextDependency(): num_pdfs_(0) { } - - ContextDependencyInterface* Copy() const { - FakeContextDependency *copy = new FakeContextDependency(); - copy->Init(num_pdfs_); - return copy; - } - - private: - int32 num_pdfs_; - } ctx_dep_; - - DISALLOW_COPY_AND_ASSIGN(SimpleHmm); -}; - -} // end namespace kaldi - -#endif // KALDI_HMM_SIMPLE_HMM_H_ diff --git a/src/hmm/simple-hmm-utils.cc b/src/simplehmm/simple-hmm-utils.cc similarity index 100% rename from src/hmm/simple-hmm-utils.cc rename to src/simplehmm/simple-hmm-utils.cc diff --git a/src/hmm/simple-hmm-utils.h b/src/simplehmm/simple-hmm-utils.h similarity index 100% rename from src/hmm/simple-hmm-utils.h rename to src/simplehmm/simple-hmm-utils.h diff --git a/src/simplehmm/simple-hmm.cc b/src/simplehmm/simple-hmm.cc index 9af077cedc6..2db6bfbf297 100644 --- a/src/simplehmm/simple-hmm.cc +++ b/src/simplehmm/simple-hmm.cc @@ -1,8 +1,6 @@ // hmm/simple-hmm.cc -// Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey) -// Johns Hopkins University (author: Guoguo Chen) -// 2016 Vimal Manohar (Johns Hopkins University) +// Copyright 2016 Vimal Manohar // See ../../COPYING for clarification regarding multiple authors // @@ -19,438 +17,63 @@ // See the Apache 2 License for the specific language governing permissions and // limitations under the License. 
-#include -#include "simplehmm/simple-hmm.h" +#include "hmm/simple-hmm.h" namespace kaldi { -namespace simple_hmm { -void SimpleHmm::Initialize() { - KALDI_ASSERT(topo_.GetPhones().size() == 1); - - const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(1); - for (int32 j = 0; j < static_cast(entry.size()); j++) { // for each state... - int32 pdf_class = entry[j].forward_pdf_class; - if (pdf_class != kNoPdf) { - states_.push_back(j); - } - } - - // now states_ is populated with all possible pairs - // (hmm_state, pdf_class). - // sort to enable reverse lookup. - std::sort(states_.begin(), states_.end()); - // this sorting defines the transition-ids. -} - -void SimpleHmm::ComputeDerived() { - state2id_.resize(states_.size()+2); // indexed by transition-state, which - // is one based, but also an entry for one past end of list. - - int32 cur_transition_id = 1; - num_pdfs_ = 0; - for (int32 tstate = 1; - tstate <= static_cast(states_.size()+1); // not a typo. - tstate++) { - state2id_[tstate] = cur_transition_id; - if (static_cast(tstate) <= states_.size()) { - int32 hmm_state = states_[tstate-1]; - const HmmTopology::HmmState &state = topo_.TopologyForPhone(1)[hmm_state]; - int32 pdf_class = state.forward_pdf_class; - num_pdfs_ = std::max(num_pdfs_, pdf_class + 1); - int32 my_num_ids = static_cast(state.transitions.size()); - cur_transition_id += my_num_ids; // # trans out of this state. - } - } - - id2state_.resize(cur_transition_id); // cur_transition_id is #transition-ids+1. - for (int32 tstate = 1; - tstate <= static_cast(states_.size()); tstate++) { - for (int32 tid = state2id_[tstate]; tid < state2id_[tstate+1]; tid++) { - id2state_[tid] = tstate; - } +void SimpleHmm::FakeContextDependency::GetPdfInfo( + const std::vector &phones, // list of phones + const std::vector &num_pdf_classes, // indexed by phone, + std::vector > > *pdf_info) const { + KALDI_ASSERT(phones.size() == 1 && phones[0] == 1); + KALDI_ASSERT(num_pdf_classes.size() == 2 && + num_pdf_classes[1] == NumPdfs()); + KALDI_ASSERT(pdf_info); + pdf_info->resize(NumPdfs(), + std::vector >()); + + for (int32 pdf = 0; pdf < NumPdfs(); pdf++) { + (*pdf_info)[pdf].push_back(std::make_pair(1, pdf)); } } -void SimpleHmm::InitializeProbs() { - log_probs_.Resize(NumTransitionIds()+1); // one-based array, zeroth element empty. 
- for (int32 trans_id = 1; trans_id <= NumTransitionIds(); trans_id++) { - int32 trans_state = id2state_[trans_id]; - int32 trans_index = trans_id - state2id_[trans_state]; - int32 hmm_state = states_[trans_state-1]; - const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(1); - KALDI_ASSERT(static_cast(hmm_state) < entry.size()); - BaseFloat prob = entry[hmm_state].transitions[trans_index].second; - if (prob <= 0.0) - KALDI_ERR << "SimpleHmm::InitializeProbs, zero " - "probability [should remove that entry in the topology]"; - if (prob > 1.0) - KALDI_WARN << "SimpleHmm::InitializeProbs, prob greater than one."; - log_probs_(trans_id) = Log(prob); - } - ComputeDerivedOfProbs(); -} - -void SimpleHmm::Check() const { - KALDI_ASSERT(topo_.GetPhones().size() == 1); - - KALDI_ASSERT(NumTransitionIds() != 0 && NumTransitionStates() != 0); - { - int32 sum = 0; - for (int32 ts = 1; ts <= NumTransitionStates(); ts++) sum += NumTransitionIndices(ts); - KALDI_ASSERT(sum == NumTransitionIds()); - } - for (int32 tid = 1; tid <= NumTransitionIds(); tid++) { - int32 tstate = TransitionIdToTransitionState(tid), - index = TransitionIdToTransitionIndex(tid); - KALDI_ASSERT(tstate > 0 && tstate <=NumTransitionStates() && index >= 0); - KALDI_ASSERT(tid == PairToTransitionId(tstate, index)); - int32 hmm_state = TransitionStateToHmmState(tstate); - KALDI_ASSERT(tstate == HmmStateToTransitionState(hmm_state)); - KALDI_ASSERT(log_probs_(tid) <= 0.0 && - log_probs_(tid) - log_probs_(tid) == 0.0); - // checking finite and non-positive (and not out-of-bounds). - } - - KALDI_ASSERT(num_pdfs_ == topo_.NumPdfClasses(1)); -} - -SimpleHmm::SimpleHmm( - const HmmTopology &hmm_topo): topo_(hmm_topo) { - Initialize(); - ComputeDerived(); - InitializeProbs(); - Check(); -} - -int32 SimpleHmm::HmmStateToTransitionState(int32 hmm_state) const { - // Note: if this ever gets too expensive, which is unlikely, we can refactor - // this code to sort first on pdf_class, and then index on pdf_class, so those - // that have the same pdf_class are in a contiguous range. - std::vector::const_iterator iter = - std::lower_bound(states_.begin(), states_.end(), hmm_state); - if (iter == states_.end() || !(*iter == hmm_state)) { - KALDI_ERR << "SimpleHmm::HmmStateToTransitionState; " - << "HmmState " << hmm_state << " not found." - << " (incompatible model?)"; - } - // states_is indexed by transition_state-1, so add one. 
- return static_cast((iter - states_.begin())) + 1; -} - - -int32 SimpleHmm::NumTransitionIndices(int32 trans_state) const { - KALDI_ASSERT(static_cast(trans_state) <= states_.size()); - return static_cast(state2id_[trans_state+1]-state2id_[trans_state]); -} - -int32 SimpleHmm::TransitionIdToTransitionState(int32 trans_id) const { - KALDI_ASSERT(trans_id != 0 && - static_cast(trans_id) < id2state_.size()); - return id2state_[trans_id]; -} - -int32 SimpleHmm::TransitionIdToTransitionIndex(int32 trans_id) const { - KALDI_ASSERT(trans_id != 0 && - static_cast(trans_id) < id2state_.size()); - return trans_id - state2id_[id2state_[trans_id]]; -} - -int32 SimpleHmm::TransitionStateToPdfClass(int32 trans_state) const { - KALDI_ASSERT(static_cast(trans_state) <= states_.size()); - int32 hmm_state = states_[trans_state-1]; - const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(1); - KALDI_ASSERT(static_cast(hmm_state) < entry.size()); - return entry[hmm_state].forward_pdf_class; -} - -int32 SimpleHmm::TransitionStateToHmmState(int32 trans_state) const { - KALDI_ASSERT(static_cast(trans_state) <= states_.size()); - return states_[trans_state-1]; -} - -int32 SimpleHmm::PairToTransitionId(int32 trans_state, - int32 trans_index) const { - KALDI_ASSERT(static_cast(trans_state) <= states_.size()); - KALDI_ASSERT(trans_index < state2id_[trans_state+1] - state2id_[trans_state]); - return state2id_[trans_state] + trans_index; -} - -bool SimpleHmm::IsFinal(int32 trans_id) const { - KALDI_ASSERT(static_cast(trans_id) < id2state_.size()); - int32 trans_state = id2state_[trans_id]; - int32 trans_index = trans_id - state2id_[trans_state]; - int32 hmm_state = states_[trans_state-1]; - const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(1); - KALDI_ASSERT(static_cast(hmm_state) < entry.size()); - KALDI_ASSERT(static_cast(trans_index) < - entry[hmm_state].transitions.size()); - // return true if the transition goes to the final state of the - // topology entry. - return (entry[hmm_state].transitions[trans_index].first + 1 == - static_cast(entry.size())); -} - -// returns the self-loop transition-id, -// or zero if does not exist. -int32 SimpleHmm::SelfLoopOf(int32 trans_state) const { - KALDI_ASSERT(static_cast(trans_state-1) < states_.size()); - int32 hmm_state = states_[trans_state-1]; - const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(1); - KALDI_ASSERT(static_cast(hmm_state) < entry.size()); - for (int32 trans_index = 0; - trans_index < static_cast(entry[hmm_state].transitions.size()); - trans_index++) - if (entry[hmm_state].transitions[trans_index].first == hmm_state) - return PairToTransitionId(trans_state, trans_index); - return 0; // invalid transition id. -} - -void SimpleHmm::ComputeDerivedOfProbs() { - // this array indexed by transition-state with nothing in zeroth element. - non_self_loop_log_probs_.Resize(NumTransitionStates()+1); - for (int32 tstate = 1; tstate <= NumTransitionStates(); tstate++) { - int32 tid = SelfLoopOf(tstate); - if (tid == 0) { // no self-loop - non_self_loop_log_probs_(tstate) = 0.0; // log(1.0) - } else { - BaseFloat self_loop_prob = Exp(GetTransitionLogProb(tid)), - non_self_loop_prob = 1.0 - self_loop_prob; - if (non_self_loop_prob <= 0.0) { - KALDI_WARN << "ComputeDerivedOfProbs(): non-self-loop prob is " << non_self_loop_prob; - non_self_loop_prob = 1.0e-10; // just so we can continue... - } - non_self_loop_log_probs_(tstate) = Log(non_self_loop_prob); // will be negative. 
- } +void SimpleHmm::FakeContextDependency::GetPdfInfo( + const std::vector &phones, + const std::vector > > &pdf_class_pairs, + std::vector > > > *pdf_info) const { + KALDI_ASSERT(pdf_info); + KALDI_ASSERT(phones.size() == 1 && phones[0] == 1); + KALDI_ASSERT(pdf_class_pairs.size() == 2); + + pdf_info->resize(2); + (*pdf_info)[1].resize(pdf_class_pairs[1].size()); + + for (size_t j = 0; j < pdf_class_pairs[1].size(); j++) { + int32 pdf_class = pdf_class_pairs[1][j].first, + self_loop_pdf_class = pdf_class_pairs[1][j].second; + KALDI_ASSERT(pdf_class == self_loop_pdf_class && + pdf_class < NumPdfs()); + + (*pdf_info)[1][j].push_back(std::make_pair(pdf_class, pdf_class)); } } void SimpleHmm::Read(std::istream &is, bool binary) { - ExpectToken(is, binary, ""); - topo_.Read(is, binary); - Initialize(); - ComputeDerived(); - ExpectToken(is, binary, ""); - log_probs_.Read(is, binary); - ExpectToken(is, binary, ""); - ExpectToken(is, binary, ""); - ComputeDerivedOfProbs(); - Check(); -} - -void SimpleHmm::Write(std::ostream &os, bool binary) const { - WriteToken(os, binary, ""); - if (!binary) os << "\n"; - topo_.Write(os, binary); - if (!binary) os << "\n"; - WriteToken(os, binary, ""); - if (!binary) os << "\n"; - log_probs_.Write(os, binary); - WriteToken(os, binary, ""); - if (!binary) os << "\n"; - WriteToken(os, binary, ""); - if (!binary) os << "\n"; -} - -BaseFloat SimpleHmm::GetTransitionProb(int32 trans_id) const { - return Exp(log_probs_(trans_id)); -} - -BaseFloat SimpleHmm::GetTransitionLogProb(int32 trans_id) const { - return log_probs_(trans_id); -} - -BaseFloat SimpleHmm::GetNonSelfLoopLogProb(int32 trans_state) const { - KALDI_ASSERT(trans_state != 0); - return non_self_loop_log_probs_(trans_state); -} - -BaseFloat SimpleHmm::GetTransitionLogProbIgnoringSelfLoops( - int32 trans_id) const { - KALDI_ASSERT(trans_id != 0); - KALDI_PARANOID_ASSERT(!IsSelfLoop(trans_id)); - return log_probs_(trans_id) - GetNonSelfLoopLogProb(TransitionIdToTransitionState(trans_id)); -} - -// stats are counts/weights, indexed by transition-id. -void SimpleHmm::MleUpdate(const Vector &stats, - const MleSimpleHmmUpdateConfig &cfg, - BaseFloat *objf_impr_out, - BaseFloat *count_out) { - BaseFloat count_sum = 0.0, objf_impr_sum = 0.0; - int32 num_skipped = 0, num_floored = 0; - KALDI_ASSERT(stats.Dim() == NumTransitionIds()+1); - for (int32 tstate = 1; tstate <= NumTransitionStates(); tstate++) { - int32 n = NumTransitionIndices(tstate); - KALDI_ASSERT(n>=1); - if (n > 1) { // no point updating if only one transition... - Vector counts(n); - for (int32 tidx = 0; tidx < n; tidx++) { - int32 tid = PairToTransitionId(tstate, tidx); - counts(tidx) = stats(tid); - } - double tstate_tot = counts.Sum(); - count_sum += tstate_tot; - if (tstate_tot < cfg.mincount) { num_skipped++; } - else { - Vector old_probs(n), new_probs(n); - for (int32 tidx = 0; tidx < n; tidx++) { - int32 tid = PairToTransitionId(tstate, tidx); - old_probs(tidx) = new_probs(tidx) = GetTransitionProb(tid); - } - for (int32 tidx = 0; tidx < n; tidx++) - new_probs(tidx) = counts(tidx) / tstate_tot; - for (int32 i = 0; i < 3; i++) { // keep flooring+renormalizing for 3 times.. 
- new_probs.Scale(1.0 / new_probs.Sum()); - for (int32 tidx = 0; tidx < n; tidx++) - new_probs(tidx) = std::max(new_probs(tidx), cfg.floor); - } - // Compute objf change - for (int32 tidx = 0; tidx < n; tidx++) { - if (new_probs(tidx) == cfg.floor) num_floored++; - double objf_change = counts(tidx) * (Log(new_probs(tidx)) - - Log(old_probs(tidx))); - objf_impr_sum += objf_change; - } - // Commit updated values. - for (int32 tidx = 0; tidx < n; tidx++) { - int32 tid = PairToTransitionId(tstate, tidx); - log_probs_(tid) = Log(new_probs(tidx)); - if (log_probs_(tid) - log_probs_(tid) != 0.0) - KALDI_ERR << "Log probs is inf or NaN: error in update or bad stats?"; - } - } - } - } - KALDI_LOG << "SimpleHmm::Update, objf change is " - << (objf_impr_sum / count_sum) << " per frame over " << count_sum - << " frames. "; - KALDI_LOG << num_floored << " probabilities floored, " << num_skipped - << " out of " << NumTransitionStates() << " transition-states " - "skipped due to insuffient data (it is normal to have some skipped.)"; - if (objf_impr_out) *objf_impr_out = objf_impr_sum; - if (count_out) *count_out = count_sum; - ComputeDerivedOfProbs(); + TransitionModel::Read(is, binary); + ctx_dep_.Init(NumPdfs()); + CheckSimpleHmm(); } - -// stats are counts/weights, indexed by transition-id. -void SimpleHmm::MapUpdate(const Vector &stats, - const MapSimpleHmmUpdateConfig &cfg, - BaseFloat *objf_impr_out, - BaseFloat *count_out) { - KALDI_ASSERT(cfg.tau > 0.0); - BaseFloat count_sum = 0.0, objf_impr_sum = 0.0; - KALDI_ASSERT(stats.Dim() == NumTransitionIds()+1); - for (int32 tstate = 1; tstate <= NumTransitionStates(); tstate++) { - int32 n = NumTransitionIndices(tstate); - KALDI_ASSERT(n>=1); - if (n > 1) { // no point updating if only one transition... - Vector counts(n); - for (int32 tidx = 0; tidx < n; tidx++) { - int32 tid = PairToTransitionId(tstate, tidx); - counts(tidx) = stats(tid); - } - double tstate_tot = counts.Sum(); - count_sum += tstate_tot; - Vector old_probs(n), new_probs(n); - for (int32 tidx = 0; tidx < n; tidx++) { - int32 tid = PairToTransitionId(tstate, tidx); - old_probs(tidx) = new_probs(tidx) = GetTransitionProb(tid); - } - for (int32 tidx = 0; tidx < n; tidx++) - new_probs(tidx) = (counts(tidx) + cfg.tau * old_probs(tidx)) / - (cfg.tau + tstate_tot); - // Compute objf change - for (int32 tidx = 0; tidx < n; tidx++) { - double objf_change = counts(tidx) * (Log(new_probs(tidx)) - - Log(old_probs(tidx))); - objf_impr_sum += objf_change; - } - // Commit updated values. 
- for (int32 tidx = 0; tidx < n; tidx++) { - int32 tid = PairToTransitionId(tstate, tidx); - log_probs_(tid) = Log(new_probs(tidx)); - if (log_probs_(tid) - log_probs_(tid) != 0.0) - KALDI_ERR << "Log probs is inf or NaN: error in update or bad stats?"; - } - } - } - KALDI_LOG << "Objf change is " << (objf_impr_sum / count_sum) - << " per frame over " << count_sum - << " frames."; - if (objf_impr_out) *objf_impr_out = objf_impr_sum; - if (count_out) *count_out = count_sum; - ComputeDerivedOfProbs(); -} - - -int32 SimpleHmm::TransitionIdToPdfClass(int32 trans_id) const { - KALDI_ASSERT(trans_id != 0 && - static_cast(trans_id) < id2state_.size()); - int32 trans_state = id2state_[trans_id]; - - int32 hmm_state = states_[trans_state-1]; - const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(1); - KALDI_ASSERT(static_cast(hmm_state) < entry.size()); - return entry[hmm_state].forward_pdf_class; -} - -int32 SimpleHmm::TransitionIdToHmmState(int32 trans_id) const { - KALDI_ASSERT(trans_id != 0 && - static_cast(trans_id) < id2state_.size()); - int32 trans_state = id2state_[trans_id]; - return states_[trans_state-1]; -} - -void SimpleHmm::Print(std::ostream &os, - const Vector *occs) { - if (occs != NULL) - KALDI_ASSERT(occs->Dim() == NumPdfs()); - for (int32 tstate = 1; tstate <= NumTransitionStates(); tstate++) { - int32 hmm_state = TransitionStateToHmmState(tstate); - int32 pdf_class = TransitionStateToPdfClass(tstate); - - os << " hmm-state = " << hmm_state; - os << " pdf-class = " << pdf_class << '\n'; - for (int32 tidx = 0; tidx < NumTransitionIndices(tstate); tidx++) { - int32 tid = PairToTransitionId(tstate, tidx); - BaseFloat p = GetTransitionProb(tid); - os << " Transition-id = " << tid << " p = " << p; - if (occs) { - os << " count of pdf-class = " << (*occs)(pdf_class); - } - // now describe what it's a transition to. - if (IsSelfLoop(tid)) { - os << " [self-loop]\n"; - } else { - const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(1); - KALDI_ASSERT(static_cast(hmm_state) < entry.size()); - int32 next_hmm_state = entry[hmm_state].transitions[tidx].first; - KALDI_ASSERT(next_hmm_state != hmm_state); - os << " [" << hmm_state << " -> " << next_hmm_state << "]\n"; - } - } +void SimpleHmm::CheckSimpleHmm() const { + KALDI_ASSERT(NumPhones() == 1); + KALDI_ASSERT(GetPhones()[0] == 1); + const HmmTopology::TopologyEntry &entry = GetTopo().TopologyForPhone(1); + for (int32 j = 0; j < static_cast(entry.size()); j++) { // for each state... 
+ int32 forward_pdf_class = entry[j].forward_pdf_class, + self_loop_pdf_class = entry[j].self_loop_pdf_class; + KALDI_ASSERT(forward_pdf_class == self_loop_pdf_class && + forward_pdf_class < NumPdfs()); } } -bool SimpleHmm::Compatible(const SimpleHmm &other) const { - return (topo_ == other.topo_ && states_ == other.states_ && - state2id_ == other.state2id_ && id2state_ == other.id2state_ - && NumPdfs() == other.NumPdfs()); -} - -bool SimpleHmm::IsSelfLoop(int32 trans_id) const { - KALDI_ASSERT(static_cast(trans_id) < id2state_.size()); - int32 trans_state = id2state_[trans_id]; - int32 trans_index = trans_id - state2id_[trans_state]; - int32 hmm_state = states_[trans_state-1]; - const HmmTopology::TopologyEntry &entry = topo_.TopologyForPhone(1); - KALDI_ASSERT(static_cast(hmm_state) < entry.size()); - return (static_cast(trans_index) < entry[hmm_state].transitions.size() - && entry[hmm_state].transitions[trans_index].first == hmm_state); -} - -} // end namespace simple_hmm } // end namespace kaldi - diff --git a/src/simplehmm/simple-hmm.h b/src/simplehmm/simple-hmm.h index ef3a5b9abde..4b40f212401 100644 --- a/src/simplehmm/simple-hmm.h +++ b/src/simplehmm/simple-hmm.h @@ -1,8 +1,6 @@ // hmm/simple-hmm.h -// Copyright 2009-2012 Microsoft Corporation -// Johns Hopkins University (author: Guoguo Chen) -// 2016 Vimal Manohar (Johns Hopkins University) +// Copyright 2016 Vimal Manohar // See ../../COPYING for clarification regarding multiple authors // @@ -19,256 +17,79 @@ // See the Apache 2 License for the specific language governing permissions and // limitations under the License. -#ifndef KALDI_HMM_SIMPLE_HMM_H -#define KALDI_HMM_SIMPLE_HMM_H +#ifndef KALDI_HMM_SIMPLE_HMM_H_ +#define KALDI_HMM_SIMPLE_HMM_H_ #include "base/kaldi-common.h" -#include "util/const-integer-set.h" -#include "fst/fst-decl.h" // forward declarations. -#include "hmm/hmm-topology.h" -#include "itf/options-itf.h" +#include "hmm/transition-model.h" +#include "itf/context-dep-itf.h" namespace kaldi { -namespace simple_hmm { - -/// \addtogroup hmm_group -/// @{ - -// The class SimpleHmm is a repository for the transition probabilities. -// The model is exactly like a single phone. It has a HMM topology defined in -// hmm-topology.h. Each HMM-state has a number of -// transitions (and final-probs) out of it. Each emitting HMM-state defined in -// the HmmTopology class has an associated class-id. -// The transition model associates the -// transition probs with the (HMM-state, class-id). We associate with -// each such pair a transition-state. Each -// transition-state has a number of associated probabilities to estimate; -// this depends on the number of transitions/final-probs in the topology for -// that HMM-state. Each probability has an associated transition-index. -// We associate with each (transition-state, transition-index) a unique transition-id. -// Each individual probability estimated by the transition-model is asociated with a -// transition-id. -// -// List of the various types of quantity referred to here and what they mean: -// HMM-state: a number (0, 1, 2...) that indexes TopologyEntry (see hmm-topology.h) -// transition-state: the states for which we estimate transition probabilities for transitions -// out of them. In some topologies, will map one-to-one with pdf-ids. -// One-based, since it appears on FSTs. -// transition-index: identifier of a transition (or final-prob) in the HMM. Indexes the -// "transitions" vector in HmmTopology::HmmState. 
[if it is out of range, -// equal to transitions.size(), it refers to the final-prob.] -// Zero-based. -// transition-id: identifier of a unique parameter of the -// SimpleHmm. -// Associated with a (transition-state, transition-index) pair. -// One-based, since it appears on FSTs. -// -// List of the possible mappings SimpleHmm can do: -// (HMM-state, class-id) -> transition-state -// (transition-state, transition-index) -> transition-id -// Reverse mappings: -// transition-id -> transition-state -// transition-id -> transition-index -// transition-state -> HMM-state -// transition-state -> class-id -// -// The main things the SimpleHmm object can do are: -// Get initialized (need HmmTopology objects). -// Read/write. -// Update [given a vector of counts indexed by transition-id]. -// Do the various integer mappings mentioned above. -// Get the probability (or log-probability) associated with a particular transition-id. - - -struct MleSimpleHmmUpdateConfig { - BaseFloat floor; - BaseFloat mincount; - MleSimpleHmmUpdateConfig(BaseFloat floor = 0.01, - BaseFloat mincount = 5.0): - floor(floor), mincount(mincount) { } - - void Register (OptionsItf *opts) { - opts->Register("transition-floor", &floor, - "Floor for transition probabilities"); - opts->Register("transition-min-count", &mincount, - "Minimum count required to update transitions from a state"); - } -}; - -struct MapSimpleHmmUpdateConfig { - BaseFloat tau; - MapSimpleHmmUpdateConfig(): tau(5.0) { } - - void Register (OptionsItf *opts) { - opts->Register("transition-tau", &tau, "Tau value for MAP estimation of transition " - "probabilities."); - } -}; - -class SimpleHmm { +class SimpleHmm: public TransitionModel { public: - /// Initialize the object [e.g. at the start of training]. - /// The class keeps a copy of the HmmTopology object. - SimpleHmm(const HmmTopology &hmm_topo); - - /// Constructor that takes no arguments: typically used prior to calling Read. - SimpleHmm() { } + SimpleHmm(const HmmTopology &hmm_topo): + ctx_dep_(hmm_topo) { + Init(ctx_dep_, hmm_topo); + CheckSimpleHmm(); + } + SimpleHmm(): TransitionModel() { } + void Read(std::istream &is, bool binary); // note, no symbol table: topo object always read/written w/o symbols. - void Write(std::ostream &os, bool binary) const; - - - /// return reference to HMM-topology object. - const HmmTopology &GetTopo() const { return topo_; } - - /// \name Integer mapping functions - /// @{ - - int32 HmmStateToTransitionState(int32 hmm_state) const; - int32 PairToTransitionId(int32 trans_state, int32 trans_index) const; - int32 TransitionIdToTransitionState(int32 trans_id) const; - int32 TransitionIdToTransitionIndex(int32 trans_id) const; - int32 TransitionStateToHmmState(int32 trans_state) const; - int32 TransitionStateToPdfClass(int32 trans_state) const; - // returns the self-loop transition-id, or zero if - // this state doesn't have a self-loop. - int32 SelfLoopOf(int32 trans_state) const; - - int32 TransitionIdToPdfClass(int32 trans_id) const; - int32 TransitionIdToHmmState(int32 trans_id) const; - - /// @} - - bool IsFinal(int32 trans_id) const; // returns true if this trans_id goes to the final state - // (which is bound to be nonemitting). - bool IsSelfLoop(int32 trans_id) const; // return true if this trans_id corresponds to a self-loop. - - /// Returns the total number of transition-ids (note, these are one-based). - inline int32 NumTransitionIds() const { return id2state_.size()-1; } - - /// Returns the number of transition-indices for a particular transition-state. 
- /// Note: "Indices" is the plural of "index". Index is not the same as "id", - /// here. A transition-index is a zero-based offset into the transitions - /// out of a particular transition state. - int32 NumTransitionIndices(int32 trans_state) const; - - /// Returns the total number of transition-states (note, these are one-based). - int32 NumTransitionStates() const { return states_.size(); } - - // NumPdfs() in the model. - int32 NumPdfs() const { return num_pdfs_; } - - // Transition-parameter-getting functions: - BaseFloat GetTransitionProb(int32 trans_id) const; - BaseFloat GetTransitionLogProb(int32 trans_id) const; - - // The following functions are more specialized functions for getting - // transition probabilities, that are provided for convenience. - - /// Returns the log-probability of a particular non-self-loop transition - /// after subtracting the probability mass of the self-loop and renormalizing; - /// will crash if called on a self-loop. Specifically: - /// for non-self-loops it returns the log of (that prob divided by (1 minus - /// self-loop-prob-for-that-state)). - BaseFloat GetTransitionLogProbIgnoringSelfLoops(int32 trans_id) const; - - /// Returns the log-prob of the non-self-loop probability - /// mass for this transition state. (you can get the self-loop prob, if a self-loop - /// exists, by calling GetTransitionLogProb(SelfLoopOf(trans_state)). - BaseFloat GetNonSelfLoopLogProb(int32 trans_state) const; - - /// Does Maximum Likelihood estimation. The stats are counts/weights, indexed - /// by transition-id. This was previously called Update(). - void MleUpdate(const Vector &stats, - const MleSimpleHmmUpdateConfig &cfg, - BaseFloat *objf_impr_out, - BaseFloat *count_out); - - /// Does Maximum A Posteriori (MAP) estimation. The stats are counts/weights, - /// indexed by transition-id. - void MapUpdate(const Vector &stats, - const MapSimpleHmmUpdateConfig &cfg, - BaseFloat *objf_impr_out, - BaseFloat *count_out); - - /// Print will print the simple HMM in a human-readable way, - /// for purposes of human - /// inspection. - /// The "occs" are optional (they are indexed by pdf-classes). - void Print(std::ostream &os, - const Vector *occs = NULL); - - - void InitStats(Vector *stats) const { stats->Resize(NumTransitionIds()+1); } - - void Accumulate(BaseFloat prob, int32 trans_id, Vector *stats) const { - KALDI_ASSERT(trans_id <= NumTransitionIds()); - (*stats)(trans_id) += prob; - // This is trivial and doesn't require class members, but leaves us more open - // to design changes than doing it manually. - } - - /// returns true if all the integer class members are identical (but does not - /// compare the transition probabilities. - bool Compatible(const SimpleHmm &other) const; private: - void MleUpdateShared(const Vector &stats, - const MleSimpleHmmUpdateConfig &cfg, - BaseFloat *objf_impr_out, BaseFloat *count_out); - void MapUpdateShared(const Vector &stats, - const MapSimpleHmmUpdateConfig &cfg, - BaseFloat *objf_impr_out, BaseFloat *count_out); - - // called from constructor and Read(): initializes states_ - void Initialize(); - // called from constructor and Read(): computes state2id_ and id2state_ - void ComputeDerived(); - // computes quantities derived from log-probs (currently just - // non_self_loop_log_probs_; called whenever log-probs change. - void ComputeDerivedOfProbs(); - void InitializeProbs(); // called from constructor. 
- void Check() const; + void CheckSimpleHmm() const; - HmmTopology topo_; - - /// States indexed by transition state minus one; - /// the states are in sorted order which allows us to do the reverse mapping - /// from state to transition state - std::vector states_; - - /// Gives the first transition_id of each transition-state; indexed by - /// the transition-state. Array indexed 1..num-transition-states+1 - /// (the last one is needed so we can know the num-transitions of the last - /// transition-state. - std::vector state2id_; + class FakeContextDependency: public ContextDependencyInterface { + public: + int ContextWidth() const { return 1; } + int CentralPosition() const { return 0; } + + bool Compute(const std::vector &phoneseq, int32 pdf_class, + int32 *pdf_id) const { + if (phoneseq.size() == 1 && phoneseq[0] == 1) { + *pdf_id = pdf_class; + return true; + } + return false; + } + + void GetPdfInfo( + const std::vector &phones, // list of phones + const std::vector &num_pdf_classes, // indexed by phone, + std::vector > > *pdf_info) const; + + void GetPdfInfo( + const std::vector &phones, + const std::vector > > &pdf_class_pairs, + std::vector > > > *pdf_info) + const; + + void Init(int32 num_pdfs) { num_pdfs_ = num_pdfs; } - /// For each transition-id, the corresponding transition - /// state (indexed by transition-id). - std::vector id2state_; + int32 NumPdfs() const { return num_pdfs_; } - /// For each transition-id, the corresponding log-prob. - /// Indexed by transition-id. - Vector log_probs_; + FakeContextDependency(const HmmTopology &topo) { + KALDI_ASSERT(topo.GetPhones().size() == 1); + num_pdfs_ = topo.NumPdfClasses(1); + } - /// For each transition-state, the log of (1 - self-loop-prob). Indexed by - /// transition-state. - Vector non_self_loop_log_probs_; + FakeContextDependency(): num_pdfs_(0) { } - /// This is equal to the one + highest-numbered pdf class. 
- int32 num_pdfs_; + ContextDependencyInterface* Copy() const { + FakeContextDependency *copy = new FakeContextDependency(); + copy->Init(num_pdfs_); + return copy; + } + private: + int32 num_pdfs_; + } ctx_dep_; DISALLOW_COPY_AND_ASSIGN(SimpleHmm); - }; -/// @} - - -} // end namespace simple_hmm -} // end namespace kaldi - +} // end namespace kaldi -#endif +#endif // KALDI_HMM_SIMPLE_HMM_H_ From be892299646d2989561dc928dfbc8d4289c90d95 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 14:59:04 -0500 Subject: [PATCH 144/213] asr_diarization: Convert GMM posteriors to feats --- src/gmm/diag-gmm.h | 10 +++ src/gmmbin/Makefile | 3 +- src/gmmbin/gmm-global-post-to-feats.cc | 103 +++++++++++++++++++++++++ 3 files changed, 115 insertions(+), 1 deletion(-) create mode 100644 src/gmmbin/gmm-global-post-to-feats.cc diff --git a/src/gmm/diag-gmm.h b/src/gmm/diag-gmm.h index 1243d7a6bfd..32ef4f146d7 100644 --- a/src/gmm/diag-gmm.h +++ b/src/gmm/diag-gmm.h @@ -32,6 +32,8 @@ #include "matrix/matrix-lib.h" #include "tree/cluster-utils.h" #include "tree/clusterable-classes.h" +#include "util/kaldi-table.h" +#include "util/kaldi-holder.h" namespace kaldi { @@ -255,6 +257,14 @@ operator << (std::ostream &os, const kaldi::DiagGmm &gmm); std::istream & operator >> (std::istream &is, kaldi::DiagGmm &gmm); +typedef KaldiObjectHolder DiagGmmHolder; + +typedef TableWriter DiagGmmWriter; +typedef SequentialTableReader SequentialDiagGmmReader; +typedef RandomAccessTableReader RandomAccessDiagGmmReader; +typedef RandomAccessTableReaderMapped +RandomAccessDiagGmmReaderMapped; + } // End namespace kaldi #include "gmm/diag-gmm-inl.h" // templated functions. diff --git a/src/gmmbin/Makefile b/src/gmmbin/Makefile index 7adb8bdc41e..caf4b1f8118 100644 --- a/src/gmmbin/Makefile +++ b/src/gmmbin/Makefile @@ -28,7 +28,8 @@ BINFILES = gmm-init-mono gmm-est gmm-acc-stats-ali gmm-align \ gmm-est-fmllr-raw gmm-est-fmllr-raw-gpost gmm-global-init-from-feats \ gmm-global-info gmm-latgen-faster-regtree-fmllr gmm-est-fmllr-global \ gmm-acc-mllt-global gmm-transform-means-global gmm-global-get-post \ - gmm-global-gselect-to-post gmm-global-est-lvtln-trans + gmm-global-gselect-to-post gmm-global-est-lvtln-trans \ + gmm-global-post-to-feats OBJFILES = diff --git a/src/gmmbin/gmm-global-post-to-feats.cc b/src/gmmbin/gmm-global-post-to-feats.cc new file mode 100644 index 00000000000..fa903b66014 --- /dev/null +++ b/src/gmmbin/gmm-global-post-to-feats.cc @@ -0,0 +1,103 @@ +// gmmbin/gmm-global-post-to-feats.cc + +// Copyright 2016 Brno University of Technology (Author: Karel Vesely) +// 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
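// A hypothetical way this new binary would be driven, inferred from the usage
// string and code below; the file names here are illustrative assumptions and
// do not come from the patch itself:
//
//   # single global GMM, read as a Kaldi object:
//   gmm-global-post-to-feats final.dubm ark:post.ark ark:feats.ark
//
//   # per-speaker GMMs supplied as a table, selected through --utt2spk
//   # (this path relies on the RandomAccessDiagGmmReaderMapped typedef that
//   # the same patch adds to diag-gmm.h):
//   gmm-global-post-to-feats --utt2spk=ark:data/train/utt2spk \
//     ark:spk_gmms.ark ark:post.ark ark:feats.ark
//
// Each output matrix has one row per frame and NumGauss() columns, obtained
// by expanding the sparse posteriors with PosteriorToMatrix().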
+ + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" +#include "hmm/posterior.h" +#include "gmm/diag-gmm.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Convert GMM global posteriors to features\n" + "\n" + "Usage: gmm-global-post-to-feats [options] \n" + "e.g.: gmm-global-post-to-feats ark:1.gmm ark:post.ark ark:feat.ark\n" + "See also: post-to-feats --post-dim, post-to-weights feat-to-post, append-vector-to-feats, append-post-to-feats\n"; + + ParseOptions po(usage); + std::string utt2spk_rspecifier; + + po.Register("utt2spk", &utt2spk_rspecifier, + "rspecifier for utterance to speaker map for reading " + "per-speaker GMM models"); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string model_in_filename = po.GetArg(1), + post_rspecifier = po.GetArg(2), + feat_wspecifier = po.GetArg(3); + + DiagGmm diag_gmm; + RandomAccessDiagGmmReaderMapped *gmm_reader = NULL; + SequentialPosteriorReader post_reader(post_rspecifier); + BaseFloatMatrixWriter feat_writer(feat_wspecifier); + + if (ClassifyRspecifier(po.GetArg(1), NULL, NULL) + != kNoRspecifier) { // We're operating on tables, e.g. archives. + gmm_reader = new RandomAccessDiagGmmReaderMapped(model_in_filename, + utt2spk_rspecifier); + } else { + ReadKaldiObject(model_in_filename, &diag_gmm); + } + + int32 num_done = 0, num_err = 0; + + for (; !post_reader.Done(); post_reader.Next()) { + const std::string &utt = post_reader.Key(); + + const DiagGmm *gmm = &diag_gmm; + if (gmm_reader) { + if (!gmm_reader->HasKey(utt)) { + KALDI_WARN << "Could not find GMM model for utterance " << utt; + num_err++; + continue; + } + gmm = &(gmm_reader->Value(utt)); + } + + int32 post_dim = gmm->NumGauss(); + + const Posterior &post = post_reader.Value(); + + Matrix output; + PosteriorToMatrix(post, post_dim, &output); + + feat_writer.Write(utt, output); + num_done++; + } + KALDI_LOG << "Done " << num_done << " utts, errors on " + << num_err; + + return (num_done == 0 ? -1 : 0); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} From c23060eb527957ab6c9e6a6064e9de3c28e0b657 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 14:59:34 -0500 Subject: [PATCH 145/213] asr_diarization: Remove some accidentally added files --- src/simplehmm/simple-hmm-acc-stats-fsts.cc | 173 --------------------- src/simplehmm/simple-hmm-computation.cc | 5 - 2 files changed, 178 deletions(-) delete mode 100644 src/simplehmm/simple-hmm-acc-stats-fsts.cc delete mode 100644 src/simplehmm/simple-hmm-computation.cc diff --git a/src/simplehmm/simple-hmm-acc-stats-fsts.cc b/src/simplehmm/simple-hmm-acc-stats-fsts.cc deleted file mode 100644 index de4a7528836..00000000000 --- a/src/simplehmm/simple-hmm-acc-stats-fsts.cc +++ /dev/null @@ -1,173 +0,0 @@ -// simplehmmbin/simple-hmm-acc-stats-fsts.cc - -// Copyright 2016 Vimal Manohar (Johns Hopkins University) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#include "base/kaldi-common.h" -#include "util/common-utils.h" -#include "simplehmm/simple-hmm.h" -#include "hmm/hmm-utils.h" -#include "fstext/fstext-lib.h" -#include "decoder/decoder-wrappers.h" - -int main(int argc, char *argv[]) { - try { - using namespace kaldi; - typedef kaldi::int32 int32; - using fst::SymbolTable; - using fst::VectorFst; - using fst::StdArc; - - const char *usage = - "Accumulate stats for simple HMM models from FSTs directly.\n" - "Usage: simple-hmm-acc-stats-fsts [options] " - " \n" - "e.g.: \n" - " simple-hmm-acc-stats-fsts 1.mdl ark:graphs.fsts scp:likes.scp pdf2class_map 1.stats\n"; - - ParseOptions po(usage); - - BaseFloat acoustic_scale = 1.0; - BaseFloat transition_scale = 1.0; - BaseFloat self_loop_scale = 1.0; - - po.Register("transition-scale", &transition_scale, - "Transition-probability scale [relative to acoustics]"); - po.Register("acoustic-scale", &acoustic_scale, - "Scaling factor for acoustic likelihoods"); - po.Register("self-loop-scale", &self_loop_scale, - "Scale of self-loop versus non-self-loop log probs [relative to acoustics]"); - po.Read(argc, argv); - - - if (po.NumArgs() != 5) { - po.PrintUsage(); - exit(1); - } - - std::string model_in_filename = po.GetArg(1), - fst_rspecifier = po.GetArg(2), - likes_rspecifier = po.GetArg(3), - pdf2class_map_rxfilename = po.GetArg(4), - accs_wxfilename = po.GetArg(5); - - simple_hmm::SimpleHmm model; - ReadKaldiObject(model_in_filename, &model); - - SequentialTableReader fst_reader(fst_rspecifier); - RandomAccessBaseFloatMatrixReader likes_reader(likes_rspecifier); - - std::vector pdf2class; - { - Input ki(pdf2class_map_rxfilename); - std::string line; - while (std::getline(ki.Stream(), line)) { - std::vector parts; - SplitStringToVector(line, " ", true, &parts); - if (parts.size() != 2) { - KALDI_ERR << "Invalid line " << line - << " in pdf2class-map " << pdf2class_map_rxfilename; - } - int32 pdf_id = std::atoi(parts[0].c_str()), - class_id = std::atoi(parts[1].c_str()); - - if (pdf_id != pdf2class.size()) - KALDI_ERR << "pdf2class-map is not sorted or does not contain " - << "pdf " << pdf_id - 1 << " in " - << pdf2class_map_rxfilename; - - if (pdf_id < pdf2class.size()) - KALDI_ERR << "Duplicate pdf " << pdf_id - << " in pdf2class-map " << pdf2class_map_rxfilename; - - pdf2class.push_back(class_id); - } - } - - int32 num_done = 0, num_err = 0; - double tot_like = 0.0, tot_t = 0.0; - int64 frame_count = 0; - - Vector transition_accs; - model.InitStats(&transition_accs); - - SimpleHmmComputation computation(model, pdf2class_map); - - for (; !fst_reader.Done(); fst_reader.Next()) { - const std::string &utt = fst_reader.Key(); - - if (!likes_reader.HasKey(utt)) { - num_err++; - KALDI_WARN << "No likes for utterance " << utt; - continue; - } - - const Matrix &likes = likes_reader.Value(utt); - VectorFst decode_fst(fst_reader.Value()); - fst_reader.FreeCurrent(); // this stops copy-on-write of the fst - // by deleting the fst inside the reader, since we're about to mutate - // the fst by adding transition probs. 
- - if (likes.NumRows() == 0) { - KALDI_WARN << "Zero-length utterance: " << utt; - num_err++; - continue; - } - - if (likes.NumCols() != pdf2class.size()) { - KALDI_ERR << "Mismatch in pdf dimension in log-likelihood matrix " - << "and pdf2class map; " << likes.NumCols() << " vs " - << pdf2class.size(); - } - - // Add transition-probs to the FST. - AddTransitionProbs(model, transition_scale, self_loop_scale, - &decode_fst); - - BaseFloat tot_like_this_utt = 0.0, tot_weight = 0.0; - if (!computation.Compute(decode_fst, likes, acoustic_scale, - &transition_accs, - &tot_like_this_utt, &tot_weight)) { - KALDI_WARN << "Failed to do computation for utterance " << utt; - num_err++; - } - tot_like += tot_like_this_utt; - tot_t += tot_weight; - frame_count += likes.NumRows(); - - num_done++; - } - - KALDI_LOG << "Done " << num_done << " files, " << num_err - << " with errors."; - - KALDI_LOG << "Overall avg like per frame = " - << (tot_like/tot_t) << " over " << tot_t << " frames."; - - { - Output ko(accs_wxfilename, binary); - transition_accs.Write(ko.Stream(), binary); - } - KALDI_LOG << "Written accs."; - return (num_done != 0 ? 0 : 1); - } catch(const std::exception &e) { - std::cerr << e.what(); - return -1; - } -} - - diff --git a/src/simplehmm/simple-hmm-computation.cc b/src/simplehmm/simple-hmm-computation.cc deleted file mode 100644 index e20f84169a1..00000000000 --- a/src/simplehmm/simple-hmm-computation.cc +++ /dev/null @@ -1,5 +0,0 @@ -SimpleHmmComputation::SimpleHmmComputation( - const SimpleHmm &model, - const std::vector &num_pdfs, - VectorFst *decode_fst, - const Matrix &log_likes) From 8786deab25721910fb31f0ffcd744fa62e563c12 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:00:23 -0500 Subject: [PATCH 146/213] asr_diarzation: Update do_corruption_data_dir{,_music} --- .../segmentation/do_corruption_data_dir.sh | 16 ++-- .../do_corruption_data_dir_music.sh | 80 +++++++++++++------ 2 files changed, 64 insertions(+), 32 deletions(-) diff --git a/egs/aspire/s5/local/segmentation/do_corruption_data_dir.sh b/egs/aspire/s5/local/segmentation/do_corruption_data_dir.sh index 1bfa08370e7..5d38be87d70 100755 --- a/egs/aspire/s5/local/segmentation/do_corruption_data_dir.sh +++ b/egs/aspire/s5/local/segmentation/do_corruption_data_dir.sh @@ -76,14 +76,14 @@ corrupted_data_dir=data/${corrupted_data_id} if $speed_perturb; then if [ $stage -le 2 ]; then ## Assuming whole data directories - for x in $clean_data_dir $corrupted_data_dir $noise_data_dir; do + for x in $corrupted_data_dir; do cp $x/reco2dur $x/utt2dur - utils/data/perturb_data_dir_speed_3way.sh $x ${x}_sp + utils/data/perturb_data_dir_speed_random.sh $x ${x}_spr done fi - corrupted_data_dir=${corrupted_data_dir}_sp - corrupted_data_id=${corrupted_data_id}_sp + corrupted_data_dir=${corrupted_data_dir}_spr + corrupted_data_id=${corrupted_data_id}_spr if [ $stage -le 3 ]; then utils/data/perturb_data_dir_volume.sh --scale-low 0.03125 --scale-high 2 \ @@ -118,14 +118,14 @@ fi if [ $stage -le 8 ]; then if [ ! -z "$reco_vad_dir" ]; then - if [ ! -f $reco_vad_dir/speech_feat.scp ]; then - echo "$0: Could not find file $reco_vad_dir/speech_feat.scp" + if [ ! 
-f $reco_vad_dir/speech_labels.scp ]; then + echo "$0: Could not find file $reco_vad_dir/speech_labels.scp" exit 1 fi - cat $reco_vad_dir/speech_feat.scp | \ + cat $reco_vad_dir/speech_labels.scp | \ steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps | \ - sort -k1,1 > ${corrupted_data_dir}/speech_feat.scp + sort -k1,1 > ${corrupted_data_dir}/speech_labels.scp cat $reco_vad_dir/deriv_weights.scp | \ steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps | \ diff --git a/egs/aspire/s5/local/segmentation/do_corruption_data_dir_music.sh b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_music.sh index 214cba347da..4fc369234ea 100755 --- a/egs/aspire/s5/local/segmentation/do_corruption_data_dir_music.sh +++ b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_music.sh @@ -71,19 +71,20 @@ if $dry_run; then fi corrupted_data_dir=data/${corrupted_data_id} -orig_corrupted_data_dir=$corrupted_data_dir +# Data dir without speed perturbation +orig_corrupted_data_dir=$corrupted_data_dir if $speed_perturb; then if [ $stage -le 2 ]; then ## Assuming whole data directories for x in $corrupted_data_dir; do cp $x/reco2dur $x/utt2dur - utils/data/perturb_data_dir_speed_3way.sh $x ${x}_sp + utils/data/perturb_data_dir_speed_random.sh $x ${x}_spr done fi - corrupted_data_dir=${corrupted_data_dir}_sp - corrupted_data_id=${corrupted_data_id}_sp + corrupted_data_dir=${corrupted_data_dir}_spr + corrupted_data_id=${corrupted_data_id}_spr if [ $stage -le 3 ]; then utils/data/perturb_data_dir_volume.sh --scale-low 0.03125 --scale-high 2 \ @@ -122,14 +123,14 @@ fi if [ $stage -le 8 ]; then if [ ! -z "$reco_vad_dir" ]; then - if [ ! -f $reco_vad_dir/speech_feat.scp ]; then - echo "$0: Could not find file $reco_vad_dir/speech_feat.scp" + if [ ! -f $reco_vad_dir/speech_labels.scp ]; then + echo "$0: Could not find file $reco_vad_dir/speech_labels.scp" exit 1 fi - cat $reco_vad_dir/speech_feat.scp | \ + cat $reco_vad_dir/speech_labels.scp | \ steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps "music" | \ - sort -k1,1 > ${corrupted_data_dir}/speech_feat.scp + sort -k1,1 > ${corrupted_data_dir}/speech_labels.scp cat $reco_vad_dir/deriv_weights.scp | \ steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps "music" | \ @@ -144,38 +145,41 @@ music_data_dir=$music_dir/music_data mkdir -p $music_data_dir if [ $stage -le 10 ]; then - utils/data/get_utt2num_frames.sh $corrupted_data_dir + utils/data/get_reco2num_frames.sh --nj $reco_nj $orig_corrupted_data_dir utils/split_data.sh --per-reco ${orig_corrupted_data_dir} $reco_nj cp $orig_corrupted_data_dir/wav.scp $music_data_dir - - # Combine the VAD from the base recording and the VAD from the overlapping segments - # to create per-frame labels of the number of overlapping speech segments - # Unreliable segments are regions where no VAD labels were available for the - # overlapping segments. These can be later removed by setting deriv weights to 0. + + # The first rspecifier is a dummy required to get the recording-id as key. + # It has no segments in it as they are all removed by --remove-labels. 
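  # A rough sketch of what the updated pipeline below does, inferred from the
  # commands in this hunk rather than stated in the patch:
  # segmentation-init-from-additive-signals-info combines the recording
  # lengths (reco2num_frames), per-signal segmentations built from
  # music_utt2num_frames, and additive_signals_info.txt into per-recording
  # segmentations marking where music was added; adjacent segments are then
  # merged, and segmentation-to-segments writes the utt2spk/segments files of
  # the music data directory.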
$train_cmd JOB=1:$reco_nj $music_dir/log/get_music_seg.JOB.log \ - segmentation-init-from-additive-signals-info --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ + segmentation-init-from-additive-signals-info --lengths-rspecifier=ark,t:${orig_corrupted_data_dir}/reco2num_frames \ --additive-signals-segmentation-rspecifier="ark:segmentation-init-from-lengths ark:$music_utt2num_frames ark:- |" \ - "ark:utils/filter_scp.pl ${orig_corrupted_data_dir}/split${reco_nj}reco/JOB/utt2spk $corrupted_data_dir/utt2num_frames | segmentation-init-from-lengths --label=1 ark:- ark:- | segmentation-post-process --remove-labels=1 ark:- ark:- |" \ - ark,t:$orig_corrupted_data_dir/additive_signals_info.txt \ + "ark,t:utils/filter_scp.pl ${orig_corrupted_data_dir}/split${reco_nj}reco/JOB/reco2utt $orig_corrupted_data_dir/additive_signals_info.txt |" \ ark:- \| \ segmentation-post-process --merge-adjacent-segments ark:- \ ark:- \| \ segmentation-to-segments ark:- ark:$music_data_dir/utt2spk.JOB \ $music_data_dir/segments.JOB + utils/data/get_reco2utt.sh $corrupted_data_dir for n in `seq $reco_nj`; do cat $music_data_dir/utt2spk.$n; done > $music_data_dir/utt2spk for n in `seq $reco_nj`; do cat $music_data_dir/segments.$n; done > $music_data_dir/segments utils/fix_data_dir.sh $music_data_dir if $speed_perturb; then - utils/data/perturb_data_dir_speed_3way.sh $music_data_dir ${music_data_dir}_sp + utils/data/perturb_data_dir_speed_3way.sh $music_data_dir ${music_data_dir}_spr + mv ${music_data_dir}_spr/segments{,.temp} + cat ${music_data_dir}_spr/segments.temp | \ + utils/filter_scp.pl -f 2 ${corrupted_data_dir}/reco2utt > ${music_data_dir}_spr/segments + utils/fix_data_dir.sh ${music_data_dir}_spr + rm ${music_data_dir}_spr/segments.temp fi fi if $speed_perturb; then - music_data_dir=${music_data_dir}_sp + music_data_dir=${music_data_dir}_spr fi label_dir=music_labels @@ -184,13 +188,20 @@ mkdir -p $label_dir label_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $label_dir ${PWD}` if [ $stage -le 11 ]; then - utils/split_data.sh --per-reco ${music_data_dir} $reco_nj + utils/split_data.sh --per-reco ${corrupted_data_dir} $reco_nj + # TODO: Don't assume that its whole data directory. 
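  # A brief summary of this stage, inferred from the commands below and not
  # part of the patch itself: the music segments are re-read as segmentations,
  # combined per recording, aligned frame-by-frame against utt2num_frames with
  # segmentation-to-ali, and written out as the music_labels_* archives that
  # music_labels.scp points to afterwards.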
+ nj=$reco_nj + if [ $nj -gt 4 ]; then + nj=4 + fi + utils/data/get_utt2num_frames.sh --cmd "$train_cmd" --nj $nj ${corrupted_data_dir} + utils/data/get_reco2utt.sh $music_data_dir/ $train_cmd JOB=1:$reco_nj $music_dir/log/get_music_labels.JOB.log \ - utils/data/get_reco2utt.sh ${music_data_dir}/split${reco_nj}reco/JOB '&&' \ segmentation-init-from-segments --shift-to-zero=false \ - ${music_data_dir}/split${reco_nj}reco/JOB/segments ark:- \| \ - segmentation-combine-segments-to-recordings ark:- ark,t:${music_data_dir}/split${reco_nj}reco/JOB/reco2utt \ + "utils/filter_scp.pl -f 2 ${corrupted_data_dir}/split${reco_nj}reco/JOB/reco2utt ${music_data_dir}/segments |" ark:- \| \ + segmentation-combine-segments-to-recordings ark:- \ + "ark,t:utils/filter_scp.pl ${corrupted_data_dir}/split${reco_nj}reco/JOB/reco2utt ${music_data_dir}/reco2utt |" \ ark:- \| \ segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- \ ark,scp:$label_dir/music_labels_${corrupted_data_id}.JOB.ark,$label_dir/music_labels_${corrupted_data_id}.JOB.scp @@ -198,6 +209,27 @@ fi for n in `seq $reco_nj`; do cat $label_dir/music_labels_${corrupted_data_id}.$n.scp -done > ${corrupted_data_dir}/music_labels.scp +done | utils/filter_scp.pl ${corrupted_data_dir}/utt2spk > ${corrupted_data_dir}/music_labels.scp + +if [ $stage -le 12 ]; then + utils/split_data.sh --per-reco ${corrupted_data_dir} $reco_nj + + cat < $music_dir/speech_music_map +0 0 0 +0 1 3 +1 0 1 +1 1 2 +EOF + + $train_cmd JOB=1:$reco_nj $music_dir/log/get_speech_music_labels.JOB.log \ + intersect-int-vectors --mapping-in=$music_dir/speech_music_map \ + "scp:utils/filter_scp.pl ${corrupted_data_dir}/split${reco_nj}reco/JOB/reco2utt ${corrupted_data_dir}/speech_labels.scp |" \ + "scp:utils/filter_scp.pl ${corrupted_data_dir}/split${reco_nj}reco/JOB/reco2utt ${corrupted_data_dir}/music_labels.scp |" \ + ark,scp:$label_dir/speech_music_labels_${corrupted_data_id}.JOB.ark,$label_dir/speech_music_labels_${corrupted_data_id}.JOB.scp + + for n in `seq $reco_nj`; do + cat $label_dir/speech_music_labels_${corrupted_data_id}.$n.scp + done > $corrupted_data_dir/speech_music_labels.scp +fi exit 0 From eb5432282b54e897d14ea65e7fc8ce8cac1c3420 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:00:58 -0500 Subject: [PATCH 147/213] asr_diarization: Prepare unsad data fisher and babel --- .../local/segmentation/prepare_babel_data.sh | 105 ++++++++++++++++++ .../local/segmentation/prepare_fisher_data.sh | 18 +-- .../local/segmentation/prepare_unsad_data.sh | 28 ++--- 3 files changed, 128 insertions(+), 23 deletions(-) create mode 100644 egs/aspire/s5/local/segmentation/prepare_babel_data.sh diff --git a/egs/aspire/s5/local/segmentation/prepare_babel_data.sh b/egs/aspire/s5/local/segmentation/prepare_babel_data.sh new file mode 100644 index 00000000000..24a61eca772 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/prepare_babel_data.sh @@ -0,0 +1,105 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +# This script prepares Babel data for training speech activity detection, +# music detection. + +. path.sh +. cmd.sh + +set -e +set -o pipefail +set -u + +lang_id=assamese +subset= # Number of recordings to keep before speed perturbation and corruption. + # In limitedLP, this is about 120. So subset, if specified, must be lower that that. + +# All the paths below can be modified to any absolute path. +ROOT_DIR=/home/vimal/workspace_waveform/egs/babel/s5c_assamese/ + +stage=-1 + +. 
utils/parse_options.sh + +if [ $# -ne 0 ]; then + echo "Usage: $0" + echo "This script is to serve as an example recipe." + echo "Edit the script to change variables if needed." + exit 1 +fi + +dir=exp/unsad/make_unsad_babel_${lang_id}_train # Work dir + +model_dir=$ROOT_DIR/exp/tri4 # Model directory used for decoding +sat_model_dir=$ROOT_DIR/exp/tri5 # Model directory used for getting alignments +lang=$ROOT_DIR/data/lang # Language directory +lang_test=$ROOT_DIR/data/lang # Language directory used to build graph + +mkdir -p $dir + +# Hard code the mapping from phones to SAD labels +# 0 for silence, 1 for speech, 2 for noise, 3 for unk +cat < $dir/babel_sad.map + 3 +_B 3 +_E 3 +_I 3 +_S 3 + 2 +_B 2 +_E 2 +_I 2 +_S 2 + 2 +_B 2 +_E 2 +_I 2 +_S 2 +SIL 0 +SIL_B 0 +SIL_E 0 +SIL_I 0 +SIL_S 0 +EOF + +# The original data directory which will be converted to a whole (recording-level) directory. +utils/copy_data_dir.sh $ROOT_DIR/data/train data/babel_${lang_id}_train +train_data_dir=data/babel_${lang_id}_train + +# Expecting the user to have done run.sh to have $model_dir, +# $sat_model_dir, $lang, $lang_test, $train_data_dir +local/segmentation/prepare_unsad_data.sh --stage 14 \ + --sad-map $dir/babel_sad.map \ + --config-dir $ROOT_DIR/conf --feat-type plp --add-pitch true \ + --reco-nj 40 --nj 100 --cmd "$train_cmd" \ + --sat-model-dir $sat_model_dir \ + --lang-test $lang_test \ + $train_data_dir $lang $model_dir $dir + +orig_data_dir=${train_data_dir}_sp + +data_dir=${train_data_dir}_whole + +if [ ! -z $subset ]; then + # Work on a subset + utils/subset_data_dir.sh ${data_dir} $subset \ + ${data_dir}_$subset + data_dir=${data_dir}_$subset +fi + +reco_vad_dir=$dir/`basename $model_dir`_reco_vad_`basename $train_data_dir`_sp + +# Add noise from MUSAN corpus to data directory and create a new data directory +local/segmentation/do_corruption_data_dir.sh \ + --data-dir $data_dir \ + --reco-vad-dir $reco_vad_dir \ + --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf + +# Add music from MUSAN corpus to data directory and create a new data directory +local/segmentation/do_corruption_data_dir_music.sh \ + --data-dir $data_dir \ + --reco-vad-dir $reco_vad_dir \ + --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf diff --git a/egs/aspire/s5/local/segmentation/prepare_fisher_data.sh b/egs/aspire/s5/local/segmentation/prepare_fisher_data.sh index 1344e185a02..4749ff7da8a 100644 --- a/egs/aspire/s5/local/segmentation/prepare_fisher_data.sh +++ b/egs/aspire/s5/local/segmentation/prepare_fisher_data.sh @@ -9,6 +9,8 @@ . path.sh . cmd.sh +set -e -o pipefail + if [ $# -ne 0 ]; then echo "Usage: $0" echo "This script is to serve as an example recipe." @@ -17,7 +19,7 @@ if [ $# -ne 0 ]; then fi dir=exp/unsad/make_unsad_fisher_train_100k # Work dir -subset=150 +subset=900 # All the paths below can be modified to any absolute path. @@ -54,21 +56,23 @@ oov_I 3 oov_S 3 EOF +false && { # Expecting the user to have done run.sh to have $model_dir, # $sat_model_dir, $lang, $lang_test, $train_data_dir local/segmentation/prepare_unsad_data.sh \ --sad-map $dir/fisher_sad.map \ --config-dir conf \ --reco-nj 40 --nj 100 --cmd "$train_cmd" \ - --sat-model $sat_model_dir \ + --sat-model-dir $sat_model_dir \ --lang-test $lang_test \ $train_data_dir $lang $model_dir $dir +} data_dir=${train_data_dir}_whole if [ ! 
-z $subset ]; then # Work on a subset - utils/subset_data_dir.sh ${data_dir} $subset \ + false && utils/subset_data_dir.sh ${data_dir} $subset \ ${data_dir}_$subset data_dir=${data_dir}_$subset fi @@ -76,13 +80,13 @@ fi reco_vad_dir=$dir/`basename $model_dir`_reco_vad_`basename $train_data_dir`_sp # Add noise from MUSAN corpus to data directory and create a new data directory -local/segmentation/do_corruption_data_dir.sh +false && local/segmentation/do_corruption_data_dir.sh \ --data-dir $data_dir \ - --reco-vad-dir $reco_vad_dir + --reco-vad-dir $reco_vad_dir \ --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf # Add music from MUSAN corpus to data directory and create a new data directory -local/segmentation/do_corruption_data_dir_music.sh +local/segmentation/do_corruption_data_dir_music.sh --stage 10 \ --data-dir $data_dir \ - --reco-vad-dir $reco_vad_dir + --reco-vad-dir $reco_vad_dir \ --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf diff --git a/egs/aspire/s5/local/segmentation/prepare_unsad_data.sh b/egs/aspire/s5/local/segmentation/prepare_unsad_data.sh index 12097811ec9..7385e309f5f 100755 --- a/egs/aspire/s5/local/segmentation/prepare_unsad_data.sh +++ b/egs/aspire/s5/local/segmentation/prepare_unsad_data.sh @@ -40,7 +40,7 @@ lang_test= # Language directory used to build graph. . utils/parse_options.sh -if [ $# -ne 5 ]; then +if [ $# -ne 4 ]; then echo "This script takes a data directory and creates a new data directory " echo "and speech activity labels" echo "for the purpose of training a Universal Speech Activity Detector." @@ -241,12 +241,12 @@ fi utils/data/get_reco2utt.sh $data_dir if [ $stage -le 0 ]; then - steps/segmentation/get_utt2num_frames.sh \ + utils/data/get_utt2num_frames.sh \ --frame-shift $frame_shift --frame-overlap $frame_overlap \ --cmd "$cmd" --nj $reco_nj $whole_data_dir awk '{print $1" "$2}' ${data_dir}/segments | utils/apply_map.pl -f 2 ${whole_data_dir}/utt2num_frames > $data_dir/utt2max_frames - utils/data/subsegment_feats.sh ${whole_data_dir}/feats.scp \ + utils/data/get_subsegmented_feats.sh ${whole_data_dir}/feats.scp \ $frame_shift $frame_overlap ${data_dir}/segments | \ utils/data/fix_subsegmented_feats.pl $data_dir/utt2max_frames \ > ${data_dir}/feats.scp @@ -289,8 +289,7 @@ utils/split_data.sh $data_dir $nj vad_dir=$dir/`basename ${ali_dir}`_vad_${data_id} if [ $stage -le 3 ]; then steps/segmentation/internal/convert_ali_to_vad.sh --cmd "$cmd" \ - $data_dir $ali_dir \ - $dir/sad_map $vad_dir + $ali_dir $dir/sad_map $vad_dir fi [ ! 
-s $vad_dir/sad_seg.scp ] && echo "$0: $vad_dir/vad.scp is empty" && exit 1 @@ -381,9 +380,9 @@ if [ $stage -le 6 ]; then utils/data/get_reco2utt.sh $outside_data_dir awk '{print $1" "$2}' $outside_data_dir/segments | utils/apply_map.pl -f 2 $whole_data_dir/utt2num_frames > $outside_data_dir/utt2max_frames - utils/data/subsegment_feats.sh ${whole_data_dir}/feats.scp \ + utils/data/get_subsegmented_feats.sh ${whole_data_dir}/feats.scp \ $frame_shift $frame_overlap ${outside_data_dir}/segments | \ - utils/data/fix_subsegmented_feats.pl $outside_data_dir/utt2max_framres \ + utils/data/fix_subsegmented_feats.pl $outside_data_dir/utt2max_frames \ > ${outside_data_dir}/feats.scp fi @@ -432,8 +431,7 @@ model_id=`basename $model_dir` decode_vad_dir=$dir/${model_id}_decode_vad_${data_id} if [ $stage -le 9 ]; then steps/segmentation/internal/convert_ali_to_vad.sh --cmd "$cmd" \ - $extended_data_dir ${model_dir}/decode_${data_id}_extended \ - $dir/sad_map $decode_vad_dir + ${model_dir}/decode_${data_id}_extended $dir/sad_map $decode_vad_dir fi [ ! -s $decode_vad_dir/sad_seg.scp ] && echo "$0: $decode_vad_dir/vad.scp is empty" && exit 1 @@ -477,7 +475,7 @@ set +e for n in `seq $reco_nj`; do utils/create_data_link.pl $reco_vad_dir/deriv_weights.$n.ark utils/create_data_link.pl $reco_vad_dir/deriv_weights_for_uncorrupted.$n.ark - utils/create_data_link.pl $reco_vad_dir/speech_feat.$n.ark + utils/create_data_link.pl $reco_vad_dir/speech_labels.$n.ark done set -e @@ -508,14 +506,12 @@ fi if [ $stage -le 14 ]; then $cmd JOB=1:$reco_nj $reco_vad_dir/log/get_speech_labels.JOB.log \ - segmentation-post-process --keep-label=1 scp:$reco_vad_dir/sad_seg.JOB.scp ark:- \| \ + segmentation-copy --keep-label=1 scp:$reco_vad_dir/sad_seg.JOB.scp ark:- \| \ segmentation-to-ali --lengths-rspecifier=ark,t:${whole_data_dir}/utt2num_frames \ - ark:- ark,t:- \| \ - steps/segmentation/convert_ali_to_vec.pl \| vector-to-feat ark:- ark:- \| copy-feats --compress \ - ark:- ark,scp:$reco_vad_dir/speech_feat.JOB.ark,$reco_vad_dir/speech_feat.JOB.scp + ark:- ark,scp:$reco_vad_dir/speech_labels.JOB.ark,$reco_vad_dir/speech_labels.JOB.scp for n in `seq $reco_nj`; do - cat $reco_vad_dir/speech_feat.$n.scp - done > $reco_vad_dir/speech_feat.scp + cat $reco_vad_dir/speech_labels.$n.scp + done > $reco_vad_dir/speech_labels.scp fi if [ $stage -le 15 ]; then From e52f0324d7e91311e5228bac3a30c41ac26797fc Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:01:28 -0500 Subject: [PATCH 148/213] asr_diarization: Bug fix in reverberate_data_dir.py --- egs/wsj/s5/steps/data/reverberate_data_dir.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/egs/wsj/s5/steps/data/reverberate_data_dir.py b/egs/wsj/s5/steps/data/reverberate_data_dir.py index 9a71126dde3..c9a4d918c91 100755 --- a/egs/wsj/s5/steps/data/reverberate_data_dir.py +++ b/egs/wsj/s5/steps/data/reverberate_data_dir.py @@ -8,6 +8,8 @@ import argparse, glob, math, os, random, sys, warnings, copy, imp, ast import data_dir_manipulation_lib as data_lib +sys.path.insert(0, 'steps') +import libs.common as common_lib def GetArgs(): # we add required arguments as named arguments for readability From b7fba13cb42f6d5f2957021977069dfb756b06c1 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:02:01 -0500 Subject: [PATCH 149/213] asr_diarization: Updated compute_output.sh to compute from Am --- egs/wsj/s5/steps/nnet3/compute_output.sh | 36 ++++++++++++++---------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git 
a/egs/wsj/s5/steps/nnet3/compute_output.sh b/egs/wsj/s5/steps/nnet3/compute_output.sh index f49790bc578..4c32b5cb0ea 100755 --- a/egs/wsj/s5/steps/nnet3/compute_output.sh +++ b/egs/wsj/s5/steps/nnet3/compute_output.sh @@ -27,7 +27,7 @@ compress=false online_ivector_dir= post_vec= output_name= -get_raw_nnet_from_am=true +use_raw_nnet=true # End configuration section. echo "$0 $@" # Print the command line for logging @@ -54,11 +54,13 @@ data=$1 srcdir=$2 dir=$3 -if $get_raw_nnet_from_am; then +if ! $use_raw_nnet; then [ ! -f $srcdir/$iter.mdl ] && echo "$0: no such file $srcdir/$iter.mdl" && exit 1 - model="nnet3-am-copy --raw=true $srcdir/$iter.mdl - |" + prog=nnet3-am-compute + model="$srcdir/$iter.mdl" else [ ! -f $srcdir/$iter.raw ] && echo "$0: no such file $srcdir/$iter.raw" && exit 1 + prog=nnet3-compute model="nnet3-copy $srcdir/$iter.raw - |" fi @@ -142,18 +144,22 @@ if [ $frame_subsampling_factor -ne 1 ]; then frame_subsampling_opt="--frame-subsampling-factor=$frame_subsampling_factor" fi -output_wspecifier="ark:| copy-feats --compress=$compress ark:- ark:- | gzip -c > $dir/nnet_output.JOB.gz" - -if [ ! -z $post_vec ]; then - if [ $stage -le 1 ]; then - copy-vector --binary=false $post_vec - | \ - awk '{for (i = 2; i < NF; i++) { sum += i; }; - printf ("["); - for (i = 2; i < NF; i++) { printf " "log(i/sum); }; - print (" ]");}' > $dir/log_priors.vec +if ! $use_raw_nnet; then + output_wspecifier="ark:| copy-feats --compress=$compress ark:- ark:- | gzip -c > $dir/log_likes.JOB.gz" +else + output_wspecifier="ark:| copy-feats --compress=$compress ark:- ark:- | gzip -c > $dir/nnet_output.JOB.gz" + + if [ ! -z $post_vec ]; then + if [ $stage -le 1 ]; then + copy-vector --binary=false $post_vec - | \ + awk '{for (i = 2; i < NF; i++) { sum += i; }; + printf ("["); + for (i = 2; i < NF; i++) { printf " "log(i/sum); }; + print (" ]");}' > $dir/log_priors.vec + fi + + output_wspecifier="ark:| matrix-add-offset ark:- 'vector-scale --scale=-1.0 $dir/log_priors.vec - |' ark:- | copy-feats --compress=$compress ark:- ark:- | gzip -c > $dir/log_likes.JOB.gz" fi - - output_wspecifier="ark:| matrix-add-offset ark:- 'vector-scale --scale=-1.0 $dir/log_priors.vec - |' ark:- | copy-feats --compress=$compress ark:- ark:- | gzip -c > $dir/log_likes.JOB.gz" fi gpu_opt="--use-gpu=no" @@ -166,7 +172,7 @@ fi if [ $stage -le 2 ]; then $cmd $gpu_queue_opt JOB=1:$nj $dir/log/compute_output.JOB.log \ - nnet3-compute $gpu_opt $ivector_opts $frame_subsampling_opt \ + $prog $gpu_opt $ivector_opts $frame_subsampling_opt \ --frames-per-chunk=$frames_per_chunk \ --extra-left-context=$extra_left_context \ --extra-right-context=$extra_right_context \ From 84889b61d1fc74bf578886c823d30d8f346acd2c Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:02:25 -0500 Subject: [PATCH 150/213] asr_diarization: Bug fix in get_egs_multiple_targets --- .../s5/steps/nnet3/get_egs_multiple_targets.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py b/egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py index 8e6f1442c7a..30449c81e81 100755 --- a/egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py +++ b/egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py @@ -675,7 +675,14 @@ def generate_training_examples_internal(dir, targets_parameters, feat_dir, fpi=samples_per_iter)) if dry_run: - cleanup(dir, archives_multiple) + if generate_egs_scp: + for i in range(1, num_archives_intermediate + 1): + for j in range(1, archives_multiple + 1): + 
archive_index = (i-1) * archives_multiple + j + common_lib.force_symlink( + "egs.{0}.ark".format(archive_index), + "{dir}/egs.{i}.{j}.ark".format(dir=dir, i=i, j=j)) + cleanup(dir, archives_multiple, generate_egs_scp) return {'num_frames': num_frames, 'num_archives': num_archives, 'egs_per_archive': egs_per_archive} @@ -763,7 +770,7 @@ def generate_training_examples_internal(dir, targets_parameters, feat_dir, for i in range(1, num_archives_intermediate + 1): for j in range(1, archives_multiple + 1): archive_index = (i-1) * archives_multiple + j - common_lib.force_sym_link( + common_lib.force_symlink( "egs.{0}.ark".format(archive_index), "{dir}/egs.{i}.{j}.ark".format(dir=dir, i=i, j=j)) @@ -785,20 +792,20 @@ def generate_training_examples_internal(dir, targets_parameters, feat_dir, print (line.strip(), file=out_egs_handle) out_egs_handle.close() - cleanup(dir, archives_multiple) + cleanup(dir, archives_multiple, generate_egs_scp) return {'num_frames': num_frames, 'num_archives': num_archives, 'egs_per_archive': egs_per_archive} -def cleanup(dir, archives_multiple): +def cleanup(dir, archives_multiple, generate_egs_scp=False): logger.info("Removing temporary archives in {0}.".format(dir)) for file_name in glob.glob("{0}/egs_orig*".format(dir)): real_path = os.path.realpath(file_name) data_lib.try_to_delete(real_path) data_lib.try_to_delete(file_name) - if archives_multiple > 1: + if archives_multiple > 1 and not generate_egs_scp: # there will be some extra soft links we want to delete for file_name in glob.glob('{0}/egs.*.*.ark'.format(dir)): os.remove(file_name) From 911d1d06b7aa230ffeb5f936502c37e4eaeb68e6 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:02:50 -0500 Subject: [PATCH 151/213] asr_diariztion: Add compute-per-dim-accuracy --- egs/wsj/s5/steps/nnet3/train_raw_rnn.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/egs/wsj/s5/steps/nnet3/train_raw_rnn.py b/egs/wsj/s5/steps/nnet3/train_raw_rnn.py index 2bea66dbcbf..d43406e7f3e 100755 --- a/egs/wsj/s5/steps/nnet3/train_raw_rnn.py +++ b/egs/wsj/s5/steps/nnet3/train_raw_rnn.py @@ -468,7 +468,8 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): background_process_handler=background_process_handler, extra_egs_copy_cmd=args.extra_egs_copy_cmd, use_multitask_egs=args.use_multitask_egs, - rename_multitask_outputs=args.rename_multitask_outputs) + rename_multitask_outputs=args.rename_multitask_outputs, + compute_per_dim_accuracy=args.compute_per_dim_accuracy) if args.cleanup: # do a clean up everythin but the last 2 models, under certain @@ -493,6 +494,9 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): if args.stage <= num_iters: logger.info("Doing final combination to produce final.raw") + common_lib.run_kaldi_command( + "cp {dir}/{num_iters}.raw {dir}/pre_combine.raw" + "".format(dir=args.dir, num_iters=num_iters)) train_lib.common.combine_models( dir=args.dir, num_iters=num_iters, models_to_combine=models_to_combine, egs_dir=egs_dir, @@ -500,7 +504,8 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): run_opts=run_opts, chunk_width=args.chunk_width, background_process_handler=background_process_handler, get_raw_nnet_from_am=False, - extra_egs_copy_cmd=args.extra_egs_copy_cmd) + extra_egs_copy_cmd=args.extra_egs_copy_cmd, + compute_per_dim_accuracy=args.compute_per_dim_accuracy) if include_log_softmax and args.stage <= num_iters + 1: logger.info("Getting average posterior for purposes of " From 
abd45fe1e28f59745cfb64ddff61c593e3f32e08 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:04:30 -0500 Subject: [PATCH 152/213] asr_diarization: Update some segmentation scripts --- egs/wsj/s5/steps/segmentation/decode_sad.sh | 18 +++-- .../segmentation/decode_sad_to_segments.sh | 25 ++++-- .../internal/post_process_segments.sh | 8 +- .../segmentation/internal/prepare_sad_lang.py | 79 +++++++++++++------ 4 files changed, 91 insertions(+), 39 deletions(-) diff --git a/egs/wsj/s5/steps/segmentation/decode_sad.sh b/egs/wsj/s5/steps/segmentation/decode_sad.sh index 9758d36e24e..2f2e5ae2586 100755 --- a/egs/wsj/s5/steps/segmentation/decode_sad.sh +++ b/egs/wsj/s5/steps/segmentation/decode_sad.sh @@ -7,6 +7,8 @@ cmd=run.pl acwt=0.1 beam=8 max_active=1000 +get_pdfs=false +iter=final . path.sh @@ -22,21 +24,27 @@ graph_dir=$1 log_likes_dir=$2 dir=$3 +mkdir -p $dir nj=`cat $log_likes_dir/num_jobs` echo $nj > $dir/num_jobs -for f in $dir/trans.mdl $log_likes_dir/log_likes.1.gz $graph_dir/HCLG.fst; do +for f in $graph_dir/$iter.mdl $log_likes_dir/log_likes.1.gz $graph_dir/HCLG.fst; do if [ ! -f $f ]; then echo "$0: Could not find file $f" + exit 1 fi done decoder_opts+=(--acoustic-scale=$acwt --beam=$beam --max-active=$max_active) +ali="ark:| ali-to-phones --per-frame $graph_dir/$iter.mdl ark:- ark:- | gzip -c > $dir/ali.JOB.gz" + +if $get_pdfs; then + ali="ark:| ali-to-pdf $graph_dir/$iter.mdl ark:- ark:- | gzip -c > $dir/ali.JOB.gz" +fi + $cmd JOB=1:$nj $dir/log/decode.JOB.log \ decode-faster-mapped ${decoder_opts[@]} \ - $dir/trans.mdl \ + $graph_dir/$iter.mdl \ $graph_dir/HCLG.fst "ark:gunzip -c $log_likes_dir/log_likes.JOB.gz |" \ - ark:/dev/null ark:- \| \ - ali-to-phones --per-frame $dir/trans.mdl ark:- \ - "ark:|gzip -c > $dir/ali.JOB.gz" + ark:/dev/null "$ali" diff --git a/egs/wsj/s5/steps/segmentation/decode_sad_to_segments.sh b/egs/wsj/s5/steps/segmentation/decode_sad_to_segments.sh index de8ab0d90e8..84287230fba 100755 --- a/egs/wsj/s5/steps/segmentation/decode_sad_to_segments.sh +++ b/egs/wsj/s5/steps/segmentation/decode_sad_to_segments.sh @@ -16,6 +16,7 @@ nonsil_transition_probability=0.1 sil_transition_probability=0.1 sil_prior=0.5 speech_prior=0.5 +use_unigram_lm=true # Decoding options acwt=1 @@ -59,14 +60,25 @@ if [ $stage -le 2 ]; then fi if [ $stage -le 3 ]; then - cat > $lang/word2prior < $lang/word2prior < $lang/G.fst + steps/segmentation/internal/make_G_fst.py --word2prior-map $lang/word2prior | \ + fstcompile --isymbols=$lang/words.txt --osymbols=$lang/words.txt \ + --keep_isymbols=false --keep_osymbols=false \ + > $lang/G.fst + else + { + echo "1 0.99 1:0.6 2:0.39"; + echo "2 0.01 1:0.5 2:0.49"; + } | \ + steps/segmentation/internal/make_bigram_G_fst.py - - | \ + fstcompile --isymbols=$lang/words.txt --osymbols=$lang/words.txt \ + --keep_isymbols=false --keep_osymbols=false \ + > $lang/G.fst + fi fi graph_dir=$dir/graph_test_${t} @@ -75,11 +87,12 @@ if [ $stage -le 4 ]; then $cmd $dir/log/make_vad_graph.log \ steps/segmentation/internal/make_sad_graph.sh --iter trans \ $lang $dir $dir/graph_test_${t} || exit 1 + cp $dir/trans.mdl $graph_dir fi if [ $stage -le 5 ]; then steps/segmentation/decode_sad.sh \ - --acwt $acwt --beam $beam --max-active $max_active \ + --acwt $acwt --beam $beam --max-active $max_active --iter trans \ $graph_dir $sad_likes_dir $dir fi diff --git a/egs/wsj/s5/steps/segmentation/internal/post_process_segments.sh b/egs/wsj/s5/steps/segmentation/internal/post_process_segments.sh index e37d5dc2f62..31f0d09f351 100755 --- 
a/egs/wsj/s5/steps/segmentation/internal/post_process_segments.sh +++ b/egs/wsj/s5/steps/segmentation/internal/post_process_segments.sh @@ -26,8 +26,10 @@ max_segment_length=1000 # Segments that are longer than this are split into overlap_length=100 # Overlapping frames when segments are split. # See the above option. min_silence_length=30 # Min silence length at which to split very long segments +min_segment_length=20 frame_shift=0.01 +frame_overlap=0.016 . utils/parse_options.sh @@ -44,7 +46,7 @@ data_dir=$1 dir=$2 segmented_data_dir=$3 -for f in $dir/orig_segmentation.1.gz $data_dir/segments; do +for f in $dir/orig_segmentation.1.gz; do if [ ! -f $f ]; then echo "$0: Could not find $f" exit 1 @@ -80,9 +82,11 @@ if [ $stage -le 2 ]; then segmentation-post-process ${post_pad_length:+--pad-label=1 --pad-length=$post_pad_length} ark:- ark:- \| \ segmentation-split-segments --alignments="ark,s,cs:gunzip -c $dir/orig_segmentation.JOB.gz | segmentation-to-ali ark:- ark:- |" \ --max-segment-length=$max_segment_length --min-alignment-chunk-length=$min_silence_length --ali-label=0 ark:- ark:- \| \ + segmentation-post-process --remove-labels=1 --max-remove-length=$min_segment_length ark:- ark:- \| \ segmentation-split-segments \ --max-segment-length=$max_segment_length --overlap-length=$overlap_length ark:- ark:- \| \ - segmentation-to-segments --frame-shift=$frame_shift ark:- \ + segmentation-to-segments --frame-shift=$frame_shift \ + --frame-overlap=$frame_overlap ark:- \ ark,t:$dir/utt2spk.JOB $dir/segments.JOB || exit 1 fi diff --git a/egs/wsj/s5/steps/segmentation/internal/prepare_sad_lang.py b/egs/wsj/s5/steps/segmentation/internal/prepare_sad_lang.py index 17b039015d2..b539286a85b 100755 --- a/egs/wsj/s5/steps/segmentation/internal/prepare_sad_lang.py +++ b/egs/wsj/s5/steps/segmentation/internal/prepare_sad_lang.py @@ -1,7 +1,12 @@ #! /usr/bin/env python from __future__ import print_function -import argparse, shlex +import argparse +import sys +import shlex + +sys.path.insert(0, 'steps') +import libs.common as common_lib def GetArgs(): parser = argparse.ArgumentParser(description="""This script generates a lang @@ -9,13 +14,13 @@ def GetArgs(): the corresponding min durations and end transition probability.""") parser.add_argument("--phone-transition-parameters", dest='phone_transition_para_array', - type=str, action='append', required = True, - help = "Options to build topology. \n" + type=str, action='append', required=True, + help="Options to build topology. 
\n" "--phone-list= # Colon-separated list of phones\n" "--min-duration= # Min duration for the phones\n" "--end-transition-probability= # Probability of the end transition after the minimum duration\n") parser.add_argument("dir", type=str, - help = "Output lang directory") + help="Output lang directory") args = parser.parse_args() return args @@ -47,7 +52,8 @@ def ParsePhoneTransitionParameters(para_array): return phone_transition_parameters -def GetPhoneMap(phone_transition_parameters): + +def get_phone_map(phone_transition_parameters): phone2int = {} n = 1 for t in phone_transition_parameters: @@ -59,36 +65,57 @@ def GetPhoneMap(phone_transition_parameters): return phone2int -def Main(): + +def print_duration_constraint_states(min_duration, topo): + for state in range(0, min_duration - 1): + print(" {state} 0" + " {dest_state} 1.0 ".format( + state=state, dest_state=state + 1), + file=topo) + + +def print_topology(phone_transition_parameters, phone2int, args, topo): + for t in phone_transition_parameters: + print ("", file=topo) + print ("", file=topo) + print ("{0}".format(" ".join([str(phone2int[p]) + for p in t.phone_list])), file=topo) + print ("", file=topo) + + print_duration_constraint_states(t.min_duration, topo) + + print(" {state} 0 " + " {state} {self_prob} " + " {next_state} {next_prob} ".format( + state=t.min_duration - 1, next_state=t.min_duration, + self_prob=1 - t.end_transition_probability, + next_prob=t.end_transition_probability), file=topo) + + print(" {state} ".format(state=t.min_duration), + file=topo) # Final state + print ("", file=topo) + + +def main(): args = GetArgs() phone_transition_parameters = ParsePhoneTransitionParameters(args.phone_transition_para_array) - phone2int = GetPhoneMap(phone_transition_parameters) + phone2int = get_phone_map(phone_transition_parameters) topo = open("{0}/topo".format(args.dir), 'w') - print ("", file = topo) + print ("", file=topo) - for t in phone_transition_parameters: - print ("", file = topo) - print ("", file = topo) - print ("{0}".format(" ".join([str(phone2int[p]) for p in t.phone_list])), file = topo) - print ("", file = topo) - - for state in range(0, t.min_duration-1): - print(" {0} 0 {1} 1.0 ".format(state, state + 1), file = topo) - print(" {state} 0 {state} {self_prob} {next_state} {next_prob} ".format( - state = t.min_duration - 1, next_state = t.min_duration, - self_prob = 1 - t.end_transition_probability, - next_prob = t.end_transition_probability), file = topo) - print(" {state} ".format(state = t.min_duration), file = topo) # Final state - print ("", file = topo) - print ("", file = topo) + print_topology(phone_transition_parameters, phone2int, args, topo) + + print ("", file=topo) phones_file = open("{0}/phones.txt".format(args.dir), 'w') - for p,n in sorted(list(phone2int.items()), key = lambda x:x[1]): - print ("{0} {1}".format(p, n), file = phones_file) + print (" 0", file=phones_file) + + for p,n in sorted(list(phone2int.items()), key=lambda x:x[1]): + print ("{0} {1}".format(p, n), file=phones_file) if __name__ == '__main__': - Main() + main() From 0cd44c87162eaaf5632801ce01f16085e442f80e Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:05:00 -0500 Subject: [PATCH 153/213] asr_diarization: SimpleHmm version of segmentation --- .../do_segmentation_data_dir_simple.sh | 239 ++++++++++++++++++ .../internal/prepare_simple_hmm_lang.py | 202 +++++++++++++++ 2 files changed, 441 insertions(+) create mode 100755 egs/wsj/s5/steps/segmentation/do_segmentation_data_dir_simple.sh create mode 
100755 egs/wsj/s5/steps/segmentation/internal/prepare_simple_hmm_lang.py diff --git a/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir_simple.sh b/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir_simple.sh new file mode 100755 index 00000000000..0da130ee3ab --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir_simple.sh @@ -0,0 +1,239 @@ +#!/bin/bash + +set -e +set -o pipefail +set -u + +. path.sh +. cmd.sh + +affix= # Affix for the segmentation +nj=32 # works on recordings as against on speakers + +# Feature options (Must match training) +mfcc_config=conf/mfcc_hires_bp.conf +feat_affix=bp # Affix for the type of feature used + +skip_output_computation=false + +stage=-1 +sad_stage=-1 +output_name=output-speech # The output node in the network +sad_name=sad # Base name for the directory storing the computed loglikes +segmentation_name=segmentation # Base name for the directory doing segmentation + +# SAD network config +iter=final # Model iteration to use + +# Contexts must ideally match training for LSTM models, but +# may not necessarily for stats components +extra_left_context=0 # Set to some large value, typically 40 for LSTM (must match training) +extra_right_context=0 + +frame_subsampling_factor=1 # Subsampling at the output + +transition_scale=10.0 +loopscale=1.0 + +# Set to true if the test data has > 8kHz sampling frequency. +do_downsampling=false + +# Segmentation configs +segmentation_config=conf/segmentation_speech.conf +convert_data_dir_to_whole=true + +echo $* + +. utils/parse_options.sh + +if [ $# -ne 5 ]; then + echo "Usage: $0 " + echo " e.g.: $0 ~/workspace/egs/ami/s5b/data/sdm1/dev exp/nnet3_sad_snr/nnet_tdnn_j_n4 mfcc_hires_bp data/ami_sdm1_dev" + exit 1 +fi + +src_data_dir=$1 # The input data directory that needs to be segmented. + # Any segments in that will be ignored. +sad_nnet_dir=$2 # The SAD neural network +lang=$3 +mfcc_dir=$4 # The directory to store the features +data_dir=$5 # The output data directory will be ${data_dir}_seg + +affix=${affix:+_$affix} +feat_affix=${feat_affix:+_$feat_affix} + +data_id=`basename $data_dir` +sad_dir=${sad_nnet_dir}/${sad_name}${affix}_${data_id}_whole${feat_affix} +seg_dir=${sad_nnet_dir}/${segmentation_name}${affix}_${data_id}_whole${feat_affix} + +export PATH="$KALDI_ROOT/tools/sph2pipe_v2.5/:$PATH" +[ ! 
-z `which sph2pipe` ] + +test_data_dir=data/${data_id}${feat_affix}_hires + +if $convert_data_dir_to_whole; then + if [ $stage -le 0 ]; then + whole_data_dir=${sad_dir}/${data_id}_whole + utils/data/convert_data_dir_to_whole.sh $src_data_dir ${whole_data_dir} + + if $do_downsampling; then + freq=`cat $mfcc_config | perl -pe 's/\s*#.*//g' | grep "sample-frequency=" | awk -F'=' '{if (NF == 0) print 16000; else print $2}'` + utils/data/downsample_data_dir.sh $freq $whole_data_dir + fi + + utils/copy_data_dir.sh ${whole_data_dir} $test_data_dir + fi +else + if [ $stage -le 0 ]; then + utils/copy_data_dir.sh $src_data_dir $test_data_dir + + if $do_downsampling; then + freq=`cat $mfcc_config | perl -pe 's/\s*#.*//g' | grep "sample-frequency=" | awk -F'=' '{if (NF == 0) print 16000; else print $2}'` + utils/data/downsample_data_dir.sh $freq $test_data_dir + fi + fi +fi + +if [ $stage -le 1 ]; then + steps/make_mfcc.sh --mfcc-config $mfcc_config --nj $nj --cmd "$train_cmd" \ + ${test_data_dir} exp/make_hires/${data_id}${feat_affix} $mfcc_dir + steps/compute_cmvn_stats.sh ${test_data_dir} exp/make_hires/${data_id}${feat_affix} $mfcc_dir + utils/fix_data_dir.sh ${test_data_dir} +fi + +post_vec=$sad_nnet_dir/post_${output_name}.vec +if [ ! -f $sad_nnet_dir/post_${output_name}.vec ]; then + echo "$0: Could not find $sad_nnet_dir/post_${output_name}.vec. See the last stage of local/segmentation/run_train_sad.sh" + exit 1 +fi + +create_topo=true +if $create_topo; then + if [ ! -f $lang/classes_info.txt ]; then + echo "$0: Could not find $lang/topo or $lang/classes_info.txt" + exit 1 + else + steps/segmentation/internal/prepare_simple_hmm_lang.py \ + $lang/classes_info.txt $lang + fi +fi + +if [ $stage -le 3 ]; then + simple-hmm-init $lang/topo $lang/init.mdl + + $train_cmd $sad_nnet_dir/log/get_final_${output_name}_model.log \ + nnet3-am-init $lang/init.mdl \ + "nnet3-copy --edits='rename-node old-name=$output_name new-name=output' $sad_nnet_dir/$iter.raw - |" - \| \ + nnet3-am-adjust-priors - $sad_nnet_dir/post_${output_name}.vec \ + $sad_nnet_dir/${iter}_${output_name}.mdl +fi +iter=${iter}_${output_name} + +if [ $stage -le 4 ]; then + steps/nnet3/compute_output.sh --nj $nj --cmd "$train_cmd" \ + --iter $iter --use-raw-nnet false \ + --extra-left-context $extra_left_context \ + --extra-right-context $extra_right_context \ + --frames-per-chunk 150 \ + --stage $sad_stage \ + --frame-subsampling-factor $frame_subsampling_factor \ + ${test_data_dir} $sad_nnet_dir $sad_dir +fi + +graph_dir=${sad_nnet_dir}/graph_${output_name} + +if [ $stage -le 5 ]; then + cp -r $lang $graph_dir + + if [ ! -f $lang/final.mdl ]; then + echo "$0: Could not find $lang/final.mdl!" 
+ echo "$0: Using $lang/init.mdl instead" + cp $lang/init.mdl $graph_dir/final.mdl + else + cp $lang/final.mdl $graph_dir + fi + + $train_cmd $lang/log/make_graph.log \ + make-simple-hmm-graph --transition-scale=$transition_scale \ + --self-loop-scale=$loopscale \ + $graph_dir/final.mdl \| \ + fstdeterminizestar --use-log=true \| \ + fstrmepslocal \| \ + fstminimizeencoded '>' $graph_dir/HCLG.fst +fi + +if [ $stage -le 6 ]; then + # 'final' here refers to $lang/final.mdl + steps/segmentation/decode_sad.sh --acwt 1.0 --cmd "$decode_cmd" \ + --iter final --get-pdfs true $graph_dir $sad_dir $seg_dir +fi + +if [ $stage -le 7 ]; then + steps/segmentation/post_process_sad_to_subsegments.sh \ + --cmd "$train_cmd" --segmentation-config $segmentation_config \ + --frame-subsampling-factor $frame_subsampling_factor \ + ${test_data_dir} $lang/phone2sad_map ${seg_dir} \ + ${seg_dir} ${data_dir}_seg + + cp $src_data_dir/wav.scp ${data_dir}_seg +fi + +exit 0 + +segments_opts="--single-speaker" + +if false; then + mkdir -p ${seg_dir}/post_process_${data_id} + echo $nj > ${seg_dir}/post_process_${data_id}/num_jobs + + $train_cmd JOB=1:$nj $seg_dir/log/convert_to_segments.JOB.log \ + segmentation-init-from-ali "ark:gunzip -c $seg_dir/ali.JOB.gz |" ark:- \| \ + segmentation-copy --label-map=$lang/phone2sad_map --frame-subsampling-factor=$frame_subsampling_factor ark:- ark:- \| \ + segmentation-to-segments --frame-overlap=0.02 $segments_opts ark:- \ + ark,t:${seg_dir}/post_process_${data_id}/utt2spk.JOB \ + ${seg_dir}/post_process_${data_id}/segments.JOB + + for n in `seq $nj`; do + cat ${seg_dir}/post_process_${data_id}/segments.$n + done > ${seg_dir}/post_process_${data_id}/segments + + for n in `seq $nj`; do + cat ${seg_dir}/post_process_${data_id}/utt2spk.$n + done > ${seg_dir}/post_process_${data_id}/utt2spk + + rm -r ${data_dir}_seg || true + mkdir -p ${data_dir}_seg + + utils/data/subsegment_data_dir.sh ${test_data_dir} \ + ${seg_dir}/post_process_${data_id}/segments ${data_dir}_seg + + cp ${src_data_dir}/wav.scp ${data_dir}_seg + cp ${seg_dir}/post_process_${data_id}/utt2spk ${data_dir}_seg + for f in stm glm reco2file_and_channel; do + [ -f $src_data_dir/$f ] && cp ${src_data_dir}/$f ${data_dir}_seg + done + + rm ${data_dir}/{cmvn.scp,spk2utt} || true + utils/fix_data_dir.sh ${data_dir}_seg +fi + +exit 0 + +# Subsegment data directory +if [ $stage -le 8 ]; then + utils/data/get_reco2num_frames.sh ${test_data_dir} + awk '{print $1" "$2}' ${data_dir}_seg/segments | \ + utils/apply_map.pl -f 2 ${test_data_dir}/reco2num_frames > \ + ${data_dir}_seg/utt2max_frames + + frame_shift_info=`cat $mfcc_config | steps/segmentation/get_frame_shift_info_from_config.pl` + utils/data/get_subsegment_feats.sh ${test_data_dir}/feats.scp \ + $frame_shift_info ${data_dir}_seg/segments | \ + utils/data/fix_subsegmented_feats.pl ${data_dir}_seg/utt2max_frames > \ + ${data_dir}_seg/feats.scp + steps/compute_cmvn_stats.sh --fake ${data_dir}_seg + + utils/fix_data_dir.sh ${data_dir}_seg +fi + + diff --git a/egs/wsj/s5/steps/segmentation/internal/prepare_simple_hmm_lang.py b/egs/wsj/s5/steps/segmentation/internal/prepare_simple_hmm_lang.py new file mode 100755 index 00000000000..eae0f142668 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/internal/prepare_simple_hmm_lang.py @@ -0,0 +1,202 @@ +#! 
/usr/bin/env python + +from __future__ import print_function +import argparse +import logging +import os +import sys + +sys.path.insert(0, 'steps') +import libs.common as common_lib + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +handler.setLevel(logging.INFO) +formatter = logging.Formatter("%(asctime)s [%(filename)s:%(lineno)s - " + "%(funcName)s - %(levelname)s ] %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) + + +def get_args(): + parser = argparse.ArgumentParser( + description="""This script generates a lang directory for decoding with + simple HMM model. + It needs as an input classes_info file with the + format: + , + where each pair is :. + destination-class -1 is used to represent final probabilitiy.""") + + parser.add_argument("classes_info", type=argparse.FileType('r'), + help="File with classes_info") + parser.add_argument("dir", type=str, + help="Output lang directory") + args = parser.parse_args() + return args + + +class ClassInfo(object): + def __init__(self, class_id): + self.class_id = class_id + self.start_state = -1 + self.num_states = 0 + self.initial_prob = 0 + self.self_loop_prob= 0 + self.transitions = {} + + def __str__(self): + return ("class-id={0},start-state={1},num-states={2}," + "initial-prob={3:.2f},transitions={4}".format( + self.class_id, self.start_state, self.num_states, + self.initial_prob, ' '.join( + ['{0}:{1}'.format(x,y) + for x,y in self.transitions.iteritems()]))) + + +def read_classes_info(file_handle): + classes_info = {} + + num_states = 1 + num_classes = 0 + + for line in file_handle.readlines(): + try: + parts = line.split() + class_id = int(parts[0]) + assert class_id > 0, class_id + if class_id in classes_info: + raise RuntimeError( + "Duplicate class-id {0} in file {1}".format( + class_id, file_handle.name)) + classes_info[class_id] = ClassInfo(class_id) + class_info = classes_info[class_id] + class_info.initial_prob = float(parts[1]) + class_info.self_loop_prob = float(parts[2]) + class_info.num_states = int(parts[3]) + class_info.start_state = num_states + num_states += class_info.num_states + num_classes += 1 + + if len(parts) > 4: + for part in parts[4:]: + dest_class, transition_prob = part.split(':') + dest_class = int(dest_class) + if dest_class in class_info.transitions: + logger.error( + "Duplicate transition to class-id {0}" + "in transitions".format(dest_class)) + raise RuntimeError + class_info.transitions[dest_class] = float(transition_prob) + else: + raise RuntimeError( + "No transitions out of class {0}".format(class_id)) + except Exception: + logger.error("Error processing line %s in file %s", + line, file_handle.name) + raise + + # Final state + classes_info[-1] = ClassInfo(-1) + class_info = classes_info[-1] + class_info.num_states = 1 + class_info.start_state = num_states + + for class_id, class_info in classes_info.iteritems(): + logger.info("For class %d, dot class-info %s", class_id, class_info) + + return classes_info, num_classes + + +def print_states_for_class(class_id, classes_info, topo): + class_info = classes_info[class_id] + + assert class_info.num_states > 1, class_info + + for state in range(class_info.start_state, + class_info.start_state + class_info.num_states - 1): + print(" {state} {pdf}" + " {dest_state} 1.0 ".format( + state=state, dest_state=state + 1, + pdf=class_info.class_id - 1), + file=topo) + + state = class_info.start_state + class_info.num_states - 1 + + transitions = [] + + transitions.append(" 
{next_state} {next_prob}".format( + next_state=state, next_prob=class_info.self_loop_prob)) + + for dest_class, prob in class_info.transitions.iteritems(): + try: + next_state = classes_info[dest_class].start_state + + transitions.append(" {next_state} {next_prob}".format( + next_state=next_state, next_prob=prob)) + except Exception: + logger.error("Failed to add transition (%d->%d).\n" + "classes_info = %s", class_id, dest_class, + class_info) + + print(" {state} {pdf} " + "{transitions} ".format( + state=state, pdf=class_id - 1, + transitions=' '.join(transitions)), file=topo) + + +def main(): + try: + args = get_args() + run(args) + except Exception: + logger.error("Failed preparing lang directory") + raise + + +def run(args): + if not os.path.exists(args.dir): + os.makedirs(args.dir) + + classes_info, num_classes = read_classes_info(args.classes_info) + + topo = open("{0}/topo".format(args.dir), 'w') + + print ("", file=topo) + print ("", file=topo) + print ("", file=topo) + print ("1", file=topo) + print ("", file=topo) + + # Print transitions from initial state (initial probs) + transitions = [] + for class_id in range(1, num_classes + 1): + class_info = classes_info[class_id] + transitions.append(" {next_state} {next_prob}".format( + next_state=class_info.start_state, + next_prob=class_info.initial_prob)) + print(" 0 {transitions} ".format( + transitions=' '.join(transitions)), file=topo) + + for class_id in range(1, num_classes + 1): + print_states_for_class(class_id, classes_info, topo) + + print(" {state} ".format( + state=classes_info[-1].start_state), file=topo) + + print ("", file=topo) + print ("", file=topo) + topo.close() + + with open('{0}/phones.txt'.format(args.dir), 'w') as phones_f: + for class_id in range(1, num_classes + 1): + print ("{0} {1}".format(class_id - 1, class_id), file=phones_f) + + common_lib.force_symlink('{0}/phones.txt'.format(args.dir), + '{0}/words.txt'.format(args.dir)) + + +if __name__ == '__main__': + main() From a4b823c7c145ceae70a026ebf3a226277397a087 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:05:30 -0500 Subject: [PATCH 154/213] More segmentation script updated --- .../post_process_sad_to_subsegments.sh | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/egs/wsj/s5/steps/segmentation/post_process_sad_to_subsegments.sh b/egs/wsj/s5/steps/segmentation/post_process_sad_to_subsegments.sh index 0ca6b3dd126..d5ad48a492f 100755 --- a/egs/wsj/s5/steps/segmentation/post_process_sad_to_subsegments.sh +++ b/egs/wsj/s5/steps/segmentation/post_process_sad_to_subsegments.sh @@ -14,6 +14,7 @@ nj=18 frame_subsampling_factor=1 frame_shift=0.01 +frame_overlap=0.015 . utils/parse_options.sh @@ -56,21 +57,32 @@ if [ $stage -le 1 ]; then fi if [ $stage -le 2 ]; then + # --frame-overlap is set to 0 to not do any additional padding when writing + # segments. This padding will be done later by the option + # --segment-end-padding to utils/data/subsegment_data_dir.sh. 
steps/segmentation/internal/post_process_segments.sh \ --stage $stage --cmd "$cmd" \ --config $segmentation_config --frame-shift $frame_shift \ + --frame-overlap 0 \ $data_dir $dir $segmented_data_dir fi mv $segmented_data_dir/segments $segmented_data_dir/sub_segments -utils/data/subsegment_data_dir.sh $data_dir $segmented_data_dir/sub_segments $segmented_data_dir +utils/data/subsegment_data_dir.sh --segment-end-padding `perl -e "print $frame_overlap"` \ + $data_dir $segmented_data_dir/sub_segments $segmented_data_dir +utils/fix_data_dir.sh $segmented_data_dir -utils/data/get_reco2num_frames.sh ${data_dir} +utils/data/get_reco2num_frames.sh --nj $nj --cmd "$cmd" ${data_dir} mv $segmented_data_dir/feats.scp $segmented_data_dir/feats.scp.tmp -cat $segmented_data_dir/segments | utils/apply_map.pl -f 2 $data_dir/reco2num_frames > $segmetned_data_dir/utt2max_frames -cat $segmented_data_dir/feats.scp.tmp | utils/data/fix_subsegmented_feats.pl $dsegmented_data_dir/utt2max_frames > $segmented_data_dir/feats.scp - -utils/utt2spk_to_spk2utt.pl $segmented_data_dir/utt2spk > $segmented_data_dir/spk2utt || exit 1 +cat $segmented_data_dir/segments | awk '{print $1" "$2}' | \ + utils/apply_map.pl -f 2 $data_dir/reco2num_frames > \ + $segmented_data_dir/utt2max_frames +cat $segmented_data_dir/feats.scp.tmp | \ + utils/data/fix_subsegmented_feats.pl $segmented_data_dir/utt2max_frames > \ + $segmented_data_dir/feats.scp + +utils/utt2spk_to_spk2utt.pl $segmented_data_dir/utt2spk > \ + $segmented_data_dir/spk2utt || exit 1 utils/fix_data_dir.sh $segmented_data_dir if [ ! -s $segmented_data_dir/utt2spk ] || [ ! -s $segmented_data_dir/segments ]; then From dd51f1ce9ee2c6f5f0467a14e38b6320940fd1b5 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:06:45 -0500 Subject: [PATCH 155/213] subsegment_data_dir fix --- egs/wsj/s5/utils/data/fix_subsegmented_feats.pl | 2 +- egs/wsj/s5/utils/data/get_subsegment_feats.sh | 1 - egs/wsj/s5/utils/data/subsegment_data_dir.sh | 1 + 3 files changed, 2 insertions(+), 2 deletions(-) delete mode 120000 egs/wsj/s5/utils/data/get_subsegment_feats.sh diff --git a/egs/wsj/s5/utils/data/fix_subsegmented_feats.pl b/egs/wsj/s5/utils/data/fix_subsegmented_feats.pl index bd8aeb8e409..b0cece46ca8 100755 --- a/egs/wsj/s5/utils/data/fix_subsegmented_feats.pl +++ b/egs/wsj/s5/utils/data/fix_subsegmented_feats.pl @@ -49,7 +49,7 @@ my @F = split(/ /, $before_range); my $utt = shift @F; - defined $utt2max_frames{$utt} or die "fix_subsegmented_feats.pl: Could not find key $utt in $utt2num_frames_file.\nError with line $line"; + defined $utt2max_frames{$utt} or die "fix_subsegmented_feats.pl: Could not find key $utt in $utt2max_frames_file.\nError with line $line"; if ($range !~ m/^(\d*):(\d*)([,]?.*)$/) { print STDERR "fix_subsegmented_feats.pl: could not make sense of input line $_"; diff --git a/egs/wsj/s5/utils/data/get_subsegment_feats.sh b/egs/wsj/s5/utils/data/get_subsegment_feats.sh deleted file mode 120000 index c1495ea63ff..00000000000 --- a/egs/wsj/s5/utils/data/get_subsegment_feats.sh +++ /dev/null @@ -1 +0,0 @@ -get_subsegmented_feats.sh \ No newline at end of file diff --git a/egs/wsj/s5/utils/data/subsegment_data_dir.sh b/egs/wsj/s5/utils/data/subsegment_data_dir.sh index b018d5ec94a..10a8a9cb264 100755 --- a/egs/wsj/s5/utils/data/subsegment_data_dir.sh +++ b/egs/wsj/s5/utils/data/subsegment_data_dir.sh @@ -202,6 +202,7 @@ utils/data/fix_data_dir.sh $dir validate_opts= [ ! -f $srcdir/feats.scp ] && validate_opts="$validate_opts --no-feats" [ ! 
-f $srcdir/wav.scp ] && validate_opts="$validate_opts --no-wav" +$no_text && validate_opts="$validate_opts --no-text" utils/data/validate_data_dir.sh $validate_opts $dir From 310f42e71973e1624f848dd7584e540ac5c33097 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:07:15 -0500 Subject: [PATCH 156/213] asr_diarization: Update get_sad_map --- egs/wsj/s5/steps/segmentation/get_sad_map.py | 42 +++++--------------- 1 file changed, 9 insertions(+), 33 deletions(-) diff --git a/egs/wsj/s5/steps/segmentation/get_sad_map.py b/egs/wsj/s5/steps/segmentation/get_sad_map.py index 9160503c7ad..222e6c1a512 100755 --- a/egs/wsj/s5/steps/segmentation/get_sad_map.py +++ b/egs/wsj/s5/steps/segmentation/get_sad_map.py @@ -20,34 +20,10 @@ """ import argparse +import sys - -class StrToBoolAction(argparse.Action): - """ A custom action to convert bools from shell format i.e., true/false - to python format i.e., True/False """ - def __call__(self, parser, namespace, values, option_string=None): - try: - if values == "true": - setattr(namespace, self.dest, True) - elif values == "true": - setattr(namespace, self.dest, False) - else: - raise ValueError - except ValueError: - raise Exception("Unknown value {0} for --{1}".format(values, - self.dest)) - - -class NullstrToNoneAction(argparse.Action): - """ A custom action to convert empty strings passed by shell - to None in python. This is necessary as shell scripts print null - strings when a variable is not specified. We could use the more apt - None in python. """ - def __call__(self, parser, namespace, values, option_string=None): - if values.strip() == "": - setattr(namespace, self.dest, None) - else: - setattr(namespace, self.dest, values) +sys.path.insert(0, 'steps') +import libs.common as common_lib def get_args(): @@ -71,7 +47,7 @@ def get_args(): or noise phones to separate SAD labels. """) - parser.add_argument("--init-sad-map", type=str, action=NullstrToNoneAction, + parser.add_argument("--init-sad-map", type=str, action=common_lib.NullstrToNoneAction, help="""Initial SAD map that will be used to override the default mapping using phones/silence.txt and phones/nonsilence.txt. Does not need to specify labels @@ -82,24 +58,24 @@ def get_args(): noise_group = parser.add_mutually_exclusive_group() noise_group.add_argument("--noise-phones-file", type=str, - action=NullstrToNoneAction, + action=common_lib.NullstrToNoneAction, help="Map noise phones from file to label 2") noise_group.add_argument("--noise-phones-list", type=str, - action=NullstrToNoneAction, + action=common_lib.NullstrToNoneAction, help="A colon-separated list of noise phones to " "map to label 2") - parser.add_argument("--unk", type=str, action=NullstrToNoneAction, + parser.add_argument("--unk", type=str, action=common_lib.NullstrToNoneAction, help="""UNK phone, if provided will be mapped to label 3""") parser.add_argument("--map-noise-to-sil", type=str, - action=StrToBoolAction, + action=common_lib.StrToBoolAction, choices=["true", "false"], default=False, help="""Map noise phones to silence before writing the map. i.e. anything with label 2 is mapped to label 0.""") parser.add_argument("--map-unk-to-speech", type=str, - action=StrToBoolAction, + action=common_lib.StrToBoolAction, choices=["true", "false"], default=False, help="""Map UNK phone to speech before writing the map i.e. 
anything with label 3 is mapped to label 1.""") From c9a44e0332089365f4660dec77735b49f0c2a62f Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:07:35 -0500 Subject: [PATCH 157/213] asr_diarization: downsample_data_dir.sh perturb_data_dir_speed_random.sh --- egs/wsj/s5/utils/data/downsample_data_dir.sh | 34 +++++++++++++ .../data/perturb_data_dir_speed_random.sh | 51 +++++++++++++++++++ 2 files changed, 85 insertions(+) create mode 100755 egs/wsj/s5/utils/data/downsample_data_dir.sh create mode 100755 egs/wsj/s5/utils/data/perturb_data_dir_speed_random.sh diff --git a/egs/wsj/s5/utils/data/downsample_data_dir.sh b/egs/wsj/s5/utils/data/downsample_data_dir.sh new file mode 100755 index 00000000000..022af67d265 --- /dev/null +++ b/egs/wsj/s5/utils/data/downsample_data_dir.sh @@ -0,0 +1,34 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +if [ $# -ne 2 ]; then + echo "Usage: $0 " + exit 1 +fi + +freq=$1 +dir=$2 + +sox=`which sox` || { echo "Could not find sox in PATH"; exit 1; } + +if [ -f $dir/feats.scp ]; then + mkdir -p $dir/.backup + mv $dir/feats.scp $dir/.backup/ + if [ -f $dir/cmvn.scp ]; then + mv $dir/cmvn.scp $dir/.backup/ + fi + echo "$0: feats.scp already exists. Moving it to $dir/.backup" +fi + +mv $dir/wav.scp $dir/wav.scp.tmp +cat $dir/wav.scp.tmp | python -c "import sys +for line in sys.stdin.readlines(): + splits = line.strip().split() + if splits[-1] == '|': + out_line = line.strip() + ' $sox -t wav - -r $freq -c 1 -b 16 -t wav - downsample |' + else: + out_line = 'cat {0} {1} | $sox -t wav - -r $freq -c 1 -b 16 -t wav - downsample |'.format(splits[0], ' '.join(splits[1:])) + print (out_line)" > ${dir}/wav.scp +rm $dir/wav.scp.tmp diff --git a/egs/wsj/s5/utils/data/perturb_data_dir_speed_random.sh b/egs/wsj/s5/utils/data/perturb_data_dir_speed_random.sh new file mode 100755 index 00000000000..d9d027b77a3 --- /dev/null +++ b/egs/wsj/s5/utils/data/perturb_data_dir_speed_random.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +# Copyright 2017 Vimal Manohar + +# Apache 2.0 + +. utils/parse_options.sh + +if [ $# != 2 ]; then + echo "Usage: perturb_data_dir_speed_random.sh " + echo "Applies 3-way speed perturbation using factors of 0.9, 1.0 and 1.1 on random subsets." + echo "e.g.:" + echo " $0 data/train data/train_spr" + echo "Note: if /feats.scp already exists, this will refuse to run." + exit 1 +fi + +srcdir=$1 +destdir=$2 + +if [ ! -f $srcdir/wav.scp ]; then + echo "$0: expected $srcdir/wav.scp to exist" + exit 1 +fi + +if [ -f $destdir/feats.scp ]; then + echo "$0: $destdir/feats.scp already exists: refusing to run this (please delete $destdir/feats.scp if you want this to run)" + exit 1 +fi + +echo "$0: making sure the utt2dur file is present in ${srcdir}, because " +echo "... obtaining it after speed-perturbing would be very slow, and" +echo "... you might need it." 
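+# In brief, the steps below: get durations, split the recordings 3 ways using
+# the --per-reco option of utils/split_data.sh, perturb subset 1 by factor 0.9
+# and subset 3 by factor 1.1, keep subset 2 at 1.0, then recombine.  Typical
+# invocation (as in the usage message above):
+#   utils/data/perturb_data_dir_speed_random.sh data/train data/train_spr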
+utils/data/get_utt2dur.sh ${srcdir} + +utils/split_data.sh --per-reco $srcdir 3 + +utils/data/perturb_data_dir_speed.sh 0.9 ${srcdir}/split3reco/1 ${destdir}_speed0.9 || exit 1 +utils/data/perturb_data_dir_speed.sh 1.1 ${srcdir}/split3reco/3 ${destdir}_speed1.1 || exit 1 +utils/data/combine_data.sh $destdir ${srcdir}/split3reco/2 ${destdir}_speed0.9 ${destdir}_speed1.1 || exit 1 + +rm -r ${destdir}_speed0.9 ${destdir}_speed1.1 + +echo "$0: generated 3-way speed-perturbed version of random subsets of data in $srcdir, in $destdir" +if [ -f $srcdir/text ]; then + utils/validate_data_dir.sh --no-feats $destdir +else + utils/validate_data_dir.sh --no-feats --no-text $destdir +fi + + From 0e276b309dde3efb1c9a2657948ba4dde294e77d Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:08:02 -0500 Subject: [PATCH 158/213] asr_diarization: normalize_data_range.pl --- egs/wsj/s5/utils/data/normalize_data_range.pl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/wsj/s5/utils/data/normalize_data_range.pl b/egs/wsj/s5/utils/data/normalize_data_range.pl index a7a144fd82e..61ccfd593f7 100755 --- a/egs/wsj/s5/utils/data/normalize_data_range.pl +++ b/egs/wsj/s5/utils/data/normalize_data_range.pl @@ -51,7 +51,7 @@ sub combine_ranges { if ($start1 + $end2 > $end1) { chop $line; print STDERR ("normalize_data_range.pl: could not make sense of line $line " . - "[second $row_or_column range too large vs first range, $start1 + $end2 > $end1]\n"); + "[second $row_or_column range too large vs first range, $start1 + $end2 > $end1]; adjusting end.\n"); } return ($start2+$start1, $end2+$start1); } @@ -75,7 +75,7 @@ sub combine_ranges { "if concat-feats was in the input data\n"; exit(1); } - print STDERR "matched: $before_range $first_range $second_range\n"; + # print STDERR "matched: $before_range $first_range $second_range\n"; if ($first_range !~ m/^((\d*):(\d*)|)(,(\d*):(\d*)|)$/) { print STDERR "normalize_data_range.pl: could not make sense of input line $_"; exit(1); From 4637f02ad5733f572879cb12c252e4b166dd276c Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:08:23 -0500 Subject: [PATCH 159/213] asr_diarization: Add reco2utt to split_data.sh --- egs/wsj/s5/utils/split_data.sh | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/egs/wsj/s5/utils/split_data.sh b/egs/wsj/s5/utils/split_data.sh index 646830481db..f90ce9e6759 100755 --- a/egs/wsj/s5/utils/split_data.sh +++ b/egs/wsj/s5/utils/split_data.sh @@ -27,14 +27,17 @@ elif [ "$1" == "--per-reco" ]; then fi if [ $# != 2 ]; then - echo "Usage: $0 [--per-utt] " + echo "Usage: $0 [--per-utt|--per-reco] " echo "E.g.: $0 data/train 50" echo "It creates its output in e.g. data/train/split50/{1,2,3,...50}, or if the " echo "--per-utt option was given, in e.g. data/train/split50utt/{1,2,3,...50}." + echo "If the --per-reco option was given, in e.g. data/train/split50reco/{1,2,3,...50}." echo "" echo "This script will not split the data-dir if it detects that the output is newer than the input." echo "By default it splits per speaker (so each speaker is in only one split dir)," echo "but with the --per-utt option it will ignore the speaker information while splitting." + echo "But if --per-reco option is given, it splits per recording " + echo "(so each recording is in only one split dir)" exit 1 fi @@ -133,7 +136,7 @@ if [ ! -f $data/segments ]; then fi # split some things that are indexed by utterance. 
-for f in feats.scp text vad.scp utt2lang $maybe_wav_scp; do +for f in feats.scp text vad.scp utt2lang $maybe_wav_scp utt2dur utt2num_frames; do if [ -f $data/$f ]; then utils/filter_scps.pl JOB=1:$numsplit \ $data/split${numsplit}${utt}/JOB/utt2spk $data/$f $data/split${numsplit}${utt}/JOB/$f || exit 1; @@ -168,6 +171,12 @@ if [ -f $data/segments ]; then $data/split${numsplit}${utt}/JOB/tmp.reco $data/wav.scp \ $data/split${numsplit}${utt}/JOB/wav.scp || exit 1 fi + if [ -f $data/reco2utt ]; then + utils/filter_scps.pl JOB=1:$numsplit \ + $data/split${numsplit}${utt}/JOB/tmp.reco $data/reco2utt \ + $data/split${numsplit}${utt}/JOB/reco2utt || exit 1 + fi + for f in $data/split${numsplit}${utt}/*/tmp.reco; do rm $f; done fi From bf1647b3bb0194b3a770074621b1fc14d0898bb2 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:08:55 -0500 Subject: [PATCH 160/213] asr_diarization: Possibly deprecated update to do_segmentation_data_dir.sh --- egs/wsj/s5/steps/segmentation/do_segmentation_data_dir.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir.sh b/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir.sh index c1e690af366..9e95cca9cc0 100755 --- a/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir.sh +++ b/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir.sh @@ -107,11 +107,12 @@ if [ $stage -le 2 ]; then --frames-per-chunk 150 \ --stage $sad_stage --output-name $output_name \ --frame-subsampling-factor $frame_subsampling_factor \ - --get-raw-nnet-from-am false ${test_data_dir} $sad_nnet_dir $sad_dir + --use-raw-nnet true ${test_data_dir} $sad_nnet_dir $sad_dir fi if [ $stage -le 3 ]; then steps/segmentation/decode_sad_to_segments.sh \ + --use-unigram-lm false \ --frame-subsampling-factor $frame_subsampling_factor \ --min-silence-duration $min_silence_duration \ --min-speech-duration $min_speech_duration \ From b63787afef3b8b4c4fbd5621164e5d54df7a8ae9 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:09:34 -0500 Subject: [PATCH 161/213] asr_diarization: Minor logging to nnet3-copy-egs --- src/nnet3bin/nnet3-copy-egs.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/nnet3bin/nnet3-copy-egs.cc b/src/nnet3bin/nnet3-copy-egs.cc index 5189ee4046f..13d9a0d6a15 100644 --- a/src/nnet3bin/nnet3-copy-egs.cc +++ b/src/nnet3bin/nnet3-copy-egs.cc @@ -429,6 +429,7 @@ int main(int argc, char *argv[]) { // count is normally 1; could be 0, or possibly >1. 
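      // (GetCount returns a randomized integer count whose expectation is
      // keep_proportion, so e.g. --keep-proportion=0.5 keeps each eg with
      // probability 0.5, while values greater than 1 duplicate some egs.)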
int32 count = GetCount(keep_proportion); std::string key = example_reader.Key(); + KALDI_VLOG(2) << "Copying eg " << key; NnetExample eg(example_reader.Value()); if (!keep_outputs_str.empty()) { From ea5004233def84a1f0a6175fc9aea5e25b6dbb9a Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:10:36 -0500 Subject: [PATCH 162/213] asr_diarization: Partial update to aspire segmentation --- .../nnet3/prep_test_aspire_segmentation.sh | 78 +------------------ 1 file changed, 3 insertions(+), 75 deletions(-) diff --git a/egs/aspire/s5/local/nnet3/prep_test_aspire_segmentation.sh b/egs/aspire/s5/local/nnet3/prep_test_aspire_segmentation.sh index e7f70c0c07f..266781fc84d 100755 --- a/egs/aspire/s5/local/nnet3/prep_test_aspire_segmentation.sh +++ b/egs/aspire/s5/local/nnet3/prep_test_aspire_segmentation.sh @@ -78,13 +78,13 @@ if [ $stage -le 1 ]; then --mfcc-config conf/mfcc_hires_bp.conf --feat-affix bp --iter $sad_iter \ --do-downsampling false --extra-left-context 100 --extra-right-context 20 \ --output-name output-speech --frame-subsampling-factor 6 \ - data/${data_set} $sad_nnet_dir mfcc_hires_bp data/${data_set} + data/${data_set} $sad_nnet_dir mfcc_hires_bp data/${data_set}${affix} # Output will be in data/${data_set}_seg fi # uniform segmentation script would have created this dataset # so update that script if you plan to change this variable -segmented_data_set=${data_set}_seg +segmented_data_set=${data_set}${affix}_seg if [ $stage -le 2 ]; then mfccdir=mfcc_reverb @@ -103,79 +103,7 @@ if [ $stage -le 2 ]; then utils/validate_data_dir.sh --no-text data/${segmented_data_set}_hires fi -decode_dir=$dir/decode_${segmented_data_set}${affix}_pp -false && { -if [ $stage -le 2 ]; then - echo "Extracting i-vectors, stage 1" - steps/online/nnet2/extract_ivectors_online.sh --cmd "$train_cmd" --nj 20 \ - --max-count $max_count \ - data/${segmented_data_set}_hires $ivector_dir/extractor \ - $ivector_dir/ivectors_${segmented_data_set}${ivector_affix}_stage1; - # float comparisons are hard in bash - if [ `bc <<< "$ivector_scale != 1"` -eq 1 ]; then - ivector_scale_affix=_scale$ivector_scale - else - ivector_scale_affix= - fi - - if [ ! -z "$ivector_scale_affix" ]; then - echo "$0: Scaling iVectors, stage 1" - srcdir=$ivector_dir/ivectors_${segmented_data_set}${ivector_affix}_stage1 - outdir=$ivector_dir/ivectors_${segmented_data_set}${ivector_affix}${ivector_scale_affix}_stage1 - mkdir -p $outdir - copy-matrix --scale=$ivector_scale scp:$srcdir/ivector_online.scp ark:- | \ - copy-feats --compress=true ark:- ark,scp:$outdir/ivector_online.ark,$outdir/ivector_online.scp; - cp $srcdir/ivector_period $outdir/ivector_period - fi -fi - -# generate the lattices -if [ $stage -le 3 ]; then - echo "Generating lattices, stage 1" - steps/nnet3/decode.sh --nj $decode_num_jobs --cmd "$decode_cmd" --config conf/decode.config \ - --acwt $acwt --post-decode-acwt $post_decode_acwt \ - --extra-left-context $extra_left_context \ - --extra-right-context $extra_right_context \ - --frames-per-chunk "$frames_per_chunk" \ - --online-ivector-dir $ivector_dir/ivectors_${segmented_data_set}${ivector_affix}${ivector_scale_affix}_stage1 \ - --skip-scoring true --iter $iter \ - $graph data/${segmented_data_set}_hires ${decode_dir}_stage1; -fi - -if [ $stage -le 4 ]; then - if $filter_ctm; then - if [ ! 
-z $weights_file ]; then - echo "$0: Using provided vad weights file $weights_file" - ivector_extractor_input=$weights_file - else - echo "$0 : Generating vad weights file" - ivector_extractor_input=${decode_dir}_stage1/weights${affix}.gz - local/extract_vad_weights.sh --cmd "$decode_cmd" --iter $iter \ - data/${segmented_data_set}_hires $lang \ - ${decode_dir}_stage1 $ivector_extractor_input - fi - else - # just use all the frames - ivector_extractor_input=${decode_dir}_stage1 - fi -fi - -if [ $stage -le 5 ]; then - echo "Extracting i-vectors, stage 2 with input $ivector_extractor_input" - # this does offline decoding, except we estimate the iVectors per - # speaker, excluding silence (based on alignments from a DNN decoding), with a - # different script. This is just to demonstrate that script. - # the --sub-speaker-frames is optional; if provided, it will divide each speaker - # up into "sub-speakers" of at least that many frames... can be useful if - # acoustic conditions drift over time within the speaker's data. - steps/online/nnet2/extract_ivectors.sh --cmd "$train_cmd" --nj 20 \ - --silence-weight $silence_weight \ - --sub-speaker-frames $sub_speaker_frames --max-count $max_count \ - data/${segmented_data_set}_hires $lang $ivector_dir/extractor \ - $ivector_extractor_input $ivector_dir/ivectors_${segmented_data_set}${ivector_affix}; -fi -} - +decode_dir=$dir/decode_${segmented_data_set}_pp if [ $stage -le 5 ]; then echo "Extracting i-vectors, stage 2" # this does offline decoding, except we estimate the iVectors per From 6a0fca91e183c5ae0b1365a3374cd5376ff64430 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:11:09 -0500 Subject: [PATCH 163/213] asr_diarization: Update overlapping speech detection in ami --- .../s5/local/segmentation/prepare_ami.sh | 136 ++++++++++-------- .../prepare_babel_data_overlapped_speech.sh | 10 +- ...are_unsad_overlapped_speech_data_simple.sh | 5 +- 3 files changed, 81 insertions(+), 70 deletions(-) diff --git a/egs/aspire/s5/local/segmentation/prepare_ami.sh b/egs/aspire/s5/local/segmentation/prepare_ami.sh index 38ed9559c89..7147a3004cb 100755 --- a/egs/aspire/s5/local/segmentation/prepare_ami.sh +++ b/egs/aspire/s5/local/segmentation/prepare_ami.sh @@ -112,87 +112,97 @@ if [ $stage -le 6 ]; then --cmd queue.pl --nj $nj \ $src_dir/data/sdm1/${dataset} - # Get a filter that selects only regions within the manual segments. - $train_cmd $dir/log/get_manual_segments_regions.log \ - segmentation-init-from-segments --shift-to-zero=false $src_dir/data/sdm1/${dataset}/segments ark:- \| \ - segmentation-combine-segments-to-recordings ark:- ark,t:$src_dir/data/sdm1/${dataset}/reco2utt ark:- \| \ - segmentation-create-subsegments --filter-label=1 --subsegment-label=1 \ - "ark:segmentation-init-from-lengths --label=0 ark,t:$src_dir/data/sdm1/${dataset}/reco2num_frames ark:- |" ark:- ark,t:- \| \ - perl -ane '$F[3] = 10000; $F[$#F-1] = 10000; print join(" ", @F) . "\n";' \| \ - segmentation-create-subsegments --filter-label=10000 --subsegment-label=10000 \ - ark,t:- "ark:gunzip -c $dir/ref_spk_seg.gz |" ark:- \| \ - segmentation-post-process --merge-labels=0:1 --merge-dst-label=1 ark:- ark:- \| \ - segmentation-post-process --merge-labels=10000 --merge-dst-label=0 --merge-adjacent-segments \ - --max-intersegment-length=10000 ark,t:- \ - "ark:| gzip -c > $dir/manual_segments_regions.seg.gz" + ## Get a filter that selects only regions within the manual segments. 
+ #$train_cmd $dir/log/get_manual_segments_regions.log \ + # segmentation-init-from-segments --shift-to-zero=false $src_dir/data/sdm1/${dataset}/segments ark:- \| \ + # segmentation-combine-segments-to-recordings ark:- ark,t:$src_dir/data/sdm1/${dataset}/reco2utt ark:- \| \ + # segmentation-create-subsegments --filter-label=1 --subsegment-label=1 \ + # "ark:segmentation-init-from-lengths --label=0 ark,t:$src_dir/data/sdm1/${dataset}/reco2num_frames ark:- |" ark:- ark,t:- \| \ + # perl -ane '$F[3] = 10000; $F[$#F-1] = 10000; print join(" ", @F) . "\n";' \| \ + # segmentation-create-subsegments --filter-label=10000 --subsegment-label=10000 \ + # ark,t:- "ark:gunzip -c $dir/ref_spk_seg.gz |" ark:- \| \ + # segmentation-post-process --merge-labels=0:1 --merge-dst-label=1 ark:- ark:- \| \ + # segmentation-post-process --merge-labels=10000 --merge-dst-label=0 --merge-adjacent-segments \ + # --max-intersegment-length=10000 ark,t:- \ + # "ark:| gzip -c > $dir/manual_segments_regions.seg.gz" fi if [ $stage -le 7 ]; then - # To get the actual RTTM, we need to add no-score - $train_cmd $dir/log/get_ref_rttm.log \ + $train_cmd $dir/log/get_overlap_sad_seg.log \ segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/${dataset}/reco2num_frames \ - "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0:10000 ark:- ark:- |" \ - ark:/dev/null ark:- \| \ - segmentation-init-from-ali ark:- ark:- \| \ - segmentation-post-process --merge-labels=1:2:3:4:5:6:7:8:9:10 --merge-dst-label=1 \ - --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ - segmentation-create-subsegments --filter-label=0 --subsegment-label=10000 \ - ark:- "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ - segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ - segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ - --no-score-label=10000 ark:- $dir/ref.rttm + "ark:gunzip -c $dir/ref_spk_seg.gz |" \ + ark:/dev/null ark:/dev/null ark:- \| \ + classes-per-frame-to-labels --junk-label=10000 ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + "ark:| gzip -c > $dir/overlap_sad_seg.gz" fi - if [ $stage -le 8 ]; then + # To get the actual RTTM, we need to add no-score + $train_cmd $dir/log/get_ref_rttm.log \ + gunzip -c $dir/overlap_sad_seg.gz \| \ + segmentation-post-process --merge-labels=1:2 --merge-dst-label=1 \ + ark:- ark:- \| \ + segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ + --no-score-label=10000 ark:- $dir/ref.rttm + # Get RTTM for overlapped speech detection with 3 classes # 0 -> SILENCE, 1 -> SINGLE_SPEAKER, 2 -> OVERLAP - $train_cmd $dir/log/get_overlapping_rttm.log \ - segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/${dataset}/reco2num_frames \ - "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0:10000 ark:- ark:- |" \ - ark:/dev/null ark:- \| \ - segmentation-init-from-ali ark:- ark:- \| \ - segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=2 \ - --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ - segmentation-create-subsegments --filter-label=0 --subsegment-label=10000 \ - ark:- "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ - segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ - segmentation-to-rttm --map-to-speech-and-sil=false --reco2file-and-channel=$dir/reco2file_and_channel \ - 
--no-score-label=10000 ark:- $dir/overlapping_speech_ref.rttm + $train_cmd $dir/log/get_ref_rttm.log \ + gunzip -c $dir/overlap_sad_seg.gz \| \ + segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ + --no-score-label=10000 --map-to-speech-and-sil=false ark:- $dir/overlapping_speech_ref.rttm fi + +#if [ $stage -le 8 ]; then +# # Get RTTM for overlapped speech detection with 3 classes +# # 0 -> SILENCE, 1 -> SINGLE_SPEAKER, 2 -> OVERLAP +# $train_cmd $dir/log/get_overlapping_rttm.log \ +# segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/${dataset}/reco2num_frames \ +# "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0:10000 ark:- ark:- |" \ +# ark:/dev/null ark:- \| \ +# segmentation-init-from-ali ark:- ark:- \| \ +# segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=2 \ +# --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ +# segmentation-create-subsegments --filter-label=0 --subsegment-label=10000 \ +# ark:- "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ +# segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ +# segmentation-to-rttm --map-to-speech-and-sil=false --reco2file-and-channel=$dir/reco2file_and_channel \ +# --no-score-label=10000 ark:- $dir/overlapping_speech_ref.rttm +#fi + +# make $dir an absolute pathname. +dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $dir ${PWD}` + if [ $stage -le 9 ]; then # Get a filter that selects only regions of speech $train_cmd $dir/log/get_speech_filter.log \ - segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/${dataset}/reco2num_frames \ - "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0:10000 ark:- ark:- |" \ - ark:/dev/null ark:- \| \ - segmentation-init-from-ali ark:- ark:- \| \ - segmentation-post-process --merge-labels=1:2:3:4:5:6:7:8:9:10 --merge-dst-label=1 ark:- ark:- \| \ - segmentation-create-subsegments --filter-label=0 --subsegment-label=0 \ - ark:- "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ - segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 \ - ark:- "ark:| gzip -c > $dir/manual_segments_speech_regions.seg.gz" + gunzip -c $dir/overlap_sad_seg.gz \| \ + segmentation-post-process --merge-labels=1:2 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-post-process --remove-labels=10000 ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:$src_dir/data/sdm1/${dataset}/reco2num_frames \ + ark:- ark,t:- \| \ + steps/segmentation/convert_ali_to_vec.pl \| \ + copy-vector ark,t: ark,scp:$dir/deriv_weights_for_overlapping_sad.ark,$dir/deriv_weights_for_overlapping_sad.scp + + # Get deriv weights + $train_cmd $dir/log/get_speech_filter.log \ + gunzip -c $dir/overlap_sad_seg.gz \| \ + segmentation-post-process --merge-labels=0:1:2 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-post-process --remove-labels=10000 ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:$src_dir/data/sdm1/${dataset}/reco2num_frames \ + ark:- ark,t:- \| \ + steps/segmentation/convert_ali_to_vec.pl \| \ + copy-vector ark,t: ark,scp:$dir/deriv_weights.ark,$dir/deriv_weights.scp fi -# make $dir an absolute pathname. 
-dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $dir ${PWD}` - if [ $stage -le 10 ]; then $train_cmd $dir/log/get_overlapping_sad.log \ - segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/${dataset}/reco2num_frames \ - "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0:10000 ark:- ark:- |" \ - ark:/dev/null ark:- \| \ - segmentation-init-from-ali ark:- ark:- \| \ - segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=2 \ - --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ - segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ - segmentation-to-ali ark:- ark,scp:$dir/overlapping_sad_labels.ark,$dir/overlapping_sad_labels.scp - - $train_cmd $dir/log/get_deriv_weights_for_overlapping_sad.log \ - segmentation-to-ali "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark,t:- \| \ - steps/segmentation/convert_ali_to_vec.pl \| \ - copy-vector ark,t: ark,scp:$dir/deriv_weights_for_overlapping_sad.ark,$dir/deriv_weights_for_overlapping_sad.scp + gunzip -c $dir/overlap_sad_seg.gz \| \ + segmentation-post-process --remove-labels=10000 ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:$src_dir/data/sdm1/${dataset}/reco2num_frames \ + ark:- ark,scp:$dir/overlapping_sad_labels.ark,$dir/overlapping_sad_labels.scp fi if false && [ $stage -le 11 ]; then diff --git a/egs/aspire/s5/local/segmentation/prepare_babel_data_overlapped_speech.sh b/egs/aspire/s5/local/segmentation/prepare_babel_data_overlapped_speech.sh index 2136f42f322..a3e087d95ec 100644 --- a/egs/aspire/s5/local/segmentation/prepare_babel_data_overlapped_speech.sh +++ b/egs/aspire/s5/local/segmentation/prepare_babel_data_overlapped_speech.sh @@ -14,7 +14,7 @@ set -o pipefail set -u lang_id=assamese -subset=25 # Number of recordings to keep before speed perturbation and corruption +subset=150 # Number of recordings to keep before speed perturbation and corruption utt_subset=30000 # Number of utterances to keep after speed perturbation for adding overlapped-speech # All the paths below can be modified to any absolute path. @@ -88,15 +88,15 @@ fi reco_vad_dir=$dir/`basename $model_dir`_reco_vad_`basename $train_data_dir`_sp # Add noise from MUSAN corpus to data directory and create a new data directory -local/segmentation/do_corruption_data_dir.sh +local/segmentation/do_corruption_data_dir.sh \ --data-dir $data_dir \ - --reco-vad-dir $reco_vad_dir + --reco-vad-dir $reco_vad_dir \ --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf # Add music from MUSAN corpus to data directory and create a new data directory -local/segmentation/do_corruption_data_dir_music.sh +local/segmentation/do_corruption_data_dir_music.sh \ --data-dir $data_dir \ - --reco-vad-dir $reco_vad_dir + --reco-vad-dir $reco_vad_dir \ --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf if [ ! 
-z $utt_subset ]; then diff --git a/egs/aspire/s5/local/segmentation/prepare_unsad_overlapped_speech_data_simple.sh b/egs/aspire/s5/local/segmentation/prepare_unsad_overlapped_speech_data_simple.sh index 73f2abca566..80810afd619 100755 --- a/egs/aspire/s5/local/segmentation/prepare_unsad_overlapped_speech_data_simple.sh +++ b/egs/aspire/s5/local/segmentation/prepare_unsad_overlapped_speech_data_simple.sh @@ -68,8 +68,9 @@ if [ $stage -le 1 ]; then segmentation-init-from-additive-signals-info --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ --junk-label=10000 \ --additive-signals-segmentation-rspecifier=scp:$utt_vad_dir/sad_seg.scp \ - "scp:utils/filter_scp.pl ${corrupted_data_dir}/split${nj}/JOB/utt2spk $corrupted_data_dir/sad_seg.scp |" \ - ark,t:$orig_corrupted_data_dir/overlapped_segments_info.txt ark:- \| \ + "ark,t:utils/filter_scp.pl ${orig_corrupted_data_dir}/split${reco_nj}reco/JOB/reco2utt $orig_corrupted_data_dir/overlapped_segments_info.txt |" \ + ark:- \| \ + segmentation-merge "scp:utils/filter_scp.pl ${corrupted_data_dir}/split${nj}/JOB/utt2spk $corrupted_data_dir/sad_seg.scp |" ark:- ark:- \| \ segmentation-get-stats --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ ark:- ark:/dev/null ark:/dev/null ark:- \| \ classes-per-frame-to-labels --junk-label=10000 ark:- ark:- \| \ From fd96de7a8867dc455579cdc62ad5f48076c9fdc4 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:11:34 -0500 Subject: [PATCH 164/213] asr_diarization: Add simplehmmbin to common_path --- src/simplehmmbin/Makefile | 23 +++ .../compile-train-simple-hmm-graphs.cc | 151 ++++++++++++++++++ src/simplehmmbin/make-simple-hmm-graph.cc | 87 ++++++++++ src/simplehmmbin/simple-hmm-acc-stats-ali.cc | 88 ++++++++++ src/simplehmmbin/simple-hmm-align-compiled.cc | 131 +++++++++++++++ src/simplehmmbin/simple-hmm-est.cc | 86 ++++++++++ src/simplehmmbin/simple-hmm-init.cc | 70 ++++++++ tools/config/common_path.sh | 3 +- 8 files changed, 638 insertions(+), 1 deletion(-) create mode 100644 src/simplehmmbin/Makefile create mode 100644 src/simplehmmbin/compile-train-simple-hmm-graphs.cc create mode 100644 src/simplehmmbin/make-simple-hmm-graph.cc create mode 100644 src/simplehmmbin/simple-hmm-acc-stats-ali.cc create mode 100644 src/simplehmmbin/simple-hmm-align-compiled.cc create mode 100644 src/simplehmmbin/simple-hmm-est.cc create mode 100644 src/simplehmmbin/simple-hmm-init.cc diff --git a/src/simplehmmbin/Makefile b/src/simplehmmbin/Makefile new file mode 100644 index 00000000000..f382b30277c --- /dev/null +++ b/src/simplehmmbin/Makefile @@ -0,0 +1,23 @@ + +all: +EXTRA_CXXFLAGS = -Wno-sign-compare +include ../kaldi.mk + +BINFILES = simple-hmm-init \ + compile-train-simple-hmm-graphs simple-hmm-align-compiled \ + simple-hmm-acc-stats-ali simple-hmm-est make-simple-hmm-graph + + +OBJFILES = + +ADDLIBS = ../decoder/kaldi-decoder.a ../lat/kaldi-lat.a \ + ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a \ + ../simplehmm/kaldi-simplehmm.a\ + ../util/kaldi-util.a ../thread/kaldi-thread.a \ + ../matrix/kaldi-matrix.a ../base/kaldi-base.a + + +TESTFILES = + +include ../makefiles/default_rules.mk + diff --git a/src/simplehmmbin/compile-train-simple-hmm-graphs.cc b/src/simplehmmbin/compile-train-simple-hmm-graphs.cc new file mode 100644 index 00000000000..a1914ed0763 --- /dev/null +++ b/src/simplehmmbin/compile-train-simple-hmm-graphs.cc @@ -0,0 +1,151 @@ +// bin/compile-train-simple-hmm-graphs.cc + +// Copyright 2009-2012 Microsoft Corporation +// 2012-2015 Johns Hopkins University (Author: 
Daniel Povey) +// 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "tree/context-dep.h" +#include "simplehmm/simple-hmm.h" +#include "fstext/fstext-lib.h" +#include "decoder/simple-hmm-graph-compiler.h" + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + typedef kaldi::int32 int32; + using fst::SymbolTable; + using fst::VectorFst; + using fst::StdArc; + + const char *usage = + "Creates training graphs (without transition-probabilities, by default)\n" + "for training SimpleHmm models using alignments of pdf-ids.\n" + "Usage: compile-train-simple-hmm-graphs [options] " + " \n" + "e.g.: \n" + " compile-train-simple-hmm-graphs 1.mdl ark:train.tra ark:graphs.fsts\n"; + ParseOptions po(usage); + + SimpleHmmGraphCompilerOptions gopts; + int32 batch_size = 250; + gopts.transition_scale = 0.0; // Change the default to 0.0 since we will generally add the + // transition probs in the alignment phase (since they change eacm time) + gopts.self_loop_scale = 0.0; // Ditto for self-loop probs. + std::string disambig_rxfilename; + gopts.Register(&po); + + po.Register("batch-size", &batch_size, + "Number of FSTs to compile at a time (more -> faster but uses " + "more memory. E.g. 500"); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string model_rxfilename = po.GetArg(1); + std::string alignment_rspecifier = po.GetArg(2); + std::string fsts_wspecifier = po.GetArg(3); + + SimpleHmm model; + ReadKaldiObject(model_rxfilename, &model); + + SimpleHmmGraphCompiler gc(model, gopts); + + SequentialInt32VectorReader alignment_reader(alignment_rspecifier); + TableWriter fst_writer(fsts_wspecifier); + + int32 num_succeed = 0, num_fail = 0; + + if (batch_size == 1) { // We treat batch_size of 1 as a special case in order + // to test more parts of the code. + for (; !alignment_reader.Done(); alignment_reader.Next()) { + const std::string &key = alignment_reader.Key(); + std::vector alignment = alignment_reader.Value(); + + for (std::vector::iterator it = alignment.begin(); + it != alignment.end(); ++it) { + KALDI_ASSERT(*it < model.NumPdfs()); + ++(*it); + } + + VectorFst decode_fst; + + if (!gc.CompileGraphFromAlignment(alignment, &decode_fst)) { + decode_fst.DeleteStates(); // Just make it empty. 
+ } + if (decode_fst.Start() != fst::kNoStateId) { + num_succeed++; + fst_writer.Write(key, decode_fst); + } else { + KALDI_WARN << "Empty decoding graph for utterance " + << key; + num_fail++; + } + } + } else { + std::vector keys; + std::vector > alignments; + while (!alignment_reader.Done()) { + keys.clear(); + alignments.clear(); + for (; !alignment_reader.Done() && + static_cast(alignments.size()) < batch_size; + alignment_reader.Next()) { + keys.push_back(alignment_reader.Key()); + alignments.push_back(alignment_reader.Value()); + + for (std::vector::iterator it = alignments.back().begin(); + it != alignments.back().end(); ++it) { + KALDI_ASSERT(*it < model.NumPdfs()); + ++(*it); + } + } + std::vector* > fsts; + if (!gc.CompileGraphsFromAlignments(alignments, &fsts)) { + KALDI_ERR << "Not expecting CompileGraphs to fail."; + } + KALDI_ASSERT(fsts.size() == keys.size()); + for (size_t i = 0; i < fsts.size(); i++) { + if (fsts[i]->Start() != fst::kNoStateId) { + num_succeed++; + fst_writer.Write(keys[i], *(fsts[i])); + } else { + KALDI_WARN << "Empty decoding graph for utterance " + << keys[i]; + num_fail++; + } + } + DeletePointers(&fsts); + } + } + KALDI_LOG << "compile-train--simple-hmm-graphs: succeeded for " + << num_succeed << " graphs, failed for " << num_fail; + return (num_succeed != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/simplehmmbin/make-simple-hmm-graph.cc b/src/simplehmmbin/make-simple-hmm-graph.cc new file mode 100644 index 00000000000..088a73e7c50 --- /dev/null +++ b/src/simplehmmbin/make-simple-hmm-graph.cc @@ -0,0 +1,87 @@ +// simplehmmbin/make-simple-hmm-graph.cc + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "simplehmm/simple-hmm.h" +#include "simplehmm/simple-hmm-utils.h" +#include "util/common-utils.h" +#include "fst/fstlib.h" +#include "fstext/table-matcher.h" +#include "fstext/fstext-utils.h" +#include "fstext/context-fst.h" +#include "decoder/simple-hmm-graph-compiler.h" + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + typedef kaldi::int32 int32; + using fst::SymbolTable; + using fst::VectorFst; + using fst::StdArc; + + const char *usage = + "Make graph to decode with simple HMM. 
It is an FST from " + "transition-ids to pdf-ids + 1, \n" + "Usage: make-simple-hmm-graph []\n" + "e.g.: \n" + " make-simple-hmm-graph 1.mdl > HCLG.fst\n"; + ParseOptions po(usage); + + SimpleHmmGraphCompilerOptions gopts; + gopts.Register(&po); + + po.Read(argc, argv); + + if (po.NumArgs() < 1 || po.NumArgs() > 2) { + po.PrintUsage(); + exit(1); + } + + std::string model_filename = po.GetArg(1); + std::string fst_out_filename; + if (po.NumArgs() >= 2) fst_out_filename = po.GetArg(2); + if (fst_out_filename == "-") fst_out_filename = ""; + + SimpleHmm trans_model; + ReadKaldiObject(model_filename, &trans_model); + + // The work gets done here. + fst::VectorFst *H = GetHTransducer (trans_model, + gopts.transition_scale, + gopts.self_loop_scale); + +#if _MSC_VER + if (fst_out_filename == "") + _setmode(_fileno(stdout), _O_BINARY); +#endif + + if (! H->Write(fst_out_filename) ) + KALDI_ERR << "make-simple-hmm-graph: error writing FST to " + << (fst_out_filename == "" ? + "standard output" : fst_out_filename); + + delete H; + return 0; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/simplehmmbin/simple-hmm-acc-stats-ali.cc b/src/simplehmmbin/simple-hmm-acc-stats-ali.cc new file mode 100644 index 00000000000..5bcf8239311 --- /dev/null +++ b/src/simplehmmbin/simple-hmm-acc-stats-ali.cc @@ -0,0 +1,88 @@ +// simplehmmbin/simple-hmm-acc-stats-ali.cc + +// Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey) +// 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "simplehmm/simple-hmm.h" + +int main(int argc, char *argv[]) { + using namespace kaldi; + typedef kaldi::int32 int32; + try { + const char *usage = + "Accumulate stats for simple HMM training.\n" + "Usage: simple-hmm-acc-stats-ali [options] " + " \n" + "e.g.:\n simple-hmm-acc-stats-ali 1.mdl ark:1.ali 1.acc\n"; + + ParseOptions po(usage); + bool binary = true; + po.Register("binary", &binary, "Write output in binary mode"); + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string model_filename = po.GetArg(1), + alignments_rspecifier = po.GetArg(2), + accs_wxfilename = po.GetArg(3); + + SimpleHmm model; + ReadKaldiObject(model_filename, &model); + + Vector transition_accs; + model.InitStats(&transition_accs); + + SequentialInt32VectorReader alignments_reader(alignments_rspecifier); + + int32 num_done = 0, num_err = 0; + for (; !alignments_reader.Done(); alignments_reader.Next()) { + const std::string &key = alignments_reader.Key(); + const std::vector &alignment = alignments_reader.Value(); + + for (size_t i = 0; i < alignment.size(); i++) { + int32 tid = alignment[i]; // transition identifier. 
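+        // Each aligned frame adds a count of 1.0 for its transition-id; these
+        // counts are all that simple-hmm-est needs to re-estimate the
+        // transition probabilities.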
+ model.Accumulate(1.0, tid, &transition_accs); + } + + num_done++; + } + KALDI_LOG << "Done " << num_done << " files, " << num_err + << " with errors."; + + { + Output ko(accs_wxfilename, binary); + transition_accs.Write(ko.Stream(), binary); + } + KALDI_LOG << "Written accs."; + if (num_done != 0) + return 0; + else + return 1; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/simplehmmbin/simple-hmm-align-compiled.cc b/src/simplehmmbin/simple-hmm-align-compiled.cc new file mode 100644 index 00000000000..4a2bc286b24 --- /dev/null +++ b/src/simplehmmbin/simple-hmm-align-compiled.cc @@ -0,0 +1,131 @@ +// simplehmmbin/simple-hmm-align-compiled.cc + +// Copyright 2009-2013 Microsoft Corporation +// Johns Hopkins University (author: Daniel Povey) +// 2016 Vimal Manohar + + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "simplehmm/simple-hmm.h" +#include "simplehmm/simple-hmm-utils.h" +#include "fstext/fstext-lib.h" +#include "decoder/decoder-wrappers.h" +#include "decoder/decodable-matrix.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + typedef kaldi::int32 int32; + using fst::SymbolTable; + using fst::VectorFst; + using fst::StdArc; + + const char *usage = + "Align matrix of log-likelihoods given simple HMM model.\n" + "Usage: simple-hmm-align-compiled [options] " + " []\n" + "e.g.: \n" + " simple-hmm-align-compiled 1.mdl ark:graphs.fsts ark:log_likes.1.ark ark:1.ali\n"; + + ParseOptions po(usage); + AlignConfig align_config; + BaseFloat acoustic_scale = 1.0; + BaseFloat transition_scale = 1.0; + BaseFloat self_loop_scale = 1.0; + + align_config.Register(&po); + po.Register("transition-scale", &transition_scale, + "Transition-probability scale [relative to acoustics]"); + po.Register("acoustic-scale", &acoustic_scale, + "Scaling factor for acoustic likelihoods"); + po.Register("self-loop-scale", &self_loop_scale, + "Scale of self-loop versus non-self-loop log probs [relative to acoustics]"); + po.Read(argc, argv); + + if (po.NumArgs() < 4 || po.NumArgs() > 5) { + po.PrintUsage(); + exit(1); + } + + std::string model_in_filename = po.GetArg(1), + fst_rspecifier = po.GetArg(2), + loglikes_rspecifier = po.GetArg(3), + alignment_wspecifier = po.GetArg(4), + scores_wspecifier = po.GetOptArg(5); + + SimpleHmm model; + ReadKaldiObject(model_in_filename, &model); + + SequentialTableReader fst_reader(fst_rspecifier); + RandomAccessBaseFloatMatrixReader loglikes_reader(loglikes_rspecifier); + Int32VectorWriter alignment_writer(alignment_wspecifier); + BaseFloatWriter scores_writer(scores_wspecifier); + + int32 num_done = 0, num_err = 0, num_retry = 0; + double tot_like = 0.0; + kaldi::int64 frame_count = 0; + + for (; !fst_reader.Done(); fst_reader.Next()) { + const std::string &utt = fst_reader.Key(); + if 
(!loglikes_reader.HasKey(utt)) { + num_err++; + KALDI_WARN << "No loglikes for utterance " << utt; + } else { + const Matrix &loglikes = loglikes_reader.Value(utt); + VectorFst decode_fst(fst_reader.Value()); + fst_reader.FreeCurrent(); // this stops copy-on-write of the fst + // by deleting the fst inside the reader, since we're about to mutate + // the fst by adding transition probs. + + if (loglikes.NumRows() == 0) { + KALDI_WARN << "Zero-length utterance: " << utt; + num_err++; + continue; + } + + { // Add transition-probs to the FST. + std::vector disambig_syms; // empty + AddTransitionProbs(model, disambig_syms, transition_scale, + self_loop_scale, &decode_fst); + } + + DecodableMatrixScaledMapped decodable(model, loglikes, acoustic_scale); + + AlignUtteranceWrapper(align_config, utt, + acoustic_scale, &decode_fst, + &decodable, + &alignment_writer, &scores_writer, + &num_done, &num_err, &num_retry, + &tot_like, &frame_count); + } + } + KALDI_LOG << "Overall log-likelihood per frame is " + << (tot_like/frame_count) + << " over " << frame_count<< " frames."; + KALDI_LOG << "Retried " << num_retry << " out of " + << (num_done + num_err) << " utterances."; + KALDI_LOG << "Done " << num_done << ", errors on " << num_err; + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/simplehmmbin/simple-hmm-est.cc b/src/simplehmmbin/simple-hmm-est.cc new file mode 100644 index 00000000000..b121bad44b0 --- /dev/null +++ b/src/simplehmmbin/simple-hmm-est.cc @@ -0,0 +1,86 @@ +// simplehmmbin/simple-hmm-est.cc + +// Copyright 2009-2011 Microsoft Corporation +// 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
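+
+// A rough sketch of how the simplehmmbin tools might be chained for one
+// iteration of transition re-estimation (filenames are illustrative, not
+// taken from any recipe; the calling scripts decide the actual sequence):
+//   simple-hmm-init topo 0.mdl
+//   simple-hmm-align-compiled 0.mdl ark:graphs.fsts ark:log_likes.1.ark ark:1.ali
+//   simple-hmm-acc-stats-ali 0.mdl ark:1.ali 1.acc
+//   simple-hmm-est 0.mdl 1.acc 1.mdl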
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "simplehmm/simple-hmm.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + typedef kaldi::int32 int32; + + const char *usage = + "Do Maximum Likelihood re-estimation of simple HMM " + "transition parameters\n" + "Usage: simple-hmm-est [options] \n" + "e.g.: simple-hmm-est 1.mdl 1.acc 2.mdl\n"; + + bool binary_write = true; + MleTransitionUpdateConfig tcfg; + std::string occs_out_filename; + + ParseOptions po(usage); + po.Register("binary", &binary_write, "Write output in binary mode"); + tcfg.Register(&po); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string model_in_filename = po.GetArg(1), + stats_filename = po.GetArg(2), + model_out_filename = po.GetArg(3); + + SimpleHmm model; + ReadKaldiObject(model_in_filename, &model); + + Vector transition_accs; + ReadKaldiObject(stats_filename, &transition_accs); + + { + BaseFloat objf_impr, count; + model.MleUpdate(transition_accs, tcfg, &objf_impr, &count); + KALDI_LOG << "Transition model update: Overall " << (objf_impr/count) + << " log-like improvement per frame over " << (count) + << " frames."; + } + + WriteKaldiObject(model, model_out_filename, binary_write); + + if (GetVerboseLevel() >= 2) { + std::vector phone_names; + phone_names.push_back("0"); + phone_names.push_back("1"); + model.Print(KALDI_LOG, phone_names); + } + + KALDI_LOG << "Written model to " << model_out_filename; + return 0; + } catch(const std::exception &e) { + std::cerr << e.what() << '\n'; + return -1; + } +} + + diff --git a/src/simplehmmbin/simple-hmm-init.cc b/src/simplehmmbin/simple-hmm-init.cc new file mode 100644 index 00000000000..ddee0893b7c --- /dev/null +++ b/src/simplehmmbin/simple-hmm-init.cc @@ -0,0 +1,70 @@ +// bin/simple-hmm-init.cc + +// Copyright 2016 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
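+
+// The first argument is a standard Kaldi HmmTopology file.  A minimal sketch
+// of what such a file might look like for a single "phone" with two emitting
+// states is given below; this is only an assumed example, see
+// hmm/hmm-topology.h for the authoritative format:
+//   <Topology>
+//   <TopologyEntry>
+//   <ForPhones> 1 </ForPhones>
+//   <State> 0 <PdfClass> 0 <Transition> 0 0.5 <Transition> 1 0.5 </State>
+//   <State> 1 <PdfClass> 1 <Transition> 1 0.5 <Transition> 2 0.5 </State>
+//   <State> 2 </State>
+//   </TopologyEntry>
+//   </Topology>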
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "hmm/hmm-topology.h" +#include "simplehmm/simple-hmm.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using kaldi::int32; + + const char *usage = + "Initialize simple HMM from topology.\n" + "Usage: simple-hmm-init \n" + "e.g.: \n" + " simple-hmm-init topo init.mdl\n"; + + bool binary = true; + + ParseOptions po(usage); + po.Register("binary", &binary, "Write output in binary mode"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string topo_filename = po.GetArg(1); + std::string model_filename = po.GetArg(2); + + HmmTopology topo; + { + bool binary_in; + Input ki(topo_filename, &binary_in); + topo.Read(ki.Stream(), binary_in); + } + + SimpleHmm model(topo); + { + Output ko(model_filename, binary); + model.Write(ko.Stream(), binary); + } + return 0; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/tools/config/common_path.sh b/tools/config/common_path.sh index 36b5350dd8e..9c2e32d0cb1 100644 --- a/tools/config/common_path.sh +++ b/tools/config/common_path.sh @@ -1,4 +1,4 @@ -# we assume KALDI_ROOT is already defined +# we assume KALDI_ROOT is already defined [ -z "$KALDI_ROOT" ] && echo "The variable KALDI_ROOT must be already defined" && exit 1 # The formatting of the path export command is intentionally weird, because # this allows for easy diff'ing @@ -21,4 +21,5 @@ ${KALDI_ROOT}/src/onlinebin:\ ${KALDI_ROOT}/src/sgmm2bin:\ ${KALDI_ROOT}/src/sgmmbin:\ ${KALDI_ROOT}/src/segmenterbin:\ +${KALDI_ROOT}/src/simplehmmbin:\ $PATH From 7a678fdb323f67b8fc0f0622cff15156111c7043 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:14:28 -0500 Subject: [PATCH 165/213] asr_diarization: Add IB clustering --- .../information-bottleneck-cluster-utils.cc | 192 +++++++++++++++ .../information-bottleneck-cluster-utils.h | 65 +++++ ...information-bottleneck-clusterable-test.cc | 94 ++++++++ .../information-bottleneck-clusterable.cc | 226 ++++++++++++++++++ .../information-bottleneck-clusterable.h | 163 +++++++++++++ src/tree/cluster-utils.cc | 73 +----- src/tree/cluster-utils.h | 89 ++++++- 7 files changed, 838 insertions(+), 64 deletions(-) create mode 100644 src/segmenter/information-bottleneck-cluster-utils.cc create mode 100644 src/segmenter/information-bottleneck-cluster-utils.h create mode 100644 src/segmenter/information-bottleneck-clusterable-test.cc create mode 100644 src/segmenter/information-bottleneck-clusterable.cc create mode 100644 src/segmenter/information-bottleneck-clusterable.h diff --git a/src/segmenter/information-bottleneck-cluster-utils.cc b/src/segmenter/information-bottleneck-cluster-utils.cc new file mode 100644 index 00000000000..75fda8c59fe --- /dev/null +++ b/src/segmenter/information-bottleneck-cluster-utils.cc @@ -0,0 +1,192 @@ +// segmenter/information-bottleneck-cluster-utils.cc + +// Copyright 2017 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "tree/cluster-utils.h" +#include "segmenter/information-bottleneck-cluster-utils.h" + +namespace kaldi { + +typedef uint16 uint_smaller; +typedef int16 int_smaller; + +class InformationBottleneckBottomUpClusterer : public BottomUpClusterer { + public: + InformationBottleneckBottomUpClusterer( + const std::vector &points, + const InformationBottleneckClustererOptions &opts, + BaseFloat max_merge_thresh, + int32 min_clusters, + std::vector *clusters_out, + std::vector *assignments_out); + + private: + virtual BaseFloat ComputeDistance(int32 i, int32 j); + virtual bool StoppingCriterion() const; + virtual void UpdateClustererStats(int32 i, int32 j); + + BaseFloat NormalizedMutualInformation() const { + return ((merged_entropy_ - current_entropy_) + / (merged_entropy_ - initial_entropy_)); + } + + /// Stop merging when the stopping criterion, e.g. NMI, reaches this + /// threshold. + BaseFloat stopping_threshold_; + + /// Weight of the relevant variables entropy towards the objective. + BaseFloat relevance_factor_; + + /// Weight of the input variables entropy towards the objective. + BaseFloat input_factor_; + + /// Running entropy of the clusters. + BaseFloat current_entropy_; + + /// Some stats computed by the constructor that will be useful for + /// adding stopping criterion. + BaseFloat initial_entropy_; + BaseFloat merged_entropy_; +}; + + +InformationBottleneckBottomUpClusterer::InformationBottleneckBottomUpClusterer( + const std::vector &points, + const InformationBottleneckClustererOptions &opts, + BaseFloat max_merge_thresh, + int32 min_clusters, + std::vector *clusters_out, + std::vector *assignments_out) : + BottomUpClusterer(points, max_merge_thresh, min_clusters, + clusters_out, assignments_out), + stopping_threshold_(opts.stopping_threshold), + relevance_factor_(opts.relevance_factor), + input_factor_(opts.input_factor), + current_entropy_(0.0), initial_entropy_(0.0), merged_entropy_(0.0) { + if (points.size() == 0) return; + + InformationBottleneckClusterable* ibc = + static_cast(points[0]->Copy()); + initial_entropy_ -= ibc->Objf(1.0, 0.0); + + for (size_t i = 1; i < points.size(); i++) { + InformationBottleneckClusterable *c = + static_cast(points[i]); + ibc->Add(*points[i]); + initial_entropy_ -= c->Objf(1.0, 0.0); + } + + merged_entropy_ = -ibc->Objf(1.0, 0.0); + current_entropy_ = initial_entropy_; +} + +BaseFloat InformationBottleneckBottomUpClusterer::ComputeDistance( + int32 i, int32 j) { + const InformationBottleneckClusterable* cluster_i + = static_cast(GetCluster(i)); + const InformationBottleneckClusterable* cluster_j + = static_cast(GetCluster(j)); + + BaseFloat dist = (cluster_i->Distance(*cluster_j, relevance_factor_, + input_factor_)); + // / (cluster_i->Normalizer() + cluster_j->Normalizer())); + Distance(i, j) = dist; // set the distance in the array. 
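+  // Here dist is the loss in the IB objective that merging clusters i and j
+  // would cause: roughly a count-weighted Jensen-Shannon divergence between
+  // the relevance distributions, minus input_factor_ times the corresponding
+  // term for the input distributions.  See
+  // InformationBottleneckClusterable::Distance() for the exact expression.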
+ return dist; +} + +bool InformationBottleneckBottomUpClusterer::StoppingCriterion() const { + bool flag = (NumClusters() <= MinClusters() || IsQueueEmpty() || + NormalizedMutualInformation() < stopping_threshold_); + if (GetVerboseLevel() < 2 || !flag) return flag; + + if (NormalizedMutualInformation() < stopping_threshold_) { + KALDI_VLOG(2) << "Stopping at " << NumClusters() << " clusters " + << "because NMI = " << NormalizedMutualInformation() + << " < stopping_threshold (" << stopping_threshold_ << ")"; + } else if (NumClusters() < MinClusters()) { + KALDI_VLOG(2) << "Stopping at " << NumClusters() << " clusters " + << "<= min-clusters (" << MinClusters() << ")"; + } else if (IsQueueEmpty()) { + KALDI_VLOG(2) << "Stopping at " << NumClusters() << " clusters " + << "because queue is empty."; + } + + return flag; +} + +void InformationBottleneckBottomUpClusterer::UpdateClustererStats( + int32 i, int32 j) { + const InformationBottleneckClusterable* cluster_i + = static_cast(GetCluster(i)); + current_entropy_ += cluster_i->Distance(*GetCluster(j), 1.0, 0.0); + + if (GetVerboseLevel() > 2) { + const InformationBottleneckClusterable* cluster_j + = static_cast(GetCluster(j)); + std::vector cluster_i_points; + { + std::map::const_iterator it + = cluster_i->Counts().begin(); + for (; it != cluster_i->Counts().end(); ++it) + cluster_i_points.push_back(it->first); + } + + std::vector cluster_j_points; + { + std::map::const_iterator it + = cluster_j->Counts().begin(); + for (; it != cluster_j->Counts().end(); ++it) + cluster_j_points.push_back(it->first); + } + KALDI_VLOG(3) << "Merging clusters " + << "(" << cluster_i_points + << ", " << cluster_j_points + << ").. distance=" << Distance(i, j) + << ", num-clusters-after-merge= " << NumClusters() - 1 + << ", NMI= " << NormalizedMutualInformation(); + } +} + +BaseFloat IBClusterBottomUp( + const std::vector &points, + const InformationBottleneckClustererOptions &opts, + BaseFloat max_merge_thresh, + int32 min_clust, + std::vector *clusters_out, + std::vector *assignments_out) { + KALDI_ASSERT(max_merge_thresh >= 0.0 && min_clust >= 0); + KALDI_ASSERT(opts.stopping_threshold >= 0.0); + KALDI_ASSERT(opts.relevance_factor >= 0.0 && opts.input_factor >= 0.0); + + KALDI_ASSERT(!ContainsNullPointers(points)); + int32 npoints = points.size(); + // make sure fits in uint_smaller and does not hit the -1 which is reserved. + KALDI_ASSERT(sizeof(uint_smaller)==sizeof(uint32) || + npoints < static_cast(static_cast(-1))); + + KALDI_VLOG(2) << "Initializing clustering object."; + InformationBottleneckBottomUpClusterer bc( + points, opts, max_merge_thresh, min_clust, + clusters_out, assignments_out); + BaseFloat ans = bc.Cluster(); + if (clusters_out) KALDI_ASSERT(!ContainsNullPointers(*clusters_out)); + return ans; +} + +} // end namespace kaldi diff --git a/src/segmenter/information-bottleneck-cluster-utils.h b/src/segmenter/information-bottleneck-cluster-utils.h new file mode 100644 index 00000000000..58f1e4f380a --- /dev/null +++ b/src/segmenter/information-bottleneck-cluster-utils.h @@ -0,0 +1,65 @@ +// segmenter/information-bottleneck-cluster-utils.h + +// Copyright 2017 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_SEGMENTER_INFORMATION_BOTTLENECK_CLUSTER_UTILS_H_ +#define KALDI_SEGMENTER_INFORMATION_BOTTLENECK_CLUSTER_UTILS_H_ + +#include "base/kaldi-common.h" +#include "tree/cluster-utils.h" +#include "segmenter/information-bottleneck-clusterable.h" +#include "util/common-utils.h" + +namespace kaldi { + +struct InformationBottleneckClustererOptions { + BaseFloat distance_threshold; + int32 num_clusters; + BaseFloat stopping_threshold; + BaseFloat relevance_factor; + BaseFloat input_factor; + + InformationBottleneckClustererOptions() : + distance_threshold(std::numeric_limits::max()), num_clusters(1), + stopping_threshold(0.3), relevance_factor(1.0), input_factor(0.1) { } + + + void Register(OptionsItf *opts) { + opts->Register("stopping-threshold", &stopping_threshold, + "Stopping merging/splitting when an objective such as " + "NMI reaches this value."); + opts->Register("relevance-factor", &relevance_factor, + "Weight factor of the entropy of relevant variables " + "in the objective function"); + opts->Register("input-factor", &input_factor, + "Weight factor of the entropy of input variables " + "in the objective function"); + } +}; + +BaseFloat IBClusterBottomUp( + const std::vector &points, + const InformationBottleneckClustererOptions &opts, + BaseFloat max_merge_thresh, + int32 min_clusters, + std::vector *clusters_out, + std::vector *assignments_out); + +} // end namespace kaldi + +#endif // KALDI_SEGMENTER_INFORMATION_BOTTLENECK_CLUSTER_UTILS_H_ diff --git a/src/segmenter/information-bottleneck-clusterable-test.cc b/src/segmenter/information-bottleneck-clusterable-test.cc new file mode 100644 index 00000000000..ee0358c8f05 --- /dev/null +++ b/src/segmenter/information-bottleneck-clusterable-test.cc @@ -0,0 +1,94 @@ + +#include "base/kaldi-common.h" +#include "segmenter/information-bottleneck-clusterable.h" + +namespace kaldi { + +static void TestClusterable() { + { + Vector a_vec(3); + a_vec(0) = 0.5; + a_vec(1) = 0.5; + int32 a_count = 100; + KALDI_ASSERT(ApproxEqual(a_vec.Sum(), 1.0)); + + Vector b_vec(3); + b_vec(1) = 0.333; + b_vec(2) = 0.667; + int32 b_count = 100; + KALDI_ASSERT(ApproxEqual(b_vec.Sum(), 1.0)); + + InformationBottleneckClusterable a(1, a_count, a_vec); + InformationBottleneckClusterable b(2, b_count, b_vec); + + Vector sum_vec(a_vec.Dim()); + sum_vec.AddVec(a_count, a_vec); + sum_vec.AddVec(b_count, b_vec); + sum_vec.Scale(1.0 / (a_count + b_count)); + KALDI_ASSERT(ApproxEqual(sum_vec.Sum(), 1.0)); + + InformationBottleneckClusterable sum(3); + InformationBottleneckClusterable c(3); + + sum.Add(a); + sum.Add(b); + + c.AddStats(1, a_count, a_vec); + c.AddStats(2, b_count, b_vec); + + KALDI_ASSERT(c.Counts() == sum.Counts()); + KALDI_ASSERT(ApproxEqual(c.Objf(), sum.Objf())); + KALDI_ASSERT(ApproxEqual(-c.Objf() + a.Objf() + b.Objf(), a.Distance(b))); + KALDI_ASSERT(sum_vec.ApproxEqual(c.RelevanceDist())); + KALDI_ASSERT(sum_vec.ApproxEqual(sum.RelevanceDist())); + } + + for (int32 i = 0; i < 100; i++) { + int32 dim = RandInt(2, 10); + + Vector a_vec(dim); + a_vec.SetRandn(); 
+ a_vec.ApplyPowAbs(1.0); + a_vec.Scale(1 / a_vec.Sum()); + KALDI_ASSERT(ApproxEqual(a_vec.Sum(), 1.0)); + int32 a_count = RandInt(1, 100); + InformationBottleneckClusterable a(1, a_count, a_vec); + + Vector b_vec(dim); + b_vec.SetRandn(); + b_vec.ApplyPowAbs(1.0); + b_vec.Scale(1 / b_vec.Sum()); + KALDI_ASSERT(ApproxEqual(b_vec.Sum(), 1.0)); + int32 b_count = RandInt(1, 100); + InformationBottleneckClusterable b(2, b_count, b_vec); + + Vector sum_vec(a_vec.Dim()); + sum_vec.AddVec(a_count, a_vec); + sum_vec.AddVec(b_count, b_vec); + sum_vec.Scale(1.0 / (a_count + b_count)); + KALDI_ASSERT(ApproxEqual(sum_vec.Sum(), 1.0)); + + InformationBottleneckClusterable sum(dim); + InformationBottleneckClusterable c(dim); + + sum.Add(a); + sum.Add(b); + + c.AddStats(1, a_count, a_vec); + c.AddStats(2, b_count, b_vec); + + KALDI_ASSERT(c.Counts() == sum.Counts()); + KALDI_ASSERT(ApproxEqual(c.Objf(), sum.Objf())); + KALDI_ASSERT(ApproxEqual(-c.Objf() + a.Objf() + b.Objf(), a.Distance(b))); + KALDI_ASSERT(sum_vec.ApproxEqual(c.RelevanceDist())); + KALDI_ASSERT(sum_vec.ApproxEqual(sum.RelevanceDist())); + } +} + +} // end namespace kaldi + +int main() { + using namespace kaldi; + + TestClusterable(); +} diff --git a/src/segmenter/information-bottleneck-clusterable.cc b/src/segmenter/information-bottleneck-clusterable.cc new file mode 100644 index 00000000000..7817f7cfdc6 --- /dev/null +++ b/src/segmenter/information-bottleneck-clusterable.cc @@ -0,0 +1,226 @@ +// segmenter/information-bottleneck-clusterable.cc + +// Copyright 2017 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
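+
+// Summary of the quantities implemented in this file (restating the class
+// documentation in information-bottleneck-clusterable.h):
+//   Objf(c) = N(c) * (-r * H(Y|c) + b * H(X|c)),
+// where N(c) is the total count of cluster c, H(Y|c) the entropy of the
+// relevance distribution p_yp_c_, H(X|c) the entropy of the input (segment)
+// distribution, r = relevance_factor and b = input_factor.
+//   Distance(c1, c2) = Objf(c1) + Objf(c2) - Objf(c1 merged with c2),
+// i.e. the drop in objective caused by merging the two clusters, which
+// reduces to a count-weighted Jensen-Shannon term for the relevance
+// distributions minus a corresponding term for the input distributions.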
+ +#include "segmenter/information-bottleneck-clusterable.h" + +namespace kaldi { + +void InformationBottleneckClusterable::AddStats( + int32 id, BaseFloat count, + const VectorBase &relevance_dist) { + std::map::iterator it = counts_.find(id); + KALDI_ASSERT(it == counts_.end() || it->first != id); + counts_.insert(it, std::make_pair(id, count)); + + double sum = relevance_dist.Sum(); + KALDI_ASSERT (sum != 0.0); + + p_yp_c_.Scale(total_count_); + p_yp_c_.AddVec(count / sum, relevance_dist); + total_count_ += count; + p_yp_c_.Scale(1.0 / total_count_); +} + +BaseFloat InformationBottleneckClusterable::Objf( + BaseFloat relevance_factor, BaseFloat input_factor) const { + double relevance_entropy = 0.0, count = 0.0; + for (int32 i = 0; i < p_yp_c_.Dim(); i++) { + if (p_yp_c_(i) > 1e-20) { + relevance_entropy -= p_yp_c_(i) * Log(p_yp_c_(i)); + count += p_yp_c_(i); + } + } + relevance_entropy = total_count_ * (relevance_entropy / count - Log(count)); + + double input_entropy = total_count_ * Log(total_count_); + for (std::map::const_iterator it = counts_.begin(); + it != counts_.end(); ++it) { + input_entropy -= it->second * Log(it->second); + } + + BaseFloat objf = -relevance_factor * relevance_entropy + + input_factor * input_entropy; + return objf; +} + +void InformationBottleneckClusterable::Add(const Clusterable &other_in) { + KALDI_ASSERT(other_in.Type() == "information-bottleneck"); + const InformationBottleneckClusterable *other = + static_cast (&other_in); + + for (std::map::const_iterator it = other->counts_.begin(); + it != other->counts_.end(); ++it) { + std::map::iterator hint_it = counts_.lower_bound( + it->first); + KALDI_ASSERT (hint_it == counts_.end() || hint_it->first != it->first); + counts_.insert(hint_it, *it); + } + + p_yp_c_.Scale(total_count_); + p_yp_c_.AddVec(other->total_count_, other->p_yp_c_); + total_count_ += other->total_count_; + p_yp_c_.Scale(1.0 / total_count_); +} + +void InformationBottleneckClusterable::Sub(const Clusterable &other_in) { + KALDI_ASSERT(other_in.Type() == "information-bottleneck"); + const InformationBottleneckClusterable *other = + static_cast (&other_in); + + for (std::map::const_iterator it = other->counts_.begin(); + it != other->counts_.end(); ++it) { + std::map::iterator hint_it = counts_.lower_bound( + it->first); + KALDI_ASSERT (hint_it->first == it->first); + counts_.erase(hint_it); + } + + p_yp_c_.Scale(total_count_); + p_yp_c_.AddVec(-other->total_count_, other->p_yp_c_); + total_count_ -= other->total_count_; + p_yp_c_.Scale(1.0 / total_count_); +} + +Clusterable* InformationBottleneckClusterable::Copy() const { + InformationBottleneckClusterable *ans = + new InformationBottleneckClusterable(RelevanceDim()); + ans->Add(*this); + return ans; +} + +void InformationBottleneckClusterable::Scale(BaseFloat f) { + KALDI_ASSERT(f >= 0.0); + for (std::map::iterator it = counts_.begin(); + it != counts_.end(); ++it) { + it->second *= f; + } + total_count_ *= f; +} + +void InformationBottleneckClusterable::Write( + std::ostream &os, bool binary) const { + WriteToken(os, binary, "IBCL"); // magic string. 
+ WriteBasicType(os, binary, counts_.size()); + BaseFloat total_count = 0.0; + for (std::map::const_iterator it = counts_.begin(); + it != counts_.end(); ++it) { + WriteBasicType(os, binary, it->first); + WriteBasicType(os, binary, it->second); + total_count += it->second; + } + KALDI_ASSERT(ApproxEqual(total_count_, total_count)); + WriteToken(os, binary, ""); + p_yp_c_.Write(os, binary); +} + +Clusterable* InformationBottleneckClusterable::ReadNew( + std::istream &is, bool binary) const { + InformationBottleneckClusterable *ibc = + new InformationBottleneckClusterable(); + ibc->Read(is, binary); + return ibc; +} + +void InformationBottleneckClusterable::Read(std::istream &is, bool binary) { + ExpectToken(is, binary, "IBCL"); // magic string. + int32 size; + ReadBasicType(is, binary, &size); + + for (int32 i = 0; i < 2 * size; i++) { + int32 id; + BaseFloat count; + ReadBasicType(is, binary, &id); + ReadBasicType(is, binary, &count); + std::pair::iterator, bool> ret; + ret = counts_.insert(std::make_pair(id, count)); + if (!ret.second) { + KALDI_ERR << "Duplicate element " << id << " when reading counts"; + } + total_count_ += count; + } + + ExpectToken(is, binary, ""); + p_yp_c_.Read(is, binary); +} + +BaseFloat InformationBottleneckClusterable::ObjfPlus( + const Clusterable &other, BaseFloat relevance_factor, + BaseFloat input_factor) const { + InformationBottleneckClusterable *copy = static_cast(Copy()); + copy->Add(other); + BaseFloat ans = copy->Objf(relevance_factor, input_factor); + delete copy; + return ans; +} + +BaseFloat InformationBottleneckClusterable::ObjfMinus( + const Clusterable &other, BaseFloat relevance_factor, + BaseFloat input_factor) const { + InformationBottleneckClusterable *copy = static_cast(Copy()); + copy->Add(other); + BaseFloat ans = copy->Objf(relevance_factor, input_factor); + delete copy; + return ans; +} + +BaseFloat InformationBottleneckClusterable::Distance( + const Clusterable &other_in, BaseFloat relevance_factor, + BaseFloat input_factor) const { + KALDI_ASSERT(other_in.Type() == "information-bottleneck"); + const InformationBottleneckClusterable *other = + static_cast (&other_in); + + BaseFloat normalizer = this->Normalizer() + other->Normalizer(); + BaseFloat pi_i = this->Normalizer() / normalizer; + BaseFloat pi_j = other->Normalizer() / normalizer; + + // Compute the distribution q_Y(y) = p(y|{c_i} + {c_j}) + Vector relevance_dist(this->RelevanceDim()); + relevance_dist.AddVec(pi_i, this->RelevanceDist()); + relevance_dist.AddVec(pi_j, other->RelevanceDist()); + + BaseFloat relevance_divergence + = pi_i * KLDivergence(this->RelevanceDist(), relevance_dist) + + pi_j * KLDivergence(other->RelevanceDist(), relevance_dist); + + BaseFloat input_divergence + = Log(normalizer) - pi_i * Log(this->Normalizer()) + - pi_j * Log(other->Normalizer()); + + KALDI_ASSERT(relevance_divergence > -1e-4); + KALDI_ASSERT(input_divergence > -1e-4); + return (normalizer * (relevance_factor * relevance_divergence + - input_factor * input_divergence)); +} + +BaseFloat KLDivergence(const VectorBase &p1, + const VectorBase &p2) { + KALDI_ASSERT(p1.Dim() == p2.Dim()); + + double ans = 0.0, sum = 0.0; + for (int32 i = 0; i < p1.Dim(); i++) { + if (p1(i) > 1e-20) { + ans += p1(i) * Log(p1(i) / p2(i)); + sum += p1(i); + } + } + return ans / sum - Log(sum); +} + +} // end namespace kaldi diff --git a/src/segmenter/information-bottleneck-clusterable.h b/src/segmenter/information-bottleneck-clusterable.h new file mode 100644 index 00000000000..cb88d1221f7 --- /dev/null +++ 
b/src/segmenter/information-bottleneck-clusterable.h @@ -0,0 +1,163 @@ +// segmenter/information-bottleneck-clusterable.h + +// Copyright 2017 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_SEGMENTER_INFORMATION_BOTTLENECK_CLUSTERABLE_H_ +#define KALDI_SEGMENTER_INFORMATION_BOTTLENECK_CLUSTERABLE_H_ + +#include "base/kaldi-common.h" +#include "matrix/kaldi-matrix.h" +#include "itf/clusterable-itf.h" + +namespace kaldi { + +class InformationBottleneckClusterable: public Clusterable { + public: + /// Constructor used for creating empty object e.g. when reading from file. + InformationBottleneckClusterable(): total_count_(0.0) { } + + /// Constructor initializing the relevant variable dimension. + /// Used for making Copy() of object. + InformationBottleneckClusterable(int32 relevance_dim) : + total_count_(0.0), p_yp_c_(relevance_dim) { } + + /// Constructor initializing from input stats corresponding to a + /// segment. + InformationBottleneckClusterable(int32 id, BaseFloat count, + const VectorBase &relevance_dist): + total_count_(0.0), p_yp_c_(relevance_dist.Dim()) { + AddStats(id, count, relevance_dist); + } + + /// Return a copy of this object. + virtual Clusterable* Copy() const; + + /// Return the objective function, which is + /// N(c) * (-r * H(Y|c) + ibeta * H(X|c)) + /// where N(c) is the total count in the cluster + /// H(Y|c) is the conditional entropy of the relevance + /// variable distribution + /// H(X|c) is the conditional entropy of the input variable + /// distribution + /// r is the weight on the relevant variables + /// ibeta is the weight on the input variables + virtual BaseFloat Objf(BaseFloat relevance_factor, + BaseFloat input_factor) const; + + /// Return the objective function with the default values + /// for relevant_factor (1.0) and input_factor (0.1) + virtual BaseFloat Objf() const { return Objf(1.0, 0.1); } + + /// Return the count in this cluster. + virtual BaseFloat Normalizer() const { return total_count_; } + + /// Set stats to empty. + virtual void SetZero() { + counts_.clear(); + p_yp_c_.Resize(0); + total_count_ = 0.0; + } + + /// Add stats to this object + virtual void AddStats(int32 id, BaseFloat count, + const VectorBase &relevance_dist); + + /// Add other stats. + virtual void Add(const Clusterable &other); + /// Subtract other stats. + virtual void Sub(const Clusterable &other); + /// Scale the stats by a positive number f. + virtual void Scale(BaseFloat f); + + /// Return a string that describes the clusterable type. + virtual std::string Type() const { return "information-bottleneck"; } + + /// Write data to stream. 
+ virtual void Write(std::ostream &os, bool binary) const; + + /// Read data from a stream and return the corresponding object (const + /// function; it's a class member because we need access to the vtable + /// so generic code can read derived types). + virtual Clusterable* ReadNew(std::istream &is, bool binary) const; + + /// Read data from stream + virtual void Read(std::istream &is, bool binary); + + /// Return the objective function of the combined object this + other. + virtual BaseFloat ObjfPlus(const Clusterable &other, + BaseFloat relevance_factor, + BaseFloat input_factor) const; + + /// Same as the above function, but using default values for + /// relevance_factor (1.0) and input_factor (0.1) + virtual BaseFloat ObjfPlus(const Clusterable &other) const { + return ObjfPlus(other, 1.0, 0.1); + } + + /// Return the objective function of the combined object this + other. + virtual BaseFloat ObjfMinus(const Clusterable &other, + BaseFloat relevance_factor, + BaseFloat input_factor) const; + + /// Same as the above function, but using default values for + /// relevance_factor (1.0) and input_factor (0.1) + virtual BaseFloat ObjfMinus(const Clusterable &other) const { + return ObjfMinus(other, 1.0, 0.1); + } + + /// Return the objective function decrease from merging the two + /// clusters. + /// Always a non-negative number. + virtual BaseFloat Distance(const Clusterable &other, + BaseFloat relevance_factor, + BaseFloat input_factor) const; + + /// Same as the above function, but using default values for + /// relevance_factor (1.0) and input_factor (0.1) + virtual BaseFloat Distance(const Clusterable &other) const { + return Distance(other, 1.0, 0.1); + } + + virtual ~InformationBottleneckClusterable() {} + + /// Public accessors + virtual const Vector& RelevanceDist() const { return p_yp_c_; } + virtual int32 RelevanceDim() const { return p_yp_c_.Dim(); } + + virtual const std::map& Counts() const { return counts_; } + + private: + /// A list of the original segments this cluster contains along with + /// their corresponding counts. + std::map counts_; + + /// Total count in this cluster. + BaseFloat total_count_; + + /// Relevant variable distribution. + /// TODO: Make sure that this is a valid probability distribution. + Vector p_yp_c_; +}; + +/// Returns the KL Divergence between two probability distributions. +BaseFloat KLDivergence(const VectorBase &p1, + const VectorBase &p2); + +} // end namespace kaldi + +#endif // KALDI_SEGMENTER_INFORMATION_BOTTLENECK_CLUSTERABLE_H_ diff --git a/src/tree/cluster-utils.cc b/src/tree/cluster-utils.cc index 53de0825e08..965eb104d9e 100644 --- a/src/tree/cluster-utils.cc +++ b/src/tree/cluster-utils.cc @@ -190,62 +190,6 @@ void AddToClustersOptimized(const std::vector &stats, // Bottom-up clustering routines // ============================================================================ -class BottomUpClusterer { - public: - BottomUpClusterer(const std::vector &points, - BaseFloat max_merge_thresh, - int32 min_clust, - std::vector *clusters_out, - std::vector *assignments_out) - : ans_(0.0), points_(points), max_merge_thresh_(max_merge_thresh), - min_clust_(min_clust), clusters_(clusters_out != NULL? clusters_out - : &tmp_clusters_), assignments_(assignments_out != NULL ? 
- assignments_out : &tmp_assignments_) { - nclusters_ = npoints_ = points.size(); - dist_vec_.resize((npoints_ * (npoints_ - 1)) / 2); - } - - BaseFloat Cluster(); - ~BottomUpClusterer() { DeletePointers(&tmp_clusters_); } - - private: - void Renumber(); - void InitializeAssignments(); - void SetInitialDistances(); ///< Sets up distances and queue. - /// CanMerge returns true if i and j are existing clusters, and the distance - /// (negated objf-change) "dist" is accurate (i.e. not outdated). - bool CanMerge(int32 i, int32 j, BaseFloat dist); - /// Merge j into i and delete j. - void MergeClusters(int32 i, int32 j); - /// Reconstructs the priority queue from the distances. - void ReconstructQueue(); - - void SetDistance(int32 i, int32 j); - BaseFloat& Distance(int32 i, int32 j) { - KALDI_ASSERT(i < npoints_ && j < i); - return dist_vec_[(i * (i - 1)) / 2 + j]; - } - - BaseFloat ans_; - const std::vector &points_; - BaseFloat max_merge_thresh_; - int32 min_clust_; - std::vector *clusters_; - std::vector *assignments_; - - std::vector tmp_clusters_; - std::vector tmp_assignments_; - - std::vector dist_vec_; - int32 nclusters_; - int32 npoints_; - typedef std::pair > QueueElement; - // Priority queue using greater (lowest distances are highest priority). - typedef std::priority_queue, - std::greater > QueueType; - QueueType queue_; -}; - BaseFloat BottomUpClusterer::Cluster() { KALDI_VLOG(2) << "Initializing cluster assignments."; InitializeAssignments(); @@ -253,12 +197,15 @@ BaseFloat BottomUpClusterer::Cluster() { SetInitialDistances(); KALDI_VLOG(2) << "Clustering..."; - while (nclusters_ > min_clust_ && !queue_.empty()) { + while (!StoppingCriterion()) { std::pair > pr = queue_.top(); BaseFloat dist = pr.first; int32 i = (int32) pr.second.first, j = (int32) pr.second.second; queue_.pop(); - if (CanMerge(i, j, dist)) MergeClusters(i, j); + if (CanMerge(i, j, dist)) { + UpdateClustererStats(i, j); + MergeClusters(i, j); + } } KALDI_VLOG(2) << "Renumbering clusters to contiguous numbers."; Renumber(); @@ -325,11 +272,12 @@ void BottomUpClusterer::InitializeAssignments() { void BottomUpClusterer::SetInitialDistances() { for (int32 i = 0; i < npoints_; i++) { for (int32 j = 0; j < i; j++) { - BaseFloat dist = (*clusters_)[i]->Distance(*((*clusters_)[j])); - dist_vec_[(i * (i - 1)) / 2 + j] = dist; + BaseFloat dist = ComputeDistance(i, j); if (dist <= max_merge_thresh_) queue_.push(std::make_pair(dist, std::make_pair(static_cast(i), static_cast(j)))); + if (j == i - 1) + KALDI_VLOG(2) << "Distance(" << i << ", " << j << ") = " << dist; } } } @@ -344,6 +292,7 @@ bool BottomUpClusterer::CanMerge(int32 i, int32 j, BaseFloat dist) { void BottomUpClusterer::MergeClusters(int32 i, int32 j) { KALDI_ASSERT(i != j && i < npoints_ && j < npoints_); + (*clusters_)[i]->Add(*((*clusters_)[j])); delete (*clusters_)[j]; (*clusters_)[j] = NULL; @@ -389,8 +338,7 @@ void BottomUpClusterer::ReconstructQueue() { void BottomUpClusterer::SetDistance(int32 i, int32 j) { KALDI_ASSERT(i < npoints_ && j < i && (*clusters_)[i] != NULL && (*clusters_)[j] != NULL); - BaseFloat dist = (*clusters_)[i]->Distance(*((*clusters_)[j])); - dist_vec_[(i * (i - 1)) / 2 + j] = dist; // set the distance in the array. 
+ BaseFloat dist = ComputeDistance(i, j); if (dist < max_merge_thresh_) { queue_.push(std::make_pair(dist, std::make_pair(static_cast(i), static_cast(j)))); @@ -403,7 +351,6 @@ void BottomUpClusterer::SetDistance(int32 i, int32 j) { } - BaseFloat ClusterBottomUp(const std::vector &points, BaseFloat max_merge_thresh, int32 min_clust, diff --git a/src/tree/cluster-utils.h b/src/tree/cluster-utils.h index 55583a237bf..b11dfe1c031 100644 --- a/src/tree/cluster-utils.h +++ b/src/tree/cluster-utils.h @@ -21,10 +21,14 @@ #ifndef KALDI_TREE_CLUSTER_UTILS_H_ #define KALDI_TREE_CLUSTER_UTILS_H_ +#include #include +using std::vector; #include "matrix/matrix-lib.h" +#include "util/stl-utils.h" #include "itf/clusterable-itf.h" + namespace kaldi { /// \addtogroup clustering_group_simple @@ -103,9 +107,92 @@ void AddToClustersOptimized(const std::vector &stats, * @param assignments_out [out] If non-NULL, will be resized to the number of * points, and each element is the index of the cluster that point * was assigned to. + */ + +class BottomUpClusterer { + public: + typedef uint16 uint_smaller; + typedef int16 int_smaller; + + BottomUpClusterer(const std::vector &points, + BaseFloat max_merge_thresh, + int32 min_clust, + std::vector *clusters_out, + std::vector *assignments_out) + : ans_(0.0), points_(points), max_merge_thresh_(max_merge_thresh), + min_clust_(min_clust), clusters_(clusters_out != NULL? clusters_out + : &tmp_clusters_), assignments_(assignments_out != NULL ? + assignments_out : &tmp_assignments_) { + nclusters_ = npoints_ = points.size(); + dist_vec_.resize((npoints_ * (npoints_ - 1)) / 2); + } + + BaseFloat Cluster(); + ~BottomUpClusterer() { DeletePointers(&tmp_clusters_); } + + /// Public accessors + const Clusterable* GetCluster(int32 i) const { return (*clusters_)[i]; } + BaseFloat& Distance(int32 i, int32 j) { + KALDI_ASSERT(i < npoints_ && j < i); + return dist_vec_[(i * (i - 1)) / 2 + j]; + } + /// CanMerge returns true if i and j are existing clusters, and the distance + /// (negated objf-change) "dist" is accurate (i.e. not outdated). + virtual bool CanMerge(int32 i, int32 j, BaseFloat dist); + + /// Merge j into i and delete j. + virtual void MergeClusters(int32 i, int32 j); + + int32 NumClusters() const { return nclusters_; } + int32 NumPoints() const { return npoints_; } + int32 MinClusters() const { return min_clust_; } + bool IsQueueEmpty() const { return queue_.empty(); } + + private: + void Renumber(); + void InitializeAssignments(); + void SetInitialDistances(); ///< Sets up distances and queue. + /// Reconstructs the priority queue from the distances. + void ReconstructQueue(); + + /// Update some stats to reflect merging clusters i and j + virtual void UpdateClustererStats(int32, int32 j) { }; + + virtual bool StoppingCriterion() const { + return nclusters_ <= min_clust_ || queue_.empty(); + } + + void SetDistance(int32 i, int32 j); + virtual BaseFloat ComputeDistance(int32 i, int32 j) { + BaseFloat dist = (*clusters_)[i]->Distance(*((*clusters_)[j])); + dist_vec_[(i * (i - 1)) / 2 + j] = dist; // set the distance in the array. + return dist; + } + + BaseFloat ans_; + const std::vector &points_; + BaseFloat max_merge_thresh_; + int32 min_clust_; + std::vector *clusters_; + std::vector *assignments_; + + std::vector tmp_clusters_; + std::vector tmp_assignments_; + + std::vector dist_vec_; + int32 nclusters_; + int32 npoints_; + typedef std::pair > QueueElement; + // Priority queue using greater (lowest distances are highest priority). 
+  typedef std::priority_queue<QueueElement, std::vector<QueueElement>,
+      std::greater<QueueElement> > QueueType;
+  QueueType queue_;
+};
+
+/** This is a wrapper function to the BottomUpClusterer class.
 * @return Returns the total objf change relative to all clusters being separate, which is
 * a negative. Note that this is not the same as what the other clustering algorithms return.
- */
+ **/
 BaseFloat ClusterBottomUp(const std::vector<Clusterable*> &points,
                           BaseFloat thresh,
                           int32 min_clust,

From 403bde7b3470c34967b6a86e24d2c2cd4a98623e Mon Sep 17 00:00:00 2001
From: Vimal Manohar
Date: Wed, 25 Jan 2017 15:14:53 -0500
Subject: [PATCH 166/213] asr_diarization: Add intersect int vectors

---
 src/segmenterbin/intersect-int-vectors.cc | 158 ++++++++++++++++
 1 file changed, 158 insertions(+)
 create mode 100644 src/segmenterbin/intersect-int-vectors.cc

diff --git a/src/segmenterbin/intersect-int-vectors.cc b/src/segmenterbin/intersect-int-vectors.cc
new file mode 100644
index 00000000000..53731bf9046
--- /dev/null
+++ b/src/segmenterbin/intersect-int-vectors.cc
@@ -0,0 +1,158 @@
+// segmenterbin/intersect-int-vectors.cc
+
+// Copyright 2017 Vimal Manohar (Johns Hopkins University)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
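+
+// Illustrative example of the behaviour (hypothetical archives, not from the
+// usage message): given the two input vectors
+//   utt1 1 1 2
+//   utt1 3 4 4
+// and no --mapping-in, the pairs (1,3), (1,4) and (2,4) are assigned new ids
+// 1, 2 and 3 in order of first occurrence, so the output vector is
+//   utt1 1 2 3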
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Intersect two integer vectors and create a new integer vectors " + "whole ids are defined as the cross-products of the integer " + "ids from the two vectors.\n" + "\n" + "Usage: intersect-int-vectors [options] " + " \n" + " e.g.: intersect-int-vectors ark:1.ali ark:2.ali ark:-\n" + "See also: segmentation-init-from-segments, " + "segmentation-combine-segments\n"; + + ParseOptions po(usage); + + std::string mapping_rxfilename, mapping_wxfilename; + int32 length_tolerance = 0; + + po.Register("mapping-in", &mapping_rxfilename, + "A file with three columns that define the mapping from " + "a pair of integers to a third one."); + po.Register("mapping-out", &mapping_wxfilename, + "Write a mapping in the same format as --mapping-in, " + "but let the program decide the mapping to unique integer " + "ids."); + po.Register("length-tolerance", &length_tolerance, + "Tolerance this number of frames of mismatch between the " + "two integer vector pairs."); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string ali_rspecifier1 = po.GetArg(1), + ali_rspecifier2 = po.GetArg(2), + ali_wspecifier = po.GetArg(3); + + std::map, int32> mapping; + if (!mapping_rxfilename.empty()) { + Input ki(mapping_rxfilename); + std::string line; + while (std::getline(ki.Stream(), line)) { + std::vector parts; + SplitStringToVector(line, " ", true, &parts); + KALDI_ASSERT(parts.size() == 3); + + std::pair id_pair = std::make_pair( + std::atoi(parts[0].c_str()), std::atoi(parts[1].c_str())); + int32 id_new = std::atoi(parts[2].c_str()); + KALDI_ASSERT(id_new >= 0); + + std::map, int32>::iterator it = + mapping.lower_bound(id_pair); + KALDI_ASSERT(it == mapping.end() || it->first != id_pair); + + mapping.insert(it, std::make_pair(id_pair, id_new)); + } + } + + SequentialInt32VectorReader ali_reader1(ali_rspecifier1); + RandomAccessInt32VectorReader ali_reader2(ali_rspecifier2); + + Int32VectorWriter ali_writer(ali_wspecifier); + + int32 num_ids = 0, num_err = 0, num_done = 0; + + for (; !ali_reader1.Done(); ali_reader1.Next()) { + const std::string &key = ali_reader1.Key(); + + if (!ali_reader2.HasKey(key)) { + KALDI_WARN << "Could not find second alignment for key " << key + << "in " << ali_rspecifier2; + num_err++; + continue; + } + + const std::vector &alignment1 = ali_reader1.Value(); + const std::vector &alignment2 = ali_reader2.Value(key); + + if (static_cast(alignment1.size()) + - static_cast(alignment2.size()) > length_tolerance) { + KALDI_WARN << "Mismatch in length of alignments in " + << ali_rspecifier1 << " and " << ali_rspecifier2 + << "; " << alignment1.size() << " vs " + << alignment2.size(); + num_err++; + } + + std::vector alignment_out(alignment1.size()); + + for (size_t i = 0; i < alignment1.size(); i++) { + std::pair id_pair = std::make_pair( + alignment1[i], alignment2[i]); + + std::map, int32>::iterator it = + mapping.lower_bound(id_pair); + + int32 id_new = -1; + if (!mapping_rxfilename.empty()) { + if (it == mapping.end() || it->first != id_pair) { + KALDI_ERR << "Could not find id-pair (" << id_pair.first + << ", " << id_pair.second + << ") in mapping " << mapping_rxfilename; + } + id_new = it->second; + } else { + if (it == mapping.end() || it->first != id_pair) { + id_new = ++num_ids; + mapping.insert(it, std::make_pair(id_pair, id_new)); + } else { + id_new = it->second; + } + } + + 
alignment_out[i] = id_new; + } + + ali_writer.Write(key, alignment_out); + num_done++; + } + + KALDI_LOG << "Intersected " << num_done << " int vector pairs; " + << "failed with " << num_err; + + return ((num_done > 0 && num_err < num_done) ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + From 71f0de65b482c419d18707ca3b5bc6ebdd7978e2 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:15:09 -0500 Subject: [PATCH 167/213] asr_diarization: Clustering using IB --- .../segmentation-cluster-adjacent-segments.cc | 290 ++++++++++++++++++ 1 file changed, 290 insertions(+) create mode 100644 src/segmenterbin/segmentation-cluster-adjacent-segments.cc diff --git a/src/segmenterbin/segmentation-cluster-adjacent-segments.cc b/src/segmenterbin/segmentation-cluster-adjacent-segments.cc new file mode 100644 index 00000000000..812785ac5e6 --- /dev/null +++ b/src/segmenterbin/segmentation-cluster-adjacent-segments.cc @@ -0,0 +1,290 @@ +// segmenterbin/segmentation-merge.cc + +// Copyright 2017 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
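+
+// The Distance() function below models each segment by a diagonal Gaussian
+// (with floored variances) estimated from its frames and scores a pair of
+// segments with the symmetric per-dimension divergence
+//   sum_d [ v1_d/v2_d + v2_d/v1_d + (m1_d - m2_d)^2 * (1/v1_d + 1/v2_d) ],
+// which is then compared against --absolute-distance-threshold and
+// --delta-distance-threshold scaled by the feature dimension.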
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" +#include "tree/clusterable-classes.h" + +namespace kaldi { +namespace segmenter { + +BaseFloat Distance(const Segment &seg1, const Segment &seg2, + const MatrixBase &feats, + BaseFloat var_floor, + int32 length_tolerance = 2) { + int32 start1 = seg1.start_frame; + int32 end1 = seg1.end_frame; + + int32 start2 = seg2.start_frame; + int32 end2 = seg2.end_frame; + + if (end1 > feats.NumRows() + length_tolerance) { + KALDI_ERR << "Segment end > feature length; " << end1 + << " vs " << feats.NumRows(); + } + + GaussClusterable stats1(feats.NumCols(), var_floor); + for (int32 i = start1; i < std::min(end1, feats.NumRows()); i++) { + stats1.AddStats(feats.Row(i)); + } + Vector means1(stats1.x_stats()); + means1.Scale(1.0 / stats1.count()); + Vector vars1(stats1.x2_stats()); + vars1.Scale(1.0 / stats1.count()); + vars1.AddVec2(-1.0, means1); + vars1.ApplyFloor(var_floor); + + GaussClusterable stats2(feats.NumCols(), var_floor); + for (int32 i = start2; i < std::min(end2, feats.NumRows()); i++) { + stats2.AddStats(feats.Row(i)); + } + Vector means2(stats2.x_stats()); + means2.Scale(1.0 / stats2.count()); + Vector vars2(stats2.x2_stats()); + vars2.Scale(1.0 / stats2.count()); + vars2.AddVec2(-1.0, means2); + vars2.ApplyFloor(var_floor); + + double ans = 0.0; + for (int32 i = 0; i < feats.NumCols(); i++) { + ans += (vars1(i) / vars2(i) + vars2(i) / vars1(i) + + (means2(i) - means1(i)) * (means2(i) - means1(i)) + * (1.0 / vars1(i) + 1.0 / vars2(i))); + } + + return ans; +} + +int32 ClusterAdjacentSegments(const MatrixBase &feats, + BaseFloat absolute_distance_threshold, + BaseFloat delta_distance_threshold, + BaseFloat var_floor, + int32 length_tolerance, + Segmentation *segmentation) { + if (segmentation->Dim() == 1) { + segmentation->Begin()->SetLabel(1); + return 1; + } + + SegmentList::iterator it = segmentation->Begin(), + next_it = segmentation->Begin(); + ++next_it; + + BaseFloat prev_dist = Distance(*it, *next_it, feats, + var_floor, length_tolerance); + + if (segmentation->Dim() == 2) { + it->SetLabel(1); + if (prev_dist < absolute_distance_threshold * feats.NumCols()) { + // Similar segments merged. + next_it->SetLabel(it->Label()); + } else { + // Segments not merged. + next_it->SetLabel(it->Label() + 1); + } + + return next_it->Label();; + } + + ++it; + ++next_it; + bool next_segment_is_new_cluster = false; + + for (; next_it != segmentation->End(); ++it, ++next_it) { + SegmentList::iterator prev_it(it); + --prev_it; + + // Compute distance between this and next segment. + BaseFloat dist = Distance(*it, *next_it, feats, var_floor, + length_tolerance); + + // Possibly merge current segment if previous. + if (next_segment_is_new_cluster || + (prev_it->end_frame + 1 >= it->start_frame && + prev_dist < absolute_distance_threshold * feats.NumCols())) { + // Previous and current segment are next to each other. + // Merge current segment with previous. + it->SetLabel(prev_it->Label()); + + KALDI_VLOG(3) << "Merging clusters " << *prev_it << " and " << *it + << " ; dist = " << prev_dist; + } else { + it->SetLabel(prev_it->Label() + 1); + KALDI_VLOG(3) << "Not merging merging cluster " << *prev_it + << " and " << *it << " ; dist = " << prev_dist; + } + + // Decide if the current segment must be merged with next. + if (prev_it->end_frame + 1 >= it->start_frame && + it->end_frame + 1 >= next_it->start_frame) { + // All 3 segments are adjacent. 
+ if (dist - prev_dist > delta_distance_threshold * feats.NumCols()) { + // Next segment is very different from the current and previous segment. + // So create a new cluster for the next segment. + next_segment_is_new_cluster = true; + } else { + next_segment_is_new_cluster = false; + } + } + + prev_dist = dist; + } + + SegmentList::iterator prev_it(it); + --prev_it; + if (next_segment_is_new_cluster || + (prev_it->end_frame + 1 >= it->start_frame && + prev_dist < absolute_distance_threshold * feats.NumCols())) { + // Merge current segment with previous. + it->SetLabel(prev_it->Label()); + + KALDI_VLOG(3) << "Merging clusters " << *prev_it << " and " << *it + << " ; dist = " << prev_dist; + } else { + it->SetLabel(prev_it->Label() + 1); + } + + return it->Label(); +} + +} // end segmenter +} // end kaldi + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Merge adjacent segments that are similar to each other.\n" + "\n" + "Usage: segmentation-cluster-adjacent-segments [options] " + " \n" + " e.g.: segmentation-cluster-adjacent-segments ark:foo.seg ark:feats.ark ark,t:-\n" + "See also: segmentation-merge, segmentation-merge-recordings, " + "segmentation-post-process --merge-labels\n"; + + bool binary = true; + int32 length_tolerance = 2; + BaseFloat var_floor = 0.01; + BaseFloat absolute_distance_threshold = 3.0; + BaseFloat delta_distance_threshold = 0.2; + + ParseOptions po(usage); + + po.Register("binary", &binary, + "Write in binary mode " + "(only relevant if output is a wxfilename)"); + po.Register("length-tolerance", &length_tolerance, + "Tolerate length difference between segmentation and " + "features if its less than this many frames."); + po.Register("variance-floor", &var_floor, + "Variance floor of Gaussians used in computing distances " + "for clustering."); + po.Register("absolute-distance-threshold", &absolute_distance_threshold, + "Maximum per-dim distance below which segments will not be " + "be merged."); + po.Register("delta-distance-threshold", &delta_distance_threshold, + "If the delta-distance is below this value, then the " + "adjacent segments will not be merged."); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_in_fn = po.GetArg(1), + feats_in_fn = po.GetArg(2), + segmentation_out_fn = po.GetArg(3); + + // all these "fn"'s are either rspecifiers or filenames. 
+ bool in_is_rspecifier = + (ClassifyRspecifier(segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier) + KALDI_ERR << "Cannot mix regular files and archives"; + + if (!in_is_rspecifier) { + Segmentation segmentation; + ReadKaldiObject(segmentation_in_fn, &segmentation); + + Matrix feats; + ReadKaldiObject(feats_in_fn, &feats); + + Sort(&segmentation); + int32 num_clusters = ClusterAdjacentSegments( + feats, absolute_distance_threshold, delta_distance_threshold, + var_floor, length_tolerance, + &segmentation); + + KALDI_LOG << "Clustered segments; got " << num_clusters << " clusters."; + WriteKaldiObject(segmentation, segmentation_out_fn, binary); + + return 0; + } else { + int32 num_done = 0, num_err = 0; + + SequentialSegmentationReader segmentation_reader(segmentation_in_fn); + RandomAccessBaseFloatMatrixReader feats_reader(feats_in_fn); + SegmentationWriter segmentation_writer(segmentation_out_fn); + + for (; !segmentation_reader.Done(); segmentation_reader.Next()) { + Segmentation segmentation(segmentation_reader.Value()); + const std::string &key = segmentation_reader.Key(); + + if (!feats_reader.HasKey(key)) { + KALDI_WARN << "Could not find key " << key << " in " + << "feats-rspecifier " << feats_in_fn; + num_err++; + continue; + } + + const MatrixBase &feats = feats_reader.Value(key); + + Sort(&segmentation); + int32 num_clusters = ClusterAdjacentSegments( + feats, absolute_distance_threshold, delta_distance_threshold, + var_floor, length_tolerance, + &segmentation); + KALDI_VLOG(2) << "For key " << key << ", got " << num_clusters + << " clusters."; + + segmentation_writer.Write(key, segmentation); + num_done++; + } + + KALDI_LOG << "Clustered segments from " << num_done << " recordings " + << "failed with " << num_err; + return (num_done != 0 ? 0 : 1); + } + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + From 120ac02034b2e748f050279309bbd40601a1d348 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:15:22 -0500 Subject: [PATCH 168/213] asr_diarization: aib cluster --- src/segmenterbin/agglomerative-cluster-ib.cc | 160 +++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 src/segmenterbin/agglomerative-cluster-ib.cc diff --git a/src/segmenterbin/agglomerative-cluster-ib.cc b/src/segmenterbin/agglomerative-cluster-ib.cc new file mode 100644 index 00000000000..489b24c24bc --- /dev/null +++ b/src/segmenterbin/agglomerative-cluster-ib.cc @@ -0,0 +1,160 @@ +// segmenterbin/agglomerative-cluster-ib.cc + +// Copyright 2017 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
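+
+// Typical inputs (filenames are only illustrative): per-utterance relevance
+// distributions such as averaged posteriors, a reco2utt map grouping
+// utterances by recording, and an output archive of integer cluster labels,
+// one label per utterance, e.g.:
+//   agglomerative-cluster-ib \
+//     --reco2num-clusters-rspecifier=ark,t:data/dev/reco2num_clusters \
+//     ark:avg_post.1.ark ark,t:data/dev/reco2utt ark,t:labels.txt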
+
+#include "base/kaldi-common.h"
+#include "util/common-utils.h"
+#include "tree/cluster-utils.h"
+#include "segmenter/information-bottleneck-cluster-utils.h"
+
+int main(int argc, char *argv[]) {
+  try {
+    using namespace kaldi;
+
+    const char *usage =
+      "Cluster per-utterance probability distributions of "
+      "relevance variables using the Information Bottleneck principle.\n"
+      "Usage: agglomerative-cluster-ib [options] "
+      "<relevance-prob-rspecifier> <reco2utt-rspecifier> <labels-wspecifier>\n"
+      " e.g.: agglomerative-cluster-ib ark:avg_post.1.ark "
+      "ark,t:data/dev/reco2utt ark,t:labels.txt";
+
+    ParseOptions po(usage);
+
+    InformationBottleneckClustererOptions opts;
+
+    std::string reco2num_clusters_rspecifier;
+    std::string counts_rspecifier;
+    int32 junk_label = -2;
+    BaseFloat max_merge_thresh = std::numeric_limits<BaseFloat>::max();
+    int32 min_clusters = 1;
+
+    po.Register("reco2num-clusters-rspecifier", &reco2num_clusters_rspecifier,
+                "If supplied, clustering creates exactly this many clusters "
+                "for the corresponding recording.");
+    po.Register("counts-rspecifier", &counts_rspecifier,
+                "The counts for each of the initial segments. If not "
+                "specified, the count is taken to be 1 for each segment.");
+    po.Register("junk-label", &junk_label,
+                "Assign this label to utterances that could not be clustered.");
+    po.Register("max-merge-thresh", &max_merge_thresh,
+                "Threshold on the cost change from merging clusters; clusters "
+                "won't be merged if the cost is more than this.");
+    po.Register("min-clusters", &min_clusters,
+                "Minimum number of clusters desired; we'll stop merging "
+                "after reaching this number.");
+
+    opts.Register(&po);
+    po.Read(argc, argv);
+
+    if (po.NumArgs() != 3) {
+      po.PrintUsage();
+      exit(1);
+    }
+
+    std::string relevance_prob_rspecifier = po.GetArg(1),
+                reco2utt_rspecifier = po.GetArg(2),
+                label_wspecifier = po.GetArg(3);
+
+    RandomAccessBaseFloatVectorReader relevance_prob_reader(
+        relevance_prob_rspecifier);
+    SequentialTokenVectorReader reco2utt_reader(reco2utt_rspecifier);
+    RandomAccessInt32Reader reco2num_clusters_reader(
+        reco2num_clusters_rspecifier);
+    Int32Writer label_writer(label_wspecifier);
+    RandomAccessBaseFloatReader counts_reader(counts_rspecifier);
+
+    int32 count = 1, num_utt_err = 0, num_reco_err = 0, num_done = 0,
+          num_reco = 0;
+
+    for (; !reco2utt_reader.Done(); reco2utt_reader.Next()) {
+      const std::vector<std::string> &uttlist = reco2utt_reader.Value();
+      const std::string &reco = reco2utt_reader.Key();
+
+      std::vector<Clusterable*> points;
+      points.reserve(uttlist.size());
+
+      int32 id = 0;
+      for (std::vector<std::string>::const_iterator it = uttlist.begin();
+           it != uttlist.end(); ++it, id++) {
+        if (!relevance_prob_reader.HasKey(*it)) {
+          KALDI_WARN << "Could not find relevance probability distribution "
+                     << "for utterance " << *it << " in archive "
+                     << relevance_prob_rspecifier;
+          num_utt_err++;
+          continue;
+        }
+
+        if (!counts_rspecifier.empty()) {
+          if (!counts_reader.HasKey(*it)) {
+            KALDI_WARN << "Could not find counts for utterance " << *it;
+            num_utt_err++;
+            continue;
+          }
+          count = counts_reader.Value(*it);
+        }
+
+        const Vector<BaseFloat>& relevance_prob =
+          relevance_prob_reader.Value(*it);
+
+        points.push_back(
+            new InformationBottleneckClusterable(id, count, relevance_prob));
+        num_done++;
+      }
+
+      std::vector<Clusterable*> clusters_out;
+      std::vector<int32> assignments_out;
+
+      int32 this_num_clusters = min_clusters;
+
+      if (!reco2num_clusters_rspecifier.empty()) {
+        if (!reco2num_clusters_reader.HasKey(reco)) {
+          KALDI_WARN << "Could not find num-clusters for recording "
+                     << reco;
+          num_reco_err++;
+        } else {
+          this_num_clusters =
+              reco2num_clusters_reader.Value(reco);
+        }
+      }
+
+      IBClusterBottomUp(points, opts, max_merge_thresh, this_num_clusters,
+                        NULL, &assignments_out);
+
+      for (int32 i = 0; i < points.size(); i++) {
+        InformationBottleneckClusterable* point
+          = static_cast<InformationBottleneckClusterable*>(points[i]);
+        int32 id = point->Counts().begin()->first;
+        const std::string &utt = uttlist[id];
+        label_writer.Write(utt, assignments_out[i] + 1);
+      }
+
+      DeletePointers(&points);
+      num_reco++;
+    }
+
+    KALDI_LOG << "Clustered " << num_done << " segments from "
+              << num_reco << " recordings; failed with "
+              << num_utt_err << " segments and "
+              << num_reco_err << " recordings.";
+
+    return (num_done > 0 ? 0 : 1);
+  } catch(const std::exception &e) {
+    std::cerr << e.what();
+    return -1;
+  }
+}

From 932073b2a537e5e79d330b6d52b3bc3af93a0b7e Mon Sep 17 00:00:00 2001
From: Vimal Manohar
Date: Wed, 25 Jan 2017 15:16:02 -0500
Subject: [PATCH 169/213] asr_diarization: LSTM SAD music

---
 .../tuning/train_lstm_sad_music_1a.sh         | 267 ++++++++++++++++
 .../tuning/train_lstm_sad_music_1b.sh         | 265 ++++++++++++++++
 .../tuning/train_lstm_sad_music_1c.sh         | 265 ++++++++++++++++
 .../tuning/train_lstm_sad_music_1e.sh         | 269 ++++++++++++++++
 .../tuning/train_lstm_sad_music_1f.sh         | 291 ++++++++++++++++++
 5 files changed, 1357 insertions(+)
 create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1a.sh
 create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1b.sh
 create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1c.sh
 create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1e.sh
 create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1f.sh

diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1a.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1a.sh
new file mode 100644
index 00000000000..4f0754d8355
--- /dev/null
+++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1a.sh
@@ -0,0 +1,267 @@
+#!/bin/bash

+# This is a script to train a time-delay neural network for speech activity detection (SAD) and
+# music-id, using a statistics pooling component for long-context information.
+# This script is the same as 1e, but removes the stats component in the 3rd layer.
+
+set -o pipefail
+set -e
+set -u
+
+. cmd.sh
+
+# At this script level we don't support running without a GPU, as it would be painfully slow.
+# If you want to run without a GPU you'd have to call train_tdnn.sh with --gpu false,
+# --num-threads 16 and --minibatch-size 128.
+
+stage=0
+train_stage=-10
+get_egs_stage=-10
+egs_opts=   # Directly passed to get_egs_multiple_targets.py
+
+chunk_width=40
+num_chunk_per_minibatch=64
+
+extra_left_context=80
+extra_right_context=0
+
+relu_dim=256
+cell_dim=256
+projection_dim=64
+
+# training options
+num_epochs=2
+initial_effective_lrate=0.0003
+final_effective_lrate=0.00003
+num_jobs_initial=3
+num_jobs_final=8
+remove_egs=false
+max_param_change=0.2  # Small max-param change for small network
+extra_egs_copy_cmd=   # Used if you want to do some weird stuff to egs
+                      # such as removing one of the targets
+
+sad_data_dir=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400
+music_data_dir=data/train_aztec_unsad_whole_music_corrupted_sp_hires_bp
+
+extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music ark:- ark:- |"
+
+egs_dir=
+nj=40
+feat_type=raw
+config_dir=
+
+dir=
+affix=1a
+
+. cmd.sh
+. ./path.sh
+. 
./utils/parse_options.sh + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music_snr/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-2,-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=$relu_dim add-log-stddev=true + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-3 + relu-renorm-layer name=tdnn4 input=Append(-6,0,6,12) dim=$relu_dim + fast-lstmp-layer name=lstm2 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn5 input=Append(-12,0,12,24) dim=$relu_dim + + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print (($num_frames_music / $num_frames_sad) ** 0.25) / $num_snr_bins"` input=tdnn5 + output-layer name=output-speech include-log-softmax=true dim=2 objective-scale=`perl -e "print (($num_frames_music / $num_frames_sad) ** 0.25)"` input=tdnn5 + output-layer name=output-music include-log-softmax=true dim=2 input=tdnn5 + + output name=output-temp input=Append(input@-2,input@-1,input,input@1,input@2) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_feat.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + 
--trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_feat.scp" \ + --dir=$dir || exit 1 +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1b.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1b.sh new file mode 100644 index 00000000000..cbbb016607a --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1b.sh @@ -0,0 +1,265 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. +# This script is same as 1e, but removes the stats component in the 3rd layer. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=40 +num_chunk_per_minibatch=64 + +extra_left_context=80 +extra_right_context=0 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +sad_data_dir=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400 +music_data_dir=data/train_aztec_unsad_whole_music_corrupted_sp_hires_bp + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1b + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music_snr/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-3,-2,-1,0,1,2,3) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-6,0,6) dim=$relu_dim + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn3 input=Append(-12,0,12) dim=$relu_dim + + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print (($num_frames_music / $num_frames_sad) ** 0.25) / $num_snr_bins"` input=tdnn3 + output-layer name=output-speech include-log-softmax=true dim=2 objective-scale=`perl -e "print (($num_frames_music / $num_frames_sad) ** 0.25)"` input=tdnn3 + output-layer name=output-music include-log-softmax=true dim=2 input=tdnn3 + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_feat.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_feat.scp" \ + --dir=$dir || exit 1 +fi + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1c.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1c.sh new file mode 100644 index 00000000000..53c2a7a47ac --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1c.sh @@ -0,0 +1,265 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. +# This script is same as 1e, but removes the stats component in the 3rd layer. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
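+
+# (Chunking options, roughly: each training example covers chunk_width
+# supervised frames, and extra_left_context extra frames are prepended so the
+# recurrent (LSTM) state can warm up before the supervised region; the model's
+# own left/right context from configs/vars is added on top of this when the
+# egs are dumped. This note describes the usual nnet3 conventions rather than
+# anything specific to this script.)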
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=40 +num_chunk_per_minibatch=64 + +extra_left_context=80 +extra_right_context=0 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +sad_data_dir=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400 +music_data_dir=data/train_aztec_unsad_whole_music_corrupted_sp_hires_bp + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1b + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music_snr/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-3,-2,-1,0,1,2,3) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-6,0,6) dim=$relu_dim + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn3 input=Append(-12,0,12) dim=$relu_dim + + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic learning-rate-factor=0.1 objective-scale=`perl -e "print $speech_scale / $num_snr_bins"` input=tdnn3 + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn3 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn3 + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_feat.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + 
--trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_feat.scp" \ + --dir=$dir || exit 1 +fi + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1e.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1e.sh new file mode 100644 index 00000000000..dfb1297c895 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1e.sh @@ -0,0 +1,269 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. +# This script is same as 1c, but uses larger amount of data. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=40 +num_chunk_per_minibatch=64 + +extra_left_context=40 +extra_right_context=0 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1b + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + +utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-3,-2,-1,0,1,2,3) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-6,0,6) dim=$relu_dim + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn3 input=Append(-12,0,12) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn3 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn3 + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + #--targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1f.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1f.sh new file mode 100644 index 00000000000..782a31132c6 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1f.sh @@ -0,0 +1,291 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. +# This script is same as 1c, but uses larger amount of data. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. 
+# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=40 +num_chunk_per_minibatch=64 + +extra_left_context=40 +extra_right_context=0 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1b + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + +cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + +utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-3,-2,-1,0,1,2,3) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-6,0,6) dim=$relu_dim + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn3 input=Append(-12,0,12) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn3 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn3 + output-layer name=output-speech_music include-log-softmax=true dim=4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 objective-scale=$speech_music_scale input=tdnn3 + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + #--targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + From 53b7649c38555d17ccee7fada57f91f9e1fac94a Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:17:13 -0500 Subject: [PATCH 170/213] asr_diarization: segmentation configs --- egs/ami/s5b/conf/segmentation_speech.conf | 14 ++++++++++++++ egs/aspire/s5/conf/segmentation_speech_simple.conf | 14 ++++++++++++++ 2 files changed, 28 insertions(+) create mode 100644 egs/ami/s5b/conf/segmentation_speech.conf create mode 100644 egs/aspire/s5/conf/segmentation_speech_simple.conf diff --git a/egs/ami/s5b/conf/segmentation_speech.conf 
b/egs/ami/s5b/conf/segmentation_speech.conf new file mode 100644 index 00000000000..c4c75b212fc --- /dev/null +++ b/egs/ami/s5b/conf/segmentation_speech.conf @@ -0,0 +1,14 @@ +# General segmentation options +pad_length=20 # Pad speech segments by this many frames on either side +max_relabel_length=10 # Maximum duration of speech that will be removed as part + # of smoothing process. This is only if there are no other + # speech segments nearby. +max_intersegment_length=30 # Merge nearby speech segments if the silence + # between them is less than this many frames. +post_pad_length=10 # Pad speech segments by this many frames on either side + # after the merging process using max_intersegment_length +max_segment_length=1000 # Segments that are longer than this are split into + # overlapping frames. +overlap_length=250 # Overlapping frames when segments are split. + # See the above option. +min_silence_length=20 # Min silence length at which to split very long segments diff --git a/egs/aspire/s5/conf/segmentation_speech_simple.conf b/egs/aspire/s5/conf/segmentation_speech_simple.conf new file mode 100644 index 00000000000..56c178c8115 --- /dev/null +++ b/egs/aspire/s5/conf/segmentation_speech_simple.conf @@ -0,0 +1,14 @@ +# General segmentation options +pad_length=20 # Pad speech segments by this many frames on either side +max_relabel_length=-1 # Maximum duration of speech that will be removed as part + # of smoothing process. This is only if there are no other + # speech segments nearby. +max_intersegment_length=30 # Merge nearby speech segments if the silence + # between them is less than this many frames. +post_pad_length=-1 # Pad speech segments by this many frames on either side + # after the merging process using max_intersegment_length +max_segment_length=1000 # Segments that are longer than this are split into + # overlapping frames. +overlap_length=250 # Overlapping frames when segments are split. + # See the above option. +min_silence_length=20 # Min silence length at which to split very long segments From cb8c7187f62506079a09f1eb5e744fcf21f715c3 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:17:59 -0500 Subject: [PATCH 171/213] An old version of resolve_ctm_overlaps --- egs/wsj/s5/steps/resolve_ctm_overlaps.py.old | 149 +++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100755 egs/wsj/s5/steps/resolve_ctm_overlaps.py.old diff --git a/egs/wsj/s5/steps/resolve_ctm_overlaps.py.old b/egs/wsj/s5/steps/resolve_ctm_overlaps.py.old new file mode 100755 index 00000000000..aaee767e7e4 --- /dev/null +++ b/egs/wsj/s5/steps/resolve_ctm_overlaps.py.old @@ -0,0 +1,149 @@ +#!/usr/bin/env python +# Copyright 2014 Johns Hopkins University (Authors: Daniel Povey, Vijayaditya Peddinti). +# 2016 Vimal Manohar +# Apache 2.0. 
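+
+# Overview (roughly): when two decoded segments of the same recording overlap
+# in time, each hypothesized word is kept from whichever segment's half of the
+# overlap its midpoint falls in, so every word appears exactly once in the
+# combined ctm.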
+ +# Script to combine ctms with overlapping segments + +import sys, math, numpy as np, argparse +break_threshold = 0.01 + +def ReadSegments(segments_file): + segments = {} + for line in open(segments_file).readlines(): + parts = line.strip().split() + segments[parts[0]] = (parts[1], float(parts[2]), float(parts[3])) + return segments + +#def get_breaks(ctm, prev_end): +# breaks = [] +# for i in xrange(0, len(ctm)): +# if ctm[i][2] - prev_end > break_threshold: +# breaks.append([i, ctm[i][2]]) +# prev_end = ctm[i][2] + ctm[i][3] +# return np.array(breaks) + +# Resolve overlaps within segments of the same recording +def ResolveOverlaps(ctms, segments): + total_ctm = [] + if len(ctms) == 0: + raise Exception('Something wrong with the input ctms') + + next_utt = ctms[0][0][0] + for ctm_index in range(len(ctms) - 1): + # Assumption here is that the segments are written in consecutive order? + cur_ctm = ctms[ctm_index] + next_ctm = ctms[ctm_index + 1] + + cur_utt = next_utt + next_utt = next_ctm[0][0] + if (next_utt not in segments): + raise Exception('Could not find utterance %s in segments' % next_utt) + + if len(cur_ctm) > 0: + assert(cur_utt == cur_ctm[0][0]) + + assert(next_utt > cur_utt) + if (cur_utt not in segments): + raise Exception('Could not find utterance %s in segments' % cur_utt) + + # length of this segment + window_length = segments[cur_utt][2] - segments[cur_utt][1] + + # overlap of this segment with the next segment + # Note: It is possible for this to be negative when there is actually + # no overlap between consecutive segments. + overlap = segments[cur_utt][2] - segments[next_utt][1] + + # find the breaks after overlap starts + index = len(cur_ctm) + + for i in xrange(len(cur_ctm)): + if (cur_ctm[i][2] + cur_ctm[i][3]/2.0 > (window_length - overlap/2.0)): + # if midpoint of a hypothesis word is beyond the midpoint of the + # overlap region + index = i + break + + # Ignore the hypotheses beyond this midpoint. They will be considered as + # part of the next segment. + total_ctm += cur_ctm[:index] + + # Ignore the hypotheses of the next utterance that overlaps with the + # current utterance + index = -1 + for i in xrange(len(next_ctm)): + if (next_ctm[i][2] + next_ctm[i][3]/2.0 > (overlap/2.0)): + index = i + break + + if index >= 0: + ctms[ctm_index + 1] = next_ctm[index:] + else: + ctms[ctm_index + 1] = [] + + # merge the last ctm entirely + total_ctm += ctms[-1] + + return total_ctm + +def ReadCtm(ctm_file_lines, segments): + ctms = {} + for key in [ x[0] for x in segments.values() ]: + ctms[key] = [] + + ctm = [] + prev_utt = ctm_file_lines[0].split()[0] + for line in ctm_file_lines: + parts = line.split() + if (prev_utt == parts[0]): + ctm.append([parts[0], parts[1], float(parts[2]), + float(parts[3])] + parts[4:]) + else: + # New utterance. 
Append the previous utterance's CTM + # into the list for the utterance's recording + ctms[segments[ctm[0][0]][0]].append(ctm) + + assert(parts[0] > prev_utt) + + prev_utt = parts[0] + ctm = [] + ctm.append([parts[0], parts[1], float(parts[2]), + float(parts[3])] + parts[4:]) + + # append the last ctm + ctms[segments[ctm[0][0]][0]].append(ctm) + return ctms + +def WriteCtm(ctm_lines, out_file): + for line in ctm_lines: + out_file.write("{0} {1} {2} {3} {4}\n".format(line[0], line[1], line[2], line[3], " ".join(line[4:]))) + +if __name__ == "__main__": + usage = """ Python script to resolve overlaps in ctms """ + parser = argparse.ArgumentParser(usage) + parser.add_argument('segments', type=str, help = 'use segments to resolve overlaps') + parser.add_argument('ctm_in', type=str, help='input_ctm_file') + parser.add_argument('ctm_out', type=str, help='output_ctm_file') + params = parser.parse_args() + + if params.ctm_in == "-": + params.ctm_in = sys.stdin + else: + params.ctm_in = open(params.ctm_in) + if params.ctm_out == "-": + params.ctm_out = sys.stdout + else: + params.ctm_out = open(params.ctm_out, 'w') + + segments = ReadSegments(params.segments) + + # Read CTMs into a dictionary indexed by the recording + ctms = ReadCtm(params.ctm_in.readlines(), segments) + + for key in sorted(ctms.keys()): + # Process CTMs in the sorted order of recordings + ctm_reco = ctms[key] + ctm_reco = ResolveOverlaps(ctm_reco, segments) + WriteCtm(ctm_reco, params.ctm_out) + params.ctm_out.close() From 58dc6a6955241c5a1841ab52fc5335c578224e44 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:18:43 -0500 Subject: [PATCH 172/213] asr_diarization: Add steps/data/make_corrupted_data_dir.py --- .../s5/steps/data/make_corrupted_data_dir.py | 613 ++++++++++++++++++ 1 file changed, 613 insertions(+) create mode 100644 egs/wsj/s5/steps/data/make_corrupted_data_dir.py diff --git a/egs/wsj/s5/steps/data/make_corrupted_data_dir.py b/egs/wsj/s5/steps/data/make_corrupted_data_dir.py new file mode 100644 index 00000000000..c0fa94c2a42 --- /dev/null +++ b/egs/wsj/s5/steps/data/make_corrupted_data_dir.py @@ -0,0 +1,613 @@ +#!/usr/bin/env python +# Copyright 2016 Tom Ko +# Apache 2.0 +# script to generate reverberated data + +# we're using python 3.x style print but want it to work in python 2.x, +from __future__ import print_function +import argparse, shlex, glob, math, os, random, sys, warnings, copy, imp, ast + +import data_dir_manipulation_lib as data_lib + +sys.path.insert(0, 'steps') +import libs.common as common_lib + +def GetArgs(): + # we add required arguments as named arguments for readability + parser = argparse.ArgumentParser(description="Reverberate the data directory with an option " + "to add isotropic and point source noises. " + "Usage: reverberate_data_dir.py [options...] " + "E.g. reverberate_data_dir.py --rir-set-parameters rir_list " + "--foreground-snrs 20:10:15:5:0 --background-snrs 20:10:15:5:0 " + "--noise-list-file noise_list --speech-rvb-probability 1 --num-replications 2 " + "--random-seed 1 data/train data/train_rvb", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("--rir-set-parameters", type=str, action='append', required = True, dest = "rir_set_para_array", + help="Specifies the parameters of an RIR set. " + "Supports the specification of mixture_weight and rir_list_file_name. The mixture weight is optional. 
" + "The default mixture weight is the probability mass remaining after adding the mixture weights " + "of all the RIR lists, uniformly divided among the RIR lists without mixture weights. " + "E.g. --rir-set-parameters '0.3, rir_list' or 'rir_list' " + "the format of the RIR list file is " + "--rir-id --room-id " + "--receiver-position-id --source-position-id " + "--rt-60 --drr location " + "E.g. --rir-id 00001 --room-id 001 --receiver-position-id 001 --source-position-id 00001 " + "--rt60 0.58 --drr -4.885 data/impulses/Room001-00001.wav") + parser.add_argument("--noise-set-parameters", type=str, action='append', + default = None, dest = "noise_set_para_array", + help="Specifies the parameters of an noise set. " + "Supports the specification of mixture_weight and noise_list_file_name. The mixture weight is optional. " + "The default mixture weight is the probability mass remaining after adding the mixture weights " + "of all the noise lists, uniformly divided among the noise lists without mixture weights. " + "E.g. --noise-set-parameters '0.3, noise_list' or 'noise_list' " + "the format of the noise list file is " + "--noise-id --noise-type " + "--bg-fg-type " + "--room-linkage " + "location " + "E.g. --noise-id 001 --noise-type isotropic --rir-id 00019 iso_noise.wav") + parser.add_argument("--speech-segments-set-parameters", type=str, action='append', + default = None, dest = "speech_segments_set_para_array", + help="Specifies the speech segments for overlapped speech generation.\n" + "Format: [], wav_scp, segments_list\n"); + parser.add_argument("--num-replications", type=int, dest = "num_replicas", default = 1, + help="Number of replicate to generated for the data") + parser.add_argument('--foreground-snrs', type=str, dest = "foreground_snr_string", + default = '20:10:0', + help='When foreground noises are being added the script will iterate through these SNRs.') + parser.add_argument('--background-snrs', type=str, dest = "background_snr_string", + default = '20:10:0', + help='When background noises are being added the script will iterate through these SNRs.') + parser.add_argument('--overlap-snrs', type=str, dest = "overlap_snr_string", + default = "20:10:0", + help='When overlapping speech segments are being added the script will iterate through these SNRs.') + parser.add_argument('--prefix', type=str, default = None, + help='This prefix will modified for each reverberated copy, by adding additional affixes.') + parser.add_argument("--speech-rvb-probability", type=float, default = 1.0, + help="Probability of reverberating a speech signal, e.g. 0 <= p <= 1") + parser.add_argument("--pointsource-noise-addition-probability", type=float, default = 1.0, + help="Probability of adding point-source noises, e.g. 0 <= p <= 1") + parser.add_argument("--isotropic-noise-addition-probability", type=float, default = 1.0, + help="Probability of adding isotropic noises, e.g. 0 <= p <= 1") + parser.add_argument("--overlapping-speech-addition-probability", type=float, default = 1.0, + help="Probability of adding overlapping speech, e.g. 0 <= p <= 1") + parser.add_argument("--rir-smoothing-weight", type=float, default = 0.3, + help="Smoothing weight for the RIR probabilties, e.g. 0 <= p <= 1. If p = 0, no smoothing will be done. " + "The RIR distribution will be mixed with a uniform distribution according to the smoothing weight") + parser.add_argument("--noise-smoothing-weight", type=float, default = 0.3, + help="Smoothing weight for the noise probabilties, e.g. 0 <= p <= 1. 
If p = 0, no smoothing will be done. " + "The noise distribution will be mixed with a uniform distribution according to the smoothing weight") + parser.add_argument("--overlapping-speech-smoothing-weight", type=float, default = 0.3, + help="The overlapping speech distribution will be mixed with a uniform distribution according to the smoothing weight") + parser.add_argument("--max-noises-per-minute", type=int, default = 2, + help="This controls the maximum number of point-source noises that could be added to a recording according to its duration") + parser.add_argument("--min-overlapping-segments-per-minute", type=int, default = 1, + help="This controls the minimum number of overlapping segments of speech that could be added to a recording per minute") + parser.add_argument("--max-overlapping-segments-per-minute", type=int, default = 5, + help="This controls the maximum number of overlapping segments of speech that could be added to a recording per minute") + parser.add_argument('--random-seed', type=int, default=0, + help='seed to be used in the randomization of impulses and noises') + parser.add_argument("--shift-output", type=str, + help="If true, the reverberated waveform will be shifted by the amount of the peak position of the RIR", + choices=['true', 'false'], default = "true") + parser.add_argument('--source-sampling-rate', type=int, default=None, + help="Sampling rate of the source data. If a positive integer is specified with this option, " + "the RIRs/noises will be resampled to the rate of the source data.") + parser.add_argument("--include-original-data", type=str, help="If true, the output data includes one copy of the original data", + choices=['true', 'false'], default = "false") + parser.add_argument("--output-additive-noise-dir", type=str, + action = common_lib.NullstrToNoneAction, default = None, + help="Output directory corresponding to the additive noise part of the data corruption") + parser.add_argument("--output-reverb-dir", type=str, + action = common_lib.NullstrToNoneAction, default = None, + help="Output directory corresponding to the reverberated signal part of the data corruption") + + parser.add_argument("input_dir", + help="Input data directory") + parser.add_argument("output_dir", + help="Output data directory") + + print(' '.join(sys.argv)) + + args = parser.parse_args() + args = CheckArgs(args) + + return args + +def CheckArgs(args): + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + ## Check arguments. + + if args.prefix is None: + if args.num_replicas > 1 or args.include_original_data == "true": + args.prefix = "rvb" + warnings.warn("--prefix is set to 'rvb' as more than one copy of data is generated") + + if args.output_reverb_dir is not None: + if not os.path.exists(args.output_reverb_dir): + os.makedirs(args.output_reverb_dir) + + if args.output_additive_noise_dir is not None: + if not os.path.exists(args.output_additive_noise_dir): + os.makedirs(args.output_additive_noise_dir) + + ## Check arguments. 
+ + if args.num_replicas > 1 and args.prefix is None: + args.prefix = "rvb" + warnings.warn("--prefix is set to 'rvb' as --num-replications is larger than 1.") + + if not args.num_replicas > 0: + raise Exception("--num-replications cannot be non-positive") + + if args.speech_rvb_probability < 0 or args.speech_rvb_probability > 1: + raise Exception("--speech-rvb-probability must be between 0 and 1") + + if args.pointsource_noise_addition_probability < 0 or args.pointsource_noise_addition_probability > 1: + raise Exception("--pointsource-noise-addition-probability must be between 0 and 1") + + if args.isotropic_noise_addition_probability < 0 or args.isotropic_noise_addition_probability > 1: + raise Exception("--isotropic-noise-addition-probability must be between 0 and 1") + + if args.overlapping_speech_addition_probability < 0 or args.overlapping_speech_addition_probability > 1: + raise Exception("--overlapping-speech-addition-probability must be between 0 and 1") + + if args.rir_smoothing_weight < 0 or args.rir_smoothing_weight > 1: + raise Exception("--rir-smoothing-weight must be between 0 and 1") + + if args.noise_smoothing_weight < 0 or args.noise_smoothing_weight > 1: + raise Exception("--noise-smoothing-weight must be between 0 and 1") + + if args.overlapping_speech_smoothing_weight < 0 or args.overlapping_speech_smoothing_weight > 1: + raise Exception("--overlapping-speech-smoothing-weight must be between 0 and 1") + + if args.max_noises_per_minute < 0: + raise Exception("--max-noises-per-minute cannot be negative") + + if args.min_overlapping_segments_per_minute < 0: + raise Exception("--min-overlapping-segments-per-minute cannot be negative") + + if args.max_overlapping_segments_per_minute < 0: + raise Exception("--max-overlapping-segments-per-minute cannot be negative") + + return args + +def ParseSpeechSegmentsList(speech_segments_set_para_array, smoothing_weight): + set_list = [] + for set_para in speech_segments_set_para_array: + set = lambda: None + setattr(set, "wav_scp", None) + setattr(set, "segments", None) + setattr(set, "probability", None) + parts = set_para.split(',') + if len(parts) == 3: + set.probability = float(parts[0]) + set.wav_scp = parts[1].strip() + set.segments = parts[2].strip() + else: + set.wav_scp = parts[0].strip() + set.segments = parts[1].strip() + if not os.path.isfile(set.wav_scp): + raise Exception(set.wav_scp + " not found") + if not os.path.isfile(set.segments): + raise Exception(set.segments + " not found") + set_list.append(set) + + data_lib.SmoothProbabilityDistribution(set_list) + + segments_list = [] + for segments_set in set_list: + current_segments_list = [] + + wav_dict = {} + for s in open(segments_set.wav_scp): + parts = s.strip().split() + wav_dict[parts[0]] = ' '.join(parts[1:]) + + for s in open(segments_set.segments): + parts = s.strip().split() + current_segment = argparse.Namespace() + current_segment.utt_id = parts[0] + current_segment.probability = None + + start_time = float(parts[2]) + end_time = float(parts[3]) + + current_segment.duration = (end_time - start_time) + + wav_rxfilename = wav_dict[parts[1]] + if wav_rxfilename.split()[-1] == '|': + current_segment.wav_rxfilename = "{0} sox -t wav - -t wav - trim {1} {2} |".format(wav_rxfilename, start_time, end_time - start_time) + else: + current_segment.wav_rxfilename = "sox {0} -t wav - trim {1} {2} |".format(wav_rxfilename, start_time, end_time - start_time) + + current_segments_list.append(current_segment) + + segments_list += 
data_lib.SmoothProbabilityDistribution(current_segments_list, smoothing_weight, segments_set.probability) + + return segments_list + +def AddOverlappingSpeech(room, # the room selected + speech_segments_list, # the speech list + overlapping_speech_addition_probability, # Probability of another speech waveform + snrs, # the SNR for adding the foreground speech + speech_dur, # duration of the recording + min_overlapping_speech_segments, # Minimum number of speech signals that can be added + max_overlapping_speech_segments, # Maximum number of speech signals that can be added + overlapping_speech_descriptor # descriptor to store the information of the overlapping speech + ): + if (len(speech_segments_list) > 0 and random.random() < overlapping_speech_addition_probability + and max_overlapping_speech_segments >= 1): + for k in range(1, random.randint(min_overlapping_speech_segments, max_overlapping_speech_segments) + 1): + # pick the overlapping_speech speech signal and the RIR to + # reverberate the overlapping_speech speech signal + speech_segment = data_lib.PickItemWithProbability(speech_segments_list) + rir = data_lib.PickItemWithProbability(room.rir_list) + + speech_rvb_command = """wav-reverberate --impulse-response="{0}" --shift-output=true """.format(rir.rir_rspecifier) + overlapping_speech_descriptor['start_times'].append( + round(random.random() + * max(speech_dur - speech_segment.duration, 0), 2)) + overlapping_speech_descriptor['snrs'].append(snrs.next()) + overlapping_speech_descriptor['utt_ids'].append(speech_segment.utt_id) + overlapping_speech_descriptor['durations'].append(speech_segment.duration) + + if len(speech_segment.wav_rxfilename.split()) == 1: + overlapping_speech_descriptor['speech_segments'].append("{1} {0} - |".format(speech_segment.wav_rxfilename, speech_rvb_command)) + else: + overlapping_speech_descriptor['speech_segments'].append("{0} {1} - - |".format(speech_segment.wav_rxfilename, speech_rvb_command)) + +# This function randomly decides whether to reverberate, and sample a RIR if it does +# It also decides whether to add the appropriate noises +# This function return the string of options to the binary wav-reverberate +def GenerateReverberationAndOverlappedSpeechOpts( + room_dict, # the room dictionary, please refer to MakeRoomDict() for the format + pointsource_noise_list, # the point source noise list + iso_noise_dict, # the isotropic noise dictionary + foreground_snrs, # the SNR for adding the foreground noises + background_snrs, # the SNR for adding the background noises + speech_segments_list, + overlap_snrs, + speech_rvb_probability, # Probability of reverberating a speech signal + isotropic_noise_addition_probability, # Probability of adding isotropic noises + pointsource_noise_addition_probability, # Probability of adding point-source noises + overlapping_speech_addition_probability, # Probability of adding overlapping speech segments + speech_dur, # duration of the recording + max_noises_recording, # Maximum number of point-source noises that can be added + min_overlapping_segments_recording, # Minimum number of overlapping segments that can be added + max_overlapping_segments_recording # Maximum number of overlapping segments that can be added + ): + impulse_response_opts = "" + + noise_addition_descriptor = {'noise_io': [], + 'start_times': [], + 'snrs': [], + 'noise_ids': [], + 'durations': [] + } + + # Randomly select the room + # Here the room probability is a sum of the probabilities of the RIRs recorded in the room. 
+ room = data_lib.PickItemWithProbability(room_dict) + # Randomly select the RIR in the room + speech_rir = data_lib.PickItemWithProbability(room.rir_list) + if random.random() < speech_rvb_probability: + # pick the RIR to reverberate the speech + impulse_response_opts = """--impulse-response="{0}" """.format(speech_rir.rir_rspecifier) + + rir_iso_noise_list = [] + if speech_rir.room_id in iso_noise_dict: + rir_iso_noise_list = iso_noise_dict[speech_rir.room_id] + # Add the corresponding isotropic noise associated with the selected RIR + if len(rir_iso_noise_list) > 0 and random.random() < isotropic_noise_addition_probability: + isotropic_noise = data_lib.PickItemWithProbability(rir_iso_noise_list) + # extend the isotropic noise to the length of the speech waveform + # check if it is really a pipe + if len(isotropic_noise.noise_rspecifier.split()) == 1: + noise_addition_descriptor['noise_io'].append("wav-reverberate --duration={1} {0} - |".format(isotropic_noise.noise_rspecifier, speech_dur)) + else: + noise_addition_descriptor['noise_io'].append("{0} wav-reverberate --duration={1} - - |".format(isotropic_noise.noise_rspecifier, speech_dur)) + noise_addition_descriptor['start_times'].append(0) + noise_addition_descriptor['snrs'].append(background_snrs.next()) + noise_addition_descriptor['noise_ids'].append(isotropic_noise.noise_id) + noise_addition_descriptor['durations'].append(speech_dur) + + data_lib.AddPointSourceNoise(room, # the room selected + pointsource_noise_list, # the point source noise list + pointsource_noise_addition_probability, # Probability of adding point-source noises + foreground_snrs, # the SNR for adding the foreground noises + background_snrs, # the SNR for adding the background noises + speech_dur, # duration of the recording + max_noises_recording, # Maximum number of point-source noises that can be added + noise_addition_descriptor # descriptor to store the information of the noise added + ) + + assert len(noise_addition_descriptor['noise_io']) == len(noise_addition_descriptor['start_times']) + assert len(noise_addition_descriptor['noise_io']) == len(noise_addition_descriptor['snrs']) + assert len(noise_addition_descriptor['noise_io']) == len(noise_addition_descriptor['noise_ids']) + assert len(noise_addition_descriptor['noise_io']) == len(noise_addition_descriptor['durations']) + + overlapping_speech_descriptor = {'speech_segments': [], + 'start_times': [], + 'snrs': [], + 'utt_ids': [], + 'durations': [] + } + + AddOverlappingSpeech(room, + speech_segments_list, # speech segments list + overlapping_speech_addition_probability, + overlap_snrs, + speech_dur, + min_overlapping_segments_recording, + max_overlapping_segments_recording, + overlapping_speech_descriptor + ) + + return [impulse_response_opts, noise_addition_descriptor, + overlapping_speech_descriptor] + +# This is the main function to generate pipeline command for the corruption +# The generic command of wav-reverberate will be like: +# wav-reverberate --duration=t --impulse-response=rir.wav +# --additive-signals='noise1.wav,noise2.wav' --snrs='snr1,snr2' --start-times='s1,s2' input.wav output.wav +def GenerateReverberatedWavScpWithOverlappedSpeech( + wav_scp, # a dictionary whose values are the Kaldi-IO strings of the speech recordings + durations, # a dictionary whose values are the duration (in sec) of the speech recordings + output_dir, # output directory to write the corrupted wav.scp + room_dict, # the room dictionary, please refer to MakeRoomDict() for the format + pointsource_noise_list, # the 
point source noise list + iso_noise_dict, # the isotropic noise dictionary + foreground_snr_array, # the SNR for adding the foreground noises + background_snr_array, # the SNR for adding the background noises + speech_segments_list, # list of speech segments to create overlapped speech + overlap_snr_array, # the SNR for adding overlapping speech + num_replicas, # Number of replicate to generated for the data + include_original, # include a copy of the original data + prefix, # prefix for the id of the corrupted utterances + speech_rvb_probability, # Probability of reverberating a speech signal + shift_output, # option whether to shift the output waveform + isotropic_noise_addition_probability, # Probability of adding isotropic noises + pointsource_noise_addition_probability, # Probability of adding point-source noises + max_noises_per_minute, # maximum number of point-source noises that can be added to a recording according to its duration + overlapping_speech_addition_probability, + min_overlapping_segments_per_minute, + max_overlapping_segments_per_minute, + output_reverb_dir = None, + output_additive_noise_dir = None + ): + foreground_snrs = data_lib.list_cyclic_iterator(foreground_snr_array) + background_snrs = data_lib.list_cyclic_iterator(background_snr_array) + overlap_snrs = data_lib.list_cyclic_iterator(overlap_snr_array) + + corrupted_wav_scp = {} + reverb_wav_scp = {} + additive_noise_wav_scp = {} + overlapping_segments_info = {} + + keys = wav_scp.keys() + keys.sort() + + if include_original: + start_index = 0 + else: + start_index = 1 + + for i in range(start_index, num_replicas+1): + for recording_id in keys: + wav_original_pipe = wav_scp[recording_id] + # check if it is really a pipe + if len(wav_original_pipe.split()) == 1: + wav_original_pipe = "cat {0} |".format(wav_original_pipe) + speech_dur = durations[recording_id] + max_noises_recording = math.floor(max_noises_per_minute * speech_dur / 60) + min_overlapping_segments_recording = max(math.floor(min_overlapping_segments_per_minute * speech_dur / 60), 1) + max_overlapping_segments_recording = math.ceil(max_overlapping_segments_per_minute * speech_dur / 60) + + [impulse_response_opts, noise_addition_descriptor, + overlapping_speech_descriptor] = GenerateReverberationAndOverlappedSpeechOpts( + room_dict = room_dict, # the room dictionary, please refer to MakeRoomDict() for the format + pointsource_noise_list = pointsource_noise_list, # the point source noise list + iso_noise_dict = iso_noise_dict, # the isotropic noise dictionary + foreground_snrs = foreground_snrs, # the SNR for adding the foreground noises + background_snrs = background_snrs, # the SNR for adding the background noises + speech_segments_list = speech_segments_list, # Speech segments for creating overlapped speech + overlap_snrs = overlap_snrs, # the SNR for adding overlapping speech + speech_rvb_probability = speech_rvb_probability, # Probability of reverberating a speech signal + isotropic_noise_addition_probability = isotropic_noise_addition_probability, # Probability of adding isotropic noises + pointsource_noise_addition_probability = pointsource_noise_addition_probability, # Probability of adding point-source noises + overlapping_speech_addition_probability = overlapping_speech_addition_probability, + speech_dur = speech_dur, # duration of the recording + max_noises_recording = max_noises_recording, # Maximum number of point-source noises that can be added + min_overlapping_segments_recording = min_overlapping_segments_recording, + 
max_overlapping_segments_recording = max_overlapping_segments_recording + ) + + additive_noise_opts = "" + + if (len(noise_addition_descriptor['noise_io']) > 0 or + len(overlapping_speech_descriptor['speech_segments']) > 0): + additive_noise_opts += ("--additive-signals='{0}' " + .format(',' + .join(noise_addition_descriptor['noise_io'] + + overlapping_speech_descriptor['speech_segments'])) + ) + additive_noise_opts += ("--start-times='{0}' " + .format(',' + .join(map(lambda x:str(x), noise_addition_descriptor['start_times'] + + overlapping_speech_descriptor['start_times']))) + ) + additive_noise_opts += ("--snrs='{0}' " + .format(',' + .join(map(lambda x:str(x), noise_addition_descriptor['snrs'] + + overlapping_speech_descriptor['snrs']))) + ) + + reverberate_opts = impulse_response_opts + additive_noise_opts + + new_recording_id = data_lib.GetNewId(recording_id, prefix, i) + + # prefix using index 0 is reserved for original data e.g. rvb0_swb0035 corresponds to the swb0035 recording in original data + if reverberate_opts == "" or i == 0: + wav_corrupted_pipe = "{0}".format(wav_original_pipe) + else: + wav_corrupted_pipe = "{0} wav-reverberate --shift-output={1} {2} - - |".format(wav_original_pipe, shift_output, reverberate_opts) + + corrupted_wav_scp[new_recording_id] = wav_corrupted_pipe + + if output_reverb_dir is not None: + if impulse_response_opts == "": + wav_reverb_pipe = "{0}".format(wav_original_pipe) + else: + wav_reverb_pipe = "{0} wav-reverberate --shift-output={1} --reverb-out-wxfilename=- {2} - /dev/null |".format(wav_original_pipe, shift_output, reverberate_opts) + reverb_wav_scp[new_recording_id] = wav_reverb_pipe + + if output_additive_noise_dir is not None: + if additive_noise_opts != "": + wav_additive_noise_pipe = "{0} wav-reverberate --shift-output={1} --additive-noise-out-wxfilename=- {2} - /dev/null |".format(wav_original_pipe, shift_output, reverberate_opts) + additive_noise_wav_scp[new_recording_id] = wav_additive_noise_pipe + else: + assert False + + if len(overlapping_speech_descriptor['speech_segments']) > 0: + overlapping_segments_info[new_recording_id] = [ + ':'.join(x) + for x in zip(overlapping_speech_descriptor['utt_ids'], + [ str(x) for x in overlapping_speech_descriptor['start_times'] ], + [ str(x) for x in overlapping_speech_descriptor['durations'] ]) + ] + + data_lib.WriteDictToFile(corrupted_wav_scp, output_dir + "/wav.scp") + + # Write for each new recording, the id, start time and durations + # of the overlapping segments + data_lib.WriteDictToFile(overlapping_segments_info, output_dir + "/overlapped_segments_info.txt") + + if output_reverb_dir is not None: + data_lib.WriteDictToFile(reverb_wav_scp, output_reverb_dir + "/wav.scp") + + if output_additive_noise_dir is not None: + data_lib.WriteDictToFile(additive_noise_wav_scp, output_additive_noise_dir + "/wav.scp") + + +# This function creates multiple copies of the necessary files, e.g. utt2spk, wav.scp ... 
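+# It computes reco2dur with wav-to-duration if it is missing, writes the corrupted wav.scp +# (and optionally the reverberated-signal and additive-noise wav.scp files) via +# GenerateReverberatedWavScpWithOverlappedSpeech(), and then copies the remaining +# data-directory files for each replica via data_lib.CopyDataDirFiles().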
+def CreateReverberatedCopy(input_dir, + output_dir, + room_dict, # the room dictionary, please refer to MakeRoomDict() for the format + pointsource_noise_list, # the point source noise list + iso_noise_dict, # the isotropic noise dictionary + speech_segments_list, + foreground_snr_string, # the SNR for adding the foreground noises + background_snr_string, # the SNR for adding the background noises + overlap_snr_string, # the SNR for overlapping speech + num_replicas, # Number of replicate to generated for the data + include_original, # include a copy of the original data + prefix, # prefix for the id of the corrupted utterances + speech_rvb_probability, # Probability of reverberating a speech signal + shift_output, # option whether to shift the output waveform + isotropic_noise_addition_probability, # Probability of adding isotropic noises + pointsource_noise_addition_probability, # Probability of adding point-source noises + max_noises_per_minute, # maximum number of point-source noises that can be added to a recording according to its duration + overlapping_speech_addition_probability, + min_overlapping_segments_per_minute, + max_overlapping_segments_per_minute, + output_reverb_dir = None, + output_additive_noise_dir = None + ): + + wav_scp = data_lib.ParseFileToDict(input_dir + "/wav.scp", value_processor = lambda x: " ".join(x)) + if not os.path.isfile(input_dir + "/reco2dur"): + print("Getting the duration of the recordings..."); + read_entire_file="false" + for value in wav_scp.values(): + # we will add more checks for sox commands which modify the header as we come across these cases in our data + if "sox" in value and "speed" in value: + read_entire_file="true" + break + data_lib.RunKaldiCommand("wav-to-duration --read-entire-file={1} scp:{0}/wav.scp ark,t:{0}/reco2dur".format(input_dir, read_entire_file)) + durations = data_lib.ParseFileToDict(input_dir + "/reco2dur", value_processor = lambda x: float(x[0])) + foreground_snr_array = map(lambda x: float(x), foreground_snr_string.split(':')) + background_snr_array = map(lambda x: float(x), background_snr_string.split(':')) + overlap_snr_array = map(lambda x: float(x), overlap_snr_string.split(':')) + + GenerateReverberatedWavScpWithOverlappedSpeech( + wav_scp = wav_scp, + durations = durations, + output_dir = output_dir, + room_dict = room_dict, + pointsource_noise_list = pointsource_noise_list, + iso_noise_dict = iso_noise_dict, + foreground_snr_array = foreground_snr_array, + background_snr_array = background_snr_array, + speech_segments_list = speech_segments_list, + overlap_snr_array = overlap_snr_array, + num_replicas = num_replicas, include_original=include_original, prefix = prefix, + speech_rvb_probability = speech_rvb_probability, + shift_output = shift_output, + isotropic_noise_addition_probability = isotropic_noise_addition_probability, + pointsource_noise_addition_probability = pointsource_noise_addition_probability, + max_noises_per_minute = max_noises_per_minute, + overlapping_speech_addition_probability = overlapping_speech_addition_probability, + min_overlapping_segments_per_minute = min_overlapping_segments_per_minute, + max_overlapping_segments_per_minute = max_overlapping_segments_per_minute, + output_reverb_dir = output_reverb_dir, + output_additive_noise_dir = output_additive_noise_dir) + + data_lib.CopyDataDirFiles(input_dir, output_dir, num_replicas, include_original=include_original, prefix=prefix) + + if output_reverb_dir is not None: + data_lib.CopyDataDirFiles(input_dir, output_reverb_dir, num_replicas, 
include_original=include_original, prefix=prefix) + + if output_additive_noise_dir is not None: + data_lib.CopyDataDirFiles(input_dir, output_additive_noise_dir, num_replicas, include_original=include_original, prefix=prefix) + + +def Main(): + args = GetArgs() + random.seed(args.random_seed) + rir_list = data_lib.ParseRirList(args.rir_set_para_array, args.rir_smoothing_weight, args.source_sampling_rate) + print("Number of RIRs is {0}".format(len(rir_list))) + pointsource_noise_list = [] + iso_noise_dict = {} + if args.noise_set_para_array is not None: + pointsource_noise_list, iso_noise_dict = data_lib.ParseNoiseList(args.noise_set_para_array, args.noise_smoothing_weight, args.source_sampling_rate) + print("Number of point-source noises is {0}".format(len(pointsource_noise_list))) + print("Number of isotropic noises is {0}".format(sum(len(iso_noise_dict[key]) for key in iso_noise_dict.keys()))) + room_dict = data_lib.MakeRoomDict(rir_list) + + if args.include_original_data == "true": + include_original = True + else: + include_original = False + + speech_segments_list = ParseSpeechSegmentsList(args.speech_segments_set_para_array, args.overlapping_speech_smoothing_weight) + + CreateReverberatedCopy(input_dir = args.input_dir, + output_dir = args.output_dir, + room_dict = room_dict, + pointsource_noise_list = pointsource_noise_list, + iso_noise_dict = iso_noise_dict, + speech_segments_list = speech_segments_list, + foreground_snr_string = args.foreground_snr_string, + background_snr_string = args.background_snr_string, + overlap_snr_string = args.overlap_snr_string, + num_replicas = args.num_replicas, + include_original = include_original, + prefix = args.prefix, + speech_rvb_probability = args.speech_rvb_probability, + shift_output = args.shift_output, + isotropic_noise_addition_probability = args.isotropic_noise_addition_probability, + pointsource_noise_addition_probability = args.pointsource_noise_addition_probability, + max_noises_per_minute = args.max_noises_per_minute, + overlapping_speech_addition_probability = args.overlapping_speech_addition_probability, + min_overlapping_segments_per_minute = args.min_overlapping_segments_per_minute, + max_overlapping_segments_per_minute = args.max_overlapping_segments_per_minute, + output_reverb_dir = args.output_reverb_dir, + output_additive_noise_dir = args.output_additive_noise_dir) + +if __name__ == "__main__": + Main() From 613f0aa8921baf05c8315f829f40e7a96b28b88e Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:19:37 -0500 Subject: [PATCH 173/213] asr_diarization: Add deprecated sad run scripts --- .../local/segmentation/run_train_sad_music.sh | 161 ++++++++++++++++ .../run_train_sad_ovlp_logprob.sh | 148 +++++++++++++++ .../segmentation/run_train_stats_sad_music.sh | 172 ++++++++++++++++++ .../v1/local/run_dnn_music_id.sh | 130 +++++++++++++ 4 files changed, 611 insertions(+) create mode 100644 egs/aspire/s5/local/segmentation/run_train_sad_music.sh create mode 100644 egs/aspire/s5/local/segmentation/run_train_sad_ovlp_logprob.sh create mode 100644 egs/aspire/s5/local/segmentation/run_train_stats_sad_music.sh create mode 100755 egs/bn_music_speech/v1/local/run_dnn_music_id.sh diff --git a/egs/aspire/s5/local/segmentation/run_train_sad_music.sh b/egs/aspire/s5/local/segmentation/run_train_sad_music.sh new file mode 100644 index 00000000000..5acb4bf4306 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/run_train_sad_music.sh @@ -0,0 +1,161 @@ +#!/bin/bash + +# this is the standard "tdnn" system, built in nnet3; it's 
what we use to +# call multi-splice. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= + +splice_indexes="-3,-2,-1,0,1,2,3 -6,0 -9,0,3 0" +relu_dim=256 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=1 +extra_egs_copy_cmd= + +num_utts_subset_valid=40 +num_utts_subset_train=40 +add_idct=true + +# target options +train_data_dir=data/train_azteec_whole_sp_corrupted_hires + +snr_scp= +speech_feat_scp= +music_labels_scp= + +deriv_weights_scp= +deriv_weights_for_irm_scp= + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_hidden_layers=`echo $splice_indexes | perl -ane 'print scalar @F'` || exit 1 +if [ -z "$dir" ]; then + dir=exp/nnet3_sad_snr/nnet_tdnn +fi + +dir=$dir${affix:+_$affix}_n${num_hidden_layers} + +if ! cuda-compiled; then + cat < " + echo " e.g.: $0 data/bn exp/nnet3_sad_snr/tdnn_b_n4/sad_bn_whole exp/nnet3_sad_snr/tdnn_b_n4/music_bn_whole exp/nnet3_sad_snr/tdnn_b_n4/segmentation_bn_whole exp/nnet3_sad_snr/tdnn_b_n4/segmentation_music_bn_whole exp/dnn_music_id" + exit 1 +fi + +data=$1 +sad_likes_dir=$2 +music_likes_dir=$3 +dir=$4 + +min_silence_duration=`perl -e "print (int($min_silence_duration / $frame_subsampling_factor))"` +min_speech_duration=`perl -e "print (int($min_speech_duration / $frame_subsampling_factor))"` +min_music_duration=`perl -e "print (int($min_music_duration / $frame_subsampling_factor))"` + +lang=$dir/lang + +if [ $stage -le 1 ]; then + mkdir -p $lang + + # Create a lang directory with phones.txt and topo with + # silence, music and speech phones. + steps/segmentation/internal/prepare_sad_lang.py \ + --phone-transition-parameters="--phone-list=1 --min-duration=$min_silence_duration --end-transition-probability=$sil_transition_probability" \ + --phone-transition-parameters="--phone-list=2 --min-duration=$min_speech_duration --end-transition-probability=$speech_transition_probability" \ + --phone-transition-parameters="--phone-list=3 --min-duration=$min_music_duration --end-transition-probability=$music_transition_probability" \ + $lang + + cp $lang/phones.txt $lang/words.txt +fi + +feat_dim=2 # dummy. We don't need this. 
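+ +# Stage 2 below only needs a transition model for decoding: a monophone GMM is +# initialized from the topology above and its transition model is extracted with +# copy-transition-model.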
+if [ $stage -le 2 ]; then + $cmd $dir/log/create_transition_model.log gmm-init-mono \ + $lang/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans.mdl || exit 1 +fi + +# Make unigram G.fst +if [ $stage -le 3 ]; then + cat > $lang/word2prior < $lang/G.fst +fi + +graph_dir=$dir/graph_test + +if [ $stage -le 4 ]; then + $cmd $dir/log/make_vad_graph.log \ + steps/segmentation/internal/make_sad_graph.sh --iter trans \ + $lang $dir $dir/graph_test || exit 1 +fi + +if [ $stage -le 5 ]; then + utils/split_data.sh $data $nj + sdata=$data/split$nj + + nj_sad=`cat $sad_likes_dir/num_jobs` + sad_likes= + for n in `seq $nj_sad`; do + sad_likes="$sad_likes $sad_likes_dir/log_likes.$n.gz" + done + + nj_music=`cat $music_likes_dir/num_jobs` + music_likes= + for n in `seq $nj_music`; do + music_likes="$music_likes $music_likes_dir/log_likes.$n.gz" + done + + decoder_opts+=(--acoustic-scale=$acwt --beam=$beam --max-active=$max_active) + $cmd JOB=1:$nj $dir/log/decode.JOB.log \ + paste-feats "ark:gunzip -c $sad_likes | extract-feature-segments ark,s,cs:- $sdata/JOB/segments ark:- |" \ + "ark,s,cs:gunzip -c $music_likes | extract-feature-segments ark,s,cs:- $sdata/JOB/segments ark:- | select-feats 1 ark:- ark:- |" \ + ark:- \| decode-faster-mapped ${decoder_opts[@]} \ + $dir/trans.mdl $graph_dir/HCLG.fst ark:- \ + ark:/dev/null ark:- \| \ + ali-to-phones --per-frame $dir/trans.mdl ark:- \ + "ark:|gzip -c > $dir/ali.JOB.gz" +fi + +include_silence=true +if [ $stage -le 6 ]; then + $cmd JOB=1:$nj $dir/log/get_class_id.JOB.log \ + ali-to-post "ark:gunzip -c $dir/ali.JOB.gz |" ark:- \| \ + post-to-feats --post-dim=4 ark:- ark:- \| \ + matrix-sum-rows --do-average ark:- ark,t:- \| \ + sid/vector_to_music_labels.pl ${include_silence:+--include-silence-in-music} '>' $dir/ratio.JOB +fi + +for n in `seq $nj`; do + cat $dir/ratio.$n +done > $dir/ratio + +cat $dir/ratio | local/print_scores.py /dev/stdin | compute-eer - From 840bee25d288d61a3364318daa34ac9ec9e9e816 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:21:45 -0500 Subject: [PATCH 174/213] asr_diarization: Add deprecated do_corruption_whole_data_dir_overlapped_speech.sh --- ...uption_whole_data_dir_overlapped_speech.sh | 284 ++++++++++++++++++ 1 file changed, 284 insertions(+) create mode 100755 egs/aspire/s5/local/segmentation/do_corruption_whole_data_dir_overlapped_speech.sh diff --git a/egs/aspire/s5/local/segmentation/do_corruption_whole_data_dir_overlapped_speech.sh b/egs/aspire/s5/local/segmentation/do_corruption_whole_data_dir_overlapped_speech.sh new file mode 100755 index 00000000000..75dbce578b2 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/do_corruption_whole_data_dir_overlapped_speech.sh @@ -0,0 +1,284 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + +set -e +set -u +set -o pipefail + +. path.sh + +stage=0 +corruption_stage=-10 +corrupt_only=false + +# Data options +data_dir=data/train_si284 # Excpecting non-whole data directory +speed_perturb=true +num_data_reps=5 # Number of corrupted versions +snrs="20:10:15:5:0:-5" +foreground_snrs="20:10:15:5:0:-5" +background_snrs="20:10:15:5:0:-5" +overlap_snrs="5:2:1:0:-1:-2" +# Whole-data directory corresponding to data_dir +whole_data_dir=data/train_si284_whole +overlap_labels_dir=overlap_labels + +# Parallel options +reco_nj=40 +nj=40 +cmd=queue.pl + +# Options for feature extraction +mfcc_config=conf/mfcc_hires_bp.conf +feat_suffix=hires_bp +energy_config=conf/log_energy.conf + +reco_vad_dir= # Output of prepare_unsad_data.sh. 
# If provided, the speech labels and deriv weights will be + # copied into the output data directory. +utt_vad_dir= + +. utils/parse_options.sh + +if [ $# -ne 0 ]; then + echo "Usage: $0" + exit 1 +fi + +rvb_opts=() +# This is the config for the system using simulated RIRs and point-source noises +rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/smallroom/rir_list") +rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/mediumroom/rir_list") +rvb_opts+=(--speech-segments-set-parameters="$data_dir/wav.scp,$data_dir/segments") + +whole_data_id=`basename ${whole_data_dir}` + +corrupted_data_id=${whole_data_id}_ovlp_corrupted +clean_data_id=${whole_data_id}_ovlp_clean +noise_data_id=${whole_data_id}_ovlp_noise + +if [ $stage -le 1 ]; then + python steps/data/make_corrupted_data_dir.py \ + "${rvb_opts[@]}" \ + --prefix="ovlp" \ + --overlap-snrs=$overlap_snrs \ + --speech-rvb-probability=1 \ + --overlapping-speech-addition-probability=1 \ + --num-replications=$num_data_reps \ + --min-overlapping-segments-per-minute=5 \ + --max-overlapping-segments-per-minute=20 \ + --output-additive-noise-dir=data/${noise_data_id} \ + --output-reverb-dir=data/${clean_data_id} \ + data/${whole_data_id} data/${corrupted_data_id} +fi + +# dry_run is not set by any option above; default it to false so that "set -u" does not abort here. +if ${dry_run:-false}; then + exit 0 +fi + +clean_data_dir=data/${clean_data_id} +corrupted_data_dir=data/${corrupted_data_id} +noise_data_dir=data/${noise_data_id} +orig_corrupted_data_dir=$corrupted_data_dir + +if $speed_perturb; then + if [ $stage -le 2 ]; then + ## Assuming whole data directories + for x in $clean_data_dir $corrupted_data_dir $noise_data_dir; do + cp $x/reco2dur $x/utt2dur + utils/data/perturb_data_dir_speed_3way.sh $x ${x}_sp + done + fi + + corrupted_data_dir=${corrupted_data_dir}_sp + clean_data_dir=${clean_data_dir}_sp + noise_data_dir=${noise_data_dir}_sp + + corrupted_data_id=${corrupted_data_id}_sp + clean_data_id=${clean_data_id}_sp + noise_data_id=${noise_data_id}_sp + + if [ $stage -le 3 ]; then + utils/data/perturb_data_dir_volume.sh --scale-low 0.03125 --scale-high 2 --force true ${corrupted_data_dir} + utils/data/perturb_data_dir_volume.sh --force true --reco2vol ${corrupted_data_dir}/reco2vol ${clean_data_dir} + utils/data/perturb_data_dir_volume.sh --force true --reco2vol ${corrupted_data_dir}/reco2vol ${noise_data_dir} + fi +fi + +if $corrupt_only; then + echo "$0: Got corrupted data directory in ${corrupted_data_dir}" + exit 0 +fi + +mfccdir=`basename $mfcc_config` +mfccdir=${mfccdir%%.conf} + +if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $mfccdir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage +fi + +if [ $stage -le 4 ]; then + utils/copy_data_dir.sh $corrupted_data_dir ${corrupted_data_dir}_$feat_suffix + corrupted_data_dir=${corrupted_data_dir}_$feat_suffix + steps/make_mfcc.sh --mfcc-config $mfcc_config \ + --cmd "$train_cmd" --nj $reco_nj \ + $corrupted_data_dir exp/make_${feat_suffix}/${corrupted_data_id} $mfccdir +fi + +if [ $stage -le 5 ]; then + steps/make_mfcc.sh --mfcc-config $energy_config \ + --cmd "$train_cmd" --nj $reco_nj \ + $clean_data_dir exp/make_log_energy/${clean_data_id} log_energy_feats +fi + +if [ $stage -le 6 ]; then + steps/make_mfcc.sh --mfcc-config $energy_config \ + --cmd "$train_cmd" --nj $reco_nj \ + $noise_data_dir exp/make_log_energy/${noise_data_id} log_energy_feats +fi + +if [ -z "$reco_vad_dir" ]; then + echo "reco-vad-dir must be provided" + exit 1 +fi + +targets_dir=irm_targets +if [ $stage -le 8 ]; then + mkdir -p exp/make_irm_targets/${corrupted_data_id} + + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $targets_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$targets_dir/storage $targets_dir/storage + fi + + steps/segmentation/make_snr_targets.sh \ + --nj $nj --cmd "$train_cmd --max-jobs-run $max_jobs_run" \ + --target-type Irm --compress true --apply-exp false \ + ${clean_data_dir} ${noise_data_dir} ${corrupted_data_dir} \ + exp/make_irm_targets/${corrupted_data_id} $targets_dir +fi + +# Combine the VAD from the base recording and the VAD from the overlapping segments +# to create per-frame labels of the number of overlapping speech segments +# Unreliable segments are regions where no VAD labels were available for the +# overlapping segments. These can be later removed by setting deriv weights to 0. + +# Data dirs without speed perturbation +overlap_dir=exp/make_overlap_labels/${corrupted_data_id} +unreliable_dir=exp/make_overlap_labels/unreliable_${corrupted_data_id} +overlap_data_dir=$overlap_dir/overlap_data +unreliable_data_dir=$overlap_dir/unreliable_data + +mkdir -p $unreliable_dir + +if [ $stage -le 8 ]; then + cat $reco_vad_dir/sad_seg.scp | \ + steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps "ovlp" \ + | sort -k1,1 > ${corrupted_data_dir}/sad_seg.scp + utils/data/get_utt2num_frames.sh $corrupted_data_dir + utils/split_data.sh --per-reco ${orig_corrupted_data_dir} $reco_nj + + $train_cmd JOB=1:$reco_nj $overlap_dir/log/get_overlap_seg.JOB.log \ + segmentation-init-from-overlap-info --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ + "scp:utils/filter_scp.pl ${orig_corrupted_data_dir}/split${reco_nj}reco/JOB/utt2spk $corrupted_data_dir/sad_seg.scp |" \ + ark,t:$orig_corrupted_data_dir/overlapped_segments_info.txt \ + scp:$utt_vad_dir/sad_seg.scp ark:- ark:$unreliable_dir/unreliable_seg_speed_unperturbed.JOB.ark \| \ + segmentation-copy --keep-label=1 ark:- ark:- \| \ + segmentation-get-stats --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ + ark:- ark:- ark:/dev/null \| \ + segmentation-init-from-ali ark:- ark:$overlap_dir/overlap_seg_speed_unperturbed.JOB.ark +fi + +if [ $stage -le 9 ]; then + mkdir -p $overlap_data_dir $unreliable_data_dir + cp $orig_corrupted_data_dir/wav.scp $overlap_data_dir + cp $orig_corrupted_data_dir/wav.scp $unreliable_data_dir + + # Create segments where there is definitely an overlap. 
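+  # (i.e. frames whose per-frame overlap count from the stage above is 2 or more; +  # labels 2-10 are merged into a single overlap label below.)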
+ # Assume no more than 10 speakers overlap. + $train_cmd JOB=1:$reco_nj $overlap_dir/log/process_to_segments.JOB.log \ + segmentation-post-process --remove-labels=0:1 \ + ark:$overlap_dir/overlap_seg_speed_unperturbed.JOB.ark ark:- \| \ + segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-to-segments ark:- ark:$overlap_data_dir/utt2spk.JOB $overlap_data_dir/segments.JOB + + $train_cmd JOB=1:$reco_nj $overlap_dir/log/get_unreliable_segments.JOB.log \ + segmentation-to-segments --single-speaker \ + ark:$unreliable_dir/unreliable_seg_speed_unperturbed.JOB.ark \ + ark:$unreliable_data_dir/utt2spk.JOB $unreliable_data_dir/segments.JOB + + for n in `seq $reco_nj`; do cat $overlap_data_dir/utt2spk.$n; done > $overlap_data_dir/utt2spk + for n in `seq $reco_nj`; do cat $overlap_data_dir/segments.$n; done > $overlap_data_dir/segments + for n in `seq $reco_nj`; do cat $unreliable_data_dir/utt2spk.$n; done > $unreliable_data_dir/utt2spk + for n in `seq $reco_nj`; do cat $unreliable_data_dir/segments.$n; done > $unreliable_data_dir/segments + + utils/fix_data_dir.sh $overlap_data_dir + utils/fix_data_dir.sh $unreliable_data_dir + + if $speed_perturb; then + utils/data/perturb_data_dir_speed_3way.sh $overlap_data_dir ${overlap_data_dir}_sp + utils/data/perturb_data_dir_speed_3way.sh $unreliable_data_dir ${unreliable_data_dir}_sp + fi +fi + +if $speed_perturb; then + overlap_data_dir=${overlap_data_dir}_sp + unreliable_data_dir=${unreliable_data_dir}_sp +fi + +# make $overlap_labels_dir an absolute pathname. +overlap_labels_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $overlap_labels_dir ${PWD}` + +if [ $stage -le 10 ]; then + utils/split_data.sh --per-reco ${overlap_data_dir} $reco_nj + + $train_cmd JOB=1:$reco_nj $overlap_dir/log/get_overlap_speech_labels.JOB.log \ + utils/data/get_reco2utt.sh ${overlap_data_dir}/split${reco_nj}reco/JOB '&&' \ + segmentation-init-from-segments --shift-to-zero=false \ + ${overlap_data_dir}/split${reco_nj}reco/JOB/segments ark:- \| \ + segmentation-combine-segments-to-recordings ark:- ark,t:${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt \ + ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- \ + ark,scp:$overlap_labels_dir/overlapped_speech_${corrupted_data_id}.JOB.ark,$overlap_labels_dir/overlapped_speech_${corrupted_data_id}.JOB.scp +fi + +for n in `seq $reco_nj`; do + cat $overlap_labels_dir/overlapped_speech_${corrupted_data_id}.$n.scp +done > ${corrupted_data_dir}/overlapped_speech_labels.scp + +if [ $stage -le 11 ]; then + utils/data/get_reco2utt.sh ${unreliable_data_dir} + + # First convert the unreliable segments into a recording-level segmentation. + # Initialize a segmentation from utt2num_frames and set to 0, the regions + # of unreliable segments. At this stage deriv weights is 1 for all but the + # unreliable segment regions. + # Initialize a segmentation from the VAD labels and retain only the speech segments. + # Intersect this with the deriv weights segmentation from above. At this stage + # deriv weights is 1 for only the regions where base VAD label is 1 and + # the overlapping segment is not unreliable. Convert this to deriv weights. 
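+  # The resulting per-frame 0/1 weights are written for each split job to +  # $overlap_labels_dir/deriv_weights_for_overlapped_speech.JOB.{ark,scp} and the scp pieces +  # are concatenated into the data directory below.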
+ $train_cmd JOB=1:$reco_nj $unreliable_dir/log/get_deriv_weights.JOB.log\ + segmentation-init-from-segments --shift-to-zero=false \ + "utils/filter_scp.pl -f 2 ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt ${unreliable_data_dir}/segments |" ark:- \| \ + segmentation-combine-segments-to-recordings ark:- "ark,t:utils/filter_scp.pl ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt ${unreliable_data_dir}/reco2utt |" \ + ark:- \| \ + segmentation-create-subsegments --filter-label=1 --subsegment-label=0 --ignore-missing \ + "ark:utils/filter_scp.pl ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt $corrupted_data_dir/utt2num_frames | segmentation-init-from-lengths ark,t:- ark:- |" \ + ark:- ark:- \| \ + segmentation-intersect-segments --mismatch-label=0 \ + "ark:utils/filter_scp.pl ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt $corrupted_data_dir/sad_seg.scp | segmentation-post-process --remove-labels=0:2:3 scp:- ark:- |" \ + ark:- ark:- \| \ + segmentation-post-process --remove-labels=0 ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- ark,t:- \| \ + steps/segmentation/convert_ali_to_vec.pl \| copy-vector ark,t:- \ + ark,scp:$overlap_labels_dir/deriv_weights_for_overlapped_speech.JOB.ark,$overlap_labels_dir/deriv_weights_for_overlapped_speech.JOB.scp + + for n in `seq $reco_nj`; do + cat $overlap_labels_dir/deriv_weights_for_overlapped_speech.${n}.scp + done > $corrupted_data_dir/deriv_weights_for_overlapped_speech.scp +fi + +exit 0 From 2725cd1e75b743513e8ba93dbde1ee8750dc0d5c Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:23:53 -0500 Subject: [PATCH 175/213] asr_diarization: steps/data/wav_scp2noise_list.py --- egs/wsj/s5/steps/data/wav_scp2noise_list.py | 39 +++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100755 egs/wsj/s5/steps/data/wav_scp2noise_list.py diff --git a/egs/wsj/s5/steps/data/wav_scp2noise_list.py b/egs/wsj/s5/steps/data/wav_scp2noise_list.py new file mode 100755 index 00000000000..960bce33c7d --- /dev/null +++ b/egs/wsj/s5/steps/data/wav_scp2noise_list.py @@ -0,0 +1,39 @@ +#! 
/usr/bin/env python + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + +from __future__ import print_function +import argparse, random + +def GetArgs(): + parser = argparse.ArgumentParser(description="""This script converts a wav.scp +into noise-set-paramters that can be passed to steps/data/reverberate_data_dir.py.""") + + parser.add_argument("wav_scp", type=str, + help = "The input wav.scp") + parser.add_argument("noise_list", type=str, + help = "File to write the output noise-set-parameters") + + args = parser.parse_args() + + return args + +def Main(): + args = GetArgs() + + noise_list = open(args.noise_list, 'w') + + for line in open(args.wav_scp): + parts = line.strip().split() + + print ('''--noise-id {reco} --noise-type point-source \ +--bg-fg-type foreground "{wav}"'''.format( + reco = parts[0], + wav = " ".join(parts[1:])), file = noise_list) + + noise_list.close() + +if __name__ == '__main__': + Main() + From 4a35cec70673db92e5d5e559fe237ee0e050abbc Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:24:16 -0500 Subject: [PATCH 176/213] asr_diarization: Cluster segments AIB --- .../segmentation/cluster_segments_aIB.sh | 138 ++++++++++++++++++ .../cluster_segments_aIB_change_point.sh | 138 ++++++++++++++++++ 2 files changed, 276 insertions(+) create mode 100755 egs/wsj/s5/steps/segmentation/cluster_segments_aIB.sh create mode 100755 egs/wsj/s5/steps/segmentation/cluster_segments_aIB_change_point.sh diff --git a/egs/wsj/s5/steps/segmentation/cluster_segments_aIB.sh b/egs/wsj/s5/steps/segmentation/cluster_segments_aIB.sh new file mode 100755 index 00000000000..a1f187fab31 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/cluster_segments_aIB.sh @@ -0,0 +1,138 @@ +#! /bin/bash + +window=2.5 +overlap=0.0 +stage=-1 +cmd=queue.pl +reco_nj=4 +frame_shift=0.01 +utt_nj=18 +min_clusters=10 +stopping_threshold=0.5 + +. path.sh +. 
utils/parse_options.sh + +set -o pipefail +set -e +set -u + +if [ $# -ne 3 ]; then + echo "Usage: $0 " + exit 1 +fi + +data=$1 +dir=$2 +out_data=$3 + +num_frames=`perl -e "print int($window / $frame_shift + 0.5)"` +num_frames_overlap=`perl -e "print int($overlap/ $frame_shift + 0.5)"` + +data_uniform_seg=${data}_uniform_seg_window${window}_ovlp${overlap} + +mkdir -p ${data_uniform_seg} + +mkdir -p $dir + +#segmentation-cluster-adjacent-segments --verbose=0 'ark:segmentation-copy --keep-label=1 "ark:gunzip -c exp/nnet3_lstm_sad_music/nnet_lstm_1e//segmentation_bn_eval97_whole_bp/orig_segmentation.1.gz |" ark:- | segmentation-split-segments --max-segment-length=250 --overlap-length=0 ark:- ark:- |' scp:data/bn_eval97_bp_hires/feats.scp "ark:| segmentation-post-process --merge-adjacent-segments ark:- ark:- | segmentation-to-segments ark:- ark,t:- /dev/null" 2>&1 | less + +if [ $stage -le 0 ]; then + $cmd $dir/log/get_subsegments.log \ + segmentation-init-from-segments --frame-overlap=0.015 $data/segments ark:- \| \ + segmentation-split-segments --max-segment-length=$num_frames --overlap-length=$num_frames_overlap ark:- ark:- \| \ + segmentation-cluster-adjacent-segments --verbose=3 ark:- "scp:$data/feats.scp" ark:- \| \ + segmentation-post-process --merge-adjacent-segments ark:- ark:- \| \ + segmentation-to-segments --frame-overlap=0.0 ark:- ark:/dev/null \ + ${data_uniform_seg}/sub_segments + + utils/data/subsegment_data_dir.sh ${data} ${data_uniform_seg}{/sub_segments,} +fi + +gmm_dir=$dir/gmms +mkdir -p $gmm_dir + +utils/split_data.sh --per-reco ${data_uniform_seg} $reco_nj + +if [ $stage -le 1 ]; then + echo $reco_nj > $gmm_dir/num_jobs + $cmd JOB=1:$reco_nj $gmm_dir/log/train_gmm.JOB.log \ + gmm-global-init-models-from-feats --share-covars=true \ + --spk2utt-rspecifier=ark,t:${data_uniform_seg}/split${reco_nj}reco/JOB/reco2utt \ + --num-gauss-init=64 --num-gauss=64 --num-gauss-fraction=0.001 --max-gauss=512 --min-gauss=64 \ + --num-iters=20 --num-frames=500000 \ + scp:${data_uniform_seg}/split${reco_nj}reco/JOB/feats.scp \ + ark,scp:$gmm_dir/gmm.JOB.ark,$gmm_dir/gmm.JOB.scp + + for n in `seq $reco_nj`; do + cat $gmm_dir/gmm.$n.scp + done > $gmm_dir/gmm.scp + +fi + +post_dir=$gmm_dir/post_`basename $data_uniform_seg` +mkdir -p $post_dir + +if [ $stage -le 2 ]; then + echo $reco_nj > $post_dir/num_jobs + + $cmd JOB=1:$reco_nj $gmm_dir/log/compute_post.JOB.log \ + gmm-global-get-post \ + --utt2spk="ark,t:cut -d ' ' -f 1,2 ${data_uniform_seg}/split${reco_nj}reco/JOB/segments |" \ + scp:$gmm_dir/gmm.scp \ + scp:${data_uniform_seg}/split${reco_nj}reco/JOB/feats.scp \ + "ark:| gzip -c > $post_dir/post.JOB.gz" \ + "ark:| gzip -c > $post_dir/frame_loglikes.JOB.gz" +fi + +if [ $stage -le 3 ]; then + utils/data/get_utt2num_frames.sh --nj $utt_nj --cmd "$cmd" ${data_uniform_seg} + + $cmd JOB=1:$reco_nj $post_dir/log/compute_average_post.JOB.log \ + gmm-global-post-to-feats \ + --utt2spk="ark,t:cut -d ' ' -f 1,2 ${data_uniform_seg}/split${reco_nj}reco/JOB/segments |" \ + scp:$gmm_dir/gmm.scp "ark:gunzip -c $post_dir/post.JOB.gz |" ark:- \| \ + matrix-sum-rows --do-average ark:- "ark:| gzip -c > $post_dir/avg_post.JOB.gz" +fi + +seg_dir=$dir/segmentation_`basename $data_uniform_seg` + +if [ $stage -le 4 ]; then + $cmd JOB=1:$reco_nj $seg_dir/log/cluster_segments.JOB.log \ + agglomerative-cluster-ib --min-clusters=$min_clusters \ + --verbose=3 --stopping-threshold=$stopping_threshold --input-factor=0 \ + --counts-rspecifier="ark,t:utils/filter_scp.pl $data_uniform_seg/split${reco_nj}reco/JOB/utt2spk 
$data_uniform_seg/utt2num_frames |" \ + "ark:gunzip -c $post_dir/avg_post.JOB.gz |" \ + "ark,t:${data_uniform_seg}/split${reco_nj}reco/JOB/reco2utt" \ + ark,t:$seg_dir/utt2cluster_id.JOB +fi + +if [ $stage -le 5 ]; then + $cmd JOB=1:$reco_nj $seg_dir/log/init_segmentation.JOB.log \ + segmentation-init-from-segments --frame-overlap=0.0 --shift-to-zero=false \ + --utt2label-rspecifier=ark,t:${seg_dir}/utt2cluster_id.JOB \ + ${data_uniform_seg}/split${reco_nj}reco/JOB/segments ark:- \| \ + segmentation-combine-segments-to-recordings ark:- \ + ark,t:${data_uniform_seg}/split${reco_nj}reco/JOB/reco2utt \ + ark:- \| \ + segmentation-post-process --merge-adjacent-segments ark:- ark:- \| \ + segmentation-post-process --max-segment-length=1000 --overlap-length=250 ark:- ark:- \| \ + segmentation-to-segments ark:- ark,t:$seg_dir/utt2spk.JOB $seg_dir/segments.JOB +fi + +if [ $stage -le 6 ]; then + rm -r $out_data || true + utils/data/convert_data_dir_to_whole.sh $data $out_data + rm $out_data/{text,cmvn.scp} || true + + for n in `seq $reco_nj`; do + cat $seg_dir/utt2spk.$n + done > $out_data/utt2spk + + for n in `seq $reco_nj`; do + cat $seg_dir/segments.$n + done > $out_data/segments + + utils/utt2spk_to_spk2utt.pl $out_data/utt2spk > $out_data/spk2utt + utils/fix_data_dir.sh $out_data +fi diff --git a/egs/wsj/s5/steps/segmentation/cluster_segments_aIB_change_point.sh b/egs/wsj/s5/steps/segmentation/cluster_segments_aIB_change_point.sh new file mode 100755 index 00000000000..a1f187fab31 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/cluster_segments_aIB_change_point.sh @@ -0,0 +1,138 @@ +#! /bin/bash + +window=2.5 +overlap=0.0 +stage=-1 +cmd=queue.pl +reco_nj=4 +frame_shift=0.01 +utt_nj=18 +min_clusters=10 +stopping_threshold=0.5 + +. path.sh +. utils/parse_options.sh + +set -o pipefail +set -e +set -u + +if [ $# -ne 3 ]; then + echo "Usage: $0 " + exit 1 +fi + +data=$1 +dir=$2 +out_data=$3 + +num_frames=`perl -e "print int($window / $frame_shift + 0.5)"` +num_frames_overlap=`perl -e "print int($overlap/ $frame_shift + 0.5)"` + +data_uniform_seg=${data}_uniform_seg_window${window}_ovlp${overlap} + +mkdir -p ${data_uniform_seg} + +mkdir -p $dir + +#segmentation-cluster-adjacent-segments --verbose=0 'ark:segmentation-copy --keep-label=1 "ark:gunzip -c exp/nnet3_lstm_sad_music/nnet_lstm_1e//segmentation_bn_eval97_whole_bp/orig_segmentation.1.gz |" ark:- | segmentation-split-segments --max-segment-length=250 --overlap-length=0 ark:- ark:- |' scp:data/bn_eval97_bp_hires/feats.scp "ark:| segmentation-post-process --merge-adjacent-segments ark:- ark:- | segmentation-to-segments ark:- ark,t:- /dev/null" 2>&1 | less + +if [ $stage -le 0 ]; then + $cmd $dir/log/get_subsegments.log \ + segmentation-init-from-segments --frame-overlap=0.015 $data/segments ark:- \| \ + segmentation-split-segments --max-segment-length=$num_frames --overlap-length=$num_frames_overlap ark:- ark:- \| \ + segmentation-cluster-adjacent-segments --verbose=3 ark:- "scp:$data/feats.scp" ark:- \| \ + segmentation-post-process --merge-adjacent-segments ark:- ark:- \| \ + segmentation-to-segments --frame-overlap=0.0 ark:- ark:/dev/null \ + ${data_uniform_seg}/sub_segments + + utils/data/subsegment_data_dir.sh ${data} ${data_uniform_seg}{/sub_segments,} +fi + +gmm_dir=$dir/gmms +mkdir -p $gmm_dir + +utils/split_data.sh --per-reco ${data_uniform_seg} $reco_nj + +if [ $stage -le 1 ]; then + echo $reco_nj > $gmm_dir/num_jobs + $cmd JOB=1:$reco_nj $gmm_dir/log/train_gmm.JOB.log \ + gmm-global-init-models-from-feats --share-covars=true \ + 
--spk2utt-rspecifier=ark,t:${data_uniform_seg}/split${reco_nj}reco/JOB/reco2utt \ + --num-gauss-init=64 --num-gauss=64 --num-gauss-fraction=0.001 --max-gauss=512 --min-gauss=64 \ + --num-iters=20 --num-frames=500000 \ + scp:${data_uniform_seg}/split${reco_nj}reco/JOB/feats.scp \ + ark,scp:$gmm_dir/gmm.JOB.ark,$gmm_dir/gmm.JOB.scp + + for n in `seq $reco_nj`; do + cat $gmm_dir/gmm.$n.scp + done > $gmm_dir/gmm.scp + +fi + +post_dir=$gmm_dir/post_`basename $data_uniform_seg` +mkdir -p $post_dir + +if [ $stage -le 2 ]; then + echo $reco_nj > $post_dir/num_jobs + + $cmd JOB=1:$reco_nj $gmm_dir/log/compute_post.JOB.log \ + gmm-global-get-post \ + --utt2spk="ark,t:cut -d ' ' -f 1,2 ${data_uniform_seg}/split${reco_nj}reco/JOB/segments |" \ + scp:$gmm_dir/gmm.scp \ + scp:${data_uniform_seg}/split${reco_nj}reco/JOB/feats.scp \ + "ark:| gzip -c > $post_dir/post.JOB.gz" \ + "ark:| gzip -c > $post_dir/frame_loglikes.JOB.gz" +fi + +if [ $stage -le 3 ]; then + utils/data/get_utt2num_frames.sh --nj $utt_nj --cmd "$cmd" ${data_uniform_seg} + + $cmd JOB=1:$reco_nj $post_dir/log/compute_average_post.JOB.log \ + gmm-global-post-to-feats \ + --utt2spk="ark,t:cut -d ' ' -f 1,2 ${data_uniform_seg}/split${reco_nj}reco/JOB/segments |" \ + scp:$gmm_dir/gmm.scp "ark:gunzip -c $post_dir/post.JOB.gz |" ark:- \| \ + matrix-sum-rows --do-average ark:- "ark:| gzip -c > $post_dir/avg_post.JOB.gz" +fi + +seg_dir=$dir/segmentation_`basename $data_uniform_seg` + +if [ $stage -le 4 ]; then + $cmd JOB=1:$reco_nj $seg_dir/log/cluster_segments.JOB.log \ + agglomerative-cluster-ib --min-clusters=$min_clusters \ + --verbose=3 --stopping-threshold=$stopping_threshold --input-factor=0 \ + --counts-rspecifier="ark,t:utils/filter_scp.pl $data_uniform_seg/split${reco_nj}reco/JOB/utt2spk $data_uniform_seg/utt2num_frames |" \ + "ark:gunzip -c $post_dir/avg_post.JOB.gz |" \ + "ark,t:${data_uniform_seg}/split${reco_nj}reco/JOB/reco2utt" \ + ark,t:$seg_dir/utt2cluster_id.JOB +fi + +if [ $stage -le 5 ]; then + $cmd JOB=1:$reco_nj $seg_dir/log/init_segmentation.JOB.log \ + segmentation-init-from-segments --frame-overlap=0.0 --shift-to-zero=false \ + --utt2label-rspecifier=ark,t:${seg_dir}/utt2cluster_id.JOB \ + ${data_uniform_seg}/split${reco_nj}reco/JOB/segments ark:- \| \ + segmentation-combine-segments-to-recordings ark:- \ + ark,t:${data_uniform_seg}/split${reco_nj}reco/JOB/reco2utt \ + ark:- \| \ + segmentation-post-process --merge-adjacent-segments ark:- ark:- \| \ + segmentation-post-process --max-segment-length=1000 --overlap-length=250 ark:- ark:- \| \ + segmentation-to-segments ark:- ark,t:$seg_dir/utt2spk.JOB $seg_dir/segments.JOB +fi + +if [ $stage -le 6 ]; then + rm -r $out_data || true + utils/data/convert_data_dir_to_whole.sh $data $out_data + rm $out_data/{text,cmvn.scp} || true + + for n in `seq $reco_nj`; do + cat $seg_dir/utt2spk.$n + done > $out_data/utt2spk + + for n in `seq $reco_nj`; do + cat $seg_dir/segments.$n + done > $out_data/segments + + utils/utt2spk_to_spk2utt.pl $out_data/utt2spk > $out_data/spk2utt + utils/fix_data_dir.sh $out_data +fi From 268e0175fd376ec9a90b7193d1d33873c7b7d478 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:24:43 -0500 Subject: [PATCH 177/213] asr_diarization: Train simple HMM --- .../s5/steps/segmentation/train_simple_hmm.py | 194 ++++++++++++++++++ 1 file changed, 194 insertions(+) create mode 100755 egs/wsj/s5/steps/segmentation/train_simple_hmm.py diff --git a/egs/wsj/s5/steps/segmentation/train_simple_hmm.py b/egs/wsj/s5/steps/segmentation/train_simple_hmm.py 
new file mode 100755 index 00000000000..9f581b0a520 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/train_simple_hmm.py @@ -0,0 +1,194 @@ +#! /usr/bin/env python + +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +import argparse +import logging +import os +import sys + +sys.path.insert(0, 'steps') +import libs.common as common_lib + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +handler.setLevel(logging.INFO) +formatter = logging.Formatter("%(asctime)s [%(pathname)s:%(lineno)s - " + "%(funcName)s - %(levelname)s ] %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) + + +def get_args(): + """Parse command-line arguments""" + + parser = argparse.ArgumentParser( + """Train a simple HMM model starting from HMM topology.""") + + # Alignment options + parser.add_argument("--align.transition-scale", dest='transition_scale', + type=float, default=10.0, + help="""Transition-probability scale [relative to + acoustics]""") + parser.add_argument("--align.self-loop-scale", dest='self_loop_scale', + type=float, default=1.0, + help="""Scale on self-loop versus non-self-loop log + probs [relative to acoustics]""") + parser.add_argument("--align.beam", dest='beam', + type=float, default=6, + help="""Decoding beam used in alignment""") + + # Training options + parser.add_argument("--training.num-iters", dest='num_iters', + type=int, default=30, + help="""Number of iterations of training""") + parser.add_argument("--training.use-soft-counts", dest='use_soft_counts', + type=str, action=common_lib.StrToBoolAction, + choices=["true", "false"], default=False, + help="""Use soft counts (posteriors) instead of + alignments""") + + # General options + parser.add_argument("--scp2ark-cmd", type=str, + default="copy-int-vector scp:- ark:- |", + help="The command used to convert scp from stdin to " + "write archive to stdout") + parser.add_argument("--cmd", dest='command', type=str, + default="run.pl", + help="Command used to run jobs") + parser.add_argument("--stage", type=int, default=-10, + help="""Stage to run training from""") + + parser.add_argument("--data", type=str, required=True, + help="Data directory; primarily used for splitting") + + labels_group = parser.add_mutually_exclusive_group(required=True) + labels_group.add_argument("--labels-scp", type=str, + help="Input labels that must be convert to alignment " + "of class-ids using --scp2ark-cmd") + labels_group.add_argument("--labels-rspecifier", type=str, + help="Input labels rspecifier") + + parser.add_argument("--lang", type=str, required=True, + help="The language directory containing the " + "HMM Topology file topo") + parser.add_argument("--loglikes-dir", type=str, required=True, + help="Directory containing the log-likelihoods") + parser.add_argument("--dir", type=str, required=True, + help="Directory where the intermediate and final " + "models will be written") + + args = parser.parse_args() + + if args.use_soft_counts: + raise NotImplementedError("--use-soft-counts not supported yet!") + + return args + + +def check_files(args): + """Check files required for this script""" + + files = ("{lang}/topo {data}/utt2spk " + "{loglikes_dir}/log_likes.1.gz {loglikes_dir}/num_jobs " + "".format(lang=args.lang, data=args.data, + loglikes_dir=args.loglikes_dir).split()) + + if args.labels_scp is not None: + files.append(args.labels_scp) + + for f in files: + if not os.path.exists(f): + logger.error("Could not find file %s", f) + raise RuntimeError + + +def run(args): + 
"""The function that does it all""" + + check_files(args) + + if args.stage <= -2: + logger.info("Initializing simple HMM model") + common_lib.run_kaldi_command( + """{cmd} {dir}/log/init.log simple-hmm-init {lang}/topo """ + """ {dir}/0.mdl""".format(cmd=args.command, dir=args.dir, + lang=args.lang)) + + num_jobs = common_lib.get_number_of_jobs(args.loglikes_dir) + split_data = common_lib.split_data(args.data, num_jobs) + + if args.labels_rspecifier is not None: + labels_rspecifier = args.labels_rspecifier + else: + labels_rspecifier = ("ark:utils/filter_scp.pl {sdata}/JOB/utt2spk " + "{labels_scp} | {scp2ark_cmd}".format( + sdata=split_data, labels_scp=args.labels_scp, + scp2ark_cmd=args.scp2ark_cmd)) + + if args.stage <= -1: + logger.info("Compiling training graphs") + common_lib.run_kaldi_command( + """{cmd} JOB=1:{nj} {dir}/log/compile_graphs.JOB.log """ + """ compile-train-simple-hmm-graphs {dir}/0.mdl """ + """ "{labels_rspecifier}" """ + """ "ark:| gzip -c > {dir}/fsts.JOB.gz" """.format( + cmd=args.command, nj=num_jobs, + dir=args.dir, lang=args.lang, + labels_rspecifier=labels_rspecifier)) + + scale_opts = ("--transition-scale={tscale} --self-loop-scale={loop_scale}" + "".format(tscale=args.transition_scale, + loop_scale=args.self_loop_scale)) + + for iter_ in range(0, args.num_iters): + if args.stage > iter_: + continue + + logger.info("Training iteration %d", iter_) + + common_lib.run_kaldi_command( + """{cmd} JOB=1:{nj} {dir}/log/align.{iter}.JOB.log """ + """ simple-hmm-align-compiled {scale_opts} """ + """ --beam={beam} --retry-beam={retry_beam} {dir}/{iter}.mdl """ + """ "ark:gunzip -c {dir}/fsts.JOB.gz |" """ + """ "ark:gunzip -c {loglikes_dir}/log_likes.JOB.gz |" """ + """ ark:- \| """ + """ simple-hmm-acc-stats-ali {dir}/{iter}.mdl ark:- """ + """ {dir}/{iter}.JOB.acc""".format( + cmd=args.command, nj=num_jobs, dir=args.dir, iter=iter_, + scale_opts=scale_opts, beam=args.beam, + retry_beam=args.beam * 4, loglikes_dir=args.loglikes_dir)) + + common_lib.run_kaldi_command( + """{cmd} {dir}/log/update.{iter}.log """ + """ simple-hmm-est {dir}/{iter}.mdl """ + """ "vector-sum {dir}/{iter}.*.acc - |" """ + """ {dir}/{new_iter}.mdl""".format( + cmd=args.command, dir=args.dir, iter=iter_, + new_iter=iter_ + 1)) + + common_lib.run_kaldi_command( + "rm {dir}/{iter}.*.acc".format(dir=args.dir, iter=iter_)) + # end train loop + + common_lib.force_symlink("{0}.mdl".format(args.num_iters), + "{0}/final.mdl".format(args.dir)) + + logger.info("Done training simple HMM in %s/final.mdl", args.dir) + + +def main(): + try: + args = get_args() + run(args) + except Exception: + logger.error("Failed training models") + raise + + +if __name__ == '__main__': + main() From 7f10cd555746dc3034f990a92a6d96c5178f6cce Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:25:39 -0500 Subject: [PATCH 178/213] asr_diarization: Add deprecated data_lib.py --- egs/wsj/s5/utils/data/data_lib.py | 57 +++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 egs/wsj/s5/utils/data/data_lib.py diff --git a/egs/wsj/s5/utils/data/data_lib.py b/egs/wsj/s5/utils/data/data_lib.py new file mode 100644 index 00000000000..5e58fcac3d5 --- /dev/null +++ b/egs/wsj/s5/utils/data/data_lib.py @@ -0,0 +1,57 @@ +import os + +import libs.common as common_lib + +def get_frame_shift(data_dir): + frame_shift = common_lib.run_kaldi_command("utils/data/get_frame_shift.sh {0}".format(data_dir))[0] + return float(frame_shift.strip()) + +def generate_utt2dur(data_dir): + 
common_lib.run_kaldi_command("utils/data/get_utt2dur.sh {0}".format(data_dir)) + +def get_utt2dur(data_dir): + GenerateUtt2Dur(data_dir) + utt2dur = {} + for line in open('{0}/utt2dur'.format(data_dir), 'r').readlines(): + parts = line.split() + utt2dur[parts[0]] = float(parts[1]) + return utt2dur + +def get_utt2uniq(data_dir): + utt2uniq_file = '{0}/utt2uniq'.format(data_dir) + if not os.path.exists(utt2uniq_file): + return None, None + utt2uniq = {} + uniq2utt = {} + for line in open(utt2uniq_file, 'r').readlines(): + parts = line.split() + utt2uniq[parts[0]] = parts[1] + if uniq2utt.has_key(parts[1]): + uniq2utt[parts[1]].append(parts[0]) + else: + uniq2utt[parts[1]] = [parts[0]] + return utt2uniq, uniq2utt + +def get_num_frames(data_dir, utts = None): + GenerateUtt2Dur(data_dir) + frame_shift = GetFrameShift(data_dir) + total_duration = 0 + utt2dur = GetUtt2Dur(data_dir) + if utts is None: + utts = utt2dur.keys() + for utt in utts: + total_duration = total_duration + utt2dur[utt] + return int(float(total_duration)/frame_shift) + +def create_data_links(file_names): + # if file_names already exist create_data_link.pl returns with code 1 + # so we just delete them before calling create_data_link.pl + for file_name in file_names: + TryToDelete(file_name) + common_lib.run_kaldi_command(" utils/create_data_link.pl {0}".format(" ".join(file_names))) + +def try_to_delete(file_name): + try: + os.remove(file_name) + except OSError: + pass From 311d31f75768ecb339502203bc777ac4b2fad9ab Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:26:12 -0500 Subject: [PATCH 179/213] asr_diarization: Add nnet3-copy-egs-overlapped --- src/nnet3bin/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nnet3bin/Makefile b/src/nnet3bin/Makefile index 2a660da232c..39aa6d477a2 100644 --- a/src/nnet3bin/Makefile +++ b/src/nnet3bin/Makefile @@ -18,7 +18,7 @@ BINFILES = nnet3-init nnet3-info nnet3-get-egs nnet3-copy-egs nnet3-subset-egs \ nnet3-discriminative-compute-objf nnet3-discriminative-train \ discriminative-get-supervision nnet3-discriminative-subset-egs \ nnet3-discriminative-compute-from-egs nnet3-get-egs-multiple-targets \ - nnet3-copy-egs-overlap-detection + nnet3-am-compute nnet3-copy-egs-overlap-detection OBJFILES = From d54b41220f9ad991445ba002c2b7539015570100 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:26:36 -0500 Subject: [PATCH 180/213] segmenterbin/Makefile --- src/segmenterbin/Makefile | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/segmenterbin/Makefile b/src/segmenterbin/Makefile index 6e2fd226019..a424f192d3e 100644 --- a/src/segmenterbin/Makefile +++ b/src/segmenterbin/Makefile @@ -19,7 +19,9 @@ BINFILES = segmentation-copy segmentation-get-stats \ segmentation-init-from-additive-signals-info \ class-counts-per-frame-to-labels \ agglomerative-cluster-ib \ - intersect-int-vectors #\ + intersect-int-vectors \ + gmm-global-init-models-from-feats \ + segmentation-cluster-adjacent-segments #\ gmm-acc-pdf-stats-segmentation \ gmm-est-segmentation gmm-update-segmentation \ segmentation-init-from-diarization \ From e27267f8d7ae5cf3302e63b9b84110af06fb2fb2 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:27:57 -0500 Subject: [PATCH 181/213] asr_diarization: Overlapping speech detection tuning scripts --- .../local/segmentation/tuning/train_lstm_overlapping_sad_1b.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/egs/aspire/s5/local/segmentation/tuning/train_lstm_overlapping_sad_1b.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_overlapping_sad_1b.sh index a634060b317..361f7c27bc0 100755 --- a/egs/aspire/s5/local/segmentation/tuning/train_lstm_overlapping_sad_1b.sh +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_overlapping_sad_1b.sh @@ -56,7 +56,7 @@ num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 400 num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` if [ -z "$dir" ]; then - dir=exp/nnet3_stats_sad_ovlp_snr/nnet_lstm + dir=exp/nnet3_lstm_sad_ovlp_snr/nnet_lstm fi dir=$dir${affix:+_$affix} From 27ab5b2752d22a735add9ada1ea265433028d69c Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 25 Jan 2017 15:28:50 -0500 Subject: [PATCH 182/213] asr_diarization: add nnet3-am-compute --- src/nnet3bin/nnet3-am-compute.cc | 186 +++++++++++++++++++++++++++++++ 1 file changed, 186 insertions(+) create mode 100644 src/nnet3bin/nnet3-am-compute.cc diff --git a/src/nnet3bin/nnet3-am-compute.cc b/src/nnet3bin/nnet3-am-compute.cc new file mode 100644 index 00000000000..c91417c0aee --- /dev/null +++ b/src/nnet3bin/nnet3-am-compute.cc @@ -0,0 +1,186 @@ +// nnet3bin/nnet3-am-compute.cc + +// Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey) +// 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "nnet3/nnet-am-decodable-simple.h" +#include "base/timer.h" +#include "nnet3/nnet-utils.h" + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace kaldi::nnet3; + typedef kaldi::int32 int32; + typedef kaldi::int64 int64; + + const char *usage = + "Propagate the features through neural network model " + "and write the pseudo log-likelihoods (after dividing by priors).\n" + "If --apply-exp=true, apply the Exp() function to the output " + "before writing it out.\n" + "\n" + "Usage: nnet3-am-compute [options] " + "\n" + " e.g.: nnet3-am-compute final.mdl scp:feats.scp ark:log_likes.ark\n" + "See also: nnet3-compute-from-egs, nnet3-compute\n"; + + ParseOptions po(usage); + Timer timer; + + NnetSimpleComputationOptions opts; + opts.acoustic_scale = 1.0; // by default do no scaling in this recipe. + + bool apply_exp = false; + std::string use_gpu = "yes"; + + std::string word_syms_filename; + std::string ivector_rspecifier, + online_ivector_rspecifier, + utt2spk_rspecifier; + int32 online_ivector_period = 0; + + opts.Register(&po); + + po.Register("ivectors", &ivector_rspecifier, "Rspecifier for " + "iVectors as vectors (i.e. 
not estimated online); per utterance " + "by default, or per speaker if you provide the --utt2spk option."); + po.Register("utt2spk", &utt2spk_rspecifier, "Rspecifier for " + "utt2spk option used to get ivectors per speaker"); + po.Register("online-ivectors", &online_ivector_rspecifier, "Rspecifier for " + "iVectors estimated online, as matrices. If you supply this," + " you must set the --online-ivector-period option."); + po.Register("online-ivector-period", &online_ivector_period, "Number of frames " + "between iVectors in matrices supplied to the --online-ivectors " + "option"); + po.Register("apply-exp", &apply_exp, "If true, apply exp function to " + "output"); + po.Register("use-gpu", &use_gpu, + "yes|no|optional|wait, only has effect if compiled with CUDA"); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + +#if HAVE_CUDA==1 + CuDevice::Instantiate().SelectGpuId(use_gpu); +#endif + + std::string nnet_rxfilename = po.GetArg(1), + feature_rspecifier = po.GetArg(2), + matrix_wspecifier = po.GetArg(3); + + TransitionModel trans_model; + AmNnetSimple am_nnet; + { + bool binary_read; + Input ki(nnet_rxfilename, &binary_read); + trans_model.Read(ki.Stream(), binary_read); + am_nnet.Read(ki.Stream(), binary_read); + } + + RandomAccessBaseFloatMatrixReader online_ivector_reader( + online_ivector_rspecifier); + RandomAccessBaseFloatVectorReaderMapped ivector_reader( + ivector_rspecifier, utt2spk_rspecifier); + + CachingOptimizingCompiler compiler(am_nnet.GetNnet(), opts.optimize_config); + + BaseFloatMatrixWriter matrix_writer(matrix_wspecifier); + + int32 num_success = 0, num_fail = 0; + int64 frame_count = 0; + + SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); + + for (; !feature_reader.Done(); feature_reader.Next()) { + std::string utt = feature_reader.Key(); + const Matrix &features (feature_reader.Value()); + if (features.NumRows() == 0) { + KALDI_WARN << "Zero-length utterance: " << utt; + num_fail++; + continue; + } + const Matrix *online_ivectors = NULL; + const Vector *ivector = NULL; + if (!ivector_rspecifier.empty()) { + if (!ivector_reader.HasKey(utt)) { + KALDI_WARN << "No iVector available for utterance " << utt; + num_fail++; + continue; + } else { + ivector = &ivector_reader.Value(utt); + } + } + if (!online_ivector_rspecifier.empty()) { + if (!online_ivector_reader.HasKey(utt)) { + KALDI_WARN << "No online iVector available for utterance " << utt; + num_fail++; + continue; + } else { + online_ivectors = &online_ivector_reader.Value(utt); + } + } + + DecodableNnetSimple nnet_computer( + opts, am_nnet.GetNnet(), am_nnet.Priors(), + features, &compiler, + ivector, online_ivectors, + online_ivector_period); + + Matrix matrix(nnet_computer.NumFrames(), + nnet_computer.OutputDim()); + for (int32 t = 0; t < nnet_computer.NumFrames(); t++) { + SubVector row(matrix, t); + nnet_computer.GetOutputForFrame(t, &row); + } + + if (apply_exp) + matrix.ApplyExp(); + + matrix_writer.Write(utt, matrix); + + frame_count += features.NumRows(); + num_success++; + } + +#if HAVE_CUDA==1 + CuDevice::Instantiate().PrintProfile(); +#endif + double elapsed = timer.Elapsed(); + KALDI_LOG << "Time taken "<< elapsed + << "s: real-time factor assuming 100 frames/sec is " + << (elapsed*100.0/frame_count); + KALDI_LOG << "Done " << num_success << " utterances, failed for " + << num_fail; + + if (num_success != 0) return 0; + else return 1; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + From 
9156d29387c370384ef0b4f92e78be90942b5bfb Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 6 Feb 2017 16:02:35 -0500 Subject: [PATCH 183/213] asr_diarization: Update simple hmm --- src/simplehmm/Makefile | 2 +- src/simplehmm/simple-hmm-utils.cc | 2 +- src/simplehmm/simple-hmm-utils.h | 2 +- src/simplehmm/simple-hmm.cc | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/simplehmm/Makefile b/src/simplehmm/Makefile index 89c9f70a8c3..d83fba05900 100644 --- a/src/simplehmm/Makefile +++ b/src/simplehmm/Makefile @@ -5,7 +5,7 @@ include ../kaldi.mk TESTFILES = simple-hmm-test -OBJFILES = simple-hmm.o simple-hmm-utils.o simple-hmm-graph-compiler.o +OBJFILES = simple-hmm.o simple-hmm-utils.o LIBNAME = kaldi-simplehmm ADDLIBS = ../hmm/kaldi-hmm.a ../decoder/kaldi-decoder.a \ diff --git a/src/simplehmm/simple-hmm-utils.cc b/src/simplehmm/simple-hmm-utils.cc index 3406b7b56f8..fc0c7e4ca3c 100644 --- a/src/simplehmm/simple-hmm-utils.cc +++ b/src/simplehmm/simple-hmm-utils.cc @@ -20,7 +20,7 @@ #include -#include "hmm/simple-hmm-utils.h" +#include "simplehmm/simple-hmm-utils.h" #include "fst/fstlib.h" #include "fstext/fstext-lib.h" diff --git a/src/simplehmm/simple-hmm-utils.h b/src/simplehmm/simple-hmm-utils.h index bd0a3a15702..5bdf185214a 100644 --- a/src/simplehmm/simple-hmm-utils.h +++ b/src/simplehmm/simple-hmm-utils.h @@ -22,7 +22,7 @@ #define KALDI_HMM_SIMPLE_HMM_UTILS_H_ #include "hmm/hmm-utils.h" -#include "hmm/simple-hmm.h" +#include "simplehmm/simple-hmm.h" #include "fst/fstlib.h" namespace kaldi { diff --git a/src/simplehmm/simple-hmm.cc b/src/simplehmm/simple-hmm.cc index 2db6bfbf297..e0e7442ead3 100644 --- a/src/simplehmm/simple-hmm.cc +++ b/src/simplehmm/simple-hmm.cc @@ -17,7 +17,7 @@ // See the Apache 2 License for the specific language governing permissions and // limitations under the License. 
-#include "hmm/simple-hmm.h" +#include "simplehmm/simple-hmm.h" namespace kaldi { From 2d13d907f35412dc0f2844fdad74a45bcbbd37e0 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 6 Feb 2017 16:02:55 -0500 Subject: [PATCH 184/213] asr_diarization: Update cluster-utils --- src/tree/cluster-utils.cc | 6 +++--- src/tree/cluster-utils.h | 40 +++++++++++++++++++++++---------------- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/src/tree/cluster-utils.cc b/src/tree/cluster-utils.cc index 965eb104d9e..aa9ae46bc01 100644 --- a/src/tree/cluster-utils.cc +++ b/src/tree/cluster-utils.cc @@ -273,7 +273,7 @@ void BottomUpClusterer::SetInitialDistances() { for (int32 i = 0; i < npoints_; i++) { for (int32 j = 0; j < i; j++) { BaseFloat dist = ComputeDistance(i, j); - if (dist <= max_merge_thresh_) + if (dist <= MergeThreshold(i, j)) queue_.push(std::make_pair(dist, std::make_pair(static_cast(i), static_cast(j)))); if (j == i - 1) @@ -325,7 +325,7 @@ void BottomUpClusterer::ReconstructQueue() { for (int32 j = 0; j < i; j++) { if ((*clusters_)[j] != NULL) { BaseFloat dist = dist_vec_[(i * (i - 1)) / 2 + j]; - if (dist <= max_merge_thresh_) { + if (dist <= MergeThreshold(i, j)) { queue_.push(std::make_pair(dist, std::make_pair( static_cast(i), static_cast(j)))); } @@ -339,7 +339,7 @@ void BottomUpClusterer::SetDistance(int32 i, int32 j) { KALDI_ASSERT(i < npoints_ && j < i && (*clusters_)[i] != NULL && (*clusters_)[j] != NULL); BaseFloat dist = ComputeDistance(i, j); - if (dist < max_merge_thresh_) { + if (dist < MergeThreshold(i, j)) { queue_.push(std::make_pair(dist, std::make_pair(static_cast(i), static_cast(j)))); } diff --git a/src/tree/cluster-utils.h b/src/tree/cluster-utils.h index b11dfe1c031..2658cb8dfd0 100644 --- a/src/tree/cluster-utils.h +++ b/src/tree/cluster-utils.h @@ -119,9 +119,10 @@ class BottomUpClusterer { int32 min_clust, std::vector *clusters_out, std::vector *assignments_out) - : ans_(0.0), points_(points), max_merge_thresh_(max_merge_thresh), + : points_(points), max_merge_thresh_(max_merge_thresh), min_clust_(min_clust), clusters_(clusters_out != NULL? clusters_out - : &tmp_clusters_), assignments_(assignments_out != NULL ? + : &tmp_clusters_), ans_(0.0), + assignments_(assignments_out != NULL ? assignments_out : &tmp_assignments_) { nclusters_ = npoints_ = points.size(); dist_vec_.resize((npoints_ * (npoints_ - 1)) / 2); @@ -131,7 +132,6 @@ class BottomUpClusterer { ~BottomUpClusterer() { DeletePointers(&tmp_clusters_); } /// Public accessors - const Clusterable* GetCluster(int32 i) const { return (*clusters_)[i]; } BaseFloat& Distance(int32 i, int32 j) { KALDI_ASSERT(i < npoints_ && j < i); return dist_vec_[(i * (i - 1)) / 2 + j]; @@ -143,10 +143,27 @@ class BottomUpClusterer { /// Merge j into i and delete j. virtual void MergeClusters(int32 i, int32 j); + typedef std::pair > + QueueElement; + // Priority queue using greater (lowest distances are highest priority). 
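
This hunk is part of a refactoring that opens BottomUpClusterer up to subclasses: the cluster list, the pairwise-distance table and the merge queue move into a protected section, and the fixed max_merge_thresh_ comparison becomes the virtual MergeThreshold(i, j) hook, so a derived clusterer can make the merge cut-off depend on the pair being considered. The queue declared by the typedef that follows holds candidate merges ordered by distance, cheapest first. A toy, self-contained Python rendering of that merge loop (an editorial sketch, not the Kaldi implementation):

    import heapq


    def bottom_up_cluster(points, distance, merge_threshold, min_clusters=1):
        # points: list of items; each cluster is kept as the list of its items.
        # distance(a, b): dissimilarity between two clusters.
        # merge_threshold(a, b): pair-dependent cut-off; a constant function
        # reproduces the old fixed-threshold behaviour.
        members = {i: [i] for i in range(len(points))}           # cluster -> point indices
        contents = {i: [points[i]] for i in range(len(points))}  # cluster -> items
        version = {i: 0 for i in range(len(points))}             # detects stale queue entries
        heap = []

        def push(i, j):
            d = distance(contents[i], contents[j])
            if d <= merge_threshold(contents[i], contents[j]):
                heapq.heappush(heap, (d, i, version[i], j, version[j]))

        for i in members:
            for j in range(i):
                push(i, j)

        while len(members) > min_clusters and heap:
            d, i, vi, j, vj = heapq.heappop(heap)
            if i not in members or j not in members:
                continue                  # one side was already merged away
            if vi != version[i] or vj != version[j]:
                continue                  # stale entry: a cluster changed since it was pushed
            members[i] += members.pop(j)  # merge j into i and delete j
            contents[i] += contents.pop(j)
            version.pop(j)
            version[i] += 1
            for k in list(members):       # refresh candidate pairs involving the merged cluster
                if k != i:
                    push(i, k)
        return members


    # Example: three natural groups under a fixed threshold of 1.0.
    clusters = bottom_up_cluster(
        [0.0, 0.1, 0.2, 5.0, 5.1, 9.0],
        distance=lambda a, b: abs(sum(a) / len(a) - sum(b) / len(b)),
        merge_threshold=lambda a, b: 1.0)
    print(clusters)  # three clusters survive: {0, 1, 2}, {3, 4} and {5}
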
+ typedef std::priority_queue, + std::greater > QueueType; + int32 NumClusters() const { return nclusters_; } int32 NumPoints() const { return npoints_; } int32 MinClusters() const { return min_clust_; } bool IsQueueEmpty() const { return queue_.empty(); } + + protected: + const std::vector &points_; + BaseFloat max_merge_thresh_; + int32 min_clust_; + std::vector *clusters_; + + std::vector dist_vec_; + int32 nclusters_; + int32 npoints_; + QueueType queue_; private: void Renumber(); @@ -162,6 +179,10 @@ class BottomUpClusterer { return nclusters_ <= min_clust_ || queue_.empty(); } + virtual BaseFloat MergeThreshold(int32 i, int32 j) { + return max_merge_thresh_; + } + void SetDistance(int32 i, int32 j); virtual BaseFloat ComputeDistance(int32 i, int32 j) { BaseFloat dist = (*clusters_)[i]->Distance(*((*clusters_)[j])); @@ -170,23 +191,10 @@ class BottomUpClusterer { } BaseFloat ans_; - const std::vector &points_; - BaseFloat max_merge_thresh_; - int32 min_clust_; - std::vector *clusters_; std::vector *assignments_; std::vector tmp_clusters_; std::vector tmp_assignments_; - - std::vector dist_vec_; - int32 nclusters_; - int32 npoints_; - typedef std::pair > QueueElement; - // Priority queue using greater (lowest distances are highest priority). - typedef std::priority_queue, - std::greater > QueueType; - QueueType queue_; }; /** This is a wrapper function to the BottomUpClusterer class. From e5988b726d2a840079d3399ba161f68e53a08128 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 6 Feb 2017 16:03:16 -0500 Subject: [PATCH 185/213] asr_diarization: ib clusterable --- .../information-bottleneck-cluster-utils.cc | 77 +++++++++++-------- .../information-bottleneck-cluster-utils.h | 11 ++- .../information-bottleneck-clusterable.cc | 11 ++- .../segmentation-cluster-adjacent-segments.cc | 55 +++++++++++-- 4 files changed, 114 insertions(+), 40 deletions(-) diff --git a/src/segmenter/information-bottleneck-cluster-utils.cc b/src/segmenter/information-bottleneck-cluster-utils.cc index 75fda8c59fe..5ed283da564 100644 --- a/src/segmenter/information-bottleneck-cluster-utils.cc +++ b/src/segmenter/information-bottleneck-cluster-utils.cc @@ -37,24 +37,27 @@ class InformationBottleneckBottomUpClusterer : public BottomUpClusterer { std::vector *assignments_out); private: + virtual void SetInitialDistances(); virtual BaseFloat ComputeDistance(int32 i, int32 j); virtual bool StoppingCriterion() const; virtual void UpdateClustererStats(int32 i, int32 j); + virtual BaseFloat MergeThreshold(int32 i, int32 j) { + if (opts_.normalize_by_count) + return max_merge_thresh_ + * ((*clusters_)[i]->Normalizer() + (*clusters_)[j]->Normalizer()); + else if (opts_.normalize_by_entropy) + return -max_merge_thresh_ * (*clusters_)[i]->ObjfPlus(*(*clusters_)[j]); + else + return max_merge_thresh_; + } + BaseFloat NormalizedMutualInformation() const { return ((merged_entropy_ - current_entropy_) / (merged_entropy_ - initial_entropy_)); } - /// Stop merging when the stopping criterion, e.g. NMI, reaches this - /// threshold. - BaseFloat stopping_threshold_; - - /// Weight of the relevant variables entropy towards the objective. - BaseFloat relevance_factor_; - - /// Weight of the input variables entropy towards the objective. - BaseFloat input_factor_; + const InformationBottleneckClustererOptions &opts_; /// Running entropy of the clusters. 
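
The entropy members declared here and just below (the running entropy of the current clustering, the entropy before any merge, and the entropy with everything merged) exist to evaluate the normalized-mutual-information stopping rule used by StoppingCriterion() further down. A small stand-alone illustration of that rule; the variable names mirror the class members and the default threshold of 0.3 mirrors the options struct, everything else is editorial:

    def nmi_should_stop(current_entropy, initial_entropy, merged_entropy,
                        stopping_threshold=0.3):
        # NMI between the current clustering and the relevance variables.
        # It is 1.0 before any merge (current == initial) and decays towards
        # 0.0 as everything collapses into one cluster (current -> merged),
        # so clustering stops once it falls below the threshold.
        nmi = (merged_entropy - current_entropy) / (merged_entropy - initial_entropy)
        return nmi < stopping_threshold
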
BaseFloat current_entropy_; @@ -75,9 +78,7 @@ InformationBottleneckBottomUpClusterer::InformationBottleneckBottomUpClusterer( std::vector *assignments_out) : BottomUpClusterer(points, max_merge_thresh, min_clusters, clusters_out, assignments_out), - stopping_threshold_(opts.stopping_threshold), - relevance_factor_(opts.relevance_factor), - input_factor_(opts.input_factor), + opts_(opts), current_entropy_(0.0), initial_entropy_(0.0), merged_entropy_(0.0) { if (points.size() == 0) return; @@ -96,34 +97,50 @@ InformationBottleneckBottomUpClusterer::InformationBottleneckBottomUpClusterer( current_entropy_ = initial_entropy_; } +void InformationBottleneckBottomUpClusterer::SetInitialDistances() { + for (int32 i = 0; i < npoints_; i++) { + for (int32 j = 0; j < i; j++) { + BaseFloat dist = ComputeDistance(i, j); + if (dist <= MergeThreshold(i, j)) { + queue_.push(std::make_pair( + dist, std::make_pair(static_cast(i), + static_cast(j)))); + } + if (j == i - 1) + KALDI_VLOG(2) << "Distance(" << i << ", " << j << ") = " << dist; + } + } +} + BaseFloat InformationBottleneckBottomUpClusterer::ComputeDistance( int32 i, int32 j) { const InformationBottleneckClusterable* cluster_i - = static_cast(GetCluster(i)); + = static_cast((*clusters_)[i]); const InformationBottleneckClusterable* cluster_j - = static_cast(GetCluster(j)); + = static_cast((*clusters_)[j]); - BaseFloat dist = (cluster_i->Distance(*cluster_j, relevance_factor_, - input_factor_)); + BaseFloat dist = (cluster_i->Distance(*cluster_j, opts_.relevance_factor, + opts_.input_factor)); // / (cluster_i->Normalizer() + cluster_j->Normalizer())); Distance(i, j) = dist; // set the distance in the array. return dist; } bool InformationBottleneckBottomUpClusterer::StoppingCriterion() const { - bool flag = (NumClusters() <= MinClusters() || IsQueueEmpty() || - NormalizedMutualInformation() < stopping_threshold_); + bool flag = (nclusters_ <= min_clust_ || queue_.empty() || + NormalizedMutualInformation() < opts_.stopping_threshold); if (GetVerboseLevel() < 2 || !flag) return flag; - if (NormalizedMutualInformation() < stopping_threshold_) { - KALDI_VLOG(2) << "Stopping at " << NumClusters() << " clusters " + if (NormalizedMutualInformation() < opts_.stopping_threshold) { + KALDI_VLOG(2) << "Stopping at " << nclusters_ << " clusters " << "because NMI = " << NormalizedMutualInformation() - << " < stopping_threshold (" << stopping_threshold_ << ")"; - } else if (NumClusters() < MinClusters()) { - KALDI_VLOG(2) << "Stopping at " << NumClusters() << " clusters " - << "<= min-clusters (" << MinClusters() << ")"; - } else if (IsQueueEmpty()) { - KALDI_VLOG(2) << "Stopping at " << NumClusters() << " clusters " + << " < stopping_threshold (" + << opts_.stopping_threshold << ")"; + } else if (nclusters_ < min_clust_) { + KALDI_VLOG(2) << "Stopping at " << nclusters_ << " clusters " + << "<= min-clusters (" << min_clust_ << ")"; + } else if (queue_.empty()) { + KALDI_VLOG(2) << "Stopping at " << nclusters_ << " clusters " << "because queue is empty."; } @@ -133,12 +150,12 @@ bool InformationBottleneckBottomUpClusterer::StoppingCriterion() const { void InformationBottleneckBottomUpClusterer::UpdateClustererStats( int32 i, int32 j) { const InformationBottleneckClusterable* cluster_i - = static_cast(GetCluster(i)); - current_entropy_ += cluster_i->Distance(*GetCluster(j), 1.0, 0.0); + = static_cast((*clusters_)[i]); + current_entropy_ += cluster_i->Distance(*(*clusters_)[j], 1.0, 0.0); if (GetVerboseLevel() > 2) { const InformationBottleneckClusterable* cluster_j 
- = static_cast(GetCluster(j)); + = static_cast((*clusters_)[j]); std::vector cluster_i_points; { std::map::const_iterator it @@ -158,7 +175,7 @@ void InformationBottleneckBottomUpClusterer::UpdateClustererStats( << "(" << cluster_i_points << ", " << cluster_j_points << ").. distance=" << Distance(i, j) - << ", num-clusters-after-merge= " << NumClusters() - 1 + << ", num-clusters-after-merge= " << nclusters_ - 1 << ", NMI= " << NormalizedMutualInformation(); } } diff --git a/src/segmenter/information-bottleneck-cluster-utils.h b/src/segmenter/information-bottleneck-cluster-utils.h index 58f1e4f380a..82b5c285c65 100644 --- a/src/segmenter/information-bottleneck-cluster-utils.h +++ b/src/segmenter/information-bottleneck-cluster-utils.h @@ -33,10 +33,13 @@ struct InformationBottleneckClustererOptions { BaseFloat stopping_threshold; BaseFloat relevance_factor; BaseFloat input_factor; + bool normalize_by_count; + bool normalize_by_entropy; InformationBottleneckClustererOptions() : distance_threshold(std::numeric_limits::max()), num_clusters(1), - stopping_threshold(0.3), relevance_factor(1.0), input_factor(0.1) { } + stopping_threshold(0.3), relevance_factor(1.0), input_factor(0.1), + normalize_by_count(false), normalize_by_entropy(false) { } void Register(OptionsItf *opts) { @@ -49,6 +52,12 @@ struct InformationBottleneckClustererOptions { opts->Register("input-factor", &input_factor, "Weight factor of the entropy of input variables " "in the objective function"); + opts->Register("normalize-by-count", &normalize_by_count, + "If provided, normalizes the score (distance) by " + "the count post-merge."); + opts->Register("normalize-by-entropy", &normalize_by_entropy, + "If provided, normalizes the score (distance) by " + "the entropy post-merge."); } }; diff --git a/src/segmenter/information-bottleneck-clusterable.cc b/src/segmenter/information-bottleneck-clusterable.cc index 7817f7cfdc6..05850c1eebc 100644 --- a/src/segmenter/information-bottleneck-clusterable.cc +++ b/src/segmenter/information-bottleneck-clusterable.cc @@ -68,7 +68,9 @@ void InformationBottleneckClusterable::Add(const Clusterable &other_in) { it != other->counts_.end(); ++it) { std::map::iterator hint_it = counts_.lower_bound( it->first); - KALDI_ASSERT (hint_it == counts_.end() || hint_it->first != it->first); + if (hint_it != counts_.end() && hint_it->first == it->first) { + KALDI_ERR << "Duplicate segment id " << it->first; + } counts_.insert(hint_it, *it); } @@ -205,8 +207,11 @@ BaseFloat InformationBottleneckClusterable::Distance( KALDI_ASSERT(relevance_divergence > -1e-4); KALDI_ASSERT(input_divergence > -1e-4); - return (normalizer * (relevance_factor * relevance_divergence - - input_factor * input_divergence)); + + double ans = (normalizer * (relevance_factor * relevance_divergence + - input_factor * input_divergence)); + KALDI_ASSERT(input_factor != 0.0 || ans > -1e-4); + return ans; } BaseFloat KLDivergence(const VectorBase &p1, diff --git a/src/segmenterbin/segmentation-cluster-adjacent-segments.cc b/src/segmenterbin/segmentation-cluster-adjacent-segments.cc index 812785ac5e6..fde13cd7ead 100644 --- a/src/segmenterbin/segmentation-cluster-adjacent-segments.cc +++ b/src/segmenterbin/segmentation-cluster-adjacent-segments.cc @@ -78,21 +78,59 @@ int32 ClusterAdjacentSegments(const MatrixBase &feats, BaseFloat var_floor, int32 length_tolerance, Segmentation *segmentation) { - if (segmentation->Dim() == 1) { - segmentation->Begin()->SetLabel(1); + if (segmentation->Dim() <= 3) { + // Very unusual case. 
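
The rest of this function, continued below, replaces the previous adjacent-merging logic (kept further down as a block comment) with a simple change-point detector: it computes a distance between every pair of adjacent segments and starts a new cluster label wherever that distance sequence forms a local peak, i.e. rises and then falls by more than the delta threshold. A rough, self-contained Python rendering of the idea; the distance function is left abstract, and the indexing and small-segmentation special cases are simplified relative to the C++:

    def label_by_change_points(distances, delta_threshold=0.0002):
        # distances[i] is the dissimilarity between adjacent segments i and i+1.
        # A change point is declared at every local peak of this sequence.
        peaks = [i for i in range(1, len(distances) - 1)
                 if distances[i] - distances[i - 1] > delta_threshold
                 and distances[i + 1] - distances[i] < -delta_threshold]
        labels, current = [], 1
        for seg in range(len(distances) + 1):
            if seg - 1 in peaks:  # a new cluster starts after a peaked boundary
                current += 1
            labels.append(current)
        return labels


    print(label_by_change_points([0.1, 0.1, 0.9, 0.1, 0.1]))
    # -> [1, 1, 1, 2, 2, 2]: the jump at the third boundary splits the segments
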
+ // TODO: Do something more reasonable. return 1; } + SegmentList::iterator it = segmentation->Begin(), next_it = segmentation->Begin(); ++next_it; + + // Vector storing for each segment, whether there is a change point at the + // beginning of the segment. + std::vector is_change_point(segmentation->Dim(), false); + is_change_point[0] = true; + + Vector distances(segmentation->Dim() - 1); + int32 i = 0; + + for (; next_it != segmentation->End(); ++it, ++next_it, i++) { + // Distance between segment i and i + 1 + distances(i) = Distance(*it, *next_it, feats, + var_floor, length_tolerance); + + if (i > 2) { + if (distances(i-1) - distances(i-2) > delta_distance_threshold && + distances(i) - distances(i-1) < -delta_distance_threshold) { + is_change_point[i-1] = true; + } + } else { + if (distances(i) - distances(i-1) > absolute_distance_threshold) + is_change_point[i] = true; + } + } + + int32 num_classes = 0; + for (i = 0, it = segmentation->Begin(); + it != segmentation->End(); ++it, i++) { + if (is_change_point[i]) { + num_classes++; + } + it->SetLabel(num_classes); + } + return num_classes; + /* BaseFloat prev_dist = Distance(*it, *next_it, feats, var_floor, length_tolerance); if (segmentation->Dim() == 2) { it->SetLabel(1); - if (prev_dist < absolute_distance_threshold * feats.NumCols()) { + if (prev_dist < absolute_distance_threshold * feats.NumCols() + && next_it->start_frame <= it->end_frame) { // Similar segments merged. next_it->SetLabel(it->Label()); } else { @@ -103,6 +141,10 @@ int32 ClusterAdjacentSegments(const MatrixBase &feats, return next_it->Label();; } + // The algorithm is a simple peak detection. + // Consider three segments that are pointed by the iterators + // prev_it, it, next_it. + // If Distance(prev_it, it) > Consider ++it; ++next_it; bool next_segment_is_new_cluster = false; @@ -162,6 +204,7 @@ int32 ClusterAdjacentSegments(const MatrixBase &feats, } return it->Label(); + */ } } // end segmenter @@ -186,7 +229,7 @@ int main(int argc, char *argv[]) { int32 length_tolerance = 2; BaseFloat var_floor = 0.01; BaseFloat absolute_distance_threshold = 3.0; - BaseFloat delta_distance_threshold = 0.2; + BaseFloat delta_distance_threshold = 0.0002; ParseOptions po(usage); @@ -203,8 +246,8 @@ int main(int argc, char *argv[]) { "Maximum per-dim distance below which segments will not be " "be merged."); po.Register("delta-distance-threshold", &delta_distance_threshold, - "If the delta-distance is below this value, then the " - "adjacent segments will not be merged."); + "If the delta-distance is below this value, then it will " + "be treated as 0."); po.Read(argc, argv); From 9a86fc0c5ad9a947df6929ce5af6792587d4cdc2 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 6 Feb 2017 16:03:42 -0500 Subject: [PATCH 186/213] asr_diarization: init-models-from-feats --- src/segmenterbin/Makefile | 3 ++- .../gmm-global-init-models-from-feats.cc | 13 +++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/segmenterbin/Makefile b/src/segmenterbin/Makefile index a424f192d3e..6e0036c6fb7 100644 --- a/src/segmenterbin/Makefile +++ b/src/segmenterbin/Makefile @@ -21,7 +21,8 @@ BINFILES = segmentation-copy segmentation-get-stats \ agglomerative-cluster-ib \ intersect-int-vectors \ gmm-global-init-models-from-feats \ - segmentation-cluster-adjacent-segments #\ + segmentation-cluster-adjacent-segments \ + ib-scoring-dense #\ gmm-acc-pdf-stats-segmentation \ gmm-est-segmentation gmm-update-segmentation \ segmentation-init-from-diarization \ diff --git 
a/src/segmenterbin/gmm-global-init-models-from-feats.cc b/src/segmenterbin/gmm-global-init-models-from-feats.cc index a472b48624c..c323306df83 100644 --- a/src/segmenterbin/gmm-global-init-models-from-feats.cc +++ b/src/segmenterbin/gmm-global-init-models-from-feats.cc @@ -65,9 +65,9 @@ void MleDiagGmmSharedVarsUpdate(const MleDiagGmmOptions &config, DiagGmm *gmm, BaseFloat *obj_change_out, BaseFloat *count_out, - int32 *floored_elements_out, - int32 *floored_gaussians_out, - int32 *removed_gaussians_out) { + int32 *floored_elements_out = NULL, + int32 *floored_gaussians_out = NULL, + int32 *removed_gaussians_out = NULL) { KALDI_ASSERT(gmm != NULL); if (flags & ~diag_gmm_acc.Flags()) @@ -213,7 +213,12 @@ void TrainOneIter(const MatrixBase &feats, << feats.NumRows() << " frames."; BaseFloat objf_change, count; - MleDiagGmmUpdate(gmm_opts, gmm_acc, kGmmAll, gmm, &objf_change, &count); + if (share_covars) { + MleDiagGmmSharedVarsUpdate(gmm_opts, gmm_acc, kGmmAll, gmm, + &objf_change, &count); + } else { + MleDiagGmmUpdate(gmm_opts, gmm_acc, kGmmAll, gmm, &objf_change, &count); + } KALDI_LOG << "Objective-function change on iteration " << iter << " was " << (objf_change / count) << " over " << count << " frames."; From 4646f14d59a0685526ee4897a806f6c75cc29177 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 6 Feb 2017 16:04:06 -0500 Subject: [PATCH 187/213] asr_diarization: Clustering script --- .../segmentation/cluster_segments_aIB.sh | 32 +++++++++---- .../cluster_segments_aIB_change_point.sh | 45 ++++++++++++++----- 2 files changed, 58 insertions(+), 19 deletions(-) diff --git a/egs/wsj/s5/steps/segmentation/cluster_segments_aIB.sh b/egs/wsj/s5/steps/segmentation/cluster_segments_aIB.sh index a1f187fab31..7cf151f1ad0 100755 --- a/egs/wsj/s5/steps/segmentation/cluster_segments_aIB.sh +++ b/egs/wsj/s5/steps/segmentation/cluster_segments_aIB.sh @@ -8,7 +8,7 @@ reco_nj=4 frame_shift=0.01 utt_nj=18 min_clusters=10 -stopping_threshold=0.5 +clustering_opts="--stopping-threshold=0.5 --max-merge-thresh=0.25 --normalize-by-entropy" . path.sh . 
utils/parse_options.sh @@ -29,7 +29,7 @@ out_data=$3 num_frames=`perl -e "print int($window / $frame_shift + 0.5)"` num_frames_overlap=`perl -e "print int($overlap/ $frame_shift + 0.5)"` -data_uniform_seg=${data}_uniform_seg_window${window}_ovlp${overlap} +data_uniform_seg=$dir/`basename ${data}`_uniform_seg_window${window}_ovlp${overlap} mkdir -p ${data_uniform_seg} @@ -41,8 +41,6 @@ if [ $stage -le 0 ]; then $cmd $dir/log/get_subsegments.log \ segmentation-init-from-segments --frame-overlap=0.015 $data/segments ark:- \| \ segmentation-split-segments --max-segment-length=$num_frames --overlap-length=$num_frames_overlap ark:- ark:- \| \ - segmentation-cluster-adjacent-segments --verbose=3 ark:- "scp:$data/feats.scp" ark:- \| \ - segmentation-post-process --merge-adjacent-segments ark:- ark:- \| \ segmentation-to-segments --frame-overlap=0.0 ark:- ark:/dev/null \ ${data_uniform_seg}/sub_segments @@ -98,16 +96,34 @@ fi seg_dir=$dir/segmentation_`basename $data_uniform_seg` if [ $stage -le 4 ]; then + $cmd JOB=1:$reco_nj $seg_dir/log/compute_scores.JOB.log \ + ib-scoring-dense --input-factor=0.0 $clustering_opts \ + --counts-rspecifier="ark,t:utils/filter_scp.pl $data_uniform_seg/split${reco_nj}reco/JOB/utt2spk $data_uniform_seg/utt2num_frames |" \ + "ark,t:${data_uniform_seg}/split${reco_nj}reco/JOB/reco2utt" \ + "ark:gunzip -c $post_dir/avg_post.JOB.gz |" \ + ark,t:$seg_dir/scores.JOB.txt ark:/dev/null +fi + +if [ $stage -le 5 ]; then + threshold=$(for n in `seq $reco_nj`; do + /export/a12/vmanoha1/kaldi-diarization-v2/src/ivectorbin/compute-calibration \ + ark,t:$seg_dir/scores.$n.txt -; done | \ + awk '{i += $1; j++;} END{print i / j}') + echo $threshold > $seg_dir/threshold +fi + +threshold=$(cat $seg_dir/threshold) +if [ $stage -le 6 ]; then $cmd JOB=1:$reco_nj $seg_dir/log/cluster_segments.JOB.log \ - agglomerative-cluster-ib --min-clusters=$min_clusters \ - --verbose=3 --stopping-threshold=$stopping_threshold --input-factor=0 \ + agglomerative-cluster-ib --input-factor=0.0 --min-clusters=$min_clusters $clustering_opts \ + --max-merge-thresh=$threshold --verbose=3 \ --counts-rspecifier="ark,t:utils/filter_scp.pl $data_uniform_seg/split${reco_nj}reco/JOB/utt2spk $data_uniform_seg/utt2num_frames |" \ "ark:gunzip -c $post_dir/avg_post.JOB.gz |" \ "ark,t:${data_uniform_seg}/split${reco_nj}reco/JOB/reco2utt" \ ark,t:$seg_dir/utt2cluster_id.JOB fi -if [ $stage -le 5 ]; then +if [ $stage -le 7 ]; then $cmd JOB=1:$reco_nj $seg_dir/log/init_segmentation.JOB.log \ segmentation-init-from-segments --frame-overlap=0.0 --shift-to-zero=false \ --utt2label-rspecifier=ark,t:${seg_dir}/utt2cluster_id.JOB \ @@ -120,7 +136,7 @@ if [ $stage -le 5 ]; then segmentation-to-segments ark:- ark,t:$seg_dir/utt2spk.JOB $seg_dir/segments.JOB fi -if [ $stage -le 6 ]; then +if [ $stage -le 8 ]; then rm -r $out_data || true utils/data/convert_data_dir_to_whole.sh $data $out_data rm $out_data/{text,cmvn.scp} || true diff --git a/egs/wsj/s5/steps/segmentation/cluster_segments_aIB_change_point.sh b/egs/wsj/s5/steps/segmentation/cluster_segments_aIB_change_point.sh index a1f187fab31..9ca3efb7b9a 100755 --- a/egs/wsj/s5/steps/segmentation/cluster_segments_aIB_change_point.sh +++ b/egs/wsj/s5/steps/segmentation/cluster_segments_aIB_change_point.sh @@ -6,9 +6,10 @@ stage=-1 cmd=queue.pl reco_nj=4 frame_shift=0.01 +frame_overlap=0.0 utt_nj=18 min_clusters=10 -stopping_threshold=0.5 +clustering_opts="--stopping-threshold=0.5 --max-merge-thresh=0.25 --normalize-by-entropy" . path.sh . 
utils/parse_options.sh @@ -29,17 +30,19 @@ out_data=$3 num_frames=`perl -e "print int($window / $frame_shift + 0.5)"` num_frames_overlap=`perl -e "print int($overlap/ $frame_shift + 0.5)"` -data_uniform_seg=${data}_uniform_seg_window${window}_ovlp${overlap} - -mkdir -p ${data_uniform_seg} +data_id=`basename $data` +data_uniform_seg=$dir/${data_id}_uniform_seg_window${window}_ovlp${overlap} mkdir -p $dir #segmentation-cluster-adjacent-segments --verbose=0 'ark:segmentation-copy --keep-label=1 "ark:gunzip -c exp/nnet3_lstm_sad_music/nnet_lstm_1e//segmentation_bn_eval97_whole_bp/orig_segmentation.1.gz |" ark:- | segmentation-split-segments --max-segment-length=250 --overlap-length=0 ark:- ark:- |' scp:data/bn_eval97_bp_hires/feats.scp "ark:| segmentation-post-process --merge-adjacent-segments ark:- ark:- | segmentation-to-segments ark:- ark,t:- /dev/null" 2>&1 | less if [ $stage -le 0 ]; then + rm -r ${data_uniform_seg} || true + mkdir -p ${data_uniform_seg} + $cmd $dir/log/get_subsegments.log \ - segmentation-init-from-segments --frame-overlap=0.015 $data/segments ark:- \| \ + segmentation-init-from-segments --frame-overlap=$frame_overlap $data/segments ark:- \| \ segmentation-split-segments --max-segment-length=$num_frames --overlap-length=$num_frames_overlap ark:- ark:- \| \ segmentation-cluster-adjacent-segments --verbose=3 ark:- "scp:$data/feats.scp" ark:- \| \ segmentation-post-process --merge-adjacent-segments ark:- ark:- \| \ @@ -86,8 +89,6 @@ if [ $stage -le 2 ]; then fi if [ $stage -le 3 ]; then - utils/data/get_utt2num_frames.sh --nj $utt_nj --cmd "$cmd" ${data_uniform_seg} - $cmd JOB=1:$reco_nj $post_dir/log/compute_average_post.JOB.log \ gmm-global-post-to-feats \ --utt2spk="ark,t:cut -d ' ' -f 1,2 ${data_uniform_seg}/split${reco_nj}reco/JOB/segments |" \ @@ -98,16 +99,38 @@ fi seg_dir=$dir/segmentation_`basename $data_uniform_seg` if [ $stage -le 4 ]; then + utils/data/get_utt2num_frames.sh --nj $utt_nj --cmd "$cmd" ${data_uniform_seg} + + $cmd JOB=1:$reco_nj $seg_dir/log/compute_scores.JOB.log \ + ib-scoring-dense --input-factor=0 $clustering_opts \ + --counts-rspecifier="ark,t:utils/filter_scp.pl $data_uniform_seg/split${reco_nj}reco/JOB/utt2spk $data_uniform_seg/utt2num_frames |" \ + "ark,t:${data_uniform_seg}/split${reco_nj}reco/JOB/reco2utt" \ + "ark:gunzip -c $post_dir/avg_post.JOB.gz |" \ + ark,t:$seg_dir/scores.JOB.txt ark:/dev/null +fi + +if [ $stage -le 5 ]; then + $cmd JOB=1:$reco_nj $seg_dir/log/calibrate.JOB.log \ + /export/a12/vmanoha1/kaldi-diarization-v2/src/ivectorbin/compute-calibration \ + ark,t:$seg_dir/scores.JOB.txt $seg_dir/threshold.JOB.txt + + threshold=$(for n in `seq $reco_nj`; do cat $seg_dir/threshold.$n.txt; done | \ + awk '{i += $1; j++;} END{print i / j}') + echo $threshold > $seg_dir/threshold +fi + +threshold=$(cat $seg_dir/threshold) +if [ $stage -le 6 ]; then $cmd JOB=1:$reco_nj $seg_dir/log/cluster_segments.JOB.log \ - agglomerative-cluster-ib --min-clusters=$min_clusters \ - --verbose=3 --stopping-threshold=$stopping_threshold --input-factor=0 \ + agglomerative-cluster-ib --input-factor=0.0 $clustering_opts \ + --max-merge-thresh=$threshold --verbose=3 \ --counts-rspecifier="ark,t:utils/filter_scp.pl $data_uniform_seg/split${reco_nj}reco/JOB/utt2spk $data_uniform_seg/utt2num_frames |" \ "ark:gunzip -c $post_dir/avg_post.JOB.gz |" \ "ark,t:${data_uniform_seg}/split${reco_nj}reco/JOB/reco2utt" \ ark,t:$seg_dir/utt2cluster_id.JOB fi -if [ $stage -le 5 ]; then +if [ $stage -le 7 ]; then $cmd JOB=1:$reco_nj $seg_dir/log/init_segmentation.JOB.log 
\ segmentation-init-from-segments --frame-overlap=0.0 --shift-to-zero=false \ --utt2label-rspecifier=ark,t:${seg_dir}/utt2cluster_id.JOB \ @@ -120,7 +143,7 @@ if [ $stage -le 5 ]; then segmentation-to-segments ark:- ark,t:$seg_dir/utt2spk.JOB $seg_dir/segments.JOB fi -if [ $stage -le 6 ]; then +if [ $stage -le 8 ]; then rm -r $out_data || true utils/data/convert_data_dir_to_whole.sh $data $out_data rm $out_data/{text,cmvn.scp} || true From 53e167d4e3500f9b01518209ca269598683d81ed Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 6 Feb 2017 16:48:58 -0500 Subject: [PATCH 188/213] asr_diarization: Added virtual destructor --- src/hmm/transition-model.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/hmm/transition-model.h b/src/hmm/transition-model.h index c059e319dd5..51802b37f41 100644 --- a/src/hmm/transition-model.h +++ b/src/hmm/transition-model.h @@ -130,12 +130,15 @@ class TransitionModel { /// Constructor that takes no arguments: typically used prior to calling Read. TransitionModel() { } + + virtual ~TransitionModel() { } /// Does the same things as the constructor. void Init(const ContextDependencyInterface &ctx_dep, const HmmTopology &hmm_topo); - void Read(std::istream &is, bool binary); // note, no symbol table: topo object always read/written w/o symbols. + // note, no symbol table: topo object always read/written w/o symbols. + virtual void Read(std::istream &is, bool binary); void Write(std::ostream &os, bool binary) const; @@ -319,7 +322,6 @@ class TransitionModel { /// of pdfs). int32 num_pdfs_; - DISALLOW_COPY_AND_ASSIGN(TransitionModel); }; From 53dec62b21da6f2ec693f96c464fd04ad9ea1b04 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 13 Feb 2017 12:26:09 -0500 Subject: [PATCH 189/213] error_msg: Simplifying err_msg --- egs/wsj/s5/steps/libs/common.py | 32 ++++++++++++++++++++--------- egs/wsj/s5/steps/nnet3/train_dnn.py | 6 +++--- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/egs/wsj/s5/steps/libs/common.py b/egs/wsj/s5/steps/libs/common.py index f2a336cd640..393fef7d4f6 100644 --- a/egs/wsj/s5/steps/libs/common.py +++ b/egs/wsj/s5/steps/libs/common.py @@ -79,10 +79,13 @@ class KaldiCommandException(Exception): kaldi command that caused the error and the error string captured. 
""" def __init__(self, command, err=None): + import re Exception.__init__(self, "There was an error while running the command " - "{0}\n{1}\n{2}".format(command, "-"*10, - "" if err is None else err)) + "{0}\n{1}\n{2}".format( + re.sub('\s+', ' ', command).strip(), + "-"*10, + "" if err is None else err)) class BackgroundProcessHandler(): @@ -165,17 +168,20 @@ def add_process(self, t): self.start() def is_process_done(self, t): - p, command = t + p, command, exit_on_failure = t if p.poll() is None: return False return True def ensure_process_is_done(self, t): - p, command = t + p, command, exit_on_failure = t logger.debug("Waiting for process '{0}' to end".format(command)) [stdout, stderr] = p.communicate() if p.returncode is not 0: - raise KaldiCommandException(command, stderr) + print("There was an error while running the command " + "{0}\n{1}\n{2}".format(command, "-"*10, stderr)) + if exit_on_failure: + os._exit(1) def ensure_processes_are_done(self): self.__process_queue.reverse() @@ -192,7 +198,8 @@ def debug(self): logger.info("Process '{0}' is running".format(command)) -def run_job(command, wait=True, background_process_handler=None): +def run_job(command, wait=True, background_process_handler=None, + exit_on_failure=False): """ Runs a kaldi job, usually using a script such as queue.pl and run.pl, and redirects the stdout and stderr to the parent process's streams. @@ -206,12 +213,14 @@ class that is instantiated by the top-level script. If this is wait: If True, wait until the process is completed. However, if the background_process_handler is provided, this option will be ignored and the process will be run in the background. + exit_on_failure: If True, will exit from the script on failure. + Only applicable when background_process_handler is specified. """ p = subprocess.Popen(command, shell=True) if background_process_handler is not None: wait = False - background_process_handler.add_process((p, command)) + background_process_handler.add_process((p, command, exit_on_failure)) if wait: p.communicate() @@ -222,7 +231,8 @@ class that is instantiated by the top-level script. If this is return p -def run_kaldi_command(command, wait=True, background_process_handler=None): +def run_kaldi_command(command, wait=True, background_process_handler=None, + exit_on_failure=False): """ Runs commands frequently seen in Kaldi scripts and captures the stdout and stderr. These are usually a sequence of commands connected by pipes, so we use @@ -235,6 +245,8 @@ class that is instantiated by the top-level script. If this is wait: If True, wait until the process is completed. However, if the background_process_handler is provided, this option will be ignored and the process will be run in the background. + exit_on_failure: If True, will exit from the script on failure. + Only applicable when background_process_handler is specified. """ p = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, @@ -242,7 +254,7 @@ class that is instantiated by the top-level script. 
If this is if background_process_handler is not None: wait = False - background_process_handler.add_process((p, command)) + background_process_handler.add_process((p, command, exit_on_failure)) if wait: [stdout, stderr] = p.communicate() @@ -281,7 +293,7 @@ def get_number_of_jobs(alidir): num_jobs = int(open('{0}/num_jobs'.format(alidir)).readline().strip()) except (IOError, ValueError) as e: raise Exception("Exception while reading the " - "number of alignment jobs: {0}".format(e.errstr)) + "number of alignment jobs: {0}".format(e)) return num_jobs diff --git a/egs/wsj/s5/steps/nnet3/train_dnn.py b/egs/wsj/s5/steps/nnet3/train_dnn.py index 2813f719606..f5ac42fd52f 100755 --- a/egs/wsj/s5/steps/nnet3/train_dnn.py +++ b/egs/wsj/s5/steps/nnet3/train_dnn.py @@ -412,14 +412,14 @@ def main(): polling_time=args.background_polling_time) train(args, run_opts, background_process_handler) background_process_handler.ensure_processes_are_done() - except Exception as e: + except Exception: if args.email is not None: message = ("Training session for experiment {dir} " "died due to an error.".format(dir=args.dir)) common_lib.send_mail(message, message, args.email) - traceback.print_exc() background_process_handler.stop() - raise e + logger.error("Training session failed; traceback = ", exc_info=True) + raise SystemExit(1) if __name__ == "__main__": From 94a419f673d214039591b57145047deffdb226c4 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 23 Feb 2017 15:57:30 -0500 Subject: [PATCH 190/213] Modify the way some of the segmentation scripts work --- egs/wsj/s5/steps/segmentation/decode_sad.sh | 14 ++++++++--- .../do_segmentation_data_dir_simple.sh | 10 +++++--- .../internal/convert_ali_to_vad.sh | 25 +++++++++++-------- 3 files changed, 31 insertions(+), 18 deletions(-) diff --git a/egs/wsj/s5/steps/segmentation/decode_sad.sh b/egs/wsj/s5/steps/segmentation/decode_sad.sh index 2f2e5ae2586..a39e93dd83f 100755 --- a/egs/wsj/s5/steps/segmentation/decode_sad.sh +++ b/egs/wsj/s5/steps/segmentation/decode_sad.sh @@ -28,7 +28,13 @@ mkdir -p $dir nj=`cat $log_likes_dir/num_jobs` echo $nj > $dir/num_jobs -for f in $graph_dir/$iter.mdl $log_likes_dir/log_likes.1.gz $graph_dir/HCLG.fst; do +if [ -f $dir/$iter.mdl ]; then + srcdir=$dir +else + srcdir=`dirname $dir` +fi + +for f in $srcdir/$iter.mdl $log_likes_dir/log_likes.1.gz $graph_dir/HCLG.fst; do if [ ! 
-f $f ]; then echo "$0: Could not find file $f" exit 1 @@ -37,14 +43,14 @@ done decoder_opts+=(--acoustic-scale=$acwt --beam=$beam --max-active=$max_active) -ali="ark:| ali-to-phones --per-frame $graph_dir/$iter.mdl ark:- ark:- | gzip -c > $dir/ali.JOB.gz" +ali="ark:| ali-to-phones --per-frame $srcdir/$iter.mdl ark:- ark:- | gzip -c > $dir/ali.JOB.gz" if $get_pdfs; then - ali="ark:| ali-to-pdf $graph_dir/$iter.mdl ark:- ark:- | gzip -c > $dir/ali.JOB.gz" + ali="ark:| ali-to-pdf $srcdir/$iter.mdl ark:- ark:- | gzip -c > $dir/ali.JOB.gz" fi $cmd JOB=1:$nj $dir/log/decode.JOB.log \ decode-faster-mapped ${decoder_opts[@]} \ - $graph_dir/$iter.mdl \ + $srcdir/$iter.mdl \ $graph_dir/HCLG.fst "ark:gunzip -c $log_likes_dir/log_likes.JOB.gz |" \ ark:/dev/null "$ali" diff --git a/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir_simple.sh b/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir_simple.sh index 0da130ee3ab..cd4f36ded6b 100755 --- a/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir_simple.sh +++ b/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir_simple.sh @@ -32,8 +32,9 @@ extra_right_context=0 frame_subsampling_factor=1 # Subsampling at the output -transition_scale=10.0 -loopscale=1.0 +transition_scale=1.0 +loopscale=0.1 +acwt=1.0 # Set to true if the test data has > 8kHz sampling frequency. do_downsampling=false @@ -95,6 +96,7 @@ else fi if [ $stage -le 1 ]; then + utils/fix_data_dir.sh $test_data_dir steps/make_mfcc.sh --mfcc-config $mfcc_config --nj $nj --cmd "$train_cmd" \ ${test_data_dir} exp/make_hires/${data_id}${feat_affix} $mfcc_dir steps/compute_cmvn_stats.sh ${test_data_dir} exp/make_hires/${data_id}${feat_affix} $mfcc_dir @@ -163,9 +165,9 @@ if [ $stage -le 5 ]; then fi if [ $stage -le 6 ]; then - # 'final' here refers to $lang/final.mdl steps/segmentation/decode_sad.sh --acwt 1.0 --cmd "$decode_cmd" \ - --iter final --get-pdfs true $graph_dir $sad_dir $seg_dir + --iter ${iter} \ + --get-pdfs true $graph_dir $sad_dir $seg_dir fi if [ $stage -le 7 ]; then diff --git a/egs/wsj/s5/steps/segmentation/internal/convert_ali_to_vad.sh b/egs/wsj/s5/steps/segmentation/internal/convert_ali_to_vad.sh index 234b5020797..0d8939a9b80 100755 --- a/egs/wsj/s5/steps/segmentation/internal/convert_ali_to_vad.sh +++ b/egs/wsj/s5/steps/segmentation/internal/convert_ali_to_vad.sh @@ -8,15 +8,20 @@ set -u cmd=run.pl -frame_shift=0.01 -frame_subsampling_factor=1 - . parse_options.sh if [ $# -ne 3 ]; then echo "This script converts the alignment in the alignment directory " echo "to speech activity segments based on the provided phone-map." - echo "Usage: $0 exp/tri3_ali data/lang/phones/sad.map exp/tri3_ali_vad" + echo "The output is stored in sad_seg.*.ark along with an scp-file " + echo "sad_seg.scp in Segmentation format.\n" + echo "If alignment directory has frame_subsampling_factor, the segments " + echo "are applied that frame-subsampling-factor.\n" + echo "The phone-map file must have two columns: " + echo " \n" + echo "\n" + echo "Usage: $0 " + echo "e.g. 
: $0 exp/tri3_ali data/lang/phones/sad.map exp/tri3_ali_vad" exit 1 fi @@ -33,21 +38,21 @@ mkdir -p $dir nj=`cat $ali_dir/num_jobs` || exit 1 echo $nj > $dir/num_jobs +frame_subsampling_factor=1 if [ -f $ali_dir/frame_subsampling_factor ]; then frame_subsampling_factor=`cat $ali_dir/frame_subsampling_factor` fi -ali_frame_shift=`perl -e "print ($frame_shift * $frame_subsampling_factor);"` -ali_frame_overlap=`perl -e "print ($ali_frame_shift * 1.5);"` - dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $dir ${PWD}` $cmd JOB=1:$nj $dir/log/get_sad.JOB.log \ segmentation-init-from-ali \ - "ark:gunzip -c ${ali_dir}/ali.JOB.gz | ali-to-phones --per-frame ${ali_dir}/final.mdl ark:- ark:- |" \ - ark:- \| segmentation-copy --label-map=$phone_map ark:- ark:- \| \ + "ark:gunzip -c ${ali_dir}/ali.JOB.gz | ali-to-phones --per-frame ${ali_dir}/final.mdl ark:- ark:- |" \ + ark:- \| \ + segmentation-copy --label-map=$phone_map \ + --frame-subsampling-factor=$frame_subsampling_factor ark:- ark:- \| \ segmentation-post-process --merge-adjacent-segments ark:- \ - ark,scp:$dir/sad_seg.JOB.ark,$dir/sad_seg.JOB.scp + ark,scp:$dir/sad_seg.JOB.ark,$dir/sad_seg.JOB.scp for n in `seq $nj`; do cat $dir/sad_seg.$n.scp From 0465262edf57e03509d2c5b4b6da877a17baa0e2 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 23 Feb 2017 15:58:37 -0500 Subject: [PATCH 191/213] asr_diarization: add more checks and messages to segmentation binaries --- src/segmenterbin/segmentation-combine-segments.cc | 3 ++- src/segmenterbin/segmentation-init-from-segments.cc | 2 +- src/segmenterbin/segmentation-merge-recordings.cc | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/segmenterbin/segmentation-combine-segments.cc b/src/segmenterbin/segmentation-combine-segments.cc index 09b789a0921..1d745ca91f9 100644 --- a/src/segmenterbin/segmentation-combine-segments.cc +++ b/src/segmenterbin/segmentation-combine-segments.cc @@ -103,7 +103,8 @@ int main(int argc, char *argv[]) { if (!utt_segmentation_reader.HasKey(*it)) { KALDI_WARN << "Could not find utterance " << *it << " in " - << "segmentation " << utt_segmentation_rspecifier; + << "segmentation " << utt_segmentation_rspecifier + << (include_missing ? "; using default segmentation": ""); if (!include_missing) { num_err++; } else { diff --git a/src/segmenterbin/segmentation-init-from-segments.cc b/src/segmenterbin/segmentation-init-from-segments.cc index 469b4ef2965..980ec697602 100644 --- a/src/segmenterbin/segmentation-init-from-segments.cc +++ b/src/segmenterbin/segmentation-init-from-segments.cc @@ -70,7 +70,7 @@ int main(int argc, char *argv[]) { ParseOptions po(usage); - po.Register("segment-label", &segment_label, + po.Register("label", &segment_label, "Label for all the segments in the segmentations"); po.Register("utt2label-rspecifier", &utt2label_rspecifier, "Mapping for each utterance to an integer label. " diff --git a/src/segmenterbin/segmentation-merge-recordings.cc b/src/segmenterbin/segmentation-merge-recordings.cc index dccd82b0595..69f6758c90d 100644 --- a/src/segmenterbin/segmentation-merge-recordings.cc +++ b/src/segmenterbin/segmentation-merge-recordings.cc @@ -92,8 +92,8 @@ int main(int argc, char *argv[]) { << "created overall " << num_segments << " segments; " << "failed to merge " << num_err << " old segmentations"; - return (num_new_segmentations > 0 && num_err < num_old_segmentations / 2 ? - 0 : 1); + return (num_segments > 0 && num_new_segmentations > 0 && + num_err < num_old_segmentations / 2 ? 
0 : 1); } catch(const std::exception &e) { std::cerr << e.what(); return -1; From ff438b98f2101ef9f490d4b8f8ebb0d3026af5ad Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 1 Mar 2017 17:01:05 -0500 Subject: [PATCH 192/213] asr_diarization: Add more control over speed in the SAD scripts --- .../s5/local/segmentation/do_corruption_data_dir.sh | 10 ++++++---- .../local/segmentation/do_corruption_data_dir_music.sh | 7 ++++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/egs/aspire/s5/local/segmentation/do_corruption_data_dir.sh b/egs/aspire/s5/local/segmentation/do_corruption_data_dir.sh index 5d38be87d70..45fdf6c1c5c 100755 --- a/egs/aspire/s5/local/segmentation/do_corruption_data_dir.sh +++ b/egs/aspire/s5/local/segmentation/do_corruption_data_dir.sh @@ -19,8 +19,9 @@ speed_perturb=true num_data_reps=5 # Number of corrupted versions snrs="20:10:15:5:0:-5" foreground_snrs="20:10:15:5:0:-5" -background_snrs="20:10:15:5:0:-5" +background_snrs="20:10:15:5:2:0:-2:-5" base_rirs=simulated +speeds="0.9 1.0 1.1" # Parallel options reco_nj=40 @@ -48,7 +49,8 @@ if [ "$base_rirs" == "simulated" ]; then # This is the config for the system using simulated RIRs and point-source noises rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/smallroom/rir_list") rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/mediumroom/rir_list") - rvb_opts+=(--noise-set-parameters RIRS_NOISES/pointsource_noises/noise_list) + rvb_opts+=(--noise-set-parameters "0.1, RIRS_NOISES/pointsource_noises/background_noise_list") + rvb_opts+=(--noise-set-parameters "0.9, RIRS_NOISES/pointsource_noises/foreground_noise_list") else # This is the config for the JHU ASpIRE submission system rvb_opts+=(--rir-set-parameters "1.0, RIRS_NOISES/real_rirs_isotropic_noises/rir_list") @@ -67,7 +69,7 @@ if [ $stage -le 1 ]; then --pointsource-noise-addition-probability=1 \ --isotropic-noise-addition-probability=1 \ --num-replications=$num_data_reps \ - --max-noises-per-minute=1 \ + --max-noises-per-minute=2 \ data/${data_id} data/${corrupted_data_id} fi @@ -78,7 +80,7 @@ if $speed_perturb; then ## Assuming whole data directories for x in $corrupted_data_dir; do cp $x/reco2dur $x/utt2dur - utils/data/perturb_data_dir_speed_random.sh $x ${x}_spr + utils/data/perturb_data_dir_speed_random.sh --speeds "$speeds" $x ${x}_spr done fi diff --git a/egs/aspire/s5/local/segmentation/do_corruption_data_dir_music.sh b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_music.sh index 4fc369234ea..8865e640674 100755 --- a/egs/aspire/s5/local/segmentation/do_corruption_data_dir_music.sh +++ b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_music.sh @@ -24,6 +24,7 @@ mfcc_irm_config=conf/mfcc_hires_bp.conf dry_run=false corrupt_only=false speed_perturb=true +speeds="0.9 1.0 1.1" reco_vad_dir= @@ -79,7 +80,7 @@ if $speed_perturb; then ## Assuming whole data directories for x in $corrupted_data_dir; do cp $x/reco2dur $x/utt2dur - utils/data/perturb_data_dir_speed_random.sh $x ${x}_spr + utils/data/perturb_data_dir_speed_random.sh --speeds "$speeds" $x ${x}_spr done fi @@ -169,7 +170,7 @@ if [ $stage -le 10 ]; then utils/fix_data_dir.sh $music_data_dir if $speed_perturb; then - utils/data/perturb_data_dir_speed_3way.sh $music_data_dir ${music_data_dir}_spr + utils/data/perturb_data_dir_speed_4way.sh $music_data_dir ${music_data_dir}_spr mv ${music_data_dir}_spr/segments{,.temp} cat ${music_data_dir}_spr/segments.temp | \ utils/filter_scp.pl -f 2 ${corrupted_data_dir}/reco2utt > 
${music_data_dir}_spr/segments @@ -222,7 +223,7 @@ if [ $stage -le 12 ]; then EOF $train_cmd JOB=1:$reco_nj $music_dir/log/get_speech_music_labels.JOB.log \ - intersect-int-vectors --mapping-in=$music_dir/speech_music_map \ + intersect-int-vectors --mapping-in=$music_dir/speech_music_map --length-tolerance=2 \ "scp:utils/filter_scp.pl ${corrupted_data_dir}/split${reco_nj}reco/JOB/reco2utt ${corrupted_data_dir}/speech_labels.scp |" \ "scp:utils/filter_scp.pl ${corrupted_data_dir}/split${reco_nj}reco/JOB/reco2utt ${corrupted_data_dir}/music_labels.scp |" \ ark,scp:$label_dir/speech_music_labels_${corrupted_data_id}.JOB.ark,$label_dir/speech_music_labels_${corrupted_data_id}.JOB.scp From 6ef5b5885a6f72441dc3f4d9876a7f620dc2716a Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 1 Mar 2017 17:05:08 -0500 Subject: [PATCH 193/213] asr_diarization: prepare unsad data --- .../local/segmentation/prepare_babel_data.sh | 14 +++++---- .../local/segmentation/prepare_fisher_data.sh | 10 +++---- .../local/segmentation/prepare_unsad_data.sh | 29 ++++++++++--------- 3 files changed, 29 insertions(+), 24 deletions(-) diff --git a/egs/aspire/s5/local/segmentation/prepare_babel_data.sh b/egs/aspire/s5/local/segmentation/prepare_babel_data.sh index 24a61eca772..927c530663d 100644 --- a/egs/aspire/s5/local/segmentation/prepare_babel_data.sh +++ b/egs/aspire/s5/local/segmentation/prepare_babel_data.sh @@ -68,18 +68,20 @@ EOF # The original data directory which will be converted to a whole (recording-level) directory. utils/copy_data_dir.sh $ROOT_DIR/data/train data/babel_${lang_id}_train train_data_dir=data/babel_${lang_id}_train +speeds="0.9 1.0 1.1" +num_speeds=$(echo $speeds | awk '{print NF}') # Expecting the user to have done run.sh to have $model_dir, # $sat_model_dir, $lang, $lang_test, $train_data_dir -local/segmentation/prepare_unsad_data.sh --stage 14 \ - --sad-map $dir/babel_sad.map \ +local/segmentation/prepare_unsad_data.sh \ + --sad-map $dir/babel_sad.map --speeds "$speeds" \ --config-dir $ROOT_DIR/conf --feat-type plp --add-pitch true \ --reco-nj 40 --nj 100 --cmd "$train_cmd" \ --sat-model-dir $sat_model_dir \ --lang-test $lang_test \ $train_data_dir $lang $model_dir $dir -orig_data_dir=${train_data_dir}_sp +orig_data_dir=${train_data_dir}_sp${num_speeds} data_dir=${train_data_dir}_whole @@ -90,16 +92,16 @@ if [ ! 
-z $subset ]; then data_dir=${data_dir}_$subset fi -reco_vad_dir=$dir/`basename $model_dir`_reco_vad_`basename $train_data_dir`_sp +reco_vad_dir=$dir/`basename $model_dir`_reco_vad_`basename $train_data_dir`_sp4 # Add noise from MUSAN corpus to data directory and create a new data directory local/segmentation/do_corruption_data_dir.sh \ - --data-dir $data_dir \ + --data-dir $data_dir --speeds "$speeds" \ --reco-vad-dir $reco_vad_dir \ --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf # Add music from MUSAN corpus to data directory and create a new data directory local/segmentation/do_corruption_data_dir_music.sh \ - --data-dir $data_dir \ + --data-dir $data_dir --speeds "$speeds" \ --reco-vad-dir $reco_vad_dir \ --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf diff --git a/egs/aspire/s5/local/segmentation/prepare_fisher_data.sh b/egs/aspire/s5/local/segmentation/prepare_fisher_data.sh index 4749ff7da8a..d90dd05f472 100644 --- a/egs/aspire/s5/local/segmentation/prepare_fisher_data.sh +++ b/egs/aspire/s5/local/segmentation/prepare_fisher_data.sh @@ -56,7 +56,7 @@ oov_I 3 oov_S 3 EOF -false && { +true && { # Expecting the user to have done run.sh to have $model_dir, # $sat_model_dir, $lang, $lang_test, $train_data_dir local/segmentation/prepare_unsad_data.sh \ @@ -72,21 +72,21 @@ data_dir=${train_data_dir}_whole if [ ! -z $subset ]; then # Work on a subset - false && utils/subset_data_dir.sh ${data_dir} $subset \ + true && utils/subset_data_dir.sh ${data_dir} $subset \ ${data_dir}_$subset data_dir=${data_dir}_$subset fi -reco_vad_dir=$dir/`basename $model_dir`_reco_vad_`basename $train_data_dir`_sp +reco_vad_dir=$dir/`basename $model_dir`_reco_vad_`basename $train_data_dir`_sp4 # Add noise from MUSAN corpus to data directory and create a new data directory -false && local/segmentation/do_corruption_data_dir.sh \ +true && local/segmentation/do_corruption_data_dir.sh \ --data-dir $data_dir \ --reco-vad-dir $reco_vad_dir \ --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf # Add music from MUSAN corpus to data directory and create a new data directory -local/segmentation/do_corruption_data_dir_music.sh --stage 10 \ +local/segmentation/do_corruption_data_dir_music.sh \ --data-dir $data_dir \ --reco-vad-dir $reco_vad_dir \ --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf diff --git a/egs/aspire/s5/local/segmentation/prepare_unsad_data.sh b/egs/aspire/s5/local/segmentation/prepare_unsad_data.sh index 7385e309f5f..dc4cbf58994 100755 --- a/egs/aspire/s5/local/segmentation/prepare_unsad_data.sh +++ b/egs/aspire/s5/local/segmentation/prepare_unsad_data.sh @@ -38,6 +38,8 @@ sat_model_dir= # Model directory used for getting alignments lang_test= # Language directory used to build graph. # If its not provided, $lang will be used instead. +speeds="0.9 1.0 1.1" + . 
utils/parse_options.sh if [ $# -ne 4 ]; then @@ -188,14 +190,15 @@ if [ $stage -le -2 ]; then utils/data/get_utt2dur.sh ${whole_data_dir} fi +num_speeds=`echo $speeds | awk '{print NF}'` if $speed_perturb; then - plpdir=${plpdir}_sp - mfccdir=${mfccdir}_sp + plpdir=${plpdir}_sp$num_speeds + mfccdir=${mfccdir}_sp$num_speeds if [ $stage -le -1 ]; then - utils/data/perturb_data_dir_speed_3way.sh ${whole_data_dir} ${whole_data_dir}_sp - utils/data/perturb_data_dir_speed_3way.sh ${data_dir} ${data_dir}_sp + utils/data/perturb_data_dir_speed_${num_speeds}way.sh ${whole_data_dir} ${whole_data_dir}_sp${num_speeds} + utils/data/perturb_data_dir_speed_${num_speeds}way.sh ${data_dir} ${data_dir}_sp${num_speeds} if [ $feat_type == "mfcc" ]; then if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then @@ -205,9 +208,9 @@ if $speed_perturb; then make_mfcc --cmd "$cmd --max-jobs-run 40" --nj $nj \ --mfcc-config $feat_config \ --add-pitch $add_pitch --pitch-config $pitch_config \ - ${whole_data_dir}_sp exp/make_mfcc $mfccdir || exit 1 + ${whole_data_dir}_sp${num_speeds} exp/make_mfcc $mfccdir || exit 1 steps/compute_cmvn_stats.sh \ - ${whole_data_dir}_sp exp/make_mfcc $mfccdir || exit 1 + ${whole_data_dir}_sp${num_speeds} exp/make_mfcc $mfccdir || exit 1 elif [ $feat_type == "plp" ]; then if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $plpdir/storage ]; then utils/create_split_dir.pl \ @@ -217,20 +220,20 @@ if $speed_perturb; then make_plp --cmd "$cmd --max-jobs-run 40" --nj $nj \ --plp-config $feat_config \ --add-pitch $add_pitch --pitch-config $pitch_config \ - ${whole_data_dir}_sp exp/make_plp $plpdir || exit 1 + ${whole_data_dir}_sp${num_speeds} exp/make_plp $plpdir || exit 1 steps/compute_cmvn_stats.sh \ - ${whole_data_dir}_sp exp/make_plp $plpdir || exit 1 + ${whole_data_dir}_sp${num_speeds} exp/make_plp $plpdir || exit 1 else echo "$0: Unknown feat-type $feat_type. Must be mfcc or plp." exit 1 fi - utils/fix_data_dir.sh ${whole_data_dir}_sp + utils/fix_data_dir.sh ${whole_data_dir}_sp${num_speeds} fi - data_dir=${data_dir}_sp - whole_data_dir=${whole_data_dir}_sp - data_id=${data_id}_sp + data_dir=${data_dir}_sp${num_speeds} + whole_data_dir=${whole_data_dir}_sp${num_speeds} + data_id=${data_id}_sp${num_speeds} fi @@ -440,7 +443,7 @@ vad_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } pr if [ $stage -le 10 ]; then segmentation-init-from-segments --frame-shift=$frame_shift \ - --frame-overlap=$frame_overlap --segment-label=0 \ + --frame-overlap=$frame_overlap --label=0 \ $outside_data_dir/segments \ ark,scp:$vad_dir/outside_sad_seg.ark,$vad_dir/outside_sad_seg.scp fi From 1a1712350d7a7f1d6d7faebac9e50e3c826abd33 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 1 Mar 2017 17:05:45 -0500 Subject: [PATCH 194/213] asr_diarization: Better logging in compute_cmvn_stats --- egs/wsj/s5/steps/compute_cmvn_stats.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/wsj/s5/steps/compute_cmvn_stats.sh b/egs/wsj/s5/steps/compute_cmvn_stats.sh index 9056d88691c..6e7531394a2 100755 --- a/egs/wsj/s5/steps/compute_cmvn_stats.sh +++ b/egs/wsj/s5/steps/compute_cmvn_stats.sh @@ -91,18 +91,18 @@ if $fake; then ! 
cat $data/spk2utt | awk -v dim=$dim '{print $1, "["; for (n=0; n < dim; n++) { printf("0 "); } print "1"; for (n=0; n < dim; n++) { printf("1 "); } print "0 ]";}' | \ copy-matrix ark:- ark,scp:$cmvndir/cmvn_$name.ark,$cmvndir/cmvn_$name.scp && \ - echo "Error creating fake CMVN stats" && exit 1; + echo "Error creating fake CMVN stats. See $logdir/cmvn_$name.log." && exit 1; elif $two_channel; then ! compute-cmvn-stats-two-channel $data/reco2file_and_channel scp:$data/feats.scp \ ark,scp:$cmvndir/cmvn_$name.ark,$cmvndir/cmvn_$name.scp \ - 2> $logdir/cmvn_$name.log && echo "Error computing CMVN stats (using two-channel method)" && exit 1; + 2> $logdir/cmvn_$name.log && echo "Error computing CMVN stats (using two-channel method). See $logdir/cmvn_$name.log." && exit 1; elif [ ! -z "$fake_dims" ]; then ! compute-cmvn-stats --spk2utt=ark:$data/spk2utt scp:$data/feats.scp ark:- | \ modify-cmvn-stats "$fake_dims" ark:- ark,scp:$cmvndir/cmvn_$name.ark,$cmvndir/cmvn_$name.scp && \ - echo "Error computing (partially fake) CMVN stats" && exit 1; + echo "Error computing (partially fake) CMVN stats. See $logdir/cmvn_$name.log" && exit 1; else ! compute-cmvn-stats --spk2utt=ark:$data/spk2utt scp:$data/feats.scp ark,scp:$cmvndir/cmvn_$name.ark,$cmvndir/cmvn_$name.scp \ - 2> $logdir/cmvn_$name.log && echo "Error computing CMVN stats" && exit 1; + 2> $logdir/cmvn_$name.log && echo "Error computing CMVN stats. See $logdir/cmvn_$name.log" && exit 1; fi cp $cmvndir/cmvn_$name.scp $data/cmvn.scp || exit 1; From d02ef223a7cb80b4e930c5df0414fcdfdea86a8d Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 1 Mar 2017 17:06:22 -0500 Subject: [PATCH 195/213] asr_diarization: Add perturb_data_dir_speed_random.sh --- .../data/perturb_data_dir_speed_random.sh | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/egs/wsj/s5/utils/data/perturb_data_dir_speed_random.sh b/egs/wsj/s5/utils/data/perturb_data_dir_speed_random.sh index d9d027b77a3..1eb7ebb874c 100755 --- a/egs/wsj/s5/utils/data/perturb_data_dir_speed_random.sh +++ b/egs/wsj/s5/utils/data/perturb_data_dir_speed_random.sh @@ -4,6 +4,8 @@ # Apache 2.0 +speeds="0.9 1.0 1.1" + . utils/parse_options.sh if [ $# != 2 ]; then @@ -33,15 +35,25 @@ echo "... obtaining it after speed-perturbing would be very slow, and" echo "... you might need it." 
utils/data/get_utt2dur.sh ${srcdir} -utils/split_data.sh --per-reco $srcdir 3 +num_speeds=`echo $speeds | awk '{print NF}'` +utils/split_data.sh --per-reco $srcdir $num_speeds + +speed_dirs= +i=1 +for speed in $speeds; do + if [ $speed != 1.0 ]; then + utils/data/perturb_data_dir_speed.sh $speed ${srcdir}/split${num_speeds}reco/$i ${destdir}_speed$speed || exit 1 + speed_dirs="${speed_dirs} ${destdir}_speed$speed" + else + speed_dirs="$speed_dirs ${srcdir}/split${num_speeds}reco/$i" + fi +done -utils/data/perturb_data_dir_speed.sh 0.9 ${srcdir}/split3reco/1 ${destdir}_speed0.9 || exit 1 -utils/data/perturb_data_dir_speed.sh 1.1 ${srcdir}/split3reco/3 ${destdir}_speed1.1 || exit 1 -utils/data/combine_data.sh $destdir ${srcdir}/split3reco/2 ${destdir}_speed0.9 ${destdir}_speed1.1 || exit 1 +utils/data/combine_data.sh $destdir ${speed_dirs} || exit 1 -rm -r ${destdir}_speed0.9 ${destdir}_speed1.1 +rm -r $speed_dirs ${srcdir}/split${num_speeds}reco -echo "$0: generated 3-way speed-perturbed version of random subsets of data in $srcdir, in $destdir" +echo "$0: generated $num_speeds-way speed-perturbed version of random subsets of data in $srcdir, in $destdir" if [ -f $srcdir/text ]; then utils/validate_data_dir.sh --no-feats $destdir else From 997d17de332c3e1e21f8aac1a38461475d94a05d Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 1 Mar 2017 17:08:19 -0500 Subject: [PATCH 196/213] asr_diarization: AMI Segmentation run script --- .../segmentation/run_segmentation_ami.sh | 180 +++++++-- .../segmentation/run_segmentation_amitrain.sh | 373 ++++++++++++++++++ 2 files changed, 513 insertions(+), 40 deletions(-) create mode 100755 egs/aspire/s5/local/segmentation/run_segmentation_amitrain.sh diff --git a/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh b/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh index 733c6aa53fe..48677598728 100755 --- a/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh +++ b/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh @@ -23,6 +23,7 @@ speech_prior=0.3 min_silence_duration=30 min_speech_duration=10 frame_subsampling_factor=3 +ali_dir=/export/a09/vmanoha1/workspace_asr_diarization/egs/ami/s5b/exp/ihm/nnet3_cleaned/tdnn_sp_ali_dev_ihmdata_oraclespk . 
utils/parse_options.sh @@ -56,65 +57,110 @@ fi if [ $stage -le 2 ]; then ( cd $src_dir - utils/data/get_reco2utt.sh $src_dir/data/sdm1/dev + + utils/copy_data_dir.sh $src_dir/data/sdm1/dev_ihmdata \ + $src_dir/data/sdm1/dev_ihmdata_oraclespk + + cut -d ' ' -f 1,2 $src_dir/data/ihm/dev/segments | \ + utils/apply_map.pl -f 1 $src_dir/data/sdm1/dev_ihmdata/ihmutt2utt > \ + $src_dir/data/sdm1/dev_ihmdata_oraclespk/utt2spk.temp + + cat $src_dir/data/sdm1/dev_ihmdata_oraclespk/utt2spk.temp | \ + awk '{print $1" "$2"-"$1}' > \ + $src_dir/data/sdm1/dev_ihmdata_oraclespk/utt2newutt + + utils/apply_map.pl -f 1 $src_dir/data/sdm1/dev_ihmdata_oraclespk/utt2newutt \ + < $src_dir/data/sdm1/dev_ihmdata_oraclespk/utt2spk.temp > \ + $src_dir/data/sdm1/dev_ihmdata_oraclespk/utt2spk + + for f in feats.scp segments text; do + utils/apply_map.pl -f 1 $src_dir/data/sdm1/dev_ihmdata_oraclespk/utt2newutt \ + < $src_dir/data/sdm1/dev_ihmdata/$f > \ + $src_dir/data/sdm1/dev_ihmdata_oraclespk/$f + done + + rm $src_dir/data/sdm1/dev_ihmdata_oraclespk/{spk2utt,cmvn.scp} + utils/fix_data_dir.sh \ + $src_dir/data/sdm1/dev_ihmdata_oraclespk + + utils/data/get_reco2utt.sh $src_dir/data/sdm1/dev_ihmdata_oraclespk ) +fi - phone_map=$dir/phone_map +phone_map=$dir/phone_map +if [ $stage -le 2 ]; then steps/segmentation/get_sad_map.py \ $src_dir/data/lang | utils/sym2int.pl -f 1 $src_dir/data/lang/phones.txt > \ $phone_map fi -if [ $stage -le 3 ]; then - # Expecting user to have run local/run_cleanup_segmentation.sh in $src_dir - ( - cd $src_dir - steps/align_fmllr.sh --nj 18 --cmd "$train_cmd" \ - data/sdm1/dev_ihmdata data/lang \ - exp/ihm/tri3_cleaned \ - exp/sdm1/tri3_cleaned_dev_ihmdata - ) +if [ -z $ali_dir ]; then + if [ $stage -le 3 ]; then + # Expecting user to have run local/run_cleanup_segmentation.sh in $src_dir + ( + cd $src_dir + steps/align_fmllr.sh --nj 18 --cmd "$train_cmd" \ + data/sdm1/dev_ihmdata_oraclespk data/lang \ + exp/ihm/tri3_cleaned \ + exp/sdm1/tri3_cleaned_dev_ihmdata_oraclespk + ) + fi + ali_dir=exp/sdm1/tri3_cleaned_ali_dev_ihmdata_oraclespk fi if [ $stage -le 4 ]; then steps/segmentation/internal/convert_ali_to_vad.sh --cmd "$train_cmd" \ - $src_dir/exp/sdm1/tri3_cleaned_dev_ihmdata $phone_map $dir + $ali_dir $phone_map $dir fi echo "A 1" > $dir/channel_map cat $src_dir/data/sdm1/dev/reco2file_and_channel | \ utils/apply_map.pl -f 3 $dir/channel_map > $dir/reco2file_and_channel -cat $src_dir/data/sdm1/dev_ihmdata/reco2utt | \ - awk 'BEGIN{i=1} {print $1" "i; i++;}' > \ - $src_dir/data/sdm1/dev_ihmdata/reco.txt +# Map each IHM recording to a unique integer id. +# This will be the "speaker label" as each recording is assumed to have a +# single speaker. +cat $src_dir/data/sdm1/dev_ihmdata_oraclespk/reco2utt | \ + awk 'BEGIN{i=1} {print $1" "1":"i" 100000:100000"; i++;}' > \ + $src_dir/data/sdm1/dev_ihmdata_oraclespk/reco.txt if [ $stage -le 5 ]; then utils/data/get_reco2num_frames.sh --frame-shift 0.01 --frame-overlap 0.015 \ - --cmd queue.pl --nj 18 \ + --cmd "$train_cmd" --nj 18 \ $src_dir/data/sdm1/dev - # Get a filter that selects only regions within the manual segments. - $train_cmd $dir/log/get_manual_segments_regions.log \ - segmentation-init-from-segments --shift-to-zero=false $src_dir/data/sdm1/dev/segments ark:- \| \ + # Get a filter that changes the first and the last segment region outside + # the manual segmentation (usually some preparation lines) that are not + # transcribed. 
+ $train_cmd $dir/log/interior_regions.log \ + segmentation-init-from-segments --shift-to-zero=false --frame-overlap=0.0 $src_dir/data/sdm1/dev/segments ark:- \| \ segmentation-combine-segments-to-recordings ark:- ark,t:$src_dir/data/sdm1/dev/reco2utt ark:- \| \ segmentation-create-subsegments --filter-label=1 --subsegment-label=1 \ "ark:segmentation-init-from-lengths --label=0 ark,t:$src_dir/data/sdm1/dev/reco2num_frames ark:- |" ark:- ark,t:- \| \ - perl -ane '$F[3] = 10000; $F[$#F-1] = 10000; print join(" ", @F) . "\n";' \| \ + perl -ane '$F[3] = 100000; $F[$#F-1] = 100000; print join(" ", @F) . "\n";' \| \ segmentation-post-process --merge-labels=0:1 --merge-dst-label=1 ark:- ark:- \| \ - segmentation-post-process --merge-labels=10000 --merge-dst-label=0 --merge-adjacent-segments \ - --max-intersegment-length=10000 ark,t:- \ + segmentation-post-process --merge-labels=100000 --merge-dst-label=0 --merge-adjacent-segments \ + --max-intersegment-length=1000000 ark,t:- \ + "ark:| gzip -c > $dir/interior_regions.seg.gz" + + $train_cmd $dir/log/get_manual_segments_regions.log \ + segmentation-init-from-segments --shift-to-zero=false --frame-overlap=0.0 $src_dir/data/sdm1/dev/segments ark:- \| \ + segmentation-combine-segments-to-recordings ark:- ark,t:$src_dir/data/sdm1/dev/reco2utt ark:- \| \ + segmentation-create-subsegments --filter-label=1 --subsegment-label=1 \ + "ark:segmentation-init-from-lengths --label=100000 ark,t:$src_dir/data/sdm1/dev/reco2num_frames ark:- |" ark:- ark:- \| \ + segmentation-post-process --merge-labels=100000 --merge-dst-label=0 --merge-adjacent-segments \ + --max-intersegment-length=1000000 ark,t:- \ "ark:| gzip -c > $dir/manual_segments_regions.seg.gz" fi if [ $stage -le 6 ]; then # Reference RTTM where SPEECH frames are obtainted by combining IHM VAD alignments $train_cmd $dir/log/get_ref_spk_seg.log \ - segmentation-combine-segments scp:$dir/sad_seg.scp \ - "ark:segmentation-init-from-segments --shift-to-zero=false $src_dir/data/sdm1/dev_ihmdata/segments ark:- |" \ - ark,t:$src_dir/data/sdm1/dev_ihmdata/reco2utt ark:- \| \ - segmentation-copy --keep-label=1 ark:- ark:- \| \ - segmentation-copy --utt2label-rspecifier=ark,t:$src_dir/data/sdm1/dev_ihmdata/reco.txt \ + segmentation-combine-segments --include-missing-utt-level-segmentations scp:$dir/sad_seg.scp \ + "ark:segmentation-init-from-segments --shift-to-zero=false --frame-overlap=0.0 --label=100000 $src_dir/data/sdm1/dev_ihmdata_oraclespk/segments ark:- |" \ + ark,t:$src_dir/data/sdm1/dev_ihmdata_oraclespk/reco2utt ark:- \| \ + segmentation-post-process --remove-labels=0 ark:- ark:- \| \ + segmentation-copy --utt2label-map-rspecifier=ark,t:$src_dir/data/sdm1/dev_ihmdata_oraclespk/reco.txt \ ark:- ark:- \| \ segmentation-merge-recordings \ "ark,t:utils/utt2spk_to_spk2utt.pl $src_dir/data/sdm1/dev_ihmdata/ihm2sdm_reco |" \ @@ -123,41 +169,95 @@ fi if [ $stage -le 7 ]; then # To get the actual RTTM, we need to add no-score - $train_cmd $dir/log/get_ref_rttm.log \ + $train_cmd $dir/log/get_ref_spk_rttm_manual_seg.log \ + export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH '&&' \ + segmentation-copy --keep-label=0 "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-labels=0 --merge-dst-label=100000 ark:- ark:- \| \ + segmentation-merge "ark:gunzip -c $dir/ref_spk_seg.gz |" ark:- ark:- \| \ + segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ + --map-to-speech-and-sil=false --no-score-label=100000 ark:- - \| \ + rttmSmooth.pl -s 0 \| rttmSort.pl '>' 
$dir/ref_spk_manual_seg.rttm + + $train_cmd $dir/log/get_ref_spk_rttm_interior.log \ + export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH '&&' \ + segmentation-copy --keep-label=0 "ark:gunzip -c $dir/interior_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-labels=0 --merge-dst-label=100000 ark:- ark:- \| \ + segmentation-merge "ark:gunzip -c $dir/ref_spk_seg.gz |" ark:- ark:- \| \ + segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ + --map-to-speech-and-sil=false --no-score-label=100000 ark:- - \| \ + rttmSmooth.pl -s 0 \| rttmSort.pl '>' $dir/ref_spk_interior.rttm + + $train_cmd $dir/log/get_ref_rttm_manual_seg.log \ + export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH '&&' \ segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/dev/reco2num_frames \ "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \ - ark:/dev/null ark:- \| \ + ark:/dev/null ark:- ark:/dev/null \| \ segmentation-init-from-ali ark:- ark:- \| \ segmentation-post-process --merge-labels=1:2:3:4:5:6:7:8:9:10 --merge-dst-label=1 \ --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ - segmentation-create-subsegments --filter-label=0 --subsegment-label=10000 \ + segmentation-create-subsegments --filter-label=0 --subsegment-label=100000 \ ark:- "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ - --no-score-label=10000 ark:- $dir/ref.rttm + --no-score-label=100000 ark:- - \| \ + rttmSmooth.pl -s 0 \| rttmSort.pl '>' $dir/ref_manual_seg.rttm + + $train_cmd $dir/log/get_ref_rttm_interior.log \ + export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH '&&' \ + segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/dev/reco2num_frames \ + "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \ + ark:/dev/null ark:- ark:/dev/null \| \ + segmentation-init-from-ali ark:- ark:- \| \ + segmentation-post-process --merge-labels=1:2:3:4:5:6:7:8:9:10 --merge-dst-label=1 \ + --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-create-subsegments --filter-label=0 --subsegment-label=100000 \ + ark:- "ark:gunzip -c $dir/interior_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ + --no-score-label=100000 ark:- - \| \ + rttmSmooth.pl -s 0 \| rttmSort.pl '>' $dir/ref_interior.rttm # Get RTTM for overlapped speech detection with 3 classes # 0 -> SILENCE, 1 -> SINGLE_SPEAKER, 2 -> OVERLAP - $train_cmd $dir/log/get_overlapping_rttm.log \ + $train_cmd $dir/log/get_overlapping_rttm_manual_seg.log \ + export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH '&&' \ segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/dev/reco2num_frames \ "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \ - ark:/dev/null ark:- \| \ + ark:/dev/null ark:- ark:/dev/null \| \ segmentation-init-from-ali ark:- ark:- \| \ segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=2 \ --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ - segmentation-create-subsegments --filter-label=0 --subsegment-label=10000 \ + segmentation-create-subsegments 
--filter-label=0 --subsegment-label=100000 \ ark:- "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ segmentation-to-rttm --map-to-speech-and-sil=false --reco2file-and-channel=$dir/reco2file_and_channel \ - --no-score-label=10000 ark:- $dir/overlapping_speech_ref.rttm + --no-score-label=100000 ark:- - \| \ + rttmSmooth.pl -s 0 \| rttmSort.pl '>' $dir/overlapping_speech_ref_manual_seg.rttm + + $train_cmd $dir/log/get_overlapping_rttm_manual_seg.log \ + export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH '&&' \ + segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/dev/reco2num_frames \ + "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \ + ark:/dev/null ark:- ark:/dev/null \| \ + segmentation-init-from-ali ark:- ark:- \| \ + segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=2 \ + --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-create-subsegments --filter-label=0 --subsegment-label=100000 \ + ark:- "ark:gunzip -c $dir/interior_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-to-rttm --map-to-speech-and-sil=false --reco2file-and-channel=$dir/reco2file_and_channel \ + --no-score-label=100000 ark:- - \| \ + rttmSmooth.pl -s 0 \| rttmSort.pl'>' $dir/overlapping_speech_ref_interior.rttm fi +exit 0 + if [ $stage -le 8 ]; then # Get a filter that selects only regions of speech $train_cmd $dir/log/get_speech_filter.log \ segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/dev/reco2num_frames \ "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \ - ark:/dev/null ark:- \| \ + ark:/dev/null ark:- ark:/dev/null \| \ segmentation-init-from-ali ark:- ark:- \| \ segmentation-post-process --merge-labels=1:2:3:4:5:6:7:8:9:10 --merge-dst-label=1 ark:- ark:- \| \ segmentation-create-subsegments --filter-label=0 --subsegment-label=0 \ @@ -181,10 +281,10 @@ sad_dir=${nnet_dir}/sad_ami_sdm1_dev_whole_bp/ hyp_dir=${hyp_dir}_seg if [ $stage -le 10 ]; then - utils/data/get_reco2utt.sh $src_dir/data/sdm1/dev_ihmdata + utils/data/get_reco2utt.sh $src_dir/data/sdm1/dev_ihmdata_oraclespk utils/data/get_reco2utt.sh $hyp_dir - segmentation-init-from-segments --shift-to-zero=false $hyp_dir/segments ark:- | \ + segmentation-init-from-segments --shift-to-zero=false --frame-overlap=0.0 $hyp_dir/segments ark:- | \ segmentation-combine-segments-to-recordings ark:- ark,t:$hyp_dir/reco2utt ark:- | \ segmentation-to-ali --length-tolerance=48 --lengths-rspecifier=ark,t:$src_dir/data/sdm1/dev/reco2num_frames \ ark:- ark:- | \ @@ -301,7 +401,7 @@ fi if [ $stage -le 17 ]; then steps/segmentation/decode_sad.sh \ - --acwt 1 --beam 10 --max-active 7000 \ + --acwt 1 --beam 10 --max-active 7000 --iter trans \ $seg_dir/graph_test $likes_dir $seg_dir fi @@ -323,7 +423,7 @@ EOF $train_cmd $dir/log/get_overlapping_rttm.log \ segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/dev/reco2num_frames \ "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \ - ark:/dev/null ark:- \| \ + ark:/dev/null ark:- ark:/dev/null \| \ segmentation-init-from-ali ark:- ark:- \| \ segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=2 ark:- ark:- \| \ segmentation-create-subsegments 
--filter-label=0 --subsegment-label=10000 \ diff --git a/egs/aspire/s5/local/segmentation/run_segmentation_amitrain.sh b/egs/aspire/s5/local/segmentation/run_segmentation_amitrain.sh new file mode 100755 index 00000000000..59f618b5e8e --- /dev/null +++ b/egs/aspire/s5/local/segmentation/run_segmentation_amitrain.sh @@ -0,0 +1,373 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +. cmd.sh +. path.sh + +set -e +set -o pipefail +set -u + +stage=-1 +nnet_dir=exp/nnet3_sad_snr/nnet_tdnn_k_n4 +extra_left_context=100 +extra_right_context=20 +task=SAD +iter=final + +segmentation_stage=-1 +sil_prior=0.7 +speech_prior=0.3 +min_silence_duration=30 +min_speech_duration=10 +frame_subsampling_factor=3 +ali_dir=/export/a09/vmanoha1/workspace_asr_diarization/egs/ami/s5b/exp/ihm/nnet3_cleaned/tdnn_sp_ali_train_ihmdata_oraclespk + +. utils/parse_options.sh + +export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH + +src_dir=/export/a09/vmanoha1/workspace_asr_diarization/egs/ami/s5b # AMI src_dir +dir=exp/sad_ami_sdm1_train/ref + +mkdir -p $dir + +# Expecting user to have done run.sh to run the AMI recipe in $src_dir for +# both sdm and ihm microphone conditions + +if [ $stage -le 1 ]; then + ( + cd $src_dir + # local/prepare_parallel_train_data.sh --train-set train sdm1 + + awk '{print $1" "$2}' $src_dir/data/ihm/train/segments > \ + $src_dir/data/ihm/train/utt2reco + awk '{print $1" "$2}' $src_dir/data/sdm1/train/segments > \ + $src_dir/data/sdm1/train/utt2reco + + cat $src_dir/data/sdm1/train_ihmdata/ihmutt2utt | \ + utils/apply_map.pl -f 1 $src_dir/data/ihm/train/utt2reco | \ + utils/apply_map.pl -f 2 $src_dir/data/sdm1/train/utt2reco | \ + sort -u > $src_dir/data/sdm1/train_ihmdata/ihm2sdm_reco + ) +fi + +if [ $stage -le 2 ]; then + ( + cd $src_dir + + utils/copy_data_dir.sh $src_dir/data/sdm1/train_ihmdata \ + $src_dir/data/sdm1/train_ihmdata_oraclespk + + cat $src_dir/data/ihm/train/utt2spk | \ + utils/apply_map.pl -f 1 $src_dir/data/sdm1/train_ihmdata/ihmutt2utt > \ + $src_dir/data/sdm1/train_ihmdata_oraclespk/utt2spk.temp + + cat $src_dir/data/sdm1/train_ihmdata_oraclespk/utt2spk.temp | \ + awk '{print $1" "$2"-"$1}' > \ + $src_dir/data/sdm1/train_ihmdata_oraclespk/utt2newutt + + utils/apply_map.pl -f 1 $src_dir/data/sdm1/train_ihmdata_oraclespk/utt2newutt \ + < $src_dir/data/sdm1/train_ihmdata_oraclespk/utt2spk.temp > \ + $src_dir/data/sdm1/train_ihmdata_oraclespk/utt2spk + + for f in feats.scp segments text; do + utils/apply_map.pl -f 1 $src_dir/data/sdm1/train_ihmdata_oraclespk/utt2newutt \ + < $src_dir/data/sdm1/train_ihmdata/$f > \ + $src_dir/data/sdm1/train_ihmdata_oraclespk/$f + done + + rm $src_dir/data/sdm1/train_ihmdata_oraclespk/{spk2utt,cmvn.scp} + utils/fix_data_dir.sh \ + $src_dir/data/sdm1/train_ihmdata_oraclespk + + utils/data/get_reco2utt.sh $src_dir/data/sdm1/train_ihmdata_oraclespk + ) +fi + +phone_map=$dir/phone_map +if [ $stage -le 2 ]; then + steps/segmentation/get_sad_map.py \ + $src_dir/data/lang | utils/sym2int.pl -f 1 $src_dir/data/lang/phones.txt > \ + $phone_map +fi + +if [ -z $ali_dir ]; then + if [ $stage -le 3 ]; then + # Expecting user to have run local/run_cleanup_segmentation.sh in $src_dir + ( + cd $src_dir + steps/align_fmllr.sh --nj 18 --cmd "$train_cmd" \ + data/sdm1/train_ihmdata_oraclespk data/lang \ + exp/ihm/tri3_cleaned \ + exp/sdm1/tri3_cleaned_train_ihmdata_oraclespk + ) + fi + ali_dir=exp/sdm1/tri3_cleaned_ali_train_ihmdata_oraclespk +fi + +if [ $stage -le 4 ]; then + steps/segmentation/internal/convert_ali_to_vad.sh --cmd "$train_cmd" \ + 
$ali_dir $phone_map $dir +fi + +echo "A 1" > $dir/channel_map +cat $src_dir/data/sdm1/train/reco2file_and_channel | \ + utils/apply_map.pl -f 3 $dir/channel_map > $dir/reco2file_and_channel + +# Map each IHM recording to a unique integer id. +# This will be the "speaker label" as each recording is assumed to have a +# single speaker. +cat $src_dir/data/sdm1/train_ihmdata_oraclespk/reco2utt | \ + awk 'BEGIN{i=1} {print $1" "1":"i; i++;}' > \ + $src_dir/data/sdm1/train_ihmdata_oraclespk/reco.txt +if [ $stage -le 5 ]; then + utils/data/get_reco2num_frames.sh --frame-shift 0.01 --frame-overlap 0.015 \ + --cmd "$train_cmd" --nj 18 \ + $src_dir/data/sdm1/train + + # Get a filter that changes the first and the last segment region outside + # the manual segmentation (usually some preparation lines) that are not + # transcribed. + $train_cmd $dir/log/interior_regions.log \ + segmentation-init-from-segments --shift-to-zero=false --frame-overlap=0.0 $src_dir/data/sdm1/train/segments ark:- \| \ + segmentation-combine-segments-to-recordings ark:- ark,t:$src_dir/data/sdm1/train/reco2utt ark:- \| \ + segmentation-create-subsegments --filter-label=1 --subsegment-label=1 \ + "ark:segmentation-init-from-lengths --label=0 ark,t:$src_dir/data/sdm1/train/reco2num_frames ark:- |" ark:- ark,t:- \| \ + perl -ane '$F[3] = 10000; $F[$#F-1] = 10000; print join(" ", @F) . "\n";' \| \ + segmentation-post-process --merge-labels=0:1 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-post-process --merge-labels=10000 --merge-dst-label=0 --merge-adjacent-segments \ + --max-intersegment-length=1000000 ark,t:- \ + "ark:| gzip -c > $dir/interior_regions.seg.gz" + + $train_cmd $dir/log/get_manual_segments_regions.log \ + segmentation-init-from-segments --shift-to-zero=false --frame-overlap=0.0 $src_dir/data/sdm1/train/segments ark:- \| \ + segmentation-combine-segments-to-recordings ark:- ark,t:$src_dir/data/sdm1/train/reco2utt ark:- \| \ + segmentation-create-subsegments --filter-label=1 --subsegment-label=1 \ + "ark:segmentation-init-from-lengths --label=0 ark,t:$src_dir/data/sdm1/train/reco2num_frames ark:- |" ark:- ark,t:- \| \ + perl -ane '$F[3] = 10000; $F[$#F-1] = 10000; print join(" ", @F) . 
"\n";' \| \ + segmentation-post-process --merge-labels=0:1 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-post-process --merge-labels=10000 --merge-dst-label=0 --merge-adjacent-segments \ + --max-intersegment-length=10000 ark,t:- \ + "ark:| gzip -c > $dir/manual_segments_regions.seg.gz" +fi + +if [ $stage -le 6 ]; then + # Reference RTTM where SPEECH frames are obtainted by combining IHM VAD alignments + $train_cmd $dir/log/get_ref_spk_seg.log \ + segmentation-combine-segments scp:$dir/sad_seg.scp \ + "ark:segmentation-init-from-segments --shift-to-zero=false --frame-overlap=0.0 $src_dir/data/sdm1/train_ihmdata_oraclespk/segments ark:- |" \ + ark,t:$src_dir/data/sdm1/train_ihmdata_oraclespk/reco2utt ark:- \| \ + segmentation-copy --keep-label=1 ark:- ark:- \| \ + segmentation-copy --utt2label-map-rspecifier=ark,t:$src_dir/data/sdm1/train_ihmdata/reco.txt \ + ark:- ark:- \| \ + segmentation-merge-recordings \ + "ark,t:utils/utt2spk_to_spk2utt.pl $src_dir/data/sdm1/train_ihmdata/ihm2sdm_reco |" \ + ark:- "ark:| gzip -c > $dir/ref_spk_seg.gz" +fi + +if [ $stage -le 7 ]; then + # To get the actual RTTM, we need to add no-score + $train_cmd $dir/log/get_ref_spk_rttm_manual_seg.log \ + export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH '&&' \ + segmentation-copy --keep-label=0 "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-labels=0 --merge-dst-label=10000 ark:- ark:- \| \ + segmentation-merge "ark:gunzip -c $dir/ref_spk_seg.gz |" ark:- ark:- \| \ + segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ + --map-to-speech-and-sil=false --no-score-label=10000 ark:- - \| \ + rttmSmooth.pl -s 0 \| rttmSort.pl '>' $dir/ref_spk_manual_seg.rttm + + $train_cmd $dir/log/get_ref_spk_rttm_interior.log \ + export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH '&&' \ + segmentation-copy --keep-label=0 "ark:gunzip -c $dir/interior_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-labels=0 --merge-dst-label=10000 ark:- ark:- \| \ + segmentation-merge "ark:gunzip -c $dir/ref_spk_seg.gz |" ark:- ark:- \| \ + segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ + --map-to-speech-and-sil=false --no-score-label=10000 ark:- - \| \ + rttmSmooth.pl -s 0 \| rttmSort.pl '>' $dir/ref_spk_interior.rttm + + $train_cmd $dir/log/get_ref_rttm_manual_seg.log \ + export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH '&&' \ + segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/train/reco2num_frames \ + "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \ + ark:/dev/null ark:- ark:/dev/null \| \ + segmentation-init-from-ali ark:- ark:- \| \ + segmentation-post-process --merge-labels=1:2:3:4:5:6:7:8:9:10 --merge-dst-label=1 \ + --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-create-subsegments --filter-label=0 --subsegment-label=10000 \ + ark:- "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ + --no-score-label=10000 ark:- - \| \ + rttmSmooth.pl -s 0 \| rttmSort.pl '>' $dir/ref_manual_seg.rttm + + $train_cmd $dir/log/get_ref_rttm_interior.log \ + export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH '&&' \ + segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/train/reco2num_frames \ + "ark:gunzip -c $dir/ref_spk_seg.gz | 
segmentation-post-process --remove-labels=0 ark:- ark:- |" \ + ark:/dev/null ark:- ark:/dev/null \| \ + segmentation-init-from-ali ark:- ark:- \| \ + segmentation-post-process --merge-labels=1:2:3:4:5:6:7:8:9:10 --merge-dst-label=1 \ + --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-create-subsegments --filter-label=0 --subsegment-label=10000 \ + ark:- "ark:gunzip -c $dir/interior_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ + --no-score-label=10000 ark:- - \| \ + rttmSmooth.pl -s 0 \| rttmSort.pl '>' $dir/ref_interior.rttm + + # Get RTTM for overlapped speech detection with 3 classes + # 0 -> SILENCE, 1 -> SINGLE_SPEAKER, 2 -> OVERLAP + $train_cmd $dir/log/get_overlapping_rttm_manual_seg.log \ + export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH '&&' \ + segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/train/reco2num_frames \ + "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \ + ark:/dev/null ark:- ark:/dev/null \| \ + segmentation-init-from-ali ark:- ark:- \| \ + segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=2 \ + --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-create-subsegments --filter-label=0 --subsegment-label=10000 \ + ark:- "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-to-rttm --map-to-speech-and-sil=false --reco2file-and-channel=$dir/reco2file_and_channel \ + --no-score-label=10000 ark:- - \| \ + rttmSmooth.pl -s 0 \| rttmSort.pl '>' $dir/overlapping_speech_ref_manual_seg.rttm + + $train_cmd $dir/log/get_overlapping_rttm_manual_seg.log \ + export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH '&&' \ + segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/train/reco2num_frames \ + "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \ + ark:/dev/null ark:- ark:/dev/null \| \ + segmentation-init-from-ali ark:- ark:- \| \ + segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=2 \ + --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-create-subsegments --filter-label=0 --subsegment-label=10000 \ + ark:- "ark:gunzip -c $dir/interior_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-to-rttm --map-to-speech-and-sil=false --reco2file-and-channel=$dir/reco2file_and_channel \ + --no-score-label=10000 ark:- - \| \ + rttmSmooth.pl -s 0 \| rttmSort.pl'>' $dir/overlapping_speech_ref_interior.rttm +fi + +exit 0 + +if [ $stage -le 8 ]; then + # Get a filter that selects only regions of speech + $train_cmd $dir/log/get_speech_filter.log \ + segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/train/reco2num_frames \ + "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \ + ark:/dev/null ark:- ark:/dev/null \| \ + segmentation-init-from-ali ark:- ark:- \| \ + segmentation-post-process --merge-labels=1:2:3:4:5:6:7:8:9:10 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-create-subsegments --filter-label=0 --subsegment-label=0 \ + ark:- "ark:gunzip -c 
$dir/manual_segments_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 \ + ark:- "ark:| gzip -c > $dir/manual_segments_speech_regions.seg.gz" +fi + +hyp_dir=${nnet_dir}/segmentation_ovlp_ami_sdm1_train_whole_bp/ami_sdm1_train + +if [ $stage -le 12 ]; then + steps/segmentation/do_segmentation_data_dir_generic.sh --reco-nj 18 \ + --mfcc-config conf/mfcc_hires_bp.conf --feat-affix bp --do-downsampling true \ + --extra-left-context $extra_left_context --extra-right-context $extra_right_context \ + --segmentation-config conf/segmentation_ovlp.conf \ + --output-name output-overlapping_sad \ + --min-durations 30:10:10 --priors 0.5:0.35:0.15 \ + --sad-name ovlp_sad --segmentation-name segmentation_ovlp_sad \ + --frame-subsampling-factor $frame_subsampling_factor --iter $iter \ + --stage $segmentation_stage \ + $src_dir/data/sdm1/train $nnet_dir mfcc_hires_bp $hyp_dir +fi + +likes_dir=${nnet_dir}/ovlp_sad_ami_sdm1_train_whole_bp/ + +hyp_dir=${hyp_dir}_seg +mkdir -p $hyp_dir + +seg_dir=${nnet_dir}/segmentation_ovlp_sad_ami_sdm1_train_whole_bp/ +lang=${seg_dir}/lang + +if [ $stage -le 14 ]; then +mkdir -p $lang +steps/segmentation/internal/prepare_sad_lang.py \ + --phone-transition-parameters="--phone-list=1 --min-duration=10 --end-transition-probability=0.1" \ + --phone-transition-parameters="--phone-list=2 --min-duration=3 --end-transition-probability=0.1" \ + --phone-transition-parameters="--phone-list=3 --min-duration=3 --end-transition-probability=0.1" $lang +cp $lang/phones.txt $lang/words.txt + +feat_dim=2 # dummy. We don't need this. +$train_cmd $seg_dir/log/create_transition_model.log gmm-init-mono \ + $lang/topo $feat_dim - $seg_dir/tree \| \ + copy-transition-model --binary=false - $seg_dir/trans.mdl || exit 1 +fi + +if [ $stage -le 15 ]; then + +cat > $lang/word2prior < $lang/G.fst +fi + +if [ $stage -le 16 ]; then + $train_cmd $seg_dir/log/make_vad_graph.log \ + steps/segmentation/internal/make_sad_graph.sh --iter trans \ + $lang $seg_dir $seg_dir/graph_test || exit 1 +fi + +if [ $stage -le 17 ]; then + steps/segmentation/decode_sad.sh \ + --acwt 1 --beam 10 --max-active 7000 \ + $seg_dir/graph_test $likes_dir $seg_dir +fi + +if [ $stage -le 18 ]; then + cat < $hyp_dir/labels_map +1 0 +2 1 +3 2 +EOF + gunzip -c $seg_dir/ali.*.gz | \ + segmentation-init-from-ali ark:- ark:- | \ + segmentation-copy --frame-subsampling-factor=$frame_subsampling_factor \ + --label-map=$hyp_dir/labels_map ark:- ark:- | \ + segmentation-to-rttm --map-to-speech-and-sil=false \ + --reco2file-and-channel=$dir/reco2file_and_channel ark:- $hyp_dir/sys.rttm +fi +# Get RTTM for overlapped speech detection with 3 classes +# 0 -> SILENCE, 1 -> SINGLE_SPEAKER, 2 -> OVERLAP +$train_cmd $dir/log/get_overlapping_rttm.log \ + segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/train/reco2num_frames \ + "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \ + ark:/dev/null ark:- \| \ + segmentation-init-from-ali ark:- ark:- \| \ + segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=2 ark:- ark:- \| \ + segmentation-create-subsegments --filter-label=0 --subsegment-label=10000 \ + ark:- "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-to-rttm --map-to-speech-and-sil=false --reco2file-and-channel=$dir/reco2file_and_channel \ + 
--no-score-label=10000 ark:- $dir/overlapping_speech_ref.rttm + +if [ $stage -le 19 ]; then + cat < Date: Wed, 1 Mar 2017 17:11:00 -0500 Subject: [PATCH 197/213] asr_diarization: Trap PIPE failure in get_egs.sh --- egs/wsj/s5/steps/nnet3/get_egs.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/egs/wsj/s5/steps/nnet3/get_egs.sh b/egs/wsj/s5/steps/nnet3/get_egs.sh index 79bfc25fff6..bfa800df4db 100755 --- a/egs/wsj/s5/steps/nnet3/get_egs.sh +++ b/egs/wsj/s5/steps/nnet3/get_egs.sh @@ -12,6 +12,8 @@ # right, and this ends up getting shared. This is at the expense of slightly # higher disk I/O while training. +set -o pipefail +trap "" PIPE # Begin configuration section. cmd=run.pl From 805e300ff96ca1bbd6bdcb670eab1474d3420094 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Fri, 3 Mar 2017 11:38:55 -0500 Subject: [PATCH 198/213] asr_diarization: Fix merging with kaldi 5.1 master --- src/Makefile | 10 +++++----- src/decoder/Makefile | 3 +-- src/lat/sausages.cc | 6 +++--- src/nnet3bin/nnet3-get-egs-dense-targets.cc | 16 +++++++++------- src/nnet3bin/nnet3-get-egs.cc | 5 +++-- src/nnet3bin/nnet3-info.cc | 17 ----------------- src/segmenter/Makefile | 6 +++--- src/segmenterbin/intersect-int-vectors.cc | 8 +++++--- src/simplehmm/Makefile | 2 +- .../simple-hmm-graph-compiler.cc | 2 +- .../simple-hmm-graph-compiler.h | 0 src/simplehmm/simple-hmm.h | 2 +- src/simplehmmbin/Makefile | 4 ++-- .../compile-train-simple-hmm-graphs.cc | 2 +- 14 files changed, 35 insertions(+), 48 deletions(-) rename src/{decoder => simplehmm}/simple-hmm-graph-compiler.cc (98%) rename src/{decoder => simplehmm}/simple-hmm-graph-compiler.h (100%) diff --git a/src/Makefile b/src/Makefile index fcfdc1412d8..b7ac6f60bd4 100644 --- a/src/Makefile +++ b/src/Makefile @@ -6,13 +6,13 @@ SHELL := /bin/bash SUBDIRS = base matrix util feat tree thread gmm transform \ - fstext hmm lm decoder lat kws cudamatrix nnet segmenter simplehmm \ + fstext hmm simplehmm lm decoder lat kws cudamatrix nnet segmenter \ bin fstbin gmmbin fgmmbin featbin \ nnetbin latbin sgmm2 sgmm2bin nnet2 nnet3 chain nnet3bin nnet2bin kwsbin \ ivector ivectorbin online2 online2bin lmbin chainbin segmenterbin simplehmmbin MEMTESTDIRS = base matrix util feat tree thread gmm transform \ - fstext hmm lm decoder lat nnet kws chain segmenter simplehmm \ + fstext hmm simplehmm lm decoder lat nnet kws chain segmenter \ bin fstbin gmmbin fgmmbin featbin \ nnetbin latbin sgmm2 nnet2 nnet3 nnet2bin nnet3bin sgmm2bin kwsbin \ ivector ivectorbin online2 online2bin lmbin segmenterbin simplehmmbin @@ -151,8 +151,8 @@ $(EXT_SUBDIRS) : mklibdir ext_depend #1)The tools depend on all the libraries bin fstbin gmmbin fgmmbin sgmm2bin featbin nnetbin nnet2bin nnet3bin chainbin latbin ivectorbin lmbin kwsbin online2bin segmenterbin simplehmmbin: \ - base matrix util feat tree thread gmm transform sgmm2 fstext hmm \ - lm decoder lat cudamatrix nnet nnet2 nnet3 ivector chain kws online2 segmenter simplehmm + base matrix util feat tree thread gmm transform sgmm2 fstext hmm simplehmm \ + lm decoder lat cudamatrix nnet nnet2 nnet3 ivector chain kws online2 segmenter #2)The libraries have inter-dependencies base: base/.depend.mk @@ -167,7 +167,7 @@ sgmm2: base util matrix gmm tree transform thread hmm fstext: base util thread matrix tree hmm: base tree matrix util thread lm: base util thread matrix fstext -decoder: base util thread matrix gmm hmm simplehmm tree transform lat +decoder: base util thread matrix gmm hmm tree transform lat lat: base util thread hmm tree matrix cudamatrix: base 
util thread matrix
 nnet: base util hmm tree thread matrix cudamatrix
diff --git a/src/decoder/Makefile b/src/decoder/Makefile
index 8ee63103ab9..93db701cb7a 100644
--- a/src/decoder/Makefile
+++ b/src/decoder/Makefile
@@ -7,12 +7,11 @@ TESTFILES =
 
 OBJFILES = training-graph-compiler.o lattice-simple-decoder.o lattice-faster-decoder.o \
            lattice-faster-online-decoder.o simple-decoder.o faster-decoder.o \
-           decoder-wrappers.o simple-hmm-graph-compiler.o
+           decoder-wrappers.o
 
 LIBNAME = kaldi-decoder
 
 ADDLIBS = ../lat/kaldi-lat.a ../hmm/kaldi-hmm.a \
-          ../simplehmm/kaldi-simplehmm.a \
           ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \
           ../tree/kaldi-tree.a ../util/kaldi-util.a ../thread/kaldi-thread.a \
           ../matrix/kaldi-matrix.a ../base/kaldi-base.a
diff --git a/src/lat/sausages.cc b/src/lat/sausages.cc
index ee5df716d0c..89734f76a04 100644
--- a/src/lat/sausages.cc
+++ b/src/lat/sausages.cc
@@ -60,7 +60,7 @@ void MinimumBayesRisk::MbrDecode() {
       KALDI_ASSERT(confidence > 0);
       KALDI_ASSERT(begin_times_[q].count(R_[q]) > 0);
       KALDI_ASSERT(end_times_[q].count(R_[q]) > 0);
-      one_best_times_.push_back(make_pair(
+      one_best_times_.push_back(std::make_pair(
           begin_times_[q][R_[q]] / confidence,
           end_times_[q][R_[q]] / confidence));
     }
@@ -278,13 +278,13 @@ void MinimumBayesRisk::AccStats() {
       for (map<int32, double>::iterator iter = tau_b[q].begin();
           iter != tau_b[q].end(); ++iter) {
         times_[q-1].first += iter->second;
-        begin_times_[q-1].insert(make_pair(iter->first, iter->second));
+        begin_times_[q-1].insert(std::make_pair(iter->first, iter->second));
       }
       for (map<int32, double>::iterator iter = tau_e[q].begin();
           iter != tau_e[q].end(); ++iter) {
         times_[q-1].second += iter->second;
-        end_times_[q-1].insert(make_pair(iter->first, iter->second));
+        end_times_[q-1].insert(std::make_pair(iter->first, iter->second));
       }
       if (times_[q-1].first > times_[q-1].second) // this is quite bad.
diff --git a/src/nnet3bin/nnet3-get-egs-dense-targets.cc b/src/nnet3bin/nnet3-get-egs-dense-targets.cc
index 4ec1cd09b6b..0e387e19fcb 100644
--- a/src/nnet3bin/nnet3-get-egs-dense-targets.cc
+++ b/src/nnet3bin/nnet3-get-egs-dense-targets.cc
@@ -31,7 +31,7 @@ namespace kaldi {
 namespace nnet3 {
 
 
-static void ProcessFile(const MatrixBase<BaseFloat> &feats,
+static bool ProcessFile(const MatrixBase<BaseFloat> &feats,
                         const MatrixBase<BaseFloat> *ivector_feats,
                         int32 ivector_period,
                         const VectorBase<BaseFloat> *deriv_weights,
@@ -47,7 +47,7 @@ static void ProcessFile(const MatrixBase<BaseFloat> &feats,
   if (!utt_splitter->LengthsMatch(utt_id, num_input_frames,
                                   targets.NumRows())) {
     if (targets.NumRows() == 0)
-      return;
+      return false;
     // normally we wouldn't process such an utterance but there may be
     // situations when a small disagreement is acceptable.
     KALDI_WARN << " .. 
processing this utterance anyway."; @@ -62,7 +62,7 @@ static void ProcessFile(const MatrixBase &feats, KALDI_WARN << "Not producing egs for utterance " << utt_id << " because it is too short: " << num_input_frames << " frames."; - return; + return false; } // 'frame_subsampling_factor' is not used in any recipes at the time of @@ -146,7 +146,8 @@ static void ProcessFile(const MatrixBase &feats, int32 t = i + start_frame_subsampled; if (t >= targets.NumRows()) t = targets.NumRows() - 1; - this_deriv_weights(i) = deriv_weights(t); + this_deriv_weights(i) = (*deriv_weights)(t); + } eg.io.push_back(NnetIo("output", this_deriv_weights, 0, targets_part)); } @@ -163,9 +164,9 @@ static void ProcessFile(const MatrixBase &feats, example_writer->Write(key, eg); } -} - + return true; +} } // namespace nnet2 } // namespace kaldi @@ -311,7 +312,8 @@ int main(int argc, char *argv[]) { } if (!ProcessFile(feats, online_ivector_feats, online_ivector_period, - deriv_weights, target_matrix, key, compress, num_targets, + deriv_weights, target_matrix, key, compress, + input_compress_format, feats_compress_format, num_targets, &utt_splitter, &example_writer)) num_err++; } diff --git a/src/nnet3bin/nnet3-get-egs.cc b/src/nnet3bin/nnet3-get-egs.cc index 02287f228c4..490124449ac 100644 --- a/src/nnet3bin/nnet3-get-egs.cc +++ b/src/nnet3bin/nnet3-get-egs.cc @@ -140,7 +140,7 @@ static bool ProcessFile(const MatrixBase &feats, Vector this_deriv_weights(num_frames_subsampled); for (int32 i = 0; i < num_frames_subsampled; i++) { int32 t = i + start_frame_subsampled; - this_deriv_weights(i) = deriv_weights(t); + this_deriv_weights(i) = (*deriv_weights)(t); } // Ignore frames that have frame weights 0 if (this_deriv_weights.Sum() == 0) continue; @@ -316,7 +316,8 @@ int main(int argc, char *argv[]) { } if (!ProcessFile(feats, online_ivector_feats, online_ivector_period, - deriv_weights, pdf_post, key, compress, num_pdfs, + deriv_weights, pdf_post, key, compress, + input_compress_format, feats_compress_format, num_pdfs, &utt_splitter, &example_writer)) num_err++; } diff --git a/src/nnet3bin/nnet3-info.cc b/src/nnet3bin/nnet3-info.cc index 7f8dc82b3ce..c722c3b0a85 100644 --- a/src/nnet3bin/nnet3-info.cc +++ b/src/nnet3bin/nnet3-info.cc @@ -38,13 +38,10 @@ int main(int argc, char *argv[]) { "See also: nnet3-am-info\n"; bool print_detailed_info = false; - bool print_learning_rates = false; ParseOptions po(usage); po.Register("print-detailed-info", &print_detailed_info, "Print more detailed info"); - po.Register("print-learning-rates", &print_learning_rates, - "Print learning rates of updatable components"); po.Read(argc, argv); @@ -58,20 +55,6 @@ int main(int argc, char *argv[]) { Nnet nnet; ReadKaldiObject(raw_nnet_rxfilename, &nnet); - if (print_learning_rates) { - Vector learning_rates; - GetLearningRates(nnet, &learning_rates); - std::cout << "learning-rates: " - << PrintVectorPerUpdatableComponent(nnet, learning_rates) - << "\n"; - - Vector learning_rate_factors; - GetLearningRateFactors(nnet, &learning_rate_factors); - std::cout << "learning-rate-factors: " - << PrintVectorPerUpdatableComponent(nnet, learning_rate_factors) - << "\n"; - } - if (print_detailed_info) std::cout << NnetInfo(nnet); else diff --git a/src/segmenter/Makefile b/src/segmenter/Makefile index 8a9b37cad75..8259de32c1f 100644 --- a/src/segmenter/Makefile +++ b/src/segmenter/Makefile @@ -5,9 +5,9 @@ include ../kaldi.mk TESTFILES = segmentation-io-test information-bottleneck-clusterable-test OBJFILES = segment.o segmentation.o segmentation-utils.o \ - 
segmentation-post-processor.o #\
           #information-bottleneck-clusterable.o \
           #information-bottleneck-cluster-utils.o
 
 LIBNAME = kaldi-segmenter
 
diff --git a/src/segmenterbin/intersect-int-vectors.cc b/src/segmenterbin/intersect-int-vectors.cc
index 53731bf9046..0611dd513e1 100644
--- a/src/segmenterbin/intersect-int-vectors.cc
+++ b/src/segmenterbin/intersect-int-vectors.cc
@@ -67,7 +67,7 @@ int main(int argc, char *argv[]) {
       Input ki(mapping_rxfilename);
       std::string line;
       while (std::getline(ki.Stream(), line)) {
-        std::vector parts;
+        std::vector parts;
         SplitStringToVector(line, " ", true, &parts);
 
         KALDI_ASSERT(parts.size() == 3);
@@ -113,9 +113,11 @@ int main(int argc, char *argv[]) {
        num_err++;
      }
 
-    std::vector<int32> alignment_out(alignment1.size());
+    int32 min_length = std::min(static_cast<int32>(alignment1.size()),
+                                static_cast<int32>(alignment2.size()));
+    std::vector<int32> alignment_out(min_length);
 
-    for (size_t i = 0; i < alignment1.size(); i++) {
+    for (size_t i = 0; i < min_length; i++) {
       std::pair<int32, int32> id_pair = std::make_pair(
           alignment1[i], alignment2[i]);
 
diff --git a/src/simplehmm/Makefile b/src/simplehmm/Makefile
index d83fba05900..89c9f70a8c3 100644
--- a/src/simplehmm/Makefile
+++ b/src/simplehmm/Makefile
@@ -5,7 +5,7 @@ include ../kaldi.mk
 
 TESTFILES = simple-hmm-test
 
-OBJFILES = simple-hmm.o simple-hmm-utils.o
+OBJFILES = simple-hmm.o simple-hmm-utils.o simple-hmm-graph-compiler.o
 
 LIBNAME = kaldi-simplehmm
 
 ADDLIBS = ../hmm/kaldi-hmm.a ../decoder/kaldi-decoder.a \
diff --git a/src/decoder/simple-hmm-graph-compiler.cc b/src/simplehmm/simple-hmm-graph-compiler.cc
similarity index 98%
rename from src/decoder/simple-hmm-graph-compiler.cc
rename to src/simplehmm/simple-hmm-graph-compiler.cc
index 5f91380ca06..9626e08ae5f 100644
--- a/src/decoder/simple-hmm-graph-compiler.cc
+++ b/src/simplehmm/simple-hmm-graph-compiler.cc
@@ -18,7 +18,7 @@
 // See the Apache 2 License for the specific language governing permissions and
 // limitations under the License. 
-#include "decoder/simple-hmm-graph-compiler.h" +#include "simplehmm/simple-hmm-graph-compiler.h" #include "simplehmm/simple-hmm-utils.h" // for GetHTransducer namespace kaldi { diff --git a/src/decoder/simple-hmm-graph-compiler.h b/src/simplehmm/simple-hmm-graph-compiler.h similarity index 100% rename from src/decoder/simple-hmm-graph-compiler.h rename to src/simplehmm/simple-hmm-graph-compiler.h diff --git a/src/simplehmm/simple-hmm.h b/src/simplehmm/simple-hmm.h index 4b40f212401..6fa9b1db6d2 100644 --- a/src/simplehmm/simple-hmm.h +++ b/src/simplehmm/simple-hmm.h @@ -87,7 +87,7 @@ class SimpleHmm: public TransitionModel { int32 num_pdfs_; } ctx_dep_; - DISALLOW_COPY_AND_ASSIGN(SimpleHmm); + KALDI_DISALLOW_COPY_AND_ASSIGN(SimpleHmm); }; } // end namespace kaldi diff --git a/src/simplehmmbin/Makefile b/src/simplehmmbin/Makefile index f382b30277c..3546ebae7c2 100644 --- a/src/simplehmmbin/Makefile +++ b/src/simplehmmbin/Makefile @@ -10,9 +10,9 @@ BINFILES = simple-hmm-init \ OBJFILES = -ADDLIBS = ../decoder/kaldi-decoder.a ../lat/kaldi-lat.a \ +ADDLIBS = ../decoder/kaldi-decoder.a \ + ../simplehmm/kaldi-simplehmm.a ../lat/kaldi-lat.a \ ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a \ - ../simplehmm/kaldi-simplehmm.a\ ../util/kaldi-util.a ../thread/kaldi-thread.a \ ../matrix/kaldi-matrix.a ../base/kaldi-base.a diff --git a/src/simplehmmbin/compile-train-simple-hmm-graphs.cc b/src/simplehmmbin/compile-train-simple-hmm-graphs.cc index a1914ed0763..3797e24f2a4 100644 --- a/src/simplehmmbin/compile-train-simple-hmm-graphs.cc +++ b/src/simplehmmbin/compile-train-simple-hmm-graphs.cc @@ -24,7 +24,7 @@ #include "tree/context-dep.h" #include "simplehmm/simple-hmm.h" #include "fstext/fstext-lib.h" -#include "decoder/simple-hmm-graph-compiler.h" +#include "simplehmm/simple-hmm-graph-compiler.h" int main(int argc, char *argv[]) { From 40a20864322fda62f4247a861e95cacd2c5e6b42 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 6 Mar 2017 00:02:04 -0500 Subject: [PATCH 199/213] asr_diarization: remove short segments. --- src/segmenter/segmentation-post-processor.cc | 4 +++- src/segmenter/segmentation-utils.cc | 16 ++++++++++------ src/segmenter/segmentation-utils.h | 4 ++-- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/segmenter/segmentation-post-processor.cc b/src/segmenter/segmentation-post-processor.cc index e8c7747c8c4..1bec12360fc 100644 --- a/src/segmenter/segmentation-post-processor.cc +++ b/src/segmenter/segmentation-post-processor.cc @@ -115,7 +115,9 @@ void SegmentationPostProcessor::Check() const { << "--max-blend-length. It must be positive."; } - if (IsRemovingSegmentsToBeDone(opts_) && remove_labels_[0] < 0) { + if (IsRemovingSegmentsToBeDone(opts_) && + (remove_labels_[0] < -1 || + (remove_labels_.size() > 1 && remove_labels_[0] == -1))) { KALDI_ERR << "Invalid value " << opts_.remove_labels_csl << " for option " << "--remove-labels. 
" << "The labels must be non-negative."; diff --git a/src/segmenter/segmentation-utils.cc b/src/segmenter/segmentation-utils.cc index 3cece810d45..1a87ba78ad8 100644 --- a/src/segmenter/segmentation-utils.cc +++ b/src/segmenter/segmentation-utils.cc @@ -115,11 +115,15 @@ void RemoveSegments(const std::vector &labels, for (SegmentList::iterator it = segmentation->Begin(); it != segmentation->End(); ) { - if ((max_remove_length == -1 || - it->Length() < max_remove_length) && - std::binary_search(labels.begin(), labels.end(), - it->Label())) { - it = segmentation->Erase(it); + if (max_remove_length == -1) { + if (std::binary_search(labels.begin(), labels.end(), + it->Label())) + it = segmentation->Erase(it); + } else if (it->Length() < max_remove_length) { + if (std::binary_search(labels.begin(), labels.end(), + it->Label()) || + (labels.size() == 1 && labels[0] == -1)) + it = segmentation->Erase(it); } else { ++it; } @@ -355,7 +359,7 @@ void SubSegmentUsingNonOverlappingSegments( for (SegmentList::const_iterator f_it = filter_segmentation.Begin(); f_it != filter_segmentation.End(); ++f_it) { - int32 label = (unmatched_label > 0 ? unmatched_label : it->Label()); + int32 label = (unmatched_label >= 0 ? unmatched_label : it->Label()); if (f_it->Label() == secondary_label) { if (subsegment_label >= 0) { label = subsegment_label; diff --git a/src/segmenter/segmentation-utils.h b/src/segmenter/segmentation-utils.h index 4fa3271e874..16e63710c6a 100644 --- a/src/segmenter/segmentation-utils.h +++ b/src/segmenter/segmentation-utils.h @@ -150,12 +150,12 @@ void IntersectSegmentationAndAlignment(const Segmentation &in_segmentation, * changing the class_id of the filtered sub-segments. * The label for the newly created subsegments is determined as follows: * if secondary segment's label == secondary_label: - * if subsegment_label > 0: + * if subsegment_label >= 0: * label = subsegment_label * else: * label = secondary_label * else: - * if unmatched_label > 0: + * if unmatched_label >= 0: * label = unmatched_label * else: * label = primary_label From 95a550b88d16b12dcaab9ebe5ac5333f7374e571 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 24 Apr 2017 12:17:57 -0400 Subject: [PATCH 200/213] segmenter: Fixing RemoveSegments --- src/segmenter/segmentation-utils.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/segmenter/segmentation-utils.cc b/src/segmenter/segmentation-utils.cc index 1a87ba78ad8..4d76afba0b8 100644 --- a/src/segmenter/segmentation-utils.cc +++ b/src/segmenter/segmentation-utils.cc @@ -115,15 +115,19 @@ void RemoveSegments(const std::vector &labels, for (SegmentList::iterator it = segmentation->Begin(); it != segmentation->End(); ) { - if (max_remove_length == -1) { + if (max_remove_length < 0) { if (std::binary_search(labels.begin(), labels.end(), it->Label())) it = segmentation->Erase(it); + else + ++it; } else if (it->Length() < max_remove_length) { if (std::binary_search(labels.begin(), labels.end(), it->Label()) || (labels.size() == 1 && labels[0] == -1)) it = segmentation->Erase(it); + else + ++it; } else { ++it; } From ecc483f8b93993e4e73c101d208c45b209b1c717 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 24 Apr 2017 12:19:54 -0400 Subject: [PATCH 201/213] sad: Updating subsegment_data_dir --- egs/wsj/s5/utils/data/get_utt2num_frames.sh | 2 +- egs/wsj/s5/utils/data/subsegment_data_dir.sh | 30 +++++++++++++++----- src/featbin/copy-feats.cc | 18 ++++++++++-- 3 files changed, 39 insertions(+), 11 deletions(-) diff --git 
a/egs/wsj/s5/utils/data/get_utt2num_frames.sh b/egs/wsj/s5/utils/data/get_utt2num_frames.sh index 3f6d15c45a5..ec80e771c83 100755 --- a/egs/wsj/s5/utils/data/get_utt2num_frames.sh +++ b/egs/wsj/s5/utils/data/get_utt2num_frames.sh @@ -19,7 +19,7 @@ fi data=$1 -if [ -f $data/utt2num_frames ]; then +if [ -s $data/utt2num_frames ]; then echo "$0: $data/utt2num_frames already present!" exit 0; fi diff --git a/egs/wsj/s5/utils/data/subsegment_data_dir.sh b/egs/wsj/s5/utils/data/subsegment_data_dir.sh index 10a8a9cb264..4c664f16441 100755 --- a/egs/wsj/s5/utils/data/subsegment_data_dir.sh +++ b/egs/wsj/s5/utils/data/subsegment_data_dir.sh @@ -52,11 +52,11 @@ export LC_ALL=C srcdir=$1 subsegments=$2 -no_text=true +add_subsegment_text=false if [ $# -eq 4 ]; then new_text=$3 dir=$4 - no_text=false + add_subsegment_text=true if [ ! -f "$new_text" ]; then echo "$0: no such file $new_text" @@ -78,7 +78,7 @@ if ! mkdir -p $dir; then echo "$0: failed to create directory $dir" fi -if ! $no_text; then +if $add_subsegment_text; then if ! cmp <(awk '{print $1}' <$subsegments) <(awk '{print $1}' <$new_text); then echo "$0: expected the first fields of the files $subsegments and $new_text to be identical" exit 1 @@ -102,7 +102,7 @@ utils/apply_map.pl -f 2 $srcdir/utt2spk < $dir/new2old_utt >$dir/utt2spk # .. and the new spk2utt file. utils/utt2spk_to_spk2utt.pl <$dir/utt2spk >$dir/spk2utt -if ! $no_text; then +if $add_subsegment_text; then # the new text file is just what the user provides. cp $new_text $dir/text fi @@ -143,6 +143,10 @@ if [ -f $srcdir/feats.scp ]; then frame_shift=$(utils/data/get_frame_shift.sh $srcdir) echo "$0: note: frame shift is $frame_shift [affects feats.scp]" + utils/data/get_utt2num_frames.sh --cmd "run.pl" --nj 1 $srcdir + awk '{print $1" "$2}' $subsegments | \ + utils/apply_map.pl -f 2 $srcdir/utt2num_frames > \ + $dir/utt2max_frames # The subsegments format is . # e.g. 'utt_foo-1 utt_foo 7.21 8.93' @@ -165,10 +169,22 @@ if [ -f $srcdir/feats.scp ]; then # utt_foo-1 some command|[721:892] # Lastly, utils/data/normalize_data_range.pl will only do something nontrivial if # the original data-dir already had data-ranges in square brackets. - awk -v s=$frame_shift '{print $1, $2, int(($3/s)+0.5), int(($4/s)-0.5);}' <$subsegments| \ + cat $subsegments | awk -v s=$frame_shift '{print $1, $2, int(($3/s)+0.5), int(($4/s)-0.5);}' | \ utils/apply_map.pl -f 2 $srcdir/feats.scp | \ awk '{p=NF-1; for (n=1;n$dir/feats.scp + utils/data/normalize_data_range.pl | \ + utils/data/fix_subsegmented_feats.pl $dir/utt2max_frames >$dir/feats.scp + + cat $dir/feats.scp | perl -ne 'm/^(\S+) .+\[(\d+):(\d+)\]$/; print "$1 " . ($3-$2+1) . "\n"' > \ + $dir/utt2num_frames + + if [ -f $srcdir/vad.scp ]; then + cat $subsegments | awk -v s=$frame_shift '{print $1, $2, int(($3/s)+0.5), int(($4/s)-0.5);}' | \ + utils/apply_map.pl -f 2 $srcdir/vad.scp | \ + awk '{p=NF-1; for (n=1;n$dir/vad.scp + fi fi @@ -202,7 +218,7 @@ utils/data/fix_data_dir.sh $dir validate_opts= [ ! -f $srcdir/feats.scp ] && validate_opts="$validate_opts --no-feats" [ ! -f $srcdir/wav.scp ] && validate_opts="$validate_opts --no-wav" -$no_text && validate_opts="$validate_opts --no-text" +! 
$add_subsegment_text && validate_opts="$validate_opts --no-text" utils/data/validate_data_dir.sh $validate_opts $dir diff --git a/src/featbin/copy-feats.cc b/src/featbin/copy-feats.cc index 0fbcca6399a..f1f58653f2f 100644 --- a/src/featbin/copy-feats.cc +++ b/src/featbin/copy-feats.cc @@ -102,19 +102,31 @@ int main(int argc, char *argv[]) { CompressedMatrixWriter kaldi_writer(wspecifier); if (htk_in) { SequentialTableReader htk_reader(rspecifier); - for (; !htk_reader.Done(); htk_reader.Next(), num_done++) + for (; !htk_reader.Done(); htk_reader.Next(), num_done++) { kaldi_writer.Write(htk_reader.Key(), CompressedMatrix(htk_reader.Value().first)); + if (!num_frames_wspecifier.empty()) + num_frames_writer.Write(htk_reader.Key(), + htk_reader.Value().first.NumRows()); + } } else if (sphinx_in) { SequentialTableReader > sphinx_reader(rspecifier); - for (; !sphinx_reader.Done(); sphinx_reader.Next(), num_done++) + for (; !sphinx_reader.Done(); sphinx_reader.Next(), num_done++) { kaldi_writer.Write(sphinx_reader.Key(), CompressedMatrix(sphinx_reader.Value())); + if (!num_frames_wspecifier.empty()) + num_frames_writer.Write(sphinx_reader.Key(), + sphinx_reader.Value().NumRows()); + } } else { SequentialBaseFloatMatrixReader kaldi_reader(rspecifier); - for (; !kaldi_reader.Done(); kaldi_reader.Next(), num_done++) + for (; !kaldi_reader.Done(); kaldi_reader.Next(), num_done++) { kaldi_writer.Write(kaldi_reader.Key(), CompressedMatrix(kaldi_reader.Value())); + if (!num_frames_wspecifier.empty()) + num_frames_writer.Write(kaldi_reader.Key(), + kaldi_reader.Value().NumRows()); + } } } KALDI_LOG << "Copied " << num_done << " feature matrices."; From 20f3072bf327eeff36923765be696863ba9cd715 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 24 Apr 2017 12:21:36 -0400 Subject: [PATCH 202/213] sad: xconfig stats layer --- egs/wsj/s5/steps/libs/nnet3/xconfig/layers.py | 1 - egs/wsj/s5/steps/libs/nnet3/xconfig/stats_layer.py | 9 ++++----- egs/wsj/s5/steps/libs/nnet3/xconfig/utils.py | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/layers.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/layers.py index 1092be572b4..188e0ec4322 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/layers.py +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/layers.py @@ -5,5 +5,4 @@ from basic_layers import * from lstm import * -from tdnn import * from stats_layer import * diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/stats_layer.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/stats_layer.py index beaf7c8923a..e49a4fa3df6 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/stats_layer.py +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/stats_layer.py @@ -6,7 +6,6 @@ from __future__ import print_function import re -from libs.nnet3.xconfig.utils import XconfigParserError as xparser_error from libs.nnet3.xconfig.basic_layers import XconfigLayerBase @@ -46,13 +45,13 @@ def set_default_configs(self): def set_derived_configs(self): config_string = self.config['config'] if config_string == '': - raise xparser_error("config has to be non-empty", + raise RuntimeError("config has to be non-empty", self.str()) m = re.search("(mean|mean\+stddev|mean\+count|mean\+stddev\+count)" "\((-?\d+):(-?\d+):(-?\d+):(-?\d+)\)", config_string) if m is None: - raise xparser_error("Invalid statistic-config string: {0}".format( + raise RuntimeError("Invalid statistic-config string: {0}".format( config_string), self) self._output_stddev = (m.group(1) in ['mean+stddev', @@ -69,7 +68,7 @@ def set_derived_configs(self): + 1 if 
self._output_log_counts else 0) if self.config['dim'] > 0 and self.config['dim'] != output_dim: - raise xparser_error( + raise RuntimeError( "Invalid dim supplied {0:d} != " "actual output dim {1:d}".format( self.config['dim'], output_dim)) @@ -81,7 +80,7 @@ def check_configs(self): and self._left_context % self._stats_period == 0 and self._right_context % self._stats_period == 0 and self._stats_period % self._input_period == 0): - raise xparser_error( + raise RuntimeError( "Invalid configuration of statistics-extraction: {0}".format( self.config['config']), self) super(XconfigStatsLayer, self).check_configs() diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/utils.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/utils.py index 3d958568717..76477300884 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/utils.py +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/utils.py @@ -484,7 +484,7 @@ def parse_config_line(orig_config_line): # treats splitting on space as a special case that may give zero fields. config_line = orig_config_line.split('#')[0] # Note: this set of allowed characters may have to be expanded in future. - x = re.search('[^a-zA-Z0-9\.\-\(\)@_=,/\s"]', config_line) + x = re.search('[^a-zA-Z0-9\.\-\(\)@_=,/+:\s"]', config_line) if x is not None: bad_char = x.group(0) if bad_char == "'": From ed129f1a17689f14c64095ddb3751e696f92619e Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 24 Apr 2017 18:46:00 -0400 Subject: [PATCH 203/213] sad: Make utt2num_frames default in feats extraction --- egs/wsj/s5/steps/make_mfcc.sh | 2 +- egs/wsj/s5/steps/make_mfcc_pitch.sh | 17 +++++++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/egs/wsj/s5/steps/make_mfcc.sh b/egs/wsj/s5/steps/make_mfcc.sh index ddb63a0e6fb..5362e7fa9d9 100755 --- a/egs/wsj/s5/steps/make_mfcc.sh +++ b/egs/wsj/s5/steps/make_mfcc.sh @@ -10,7 +10,7 @@ nj=4 cmd=run.pl mfcc_config=conf/mfcc.conf compress=true -write_utt2num_frames=false # if true writes utt2num_frames +write_utt2num_frames=true # if true writes utt2num_frames # End configuration section. echo "$0 $@" # Print the command line for logging diff --git a/egs/wsj/s5/steps/make_mfcc_pitch.sh b/egs/wsj/s5/steps/make_mfcc_pitch.sh index ff9a7d2f5f3..4a2808b811f 100755 --- a/egs/wsj/s5/steps/make_mfcc_pitch.sh +++ b/egs/wsj/s5/steps/make_mfcc_pitch.sh @@ -96,6 +96,12 @@ for n in $(seq $nj); do utils/create_data_link.pl $mfcc_pitch_dir/raw_mfcc_pitch_$name.$n.ark done +if $write_utt2num_frames; then + write_num_frames_opt="--write-num-frames=ark,t:$logdir/utt2num_frames.JOB" +else + write_num_frames_opt= +fi + if [ -f $data/segments ]; then echo "$0 [info]: segments file exists: using that." 
split_segments="" @@ -111,7 +117,7 @@ if [ -f $data/segments ]; then $cmd JOB=1:$nj $logdir/make_mfcc_pitch_${name}.JOB.log \ paste-feats --length-tolerance=$paste_length_tolerance "$mfcc_feats" "$pitch_feats" ark:- \| \ - copy-feats --compress=$compress ark:- \ + copy-feats --compress=$compress $write_num_frames_opt ark:- \ ark,scp:$mfcc_pitch_dir/raw_mfcc_pitch_$name.JOB.ark,$mfcc_pitch_dir/raw_mfcc_pitch_$name.JOB.scp \ || exit 1; @@ -129,7 +135,7 @@ else $cmd JOB=1:$nj $logdir/make_mfcc_pitch_${name}.JOB.log \ paste-feats --length-tolerance=$paste_length_tolerance "$mfcc_feats" "$pitch_feats" ark:- \| \ - copy-feats --compress=$compress ark:- \ + copy-feats --compress=$compress $write_num_frames_opt ark:- \ ark,scp:$mfcc_pitch_dir/raw_mfcc_pitch_$name.JOB.ark,$mfcc_pitch_dir/raw_mfcc_pitch_$name.JOB.scp \ || exit 1; @@ -147,6 +153,13 @@ for n in $(seq $nj); do cat $mfcc_pitch_dir/raw_mfcc_pitch_$name.$n.scp || exit 1; done > $data/feats.scp +if $write_utt2num_frames; then + for n in $(seq $nj); do + cat $logdir/utt2num_frames.$n || exit 1; + done > $data/utt2num_frames || exit 1 + rm $logdir/uttnum_frames.* +fi + rm $logdir/wav_${name}.*.scp $logdir/segments.* 2>/dev/null nf=`cat $data/feats.scp | wc -l` From ddf58d3e012416d81664fc6fbf240b3d8590947c Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 24 Apr 2017 18:54:03 -0400 Subject: [PATCH 204/213] segmenter: Update local recipes --- .../local/segmentation/prepare_babel_data.sh | 14 +++-- .../local/segmentation/prepare_fisher_data.sh | 8 ++- .../local/segmentation/prepare_unsad_data.sh | 52 ++++++------------- 3 files changed, 26 insertions(+), 48 deletions(-) diff --git a/egs/aspire/s5/local/segmentation/prepare_babel_data.sh b/egs/aspire/s5/local/segmentation/prepare_babel_data.sh index 927c530663d..e70dc216980 100644 --- a/egs/aspire/s5/local/segmentation/prepare_babel_data.sh +++ b/egs/aspire/s5/local/segmentation/prepare_babel_data.sh @@ -68,20 +68,18 @@ EOF # The original data directory which will be converted to a whole (recording-level) directory. utils/copy_data_dir.sh $ROOT_DIR/data/train data/babel_${lang_id}_train train_data_dir=data/babel_${lang_id}_train -speeds="0.9 1.0 1.1" -num_speeds=$(echo $speeds | awk '{print NF}') # Expecting the user to have done run.sh to have $model_dir, # $sat_model_dir, $lang, $lang_test, $train_data_dir local/segmentation/prepare_unsad_data.sh \ - --sad-map $dir/babel_sad.map --speeds "$speeds" \ + --sad-map $dir/babel_sad.map \ --config-dir $ROOT_DIR/conf --feat-type plp --add-pitch true \ --reco-nj 40 --nj 100 --cmd "$train_cmd" \ --sat-model-dir $sat_model_dir \ --lang-test $lang_test \ $train_data_dir $lang $model_dir $dir -orig_data_dir=${train_data_dir}_sp${num_speeds} +orig_data_dir=${train_data_dir}_sp data_dir=${train_data_dir}_whole @@ -92,16 +90,16 @@ if [ ! 
-z $subset ]; then data_dir=${data_dir}_$subset fi -reco_vad_dir=$dir/`basename $model_dir`_reco_vad_`basename $train_data_dir`_sp4 +reco_vad_dir=$dir/`basename $model_dir`_reco_vad_`basename $train_data_dir`_sp # Add noise from MUSAN corpus to data directory and create a new data directory -local/segmentation/do_corruption_data_dir.sh \ - --data-dir $data_dir --speeds "$speeds" \ +local/segmentation/do_corruption_data_dir_snr.sh \ + --data-dir $data_dir \ --reco-vad-dir $reco_vad_dir \ --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf # Add music from MUSAN corpus to data directory and create a new data directory local/segmentation/do_corruption_data_dir_music.sh \ - --data-dir $data_dir --speeds "$speeds" \ + --data-dir $data_dir \ --reco-vad-dir $reco_vad_dir \ --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf diff --git a/egs/aspire/s5/local/segmentation/prepare_fisher_data.sh b/egs/aspire/s5/local/segmentation/prepare_fisher_data.sh index d90dd05f472..40f43cfd442 100644 --- a/egs/aspire/s5/local/segmentation/prepare_fisher_data.sh +++ b/egs/aspire/s5/local/segmentation/prepare_fisher_data.sh @@ -56,7 +56,6 @@ oov_I 3 oov_S 3 EOF -true && { # Expecting the user to have done run.sh to have $model_dir, # $sat_model_dir, $lang, $lang_test, $train_data_dir local/segmentation/prepare_unsad_data.sh \ @@ -66,21 +65,20 @@ local/segmentation/prepare_unsad_data.sh \ --sat-model-dir $sat_model_dir \ --lang-test $lang_test \ $train_data_dir $lang $model_dir $dir -} data_dir=${train_data_dir}_whole if [ ! -z $subset ]; then # Work on a subset - true && utils/subset_data_dir.sh ${data_dir} $subset \ + false && utils/subset_data_dir.sh ${data_dir} $subset \ ${data_dir}_$subset data_dir=${data_dir}_$subset fi -reco_vad_dir=$dir/`basename $model_dir`_reco_vad_`basename $train_data_dir`_sp4 +reco_vad_dir=$dir/`basename $model_dir`_reco_vad_`basename $train_data_dir`_sp # Add noise from MUSAN corpus to data directory and create a new data directory -true && local/segmentation/do_corruption_data_dir.sh \ +local/segmentation/do_corruption_data_dir_snr.sh \ --data-dir $data_dir \ --reco-vad-dir $reco_vad_dir \ --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf diff --git a/egs/aspire/s5/local/segmentation/prepare_unsad_data.sh b/egs/aspire/s5/local/segmentation/prepare_unsad_data.sh index dc4cbf58994..cccc7e2db84 100755 --- a/egs/aspire/s5/local/segmentation/prepare_unsad_data.sh +++ b/egs/aspire/s5/local/segmentation/prepare_unsad_data.sh @@ -38,8 +38,6 @@ sat_model_dir= # Model directory used for getting alignments lang_test= # Language directory used to build graph. # If its not provided, $lang will be used instead. -speeds="0.9 1.0 1.1" - . utils/parse_options.sh if [ $# -ne 4 ]; then @@ -190,15 +188,13 @@ if [ $stage -le -2 ]; then utils/data/get_utt2dur.sh ${whole_data_dir} fi -num_speeds=`echo $speeds | awk '{print NF}'` if $speed_perturb; then - plpdir=${plpdir}_sp$num_speeds - mfccdir=${mfccdir}_sp$num_speeds - + plpdir=${plpdir}_sp + mfccdir=${mfccdir}_sp if [ $stage -le -1 ]; then - utils/data/perturb_data_dir_speed_${num_speeds}way.sh ${whole_data_dir} ${whole_data_dir}_sp${num_speeds} - utils/data/perturb_data_dir_speed_${num_speeds}way.sh ${data_dir} ${data_dir}_sp${num_speeds} + utils/data/perturb_data_dir_speed_3way.sh ${whole_data_dir} ${whole_data_dir}_sp + utils/data/perturb_data_dir_speed_3way.sh ${data_dir} ${data_dir}_sp if [ $feat_type == "mfcc" ]; then if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $mfccdir/storage ]; then @@ -208,9 +204,9 @@ if $speed_perturb; then make_mfcc --cmd "$cmd --max-jobs-run 40" --nj $nj \ --mfcc-config $feat_config \ --add-pitch $add_pitch --pitch-config $pitch_config \ - ${whole_data_dir}_sp${num_speeds} exp/make_mfcc $mfccdir || exit 1 + ${whole_data_dir}_sp exp/make_mfcc $mfccdir || exit 1 steps/compute_cmvn_stats.sh \ - ${whole_data_dir}_sp${num_speeds} exp/make_mfcc $mfccdir || exit 1 + ${whole_data_dir}_sp exp/make_mfcc $mfccdir || exit 1 elif [ $feat_type == "plp" ]; then if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $plpdir/storage ]; then utils/create_split_dir.pl \ @@ -220,20 +216,20 @@ if $speed_perturb; then make_plp --cmd "$cmd --max-jobs-run 40" --nj $nj \ --plp-config $feat_config \ --add-pitch $add_pitch --pitch-config $pitch_config \ - ${whole_data_dir}_sp${num_speeds} exp/make_plp $plpdir || exit 1 + ${whole_data_dir}_sp exp/make_plp $plpdir || exit 1 steps/compute_cmvn_stats.sh \ - ${whole_data_dir}_sp${num_speeds} exp/make_plp $plpdir || exit 1 + ${whole_data_dir}_sp exp/make_plp $plpdir || exit 1 else echo "$0: Unknown feat-type $feat_type. Must be mfcc or plp." exit 1 fi - utils/fix_data_dir.sh ${whole_data_dir}_sp${num_speeds} + utils/fix_data_dir.sh ${whole_data_dir}_sp fi - data_dir=${data_dir}_sp${num_speeds} - whole_data_dir=${whole_data_dir}_sp${num_speeds} - data_id=${data_id}_sp${num_speeds} + data_dir=${data_dir}_sp + whole_data_dir=${whole_data_dir}_sp + data_id=${data_id}_sp fi @@ -241,18 +237,9 @@ fi # Compute length of recording ############################################################################### -utils/data/get_reco2utt.sh $data_dir - if [ $stage -le 0 ]; then - utils/data/get_utt2num_frames.sh \ - --frame-shift $frame_shift --frame-overlap $frame_overlap \ - --cmd "$cmd" --nj $reco_nj $whole_data_dir - - awk '{print $1" "$2}' ${data_dir}/segments | utils/apply_map.pl -f 2 ${whole_data_dir}/utt2num_frames > $data_dir/utt2max_frames - utils/data/get_subsegmented_feats.sh ${whole_data_dir}/feats.scp \ - $frame_shift $frame_overlap ${data_dir}/segments | \ - utils/data/fix_subsegmented_feats.pl $data_dir/utt2max_frames \ - > ${data_dir}/feats.scp + utils/subsegment_data_dir.sh $whole_data_dir ${data_dir}/segments ${data_dir}/tmp + cp $data_dir/tmp/feats.scp $data_dir if [ $feat_type == mfcc ]; then steps/compute_cmvn_stats.sh ${data_dir} exp/make_mfcc/${data_id} $mfccdir @@ -380,14 +367,9 @@ fi if [ $stage -le 6 ]; then - utils/data/get_reco2utt.sh $outside_data_dir - awk '{print $1" "$2}' $outside_data_dir/segments | utils/apply_map.pl -f 2 $whole_data_dir/utt2num_frames > $outside_data_dir/utt2max_frames - - utils/data/get_subsegmented_feats.sh ${whole_data_dir}/feats.scp \ - $frame_shift $frame_overlap ${outside_data_dir}/segments | \ - utils/data/fix_subsegmented_feats.pl $outside_data_dir/utt2max_frames \ - > ${outside_data_dir}/feats.scp - + utils/data/subsegment_data_dir.sh $whole_data_dir $outside_data_dir/segments \ + $outside_data_dir/tmp + cp $outside_data_dir/tmp/feats.scp $outside_data_dir fi extended_data_dir=$dir/${data_id}_extended From ddc85cf2108f769df3ce5b6117ab68490f36b892 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 24 Apr 2017 21:25:19 -0400 Subject: [PATCH 205/213] segmenter: Adding some missing files --- .../segmentation/prepare_unsad_data_simple.sh | 114 ++++++++++++ .../internal/make_bigram_G_fst.py | 174 ++++++++++++++++++ 2 files changed, 288 insertions(+) create mode 100755 egs/aspire/s5/local/segmentation/prepare_unsad_data_simple.sh create mode 100755 
egs/wsj/s5/steps/segmentation/internal/make_bigram_G_fst.py diff --git a/egs/aspire/s5/local/segmentation/prepare_unsad_data_simple.sh b/egs/aspire/s5/local/segmentation/prepare_unsad_data_simple.sh new file mode 100755 index 00000000000..f3d1a7707e8 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/prepare_unsad_data_simple.sh @@ -0,0 +1,114 @@ +#!/bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + +# This script prepares speech labels for +# training unsad network for speech activity detection and music detection. +# This is similar to the script prepare_unsad_data.sh, but directly +# uses existing alignments to create labels, instead of creating new alignments. + +set -e +set -o pipefail +set -u + +. path.sh + +stage=-2 +cmd=queue.pl + +# Options to be passed to get_sad_map.py +map_noise_to_sil=true # Map noise phones to silence label (0) +map_unk_to_speech=true # Map unk phones to speech label (1) +sad_map= # Initial mapping from phones to speech/non-speech labels. + # Overrides the default mapping using phones/silence.txt + # and phones/nonsilence.txt + +. utils/parse_options.sh + +if [ $# -ne 4 ]; then + echo "This script takes a data directory and alignment directory and " + echo "converts it into speech activity labels" + echo "for the purpose of training a Universal Speech Activity Detector.\n" + echo "Usage: $0 [options] " + echo " e.g.: $0 data/train_100k data/lang exp/tri4a_ali exp/vad_data_prep" + echo "" + echo "Main options (for others, see top of script file)" + echo " --config # config file containing options" + echo " --cmd (run.pl|/queue.pl ) # how to run jobs." + exit 1 +fi + +data_dir=$1 +lang=$2 +ali_dir=$3 +dir=$4 + +extra_files= + +for f in $data_dir/feats.scp $lang/phones.txt $lang/phones/silence.txt $lang/phones/nonsilence.txt $sad_map $ali_dir/ali.1.gz $ali_dir/final.mdl $ali_dir/tree $extra_files; do + if [ ! -f $f ]; then + echo "$f could not be found" + exit 1 + fi +done + +mkdir -p $dir + +data_id=$(basename $data_dir) + +if [ $stage -le 0 ]; then + # Get a mapping from the phones to the speech / non-speech labels + steps/segmentation/get_sad_map.py \ + --init-sad-map="$sad_map" \ + --map-noise-to-sil=$map_noise_to_sil \ + --map-unk-to-speech=$map_unk_to_speech \ + $lang | utils/sym2int.pl -f 1 $lang/phones.txt > $dir/sad_map +fi + +############################################################################### +# Convert alignment into SAD labels at utterance-level in segmentation format +############################################################################### + +vad_dir=$dir/`basename ${ali_dir}`_vad_${data_id} + +# Convert relative path to full path +vad_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir;' $vad_dir ${PWD}` + +if [ $stage -le 1 ]; then + steps/segmentation/internal/convert_ali_to_vad.sh --cmd "$cmd" \ + $ali_dir $dir/sad_map $vad_dir +fi + +[ ! -s $vad_dir/sad_seg.scp ] && echo "$0: $vad_dir/sad_seg.scp is empty" && exit 1 + +############################################################################### +# Post-process the segmentation and create frame-level alignments and +# per-frame deriv weights. +############################################################################### + +if [ $stage -le 2 ]; then + # Create per-frame speech / non-speech labels. 
+ nj=`cat $vad_dir/num_jobs` + + utils/data/get_utt2num_frames.sh --nj $nj --cmd "$cmd" $data_dir + + set +e + for n in `seq $nj`; do + utils/create_data_link.pl $vad_dir/speech_labels.$n.ark + done + set -e + + $cmd JOB=1:$nj $vad_dir/log/get_speech_labels.JOB.log \ + segmentation-copy --keep-label=1 scp:$vad_dir/sad_seg.JOB.scp ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:$data_dir/utt2num_frames \ + ark:- ark,scp:$vad_dir/speech_labels.JOB.ark,$vad_dir/speech_labels.JOB.scp + + for n in `seq $nj`; do + cat $vad_dir/speech_labels.$n.scp + done > $vad_dir/speech_labels.scp + + cp $vad_dir/speech_labels.scp $data_dir +fi + +echo "$0: Finished creating corpus for training Universal SAD with data in $data_dir and labels in $vad_dir" diff --git a/egs/wsj/s5/steps/segmentation/internal/make_bigram_G_fst.py b/egs/wsj/s5/steps/segmentation/internal/make_bigram_G_fst.py new file mode 100755 index 00000000000..2431d293c4c --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/internal/make_bigram_G_fst.py @@ -0,0 +1,174 @@ +#! /usr/bin/env python + +from __future__ import print_function +import argparse +import logging +import math + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +handler.setLevel(logging.INFO) +formatter = logging.Formatter("%(asctime)s [%(filename)s:%(lineno)s - " + "%(funcName)s - %(levelname)s ] %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) + + +def get_args(): + parser = argparse.ArgumentParser( + description="""This script generates a bigram G.fst lang for decoding. + It needs as an input classes_info file with the format: + , + where each pair is :. + destination-class -1 is used to represent final probabilitiy.""") + + parser.add_argument("classes_info", type=argparse.FileType('r'), + help="File with classes_info") + parser.add_argument("out_file", type=argparse.FileType('w'), + help="Output G.fst. 
Use '-' for stdout") + args = parser.parse_args() + return args + + +class ClassInfo(object): + def __init__(self, class_id): + self.class_id = class_id + self.start_state = -1 + self.initial_prob = 0 + self.transitions = {} + + def __str__(self): + return ("class-id={0},start-state={1}," + "initial-prob={2:.2f},transitions={3}".format( + self.class_id, self.start_state, + self.initial_prob, ' '.join( + ['{0}:{1}'.format(x, y) + for x, y in self.transitions.iteritems()]))) + + +def read_classes_info(file_handle): + classes_info = {} + + num_states = 1 + num_classes = 0 + + for line in file_handle.readlines(): + try: + parts = line.split() + class_id = int(parts[0]) + assert class_id > 0, class_id + if class_id in classes_info: + raise RuntimeError( + "Duplicate class-id {0} in file {1}".format( + class_id, file_handle.name)) + + classes_info[class_id] = ClassInfo(class_id) + class_info = classes_info[class_id] + class_info.initial_prob = float(parts[1]) + class_info.start_state = num_states + num_states += 1 + num_classes += 1 + + total_prob = 0.0 + if len(parts) > 2: + for part in parts[2:]: + dest_class, transition_prob = part.split(':') + dest_class = int(dest_class) + total_prob += float(transition_prob) + + if total_prob > 1.0: + raise ValueError("total-probability out of class {0} " + "is {1} > 1.0".format(class_id, + total_prob)) + + if dest_class in class_info.transitions: + logger.error( + "Duplicate transition to class-id {0}" + "in transitions".format(dest_class)) + raise RuntimeError + class_info.transitions[dest_class] = float(transition_prob) + + if -1 in class_info.transitions: + if abs(total_prob - 1.0) > 0.001: + raise ValueError("total-probability out of class {0} " + "is {1} != 1.0".format(class_id, + total_prob)) + else: + class_info.transitions[-1] = 1.0 - total_prob + else: + raise RuntimeError( + "No transitions out of class {0}".format(class_id)) + except Exception: + logger.error("Error processing line %s in file %s", + line, file_handle.name) + raise + + # Final state + classes_info[-1] = ClassInfo(-1) + class_info = classes_info[-1] + class_info.start_state = num_states + + for class_id, class_info in classes_info.iteritems(): + logger.info("For class %d, got class-info %s", class_id, class_info) + + return classes_info, num_classes + + +def print_states_for_class(class_id, classes_info, out_file): + class_info = classes_info[class_id] + + state = class_info.start_state + + # Transition from the FST initial state + print ("0 {end} {logprob}".format( + end=state, logprob=-math.log(class_info.initial_prob)), + file=out_file) + + for dest_class, prob in class_info.transitions.iteritems(): + try: + if dest_class == class_id: # self loop + next_state = state + else: # other transition + next_state = classes_info[dest_class].start_state + + print ("{start} {end} {class_id} {class_id} {logprob}".format( + start=state, end=next_state, class_id=class_id, + logprob=-math.log(prob)), + file=out_file) + + except Exception: + logger.error("Failed to add transition (%d->%d).\n" + "classes_info = %s", class_id, dest_class, + class_info) + + print ("{start} {final} {class_id} {class_id}".format( + start=state, final=classes_info[-1].start_state, + class_id=class_id), + file=out_file) + print ("{0}".format(classes_info[-1].start_state), file=out_file) + + +def run(args): + classes_info, num_classes = read_classes_info(args.classes_info) + + for class_id in range(1, num_classes + 1): + print_states_for_class(class_id, classes_info, args.out_file) + + +def main(): + try: + args = 
get_args() + run(args) + except Exception: + logger.error("Failed to make G.fst") + raise + finally: + for f in [args.classes_info, args.out_file]: + if f is not None: + f.close() + + +if __name__ == '__main__': + main() From c90097e0011d3ae84b0c804e9141c02516a1d424 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 24 Apr 2017 21:25:43 -0400 Subject: [PATCH 206/213] segmenter: resample data directory --- egs/wsj/s5/utils/data/resample_data_dir.sh | 35 ++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100755 egs/wsj/s5/utils/data/resample_data_dir.sh diff --git a/egs/wsj/s5/utils/data/resample_data_dir.sh b/egs/wsj/s5/utils/data/resample_data_dir.sh new file mode 100755 index 00000000000..8781ee4c503 --- /dev/null +++ b/egs/wsj/s5/utils/data/resample_data_dir.sh @@ -0,0 +1,35 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +if [ $# -ne 2 ]; then + echo "Usage: $0 " + exit 1 +fi + +freq=$1 +dir=$2 + +sox=`which sox` || { echo "Could not find sox in PATH"; exit 1; } + +if [ -f $dir/feats.scp ]; then + mkdir -p $dir/.backup + mv $dir/feats.scp $dir/.backup/ + if [ -f $dir/cmvn.scp ]; then + mv $dir/cmvn.scp $dir/.backup/ + fi + echo "$0: feats.scp already exists. Moving it to $dir/.backup" +fi + +mv $dir/wav.scp $dir/wav.scp.tmp +cat $dir/wav.scp.tmp | python -c "import sys +for line in sys.stdin.readlines(): + splits = line.strip().split() + if splits[-1] == '|': + out_line = line.strip() + ' $sox -t wav - -c 1 -b 16 -t wav - rate $freq |' + else: + out_line = 'cat {0} {1} | $sox -t wav - -c 1 -b 16 -t wav - rate $freq |'.format(splits[0], ' '.join(splits[1:])) + print (out_line)" > ${dir}/wav.scp +rm $dir/wav.scp.tmp + From 3ad4355a73d2e74962e95921a8dbe410e2f06a85 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 24 Apr 2017 21:35:33 -0400 Subject: [PATCH 207/213] segmenter: Updating major scripts --- .../segmentation/do_segmentation_data_dir.sh | 10 +-- .../do_segmentation_data_dir_simple.sh | 73 ++----------------- 2 files changed, 12 insertions(+), 71 deletions(-) diff --git a/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir.sh b/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir.sh index 9e95cca9cc0..2117dc2d939 100755 --- a/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir.sh +++ b/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir.sh @@ -130,11 +130,11 @@ if [ $stage -le 4 ]; then utils/apply_map.pl -f 2 ${test_data_dir}/reco2num_frames > \ ${data_dir}_seg/utt2max_frames - frame_shift_info=`cat $mfcc_config | steps/segmentation/get_frame_shift_info_from_config.pl` - utils/data/get_subsegment_feats.sh ${test_data_dir}/feats.scp \ - $frame_shift_info ${data_dir}_seg/segments | \ - utils/data/fix_subsegmented_feats.pl ${data_dir}_seg/utt2max_frames > \ - ${data_dir}_seg/feats.scp + #frame_shift_info=`cat $mfcc_config | steps/segmentation/get_frame_shift_info_from_config.pl` + #utils/data/get_subsegment_feats.sh ${test_data_dir}/feats.scp \ + # $frame_shift_info ${data_dir}_seg/segments | \ + # utils/data/fix_subsegmented_feats.pl ${data_dir}_seg/utt2max_frames > \ + # ${data_dir}_seg/feats.scp steps/compute_cmvn_stats.sh --fake ${data_dir}_seg utils/fix_data_dir.sh ${data_dir}_seg diff --git a/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir_simple.sh b/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir_simple.sh index cd4f36ded6b..7211b6b7084 100755 --- a/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir_simple.sh +++ b/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir_simple.sh @@ -14,7 +14,10 @@ 
nj=32 # works on recordings as against on speakers mfcc_config=conf/mfcc_hires_bp.conf feat_affix=bp # Affix for the type of feature used -skip_output_computation=false +convert_data_dir_to_whole=true + +# Set to true if the test data has > 8kHz sampling frequency. +do_downsampling=false stage=-1 sad_stage=-1 @@ -32,16 +35,12 @@ extra_right_context=0 frame_subsampling_factor=1 # Subsampling at the output -transition_scale=1.0 +transition_scale=3.0 loopscale=0.1 acwt=1.0 -# Set to true if the test data has > 8kHz sampling frequency. -do_downsampling=false - # Segmentation configs segmentation_config=conf/segmentation_speech.conf -convert_data_dir_to_whole=true echo $* @@ -82,10 +81,12 @@ if $convert_data_dir_to_whole; then utils/data/downsample_data_dir.sh $freq $whole_data_dir fi + rm -r ${test_data_dir} || true utils/copy_data_dir.sh ${whole_data_dir} $test_data_dir fi else if [ $stage -le 0 ]; then + rm -r ${test_data_dir} || true utils/copy_data_dir.sh $src_data_dir $test_data_dir if $do_downsampling; then @@ -179,63 +180,3 @@ if [ $stage -le 7 ]; then cp $src_data_dir/wav.scp ${data_dir}_seg fi - -exit 0 - -segments_opts="--single-speaker" - -if false; then - mkdir -p ${seg_dir}/post_process_${data_id} - echo $nj > ${seg_dir}/post_process_${data_id}/num_jobs - - $train_cmd JOB=1:$nj $seg_dir/log/convert_to_segments.JOB.log \ - segmentation-init-from-ali "ark:gunzip -c $seg_dir/ali.JOB.gz |" ark:- \| \ - segmentation-copy --label-map=$lang/phone2sad_map --frame-subsampling-factor=$frame_subsampling_factor ark:- ark:- \| \ - segmentation-to-segments --frame-overlap=0.02 $segments_opts ark:- \ - ark,t:${seg_dir}/post_process_${data_id}/utt2spk.JOB \ - ${seg_dir}/post_process_${data_id}/segments.JOB - - for n in `seq $nj`; do - cat ${seg_dir}/post_process_${data_id}/segments.$n - done > ${seg_dir}/post_process_${data_id}/segments - - for n in `seq $nj`; do - cat ${seg_dir}/post_process_${data_id}/utt2spk.$n - done > ${seg_dir}/post_process_${data_id}/utt2spk - - rm -r ${data_dir}_seg || true - mkdir -p ${data_dir}_seg - - utils/data/subsegment_data_dir.sh ${test_data_dir} \ - ${seg_dir}/post_process_${data_id}/segments ${data_dir}_seg - - cp ${src_data_dir}/wav.scp ${data_dir}_seg - cp ${seg_dir}/post_process_${data_id}/utt2spk ${data_dir}_seg - for f in stm glm reco2file_and_channel; do - [ -f $src_data_dir/$f ] && cp ${src_data_dir}/$f ${data_dir}_seg - done - - rm ${data_dir}/{cmvn.scp,spk2utt} || true - utils/fix_data_dir.sh ${data_dir}_seg -fi - -exit 0 - -# Subsegment data directory -if [ $stage -le 8 ]; then - utils/data/get_reco2num_frames.sh ${test_data_dir} - awk '{print $1" "$2}' ${data_dir}_seg/segments | \ - utils/apply_map.pl -f 2 ${test_data_dir}/reco2num_frames > \ - ${data_dir}_seg/utt2max_frames - - frame_shift_info=`cat $mfcc_config | steps/segmentation/get_frame_shift_info_from_config.pl` - utils/data/get_subsegment_feats.sh ${test_data_dir}/feats.scp \ - $frame_shift_info ${data_dir}_seg/segments | \ - utils/data/fix_subsegmented_feats.pl ${data_dir}_seg/utt2max_frames > \ - ${data_dir}_seg/feats.scp - steps/compute_cmvn_stats.sh --fake ${data_dir}_seg - - utils/fix_data_dir.sh ${data_dir}_seg -fi - - From 1dd03c71b7b4cffe7cd2be857a210a8b726d2c82 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Mon, 24 Apr 2017 21:36:39 -0400 Subject: [PATCH 208/213] segmenter: snr preparation --- .../do_corruption_data_dir_snr.sh | 236 ++++++++++++++++++ 1 file changed, 236 insertions(+) create mode 100755 egs/aspire/s5/local/segmentation/do_corruption_data_dir_snr.sh diff --git 
a/egs/aspire/s5/local/segmentation/do_corruption_data_dir_snr.sh b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_snr.sh new file mode 100755 index 00000000000..19b4036c9aa --- /dev/null +++ b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_snr.sh @@ -0,0 +1,236 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + +set -e +set -u +set -o pipefail + +. path.sh + +stage=0 +corruption_stage=-10 +corrupt_only=false + +# Data options +data_dir=data/train_si284 # Expecting whole data directory. +speed_perturb=true +num_data_reps=5 # Number of corrupted versions +snrs="20:10:15:5:0:-5" +foreground_snrs="20:10:15:5:0:-5" +background_snrs="20:10:15:5:2:0:-2:-5" +base_rirs=simulated +speeds="0.9 1.0 1.1" +resample_data_dir=false + +# Parallel options +reco_nj=40 +cmd=queue.pl + +# Options for feature extraction +mfcc_config=conf/mfcc_hires_bp.conf +feat_suffix=hires_bp + +reco_vad_dir= # Output of prepare_unsad_data.sh. + # If provided, the speech labels and deriv weights will be + # copied into the output data directory. + +. utils/parse_options.sh + +if [ $# -ne 0 ]; then + echo "Usage: $0" + exit 1 +fi + +data_id=`basename ${data_dir}` + +rvb_opts=() +if [ "$base_rirs" == "simulated" ]; then + # This is the config for the system using simulated RIRs and point-source noises + rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/smallroom/rir_list") + rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/mediumroom/rir_list") + rvb_opts+=(--noise-set-parameters "0.1, RIRS_NOISES/pointsource_noises/background_noise_list") + rvb_opts+=(--noise-set-parameters "0.9, RIRS_NOISES/pointsource_noises/foreground_noise_list") +else + # This is the config for the JHU ASpIRE submission system + rvb_opts+=(--rir-set-parameters "1.0, RIRS_NOISES/real_rirs_isotropic_noises/rir_list") + rvb_opts+=(--noise-set-parameters RIRS_NOISES/real_rirs_isotropic_noises/noise_list) +fi + +if $resample_data_dir; then + sample_frequency=`cat $mfcc_config | perl -ne 'if (m/--sample-frequency=(\S+)/) { print $1; }'` + if [ -z "$sample_frequency" ]; then + sample_frequency=16000 + fi + + utils/data/resample_data_dir.sh $sample_frequency ${data_dir} || exit 1 + data_id=`basename ${data_dir}` + rvb_opts+=(--source-sampling-rate=$sample_frequency) +fi + +corrupted_data_id=${data_id}_corrupted +clean_data_id=${data_id}_clean +noise_data_id=${data_id}_noise + +if [ $stage -le 1 ]; then + python steps/data/reverberate_data_dir.py \ + "${rvb_opts[@]}" \ + --prefix="rev" \ + --foreground-snrs=$foreground_snrs \ + --background-snrs=$background_snrs \ + --speech-rvb-probability=1 \ + --pointsource-noise-addition-probability=1 \ + --isotropic-noise-addition-probability=1 \ + --num-replications=$num_data_reps \ + --max-noises-per-minute=2 \ + --output-additive-noise-dir=data/${noise_data_id} \ + --output-reverb-dir=data/${clean_data_id} \ + data/${data_id} data/${corrupted_data_id} +fi + +corrupted_data_dir=data/${corrupted_data_id} +clean_data_dir=data/${clean_data_id} +noise_data_dir=data/${noise_data_id} + +if $speed_perturb; then + if [ $stage -le 2 ]; then + ## Assuming whole data directories + for x in $corrupted_data_dir $clean_data_dir $noise_data_dir; do + cp $x/reco2dur $x/utt2dur + utils/data/perturb_data_dir_speed_random.sh --speeds "$speeds" $x ${x}_spr + done + fi + + corrupted_data_dir=${corrupted_data_dir}_spr + clean_data_dir=${clean_data_dir}_spr + noise_data_dir=${noise_data_dir}_spr + corrupted_data_id=${corrupted_data_id}_spr + clean_data_id=${clean_data_id}_spr + 
noise_data_id=${noise_data_id}_spr + + if [ $stage -le 3 ]; then + utils/data/perturb_data_dir_volume.sh --scale-low 0.03125 --scale-high 2 \ + ${corrupted_data_dir} + utils/data/perturb_data_dir_volume.sh --reco2vol ${corrupted_data_dir}/reco2vol ${clean_data_dir} + utils/data/perturb_data_dir_volume.sh --reco2vol ${corrupted_data_dir}/reco2vol ${noise_data_dir} + fi +fi + +if $corrupt_only; then + echo "$0: Got corrupted data directory in ${corrupted_data_dir}" + exit 0 +fi + +mfccdir=`basename $mfcc_config` +mfccdir=${mfccdir%%.conf} + +if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage +fi + +if [ $stage -le 4 ]; then + utils/copy_data_dir.sh $corrupted_data_dir ${corrupted_data_dir}_$feat_suffix + corrupted_data_dir=${corrupted_data_dir}_$feat_suffix + steps/make_mfcc.sh --mfcc-config $mfcc_config \ + --cmd "$cmd" --nj $reco_nj \ + $corrupted_data_dir exp/make_${feat_suffix}/${corrupted_data_id} $mfccdir + steps/compute_cmvn_stats.sh --fake \ + $corrupted_data_dir exp/make_${feat_suffix}/${corrupted_data_id} $mfccdir +else + corrupted_data_dir=${corrupted_data_dir}_$feat_suffix +fi + +if [ $stage -le 5 ]; then + utils/copy_data_dir.sh $clean_data_dir ${clean_data_dir}_$feat_suffix + clean_data_dir=${clean_data_dir}_$feat_suffix + steps/make_mfcc.sh --mfcc-config $mfcc_config \ + --cmd "$cmd" --nj $reco_nj \ + $clean_data_dir exp/make_${feat_suffix}/${clean_data_id} $mfccdir + steps/compute_cmvn_stats.sh --fake \ + $clean_data_dir exp/make_${feat_suffix}/${clean_data_id} $mfccdir +else + clean_data_dir=${clean_data_dir}_$feat_suffix +fi + +if [ $stage -le 6 ]; then + utils/copy_data_dir.sh $noise_data_dir ${noise_data_dir}_$feat_suffix + noise_data_dir=${noise_data_dir}_$feat_suffix + steps/make_mfcc.sh --mfcc-config $mfcc_config \ + --cmd "$cmd" --nj $reco_nj \ + $noise_data_dir exp/make_${feat_suffix}/${noise_data_id} $mfccdir + steps/compute_cmvn_stats.sh --fake \ + $noise_data_dir exp/make_${feat_suffix}/${noise_data_id} $mfccdir +else + noise_data_dir=${noise_data_dir}_$feat_suffix +fi + +targets_dir=irm_targets +if [ $stage -le 7 ]; then + mkdir -p exp/make_log_snr/${corrupted_data_id} + + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $targets_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$targets_dir/storage $targets_dir/storage + fi + + idct_params=`cat $mfcc_config | perl -e ' + $num_mel_bins = 23; $num_ceps = 13; $cepstral_lifter = 22.0; + while (<>) { + chomp; + s/#.+//g; + if (m/^\s*$/) { next; } + if (m/--num-mel-bins=(\S+)/) { + $num_mel_bins = $1; + } elsif (m/--num-ceps=(\S+)/) { + $num_ceps = $1; + } elsif (m/--cepstral-lifter=(\S+)/) { + $cepstral_lifter = $1; + } + } + print "$num_mel_bins $num_ceps $cepstral_lifter";'` + + num_filters=`echo $idct_params | awk '{print $1}'` + num_ceps=`echo $idct_params | awk '{print $2}'` + cepstral_lifter=`echo $idct_params | awk '{print $3}'` + echo "$num_filters $num_ceps $cepstral_lifter" + + mkdir -p exp/make_irm_targets/$corrupted_data_id + utils/data/get_dct_matrix.py --get-idct-matrix=true \ + --num-filters=$num_filters --num-ceps=$num_ceps \ + --cepstral-lifter=$cepstral_lifter \ + exp/make_irm_targets/$corrupted_data_id/idct_matrix + + # Get log-SNR targets + steps/segmentation/make_snr_targets.sh \ + --nj $reco_nj --cmd "$cmd" \ + --target-type Irm --compress false \ + --transform-matrix exp/make_irm_targets/$corrupted_data_id/idct_matrix \ + ${clean_data_dir} ${noise_data_dir} ${corrupted_data_dir} \ + exp/make_irm_targets/${corrupted_data_id} $targets_dir +fi + + +if [ $stage -le 8 ]; then + if [ ! -z "$reco_vad_dir" ]; then + if [ ! -f $reco_vad_dir/speech_labels.scp ]; then + echo "$0: Could not find file $reco_vad_dir/speech_labels.scp" + exit 1 + fi + + cat $reco_vad_dir/speech_labels.scp | \ + steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps | \ + sort -k1,1 > ${corrupted_data_dir}/speech_labels.scp + + cat $reco_vad_dir/deriv_weights.scp | \ + steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps | \ + sort -k1,1 > ${corrupted_data_dir}/deriv_weights.scp + + cat $reco_vad_dir/deriv_weights_manual_seg.scp | \ + steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps | \ + sort -k1,1 > ${corrupted_data_dir}/deriv_weights_for_irm_targets.scp + fi +fi + +exit 0 From c85d1617d3487935c0f525059d910862d8b4c948 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Tue, 25 Apr 2017 03:23:56 -0400 Subject: [PATCH 209/213] segmenter: Temporary fix for nnet3 computation --- src/nnet3/nnet-optimize-utils.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nnet3/nnet-optimize-utils.cc b/src/nnet3/nnet-optimize-utils.cc index 60ec93f3f18..72f4147931b 100644 --- a/src/nnet3/nnet-optimize-utils.cc +++ b/src/nnet3/nnet-optimize-utils.cc @@ -2523,7 +2523,7 @@ void ComputationExpander::ExpandRowRangesCommand( num_rows_new = expanded_computation_->submatrices[s1].num_rows; KALDI_ASSERT(static_cast(c_in.arg3) < computation_.indexes_ranges.size()); - KALDI_ASSERT(num_rows_old % 2 == 0); + //KALDI_ASSERT(num_rows_old % 2 == 0); int32 num_n_values = num_n_values_; From 33c6ea47e770a5ec508203c551082c734e49b817 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 26 Apr 2017 16:43:55 -0400 Subject: [PATCH 210/213] SAD: More tuning recipes --- .../tuning/train_lstm_sad_music_1g.sh | 291 ++++++++++++++++ .../tuning/train_lstm_sad_music_1h.sh | 291 ++++++++++++++++ .../tuning/train_lstm_sad_music_1i.sh | 308 +++++++++++++++++ .../tuning/train_lstm_sad_music_snr_1h.sh | 306 +++++++++++++++++ .../tuning/train_lstm_sad_music_snr_1i.sh | 315 +++++++++++++++++ .../tuning/train_lstm_sad_music_snr_1j.sh | 312 +++++++++++++++++ .../tuning/train_lstm_sad_music_snr_1k.sh | 316 
++++++++++++++++++ .../tuning/train_stats_sad_music_snr_1h.sh | 310 +++++++++++++++++ .../tuning/train_stats_sad_music_snr_1i.sh | 310 +++++++++++++++++ 9 files changed, 2759 insertions(+) create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1g.sh create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1h.sh create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1i.sh create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1h.sh create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1i.sh create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1j.sh create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1k.sh create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1h.sh create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1i.sh diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1g.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1g.sh new file mode 100644 index 00000000000..eea5956e005 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1g.sh @@ -0,0 +1,291 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. +# This script is same as 1c, but uses larger amount of data. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=40 +num_chunk_per_minibatch=64 + +extra_left_context=40 +extra_right_context=0 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1g + +. cmd.sh +. ./path.sh +. 
./utils/parse_options.sh + +utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + +cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + +utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-3,-2,-1,0,1,2,3) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-6,0,6) dim=$relu_dim + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn3 input=Append(-12,0,12) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn3 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn3 + output-layer name=output-speech_music include-log-softmax=true dim=4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 objective-scale=$speech_music_scale input=tdnn3 + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + #--targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true 
--egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1h.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1h.sh new file mode 100644 index 00000000000..d9e1966bf6a --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1h.sh @@ -0,0 +1,291 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. +# This script is same as 1c, but uses larger amount of data. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=40 +num_chunk_per_minibatch=64 + +extra_left_context=40 +extra_right_context=0 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1h + +. cmd.sh +. ./path.sh +. 
./utils/parse_options.sh + +utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + +cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + +utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-3,-2,-1,0,1,2,3) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-6,0,6) dim=$relu_dim + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn3 input=Append(-12,0,12) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn3 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn3 + output-layer name=output-speech_music include-log-softmax=true dim=4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 objective-scale=$speech_music_scale input=tdnn3 + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + #--targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $num_chunk_per_minibatch \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true 
--egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1i.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1i.sh new file mode 100644 index 00000000000..be568eefd97 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1i.sh @@ -0,0 +1,308 @@ +#!/bin/bash + +# This is a script to train a TDNN-LSTM for speech activity detection (SAD) and +# music-id using LSTM for long-context information. +# This is same as 1h, but has more layers. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=20 +num_chunk_per_minibatch=64 + +extra_left_context=40 +extra_right_context=0 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1i + +. cmd.sh +. ./path.sh +. 
./utils/parse_options.sh + +if [ $stage -le -1 ]; then + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp irm_targets.scp deriv_weights_for_irm_targets.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + + cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ +fi + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/scales +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + scales=`cat $dir/scales` + + speech_scale=`echo $scales | awk '{print $1}'` + music_scale=`echo $scales | awk '{print $2}'` + speech_music_scale=`echo $scales | awk '{print $3}'` + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-2,-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=$relu_dim add-log-stddev=true + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn4 input=Append(-6,0,6,12) dim=$relu_dim + relu-renorm-layer name=tdnn5 input=Append(-12,0,12,24) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn5 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn5 + output-layer name=output-speech_music include-log-softmax=true dim=4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 objective-scale=$speech_music_scale input=tdnn5 + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 5 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $num_chunk_per_minibatch \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 6 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + 
--egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1h.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1h.sh new file mode 100644 index 00000000000..ae85a93a7fc --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1h.sh @@ -0,0 +1,306 @@ +#!/bin/bash + +# This is a script to train a TDNN-LSTM for speech activity detection (SAD) and +# music-id using LSTM for long-context information. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=40 +num_chunk_per_minibatch=64 + +extra_left_context=40 +extra_right_context=0 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music,output-snr ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1h + +. cmd.sh +. ./path.sh +. 
./utils/parse_options.sh + +if [ $stage -le -1 ]; then + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp irm_targets.scp deriv_weights_for_irm_targets.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + + cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ +fi + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music_snr/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-3,-2,-1,0,1,2,3) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-6,0,6) dim=$relu_dim + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn3 input=Append(-12,0,12) dim=$relu_dim + relu-renorm-layer name=tdnn3-snr input=Append(lstm1@-12,lstm1@0,lstm1@12,tdnn3) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn3 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn3 + output-layer name=output-speech_music include-log-softmax=true dim=4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 objective-scale=$speech_music_scale input=tdnn3 + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic learning-rate-factor=0.1 objective-scale=$snr_scale input=tdnn3-snr + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_for_irm_targets.scp" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $num_chunk_per_minibatch \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true 
--egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1i.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1i.sh new file mode 100644 index 00000000000..b6c43a92992 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1i.sh @@ -0,0 +1,315 @@ +#!/bin/bash + +# This is a script to train a TDNN-LSTM for speech activity detection (SAD) and +# music-id using LSTM for long-context information. +# This is same as 1h, but has more layers. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=20 +num_chunk_per_minibatch=64 + +extra_left_context=40 +extra_right_context=0 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music,output-snr ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1i + +. cmd.sh +. ./path.sh +. 
./utils/parse_options.sh + +if [ $stage -le -1 ]; then + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp irm_targets.scp deriv_weights_for_irm_targets.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + + cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ +fi + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music_snr/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/scales +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + scales=`cat $dir/scales` + + speech_scale=`echo $scales | awk '{print $1}'` + music_scale=`echo $scales | awk '{print $2}'` + speech_music_scale=`echo $scales | awk '{print $3}'` + snr_scale=`echo $scales | awk '{print $4}'` + + num_snr_bins=`feat-to-dim scp:$sad_data_dir/irm_targets.scp -` + snr_scale=`perl -e "print $snr_scale / $num_snr_bins"` + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-2,-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=$relu_dim add-log-stddev=true + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn4 input=Append(-6,0,6,12) dim=$relu_dim + relu-renorm-layer name=tdnn5 input=Append(-12,0,12,24) dim=$relu_dim + relu-renorm-layer name=tdnn5-snr input=Append(lstm1@-6,lstm1@0,lstm1@6,lstm1@12,tdnn5) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn5 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn5 + output-layer name=output-speech_music include-log-softmax=true dim=4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 objective-scale=$speech_music_scale input=tdnn5 + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic learning-rate-factor=0.1 objective-scale=$snr_scale input=tdnn5-snr + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file 
$dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_for_irm_targets.scp" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 5 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $num_chunk_per_minibatch \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 6 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1j.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1j.sh new file mode 100644 index 00000000000..bf397565148 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1j.sh @@ -0,0 +1,312 @@ +#!/bin/bash + +# This is a script to train a TDNN-LSTM for speech activity detection (SAD) and +# music-id using LSTM for long-context information. +# This is same as 1i, but removes the speech-music output. + +set -o pipefail +set -e +set -u + +. 
cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=20 +num_chunk_per_minibatch=64 + +extra_left_context=40 +extra_right_context=0 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-snr ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1j + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ $stage -le -1 ]; then + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp irm_targets.scp deriv_weights_for_irm_targets.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + + cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ +fi + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music_snr/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/scales +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + scales=`cat $dir/scales` + + speech_scale=`echo $scales | awk '{print $1}'` + music_scale=`echo $scales | awk '{print $2}'` + speech_music_scale=`echo $scales | awk '{print $3}'` + snr_scale=`echo $scales | awk '{print $4}'` + + num_snr_bins=`feat-to-dim scp:$sad_data_dir/irm_targets.scp -` + snr_scale=`perl -e "print $snr_scale / $num_snr_bins"` + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-2,-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=$relu_dim add-log-stddev=true + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn4 input=Append(-6,0,6,12) dim=$relu_dim + relu-renorm-layer name=tdnn5 input=Append(-12,0,12,24) dim=$relu_dim + relu-renorm-layer name=tdnn5-snr input=Append(lstm1@-6,lstm1@0,lstm1@6,lstm1@12,tdnn5) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn5 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn5 + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic learning-rate-factor=0.1 objective-scale=$snr_scale input=tdnn5-snr + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_for_irm_targets.scp" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 5 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $num_chunk_per_minibatch \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 6 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true 
--egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1k.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1k.sh new file mode 100644 index 00000000000..cb585523f74 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1k.sh @@ -0,0 +1,316 @@ +#!/bin/bash + +# This is a script to train a TDNN-LSTM for speech activity detection (SAD) and +# music-id using LSTM for long-context information. +# This is same as 1h, but has more layers. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=20 +num_chunk_per_minibatch=64 + +extra_left_context=40 +extra_right_context=0 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music,output-snr ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1k + +. cmd.sh +. ./path.sh +. 
./utils/parse_options.sh + +if [ $stage -le -1 ]; then + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp irm_targets.scp deriv_weights_for_irm_targets.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + + cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ +fi + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music_snr/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/scales +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + scales=`cat $dir/scales` + + speech_scale=`echo $scales | awk '{print $1}'` + music_scale=`echo $scales | awk '{print $2}'` + speech_music_scale=`echo $scales | awk '{print $3}'` + snr_scale=`echo $scales | awk '{print $4}'` + + num_snr_bins=`feat-to-dim scp:$sad_data_dir/irm_targets.scp -` + snr_scale=`perl -e "print $snr_scale / $num_snr_bins"` + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-2,-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=$relu_dim add-log-stddev=true + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn4 input=Append(-6,0,6,12) dim=$relu_dim + fast-lstmp-layer name=lstm2 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn5 input=Append(-12,0,12,24) dim=$relu_dim + relu-renorm-layer name=tdnn5-snr input=Append(lstm2@-6,lstm2@0,lstm2@6,lstm2@12,tdnn5) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn5 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn5 + output-layer name=output-speech_music include-log-softmax=true dim=4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 objective-scale=$speech_music_scale input=tdnn5 + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic learning-rate-factor=0.1 objective-scale=$snr_scale input=tdnn5-snr + + output name=output-temp 
input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_for_irm_targets.scp" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 5 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $num_chunk_per_minibatch \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 6 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1h.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1h.sh new file mode 100644 index 00000000000..e585f27e5fd --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1h.sh @@ -0,0 +1,310 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. +# This script is same as 1c, but uses larger amount of data. 
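+#
+# A note on the statistics-pooling configs used further down (for example
+#   stats-layer name=tdnn2_stats config=mean+count(-108:6:18:108)
+# ): the four numbers are assumed to follow the usual xconfig convention of
+# left-context:input-period:stats-period:right-context, so such a layer pools a
+# mean (plus a count feature) over roughly +/-108 frames, i.e. about one second
+# of context on each side at a 10 ms frame shift.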
+ +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=20 +num_chunk_per_minibatch=64 + +extra_left_context=79 +extra_right_context=11 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music,output-snr ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1h + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ $stage -le -1 ]; then + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp irm_targets.scp deriv_weights_for_irm_targets.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + + cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ +fi + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_music_snr/nnet_tdnn_stats +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-3,-2,-1,0,1,2,3) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-6,0,6) dim=$relu_dim + stats-layer name=tdnn2_stats config=mean+count(-108:6:18:108) + relu-renorm-layer name=tdnn3 input=Append(tdnn2@-12,tdnn2@0,tdnn2@12,tdnn2_stats) dim=$relu_dim + stats-layer name=tdnn3_stats config=mean+count(-108:12:36:108) + relu-renorm-layer name=tdnn4 input=Append(tdnn3@-12,tdnn3@0,tdnn3@12,tdnn3_stats) dim=$relu_dim + relu-renorm-layer name=tdnn4-snr input=Append(tdnn3@-12,tdnn3@0,tdnn3@12,tdnn3_stats) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn4 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn4 + output-layer name=output-speech_music include-log-softmax=true dim=4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 objective-scale=$speech_music_scale input=tdnn4 + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic learning-rate-factor=0.1 objective-scale=$snr_scale input=tdnn4-snr + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_for_irm_targets.scp" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $num_chunk_per_minibatch \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true 
--egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + + + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1i.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1i.sh new file mode 100644 index 00000000000..3ddcdd795db --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1i.sh @@ -0,0 +1,310 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. +# This script is same as 1c, but uses larger amount of data. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=20 +num_chunk_per_minibatch=64 + +extra_left_context=79 +extra_right_context=11 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music,output-snr ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1i + +. cmd.sh +. ./path.sh +. 
./utils/parse_options.sh + +if [ $stage -le -1 ]; then + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp irm_targets.scp deriv_weights_for_irm_targets.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil,amharic}_train_whole_corrupted_spr_hires_bp/ + + cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil,amharic}_train_whole_music_corrupted_spr_hires_bp/ +fi + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_music_snr/nnet_tdnn_stats +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-3,-2,-1,0,1,2,3) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-6,0,6) dim=$relu_dim + stats-layer name=tdnn2_stats config=mean+count(-108:6:18:108) + relu-renorm-layer name=tdnn3 input=Append(tdnn2@-12,tdnn2@0,tdnn2@12,tdnn2_stats) dim=$relu_dim + stats-layer name=tdnn3_stats config=mean+count(-108:12:36:108) + relu-renorm-layer name=tdnn4 input=Append(tdnn3@-12,tdnn3@0,tdnn3@12,tdnn3_stats) dim=$relu_dim + relu-renorm-layer name=tdnn4-snr input=Append(tdnn3@-12,tdnn3@0,tdnn3@12,tdnn3_stats) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn4 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn4 + output-layer name=output-speech_music include-log-softmax=true dim=4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 objective-scale=$speech_music_scale input=tdnn4 + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic learning-rate-factor=0.1 objective-scale=$snr_scale input=tdnn4-snr + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_for_irm_targets.scp" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $num_chunk_per_minibatch \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + 
--egs.chunk-width=$chunk_width \
+    --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \
+    --egs.chunk-left-context=$extra_left_context \
+    --egs.chunk-right-context=$extra_right_context \
+    --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \
+    ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \
+    --trainer.num-epochs=$num_epochs \
+    --trainer.samples-per-iter=20000 \
+    --trainer.optimization.num-jobs-initial=$num_jobs_initial \
+    --trainer.optimization.num-jobs-final=$num_jobs_final \
+    --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \
+    --trainer.optimization.final-effective-lrate=$final_effective_lrate \
+    --trainer.optimization.shrink-value=1.0 \
+    --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \
+    --trainer.deriv-truncate-margin=8 \
+    --trainer.max-param-change=$max_param_change \
+    --trainer.compute-per-dim-accuracy=true \
+    --cmd="$decode_cmd" --nj 40 \
+    --cleanup=true \
+    --cleanup.remove-egs=$remove_egs \
+    --cleanup.preserve-model-interval=10 \
+    --use-gpu=true \
+    --use-dense-targets=false \
+    --feat-dir=$sad_data_dir \
+    --targets-scp="$sad_data_dir/speech_labels.scp" \
+    --dir=$dir || exit 1
+fi
+
+
From cbe647770466fd2acc5acdb2a9a032458437474a Mon Sep 17 00:00:00 2001
From: Vimal Manohar
Date: Thu, 27 Apr 2017 22:26:47 -0400
Subject: [PATCH 211/213] SAD: prepare musan music

---
 .../local/segmentation/prepare_musan_music.sh |  24 ++++
 egs/wsj/s5/steps/data/split_wavs_randomly.py  | 114 ++++++++++++++++++
 2 files changed, 138 insertions(+)
 create mode 100644 egs/aspire/s5/local/segmentation/prepare_musan_music.sh
 create mode 100755 egs/wsj/s5/steps/data/split_wavs_randomly.py

diff --git a/egs/aspire/s5/local/segmentation/prepare_musan_music.sh b/egs/aspire/s5/local/segmentation/prepare_musan_music.sh
new file mode 100644
index 00000000000..16fb946b0c8
--- /dev/null
+++ b/egs/aspire/s5/local/segmentation/prepare_musan_music.sh
@@ -0,0 +1,24 @@
+#! /bin/bash
+
+# Copyright 2016 Vimal Manohar
+# Apache 2.0
+
+if [ $# -ne 2 ]; then
+  echo "Usage: $0 <musan-dir> <dir>"
+  echo " e.g.: $0 /export/corpora/JHU/musan RIRS_NOISES/music"
+  exit 1
+fi
+
+SRC_DIR=$1
+dir=$2
+
+mkdir -p $dir
+
+local/segmentation/make_musan_music.py $SRC_DIR $dir/wav.scp
+
+wav-to-duration scp:$dir/wav.scp ark,t:$dir/reco2dur
+steps/data/split_wavs_randomly.py $dir/wav.scp $dir/reco2dur \
+  $dir/split_utt2dur $dir/split_wav.scp
+
+awk '{print $1" "int($2*100)}' $dir/split_utt2dur > $dir/split_utt2num_frames
+steps/data/wav_scp2noise_list.py $dir/split_wav.scp $dir/music_list
diff --git a/egs/wsj/s5/steps/data/split_wavs_randomly.py b/egs/wsj/s5/steps/data/split_wavs_randomly.py
new file mode 100755
index 00000000000..b4c3b660ddd
--- /dev/null
+++ b/egs/wsj/s5/steps/data/split_wavs_randomly.py
@@ -0,0 +1,114 @@
+#! /usr/bin/env python
+
+# Copyright 2016 Vimal Manohar
+# Apache 2.0
+
+from __future__ import print_function
+import argparse
+import random
+
+def get_args():
+    parser = argparse.ArgumentParser(description="""This script converts a
+    wav.scp into a split wav.scp that can be converted into noise-set-parameters
+    that can be passed to steps/data/reverberate_data_dir.py. The wav files in
+    wav.scp are trimmed randomly into pieces based on options such
+    as --max-duration, --skip-initial-duration and --num-parts-per-minute.""",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+    parser.add_argument("--max-duration", type=float, default=30,
+                        help="Maximum duration in seconds of the created "
+                        "signal pieces")
+    parser.add_argument("--min-duration", type=float, default=0.5,
+                        help="Minimum duration in seconds of the created "
+                        "signal pieces")
+    parser.add_argument("--skip-initial-duration", type=float, default=5,
+                        help="The duration in seconds of the original signal "
+                        "that will be ignored while creating the pieces")
+    parser.add_argument("--num-parts-per-minute", type=int, default=3,
+                        help="Used to control the number of parts to create "
+                        "from a recording")
+    parser.add_argument("--sampling-rate", type=float, default=8000,
+                        help="Required sampling rate of the output signals.")
+    parser.add_argument('--random-seed', type=int, default=0,
+                        help='seed to be used in the random split of signals')
+    parser.add_argument("wav_scp", type=str,
+                        help="The input wav.scp")
+    parser.add_argument("reco2dur", type=str,
+                        help="""Durations of the recordings corresponding to the
+                        input wav.scp""")
+    parser.add_argument("out_utt2dur", type=str,
+                        help="Output utt2dur corresponding to split wavs")
+    parser.add_argument("out_wav_scp", type=str,
+                        help="Output wav.scp corresponding to split wavs")
+
+    args = parser.parse_args()
+
+    return args
+
+
+def get_noise_set(reco, reco_dur, wav_rspecifier_split, sampling_rate,
+                  num_parts, max_duration, min_duration, skip_initial_duration):
+    noise_set = []
+    for i in range(num_parts):
+        utt = "{0}-{1}".format(reco, i+1)
+
+        start_time = round(random.random() * (reco_dur - skip_initial_duration)
+                           + skip_initial_duration, 2)
+        duration = min(round(random.random() * (max_duration-min_duration)
+                             + min_duration, 2),
+                       reco_dur - start_time)
+
+        if len(wav_rspecifier_split) == 1:
+            rspecifier = ("sox -D {wav} -r {sr} -t wav - "
+                          "trim {st} {dur} |".format(
+                              wav=wav_rspecifier_split[0],
+                              sr=sampling_rate, st=start_time, dur=duration))
+        else:
+            rspecifier = ("{wav} sox -D -t wav - -r {sr} -t wav - "
+                          "trim {st} {dur} |".format(
+                              wav=" ".join(wav_rspecifier_split),
+                              sr=sampling_rate, st=start_time, dur=duration))
+
+        noise_set.append( (utt, rspecifier, duration) )
+    return noise_set
+
+
+def main():
+    args = get_args()
+    random.seed(args.random_seed)
+
+    reco2dur = {}
+    for line in open(args.reco2dur):
+        parts = line.strip().split()
+        if len(parts) != 2:
+            raise Exception(
+                "Expecting reco2dur to contain lines of the format "
+                "<reco-id> <duration>; Got {0}".format(line))
+        reco2dur[parts[0]] = float(parts[1])
+
+    out_wav_scp = open(args.out_wav_scp, 'w')
+    out_utt2dur = open(args.out_utt2dur, 'w')
+
+    for line in open(args.wav_scp):
+        parts = line.strip().split()
+        reco = parts[0]
+        dur = reco2dur[reco]
+
+        num_parts = int(float(args.num_parts_per_minute) / 60 * reco2dur[reco])
+
+        noise_set = get_noise_set(
+            reco, reco2dur[reco], wav_rspecifier_split=parts[1:],
+            sampling_rate=args.sampling_rate, num_parts=num_parts,
+            max_duration=args.max_duration, min_duration=args.min_duration,
+            skip_initial_duration=args.skip_initial_duration)
+
+        for utt, rspecifier, dur in noise_set:
+            print("{0} {1}".format(utt, rspecifier), file=out_wav_scp)
+            print("{0} {1}".format(utt, dur), file=out_utt2dur)
+
+    out_wav_scp.close()
+    out_utt2dur.close()
+
+
+if __name__ == '__main__':
+    main()
From a632f00ce7b3fd602bff5d8edc82d659cbb8e6e5
Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Thu, 27 Apr 2017 22:45:38 -0400 Subject: [PATCH 212/213] segmenter: Prepare fisher data music --- .../s5/local/segmentation/prepare_fisher_data.sh | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/egs/aspire/s5/local/segmentation/prepare_fisher_data.sh b/egs/aspire/s5/local/segmentation/prepare_fisher_data.sh index 40f43cfd442..4f55cc6929e 100644 --- a/egs/aspire/s5/local/segmentation/prepare_fisher_data.sh +++ b/egs/aspire/s5/local/segmentation/prepare_fisher_data.sh @@ -56,6 +56,17 @@ oov_I 3 oov_S 3 EOF +if [ ! -d RIRS_NOISES/ ]; then + # Prepare MUSAN rirs and noises + wget --no-check-certificate http://www.openslr.org/resources/28/rirs_noises.zip + unzip rirs_noises.zip +fi + +if [ ! -d RIRS_NOISES/music ]; then + # Prepare MUSAN music + local/segmentation/prepare_musan_music.sh /export/corpora/JHU/musan RIRS_NOISES/music +fi + # Expecting the user to have done run.sh to have $model_dir, # $sat_model_dir, $lang, $lang_test, $train_data_dir local/segmentation/prepare_unsad_data.sh \ From 579fc8cf6a9999576846bb3d20fe7adf72301a66 Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 10 May 2017 15:29:29 -0400 Subject: [PATCH 213/213] segmentaion: Adding more recipes --- .../tuning/train_lstm_sad_music_snr_1l.sh | 316 +++++++++++++++++ .../tuning/train_stats_sad_music_snr_1j.sh | 316 +++++++++++++++++ .../tuning/train_stats_sad_music_snr_1k.sh | 317 +++++++++++++++++ .../tuning/train_stats_sad_music_snr_1l.sh | 318 ++++++++++++++++++ 4 files changed, 1267 insertions(+) create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1l.sh create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1j.sh create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1k.sh create mode 100644 egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1l.sh diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1l.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1l.sh new file mode 100644 index 00000000000..d8910053e61 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1l.sh @@ -0,0 +1,316 @@ +#!/bin/bash + +# This is a script to train a TDNN-LSTM for speech activity detection (SAD) and +# music-id using LSTM for long-context information. +# This is same as 1h, but has more layers. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=20 +num_chunk_per_minibatch=64 + +extra_left_context=40 +extra_right_context=0 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music,output-snr ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1k + +. cmd.sh +. ./path.sh +. 
./utils/parse_options.sh + +if [ $stage -le -1 ]; then + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp irm_targets.scp deriv_weights_for_irm_targets.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + + cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ +fi + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music_snr/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/scales +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + scales=`cat $dir/scales` + + speech_scale=`echo $scales | awk '{print $1}'` + music_scale=`echo $scales | awk '{print $2}'` + speech_music_scale=`echo $scales | awk '{print $3}'` + snr_scale=`echo $scales | awk '{print $4}'` + + num_snr_bins=`feat-to-dim scp:$sad_data_dir/irm_targets.scp -` + snr_scale=`perl -e "print $snr_scale / $num_snr_bins"` + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-2,-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=$relu_dim add-log-stddev=true + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-3 + relu-renorm-layer name=tdnn4 input=Append(-6,0,6,12) dim=$relu_dim + fast-lstmp-layer name=lstm2 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn5 input=Append(-12,0,12,24) dim=$relu_dim + relu-renorm-layer name=tdnn5-snr input=Append(lstm2@-6,lstm2@0,lstm2@6,lstm2@12,tdnn5) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn5 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn5 + output-layer name=output-speech_music include-log-softmax=true dim=4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 objective-scale=$speech_music_scale input=tdnn5 + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic learning-rate-factor=0.1 objective-scale=$snr_scale input=tdnn5-snr + + output name=output-temp 
input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_for_irm_targets.scp" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 5 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $num_chunk_per_minibatch \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 6 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1j.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1j.sh new file mode 100644 index 00000000000..059fbf7b1a9 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1j.sh @@ -0,0 +1,316 @@ +#!/bin/bash + +# This is a script to train a TDNN-LSTM for speech activity detection (SAD) and +# music-id using LSTM for long-context information. +# This is same as 1h, but has more layers. + +set -o pipefail +set -e +set -u + +. 
cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=20 +num_chunk_per_minibatch=64 + +extra_left_context=79 +extra_right_context=11 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music,output-snr ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1j + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ $stage -le -1 ]; then + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp irm_targets.scp deriv_weights_for_irm_targets.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + + cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ +fi + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_music_snr/nnet_tdnn_stats +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/scales +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + scales=`cat $dir/scales` + + speech_scale=`echo $scales | awk '{print $1}'` + music_scale=`echo $scales | awk '{print $2}'` + speech_music_scale=`echo $scales | awk '{print $3}'` + snr_scale=`echo $scales | awk '{print $4}'` + + num_snr_bins=`feat-to-dim scp:$sad_data_dir/irm_targets.scp -` + snr_scale=`perl -e "print $snr_scale / $num_snr_bins"` + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-2,-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=$relu_dim add-log-stddev=true + stats-layer name=tdnn3_stats config=mean+stddev+count(-99:3:9:99) + relu-renorm-layer name=tdnn4 input=Append(tdnn3@-6,tdnn3@0,tdnn3@6,tdnn3@12,tdnn3_stats) add-log-stddev=true dim=$relu_dim + stats-layer name=tdnn4_stats config=mean+stddev+count(-108:6:18:108) + relu-renorm-layer name=tdnn5 input=Append(tdnn4@-12,tdnn4@0,tdnn4@12,tdnn4@24,tdnn4_stats) dim=$relu_dim + relu-renorm-layer name=tdnn5-snr input=Append(tdnn3@-6,tdnn3@0,tdnn3@6,tdnn3@12,tdnn5) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn5 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn5 + output-layer name=output-speech_music include-log-softmax=true dim=4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 objective-scale=$speech_music_scale input=tdnn5 + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic learning-rate-factor=0.1 objective-scale=$snr_scale input=tdnn5-snr + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_for_irm_targets.scp" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 5 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $num_chunk_per_minibatch \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 6 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true 
--egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1k.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1k.sh new file mode 100644 index 00000000000..48425e50386 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1k.sh @@ -0,0 +1,317 @@ +#!/bin/bash + +# This is a script to train a TDNN-LSTM for speech activity detection (SAD) and +# music-id using LSTM for long-context information. +# This is same as 1h, but has more layers. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=20 +num_chunk_per_minibatch=64 + +extra_left_context=79 +extra_right_context=11 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1k + +. cmd.sh +. ./path.sh +. 
./utils/parse_options.sh + +if [ $stage -le -1 ]; then + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp irm_targets.scp deriv_weights_for_irm_targets.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + + cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ +fi + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_music_snr/nnet_tdnn_stats +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/scales +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + scales=`cat $dir/scales` + + speech_scale=`echo $scales | awk '{print $1}'` + music_scale=`echo $scales | awk '{print $2}'` + speech_music_scale=`echo $scales | awk '{print $3}'` + snr_scale=`echo $scales | awk '{print $4}'` + + num_snr_bins=`feat-to-dim scp:$sad_data_dir/irm_targets.scp -` + snr_scale=`perl -e "print $snr_scale / $num_snr_bins"` + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-2,-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=$relu_dim add-log-stddev=true + stats-layer name=tdnn3_stats config=mean+stddev+count(-99:3:9:99) + relu-renorm-layer name=tdnn4 input=Append(tdnn3@-6,tdnn3@0,tdnn3@6,tdnn3@12,tdnn3_stats) add-log-stddev=true dim=$relu_dim + stats-layer name=tdnn4_stats config=mean+stddev+count(-108:6:18:108) + relu-renorm-layer name=tdnn5 input=Append(tdnn4@-12,tdnn4@0,tdnn4@12,tdnn4@24,tdnn4_stats) dim=$relu_dim + relu-renorm-layer name=tdnn5-snr input=Append(tdnn3@-6,tdnn3@0,tdnn3@6,tdnn3@12,tdnn5) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn5 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn5 + output-layer name=output-speech_music include-log-softmax=true dim=4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 objective-scale=$speech_music_scale input=tdnn5 + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic learning-rate-factor=0.1 objective-scale=$snr_scale input=tdnn5-snr + + output name=output-temp 
input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_for_irm_targets.scp" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 5 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $num_chunk_per_minibatch \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 6 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + + + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1l.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1l.sh new file mode 100644 index 00000000000..689c31e623a --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1l.sh @@ -0,0 +1,318 @@ +#!/bin/bash + +# This is a script to train a TDNN-LSTM for speech activity detection (SAD) and +# music-id using LSTM for long-context information. +# This is same as 1h, but has more layers. + +set -o pipefail +set -e +set -u + +. 
cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=20 +num_chunk_per_minibatch=64 + +extra_left_context=79 +extra_right_context=11 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music,output-snr ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1l + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ $stage -le -1 ]; then + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp irm_targets.scp deriv_weights_for_irm_targets.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + + cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ +fi + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_music_snr/nnet_tdnn_stats +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then
+  cat < $dir/scales
+fi
+
+if [ $stage -le 2 ]; then
+  echo "$0: creating neural net configs using the xconfig parser";
+
+  scales=`cat $dir/scales`
+
+  speech_scale=`echo $scales | awk '{print $1}'`
+  music_scale=`echo $scales | awk '{print $2}'`
+  speech_music_scale=`echo $scales | awk '{print $3}'`
+  snr_scale=`echo $scales | awk '{print $4}'`
+
+  num_snr_bins=`feat-to-dim scp:$sad_data_dir/irm_targets.scp -`
+  snr_scale=`perl -e "print $snr_scale / $num_snr_bins"`
+
+  mkdir -p $dir/configs
+  cat <<EOF > $dir/configs/network.xconfig
+  input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input
+
+  relu-renorm-layer name=tdnn1 input=Append(-2,-1,0,1,2) dim=$relu_dim add-log-stddev=true
+  relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=$relu_dim add-log-stddev=true
+  relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=$relu_dim add-log-stddev=true
+  stats-layer name=tdnn3_stats config=mean+count(-99:3:9:99)
+  relu-renorm-layer name=tdnn4 input=Append(tdnn3@-6,tdnn3@0,tdnn3@6,tdnn3@12,tdnn3_stats) add-log-stddev=true dim=$relu_dim
+  stats-layer name=tdnn4_stats config=mean+count(-108:6:18:108)
+  relu-renorm-layer name=tdnn5 input=Append(tdnn4@-12,tdnn4@0,tdnn4@12,tdnn4@24,tdnn4_stats) dim=$relu_dim
+  relu-renorm-layer name=tdnn5-snr input=Append(tdnn3@-6,tdnn3@0,tdnn3@6,tdnn3@12,tdnn5) dim=$relu_dim
+
+  output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn5
+  output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn5
+  output-layer name=output-speech_music include-log-softmax=true dim=4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 objective-scale=$speech_music_scale input=tdnn5
+  output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic learning-rate-factor=0.1 objective-scale=$snr_scale input=tdnn5-snr
+
+  output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3)
+EOF
+  steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \
+    --config-dir $dir/configs/ \
+    --nnet-edits="rename-node old-name=output-speech new-name=output"
+
+  cat <<EOF >> $dir/configs/vars
+add_lda=false
+EOF
+fi
+
+samples_per_iter=`perl -e "print int(400000 / $chunk_width)"`
+
+if [ -z "$egs_dir" ]; then
+  egs_dir=$dir/egs_multi
+  if [ $stage -le 3 ]; then
+    if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then
+      utils/create_split_dir.pl \
+        /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage
+    fi
+
+    . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_for_irm_targets.scp" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 5 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $num_chunk_per_minibatch \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 6 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true 
--egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + + + +
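The extra_egs_copy_cmd hook set near the top of the 1l script is what makes the combined egs_multi archives reusable across related experiments: when it is non-empty, train_raw_rnn.py receives it through --egs.extra-copy-cmd and splices it into the egs-reading pipeline, so the set of targets actually trained on can be changed without regenerating egs. A minimal sketch of that kind of reuse, assuming only the --keep-outputs option of nnet3-copy-egs that the script itself already passes; the directories and option names below are just the script's own defaults, shown for illustration, not part of the patch:

# Sketch: retrain from the existing multitask egs while dropping the SNR regression target.
# Relies on the same --keep-outputs option used in the hard-coded extra_egs_copy_cmd above.
extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music ark:- ark:- |"

local/segmentation/tuning/train_stats_sad_music_snr_1l.sh \
  --stage 6 \
  --egs-dir exp/nnet3_stats_sad_music_snr/nnet_tdnn_stats_1l/egs_multi \
  --extra-egs-copy-cmd "$extra_egs_copy_cmd"

Because the variable is declared before utils/parse_options.sh is sourced, the command-line value overrides the hard-coded assignment in the script.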