
Commit

[bugfix] Backport int32 fix for sampled block and bump version (#1109)
* [BugFix] support int32 for sampled blocks (#1106)

* support int32 for sampled blocks

* Fix lint

* Update version
eric-haibin-lin authored Jan 14, 2020
1 parent 434187e commit 75b3c12
Showing 3 changed files with 21 additions and 17 deletions.
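
For context on the fix itself: the sampled blocks now cast the sampled candidates and labels to int32 and the expected counts to float32 internally, so candidate samplers that emit int32 outputs can be used directly. A minimal usage sketch under that assumption, mirroring the updated unit test below (the NCEDense argument order num_classes, num_sampled, in_unit and all constants here are illustrative, not taken from this commit):

import mxnet as mx
import gluonnlp as nlp

batch_size, num_sampled, num_hidden, num_classes = 2, 3, 5, 10
model = nlp.model.NCEDense(num_classes, num_sampled, num_hidden)
model.initialize()

x = mx.nd.ones((batch_size, num_hidden))   # input features
y = mx.nd.ones((batch_size,))              # true labels
# int32 sampled candidates and counts, e.g. from an int32-based candidate sampler;
# the block now handles the dtype conversion internally
sampled_cls = mx.nd.ones((num_sampled,), dtype='int32')
sampled_cls_cnt = mx.nd.ones((num_sampled,), dtype='int32')
true_cls_cnt = mx.nd.ones((batch_size,), dtype='int32')
pred, new_y = model(x, (sampled_cls, sampled_cls_cnt, true_cls_cnt), y)
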
2 changes: 1 addition & 1 deletion src/gluonnlp/__init__.py
@@ -31,7 +31,7 @@
 from . import initializer
 from .vocab import Vocab

-__version__ = '0.8.2'
+__version__ = '0.8.3'

 __all__ = ['data',
            'model',
16 changes: 8 additions & 8 deletions src/gluonnlp/model/sampled_block.py
@@ -70,15 +70,15 @@ def hybrid_forward(self, F, x, sampled_values, label, w_all, b_all):

         # remove accidental hits
         if self._remove_accidental_hits:
-            label_vec = F.reshape(label, (-1, 1))
-            sample_vec = F.reshape(sampled_candidates, (1, -1))
-            mask = F.broadcast_equal(label_vec, sample_vec) * -1e37
+            label_vec = F.reshape(label, (-1, 1)).astype('int32')
+            sample_vec = F.reshape(sampled_candidates, (1, -1)).astype('int32')
+            mask = F.broadcast_equal(label_vec, sample_vec).astype('float32') * -1e37
             pred_sampled = pred_sampled + mask

         # subtract log(q)
-        expected_count_sampled = F.reshape(expected_count_sampled,
-                                           shape=(1, self._num_sampled))
-        expected_count_true = expected_count_true.reshape((-1,))
+        expected_count_sampled = expected_count_sampled.astype('float32')
+        expected_count_sampled = expected_count_sampled.reshape(shape=(1, self._num_sampled))
+        expected_count_true = expected_count_true.astype('float32').reshape((-1,))
         pred_true = pred_true - F.log(expected_count_true)
         pred_true = pred_true.reshape((-1, 1))
         pred_sampled = F.broadcast_sub(pred_sampled, F.log(expected_count_sampled))
@@ -174,7 +174,7 @@ def hybrid_forward(self, F, x, sampled_values, label, weight, bias):
         # (batch_size,)
         label = F.reshape(label, shape=(-1,))
         # (num_sampled+batch_size,)
-        ids = F.concat(sampled_candidates, label, dim=0)
+        ids = F.concat(sampled_candidates.astype('int32'), label.astype('int32'), dim=0)
         # lookup weights and biases
         # (num_sampled+batch_size, dim)
         w_all = F.Embedding(data=ids, weight=weight,
@@ -477,7 +477,7 @@ def forward(self, x, sampled_values, label):  # pylint: disable=arguments-differ
         # (batch_size,)
         label = label.reshape(shape=(-1,))
         # (num_sampled+batch_size,)
-        ids = nd.concat(sampled_candidates, label, dim=0)
+        ids = nd.concat(sampled_candidates.astype('int32'), label.astype('int32'), dim=0)
         # lookup weights and biases
         weight = self.weight.row_sparse_data(ids)
         bias = self.bias.data(ids.context)
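
One note on the masking change in the first hunk above: F.broadcast_equal keeps the dtype of its operands, so once labels and candidates may be int32 the equality mask needs an explicit cast to float32 before it is scaled by -1e37 and added to the float predictions. A standalone NDArray sketch of that step (values purely illustrative):

import mxnet as mx

label_vec = mx.nd.array([1, 2], dtype='int32').reshape((-1, 1))      # (batch_size, 1)
sample_vec = mx.nd.array([0, 2, 4], dtype='int32').reshape((1, -1))  # (1, num_sampled)
# cast to float32 first so the -1e37 penalty is applied in floating point
# rather than being truncated by integer arithmetic
mask = mx.nd.broadcast_equal(label_vec, sample_vec).astype('float32') * -1e37
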
20 changes: 12 additions & 8 deletions tests/unittest/test_sampled_logits.py
@@ -27,7 +27,9 @@
 import pytest

 @pytest.mark.parametrize('f', [nlp.model.NCEDense, nlp.model.SparseNCEDense])
-def test_nce_loss(f):
+@pytest.mark.parametrize('cls_dtype', ['float32', 'int32'])
+@pytest.mark.parametrize('count_dtype', ['float32', 'int32'])
+def test_nce_loss(f, cls_dtype, count_dtype):
     ctx = mx.cpu()
     batch_size = 2
     num_sampled = 3
@@ -40,9 +42,9 @@ def test_nce_loss(f):
     trainer = mx.gluon.Trainer(model.collect_params(), 'sgd')
     x = mx.nd.ones((batch_size, num_hidden))
     y = mx.nd.ones((batch_size,))
-    sampled_cls = mx.nd.ones((num_sampled,))
-    sampled_cls_cnt = mx.nd.ones((num_sampled,))
-    true_cls_cnt = mx.nd.ones((batch_size,))
+    sampled_cls = mx.nd.ones((num_sampled,), dtype=cls_dtype)
+    sampled_cls_cnt = mx.nd.ones((num_sampled,), dtype=count_dtype)
+    true_cls_cnt = mx.nd.ones((batch_size,), dtype=count_dtype)
     samples = (sampled_cls, sampled_cls_cnt, true_cls_cnt)
     with mx.autograd.record():
         pred, new_y = model(x, samples, y)
@@ -53,7 +55,9 @@ def test_nce_loss(f):
     mx.nd.waitall()

 @pytest.mark.parametrize('f', [nlp.model.ISDense, nlp.model.SparseISDense])
-def test_is_softmax_loss(f):
+@pytest.mark.parametrize('cls_dtype', ['float32', 'int32'])
+@pytest.mark.parametrize('count_dtype', ['float32', 'int32'])
+def test_is_softmax_loss(f, cls_dtype, count_dtype):
     ctx = mx.cpu()
     batch_size = 2
     num_sampled = 3
@@ -66,9 +70,9 @@ def test_is_softmax_loss(f):
     trainer = mx.gluon.Trainer(model.collect_params(), 'sgd')
     x = mx.nd.ones((batch_size, num_hidden))
     y = mx.nd.ones((batch_size,))
-    sampled_cls = mx.nd.ones((num_sampled,))
-    sampled_cls_cnt = mx.nd.ones((num_sampled,))
-    true_cls_cnt = mx.nd.ones((batch_size,))
+    sampled_cls = mx.nd.ones((num_sampled,), dtype=cls_dtype)
+    sampled_cls_cnt = mx.nd.ones((num_sampled,), dtype=count_dtype)
+    true_cls_cnt = mx.nd.ones((batch_size,), dtype=count_dtype)
     samples = (sampled_cls, sampled_cls_cnt, true_cls_cnt)
     with mx.autograd.record():
         pred, new_y = model(x, samples, y)