Skip to content

Commit

Permalink
Faster Multinomial Sampling code
Browse files Browse the repository at this point in the history
  • Loading branch information
Amartya Sanyal committed Jun 29, 2017
1 parent 18d390c commit 6946f0e
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 19 deletions.
13 changes: 6 additions & 7 deletions lib/THC/THCTensorRandom.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -228,17 +228,15 @@ sampleMultinomialOnce(long* dest,

template <typename T>
__global__ void
sampleMultinomialWithReplacement(curandStateMtgp32* state,
sampleMultinomialWithReplacement(T* uniform_idx,
int totalSamples,
long* dest,
long distributions,
int categories,
T* normDistPrefixSum) {
// At the moment, each warp computes one sample value in the binary
// search due to divergence. It seems possible to compute multiple
// values and limit divergence though later on. However, no matter
// what, all block threads must participate in the curand_uniform
// call to update the generator state.
// values and limit divergence though later on.

// The block determines the distribution for which we generate a point
for (long curDist = blockIdx.x;
Expand All @@ -250,7 +248,7 @@ sampleMultinomialWithReplacement(curandStateMtgp32* state,
int sample = sampleBase + threadIdx.y;

// All threads participate in this
T r = ScalarConvert<float, T>::to(curand_uniform(&state[blockIdx.x]));
T r = uniform_idx[sample];

if (threadIdx.x == 0 && sample < totalSamples) {
// Find the bucket that a uniform sample lies in
Expand Down Expand Up @@ -284,6 +282,7 @@ sampleMultinomialWithoutReplacement(curandStateMtgp32* state,

// The block and warp determines the distribution for which we
// generate a point
T zero = ScalarConvert<int, T>::to(0);
for (long curDistBase = blockIdx.x * blockDim.y;
curDistBase < distributions;
curDistBase += gridDim.x * blockDim.y) {
Expand All @@ -292,7 +291,7 @@ sampleMultinomialWithoutReplacement(curandStateMtgp32* state,

// All threads must participate in this
T r = ScalarConvert<float, T>::to(curand_uniform(&state[blockIdx.x]));

if (threadIdx.x == 0 && curDist < distributions) {
// Find the bucket that a uniform sample lies in
int choice = binarySearchForMultinomial<T>(
Expand All @@ -305,7 +304,7 @@ sampleMultinomialWithoutReplacement(curandStateMtgp32* state,

// Without replacement, so update the original probability so it
// is not considered a second time
origDist[curDist * categories + choice] = ScalarConvert<int, T>::to(0);
origDist[curDist * categories + choice] = zero;
}
}
}
Expand Down
32 changes: 20 additions & 12 deletions lib/THC/generic/THCTensorRandom.cu
Original file line number Diff line number Diff line change
Expand Up @@ -202,13 +202,21 @@ THC_API void THCTensor_(multinomial)(struct THCState *state,
// distribution concurrently.
dim3 grid(numDist < MAX_NUM_BLOCKS ? numDist : MAX_NUM_BLOCKS);
//Create the matrix of uniformly sampled numbers
THCTensor *uniform_idx = THCTensor_(newWithSize1d)(state, n_sample);
THCTensor_(uniform)(state, uniform_idx, 0, 1);
sampleMultinomialWithReplacement
<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
gen->gen_states,
n_sample,
THCudaLongTensor_data(state, self),
numDist, numCategories,
THCTensor_(data)(state, prefixSum));
THCTensor_(data)(state, uniform_idx),
n_sample,
THCudaLongTensor_data(state, self),
numDist, numCategories,
THCTensor_(data)(state, prefixSum));
THCTensor_(free)(state, uniform_idx);
} else {
// Sample without replacement
Expand Down Expand Up @@ -237,13 +245,13 @@ THC_API void THCTensor_(multinomial)(struct THCState *state,
// recalculate our distribution
sampleMultinomialWithoutReplacement
<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
gen->gen_states,
n_sample,
sample,
THCudaLongTensor_data(state, self),
numDist, numCategories,
THCTensor_(data)(state, origDist),
THCTensor_(data)(state, prefixSum));
gen->gen_states,
n_sample,
sample,
THCudaLongTensor_data(state, self),
numDist, numCategories,
THCTensor_(data)(state, origDist),
THCTensor_(data)(state, prefixSum));
}
}
Expand Down
52 changes: 52 additions & 0 deletions test/multinomial.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
-- Test harness that collects and reports the checks registered in this file.
local tester = torch.Tester()

-- Command-line interface of the script. Declared local so that running the
-- test does not leak `cmd` into the global environment.
local cmd = torch.CmdLine()
cmd:text()
cmd:text()
cmd:text('Testing alias multinomial on cuda')
cmd:text()
cmd:text('Options')
cmd:option('--compare',false,'compare with cutorch multinomial')
cmd:text()

-- parse input params
-- NOTE(review): the parsed `--compare` flag is not read anywhere in this
-- file; parsing is kept so the script still accepts/validates its arguments.
local params = cmd:parse(arg)

require 'cutorch'
--- Benchmark and sanity-check CUDA `torch.multinomial` with replacement.
--
-- For every class count this function
--   1. times `torch.multinomial` for each sample count on random,
--      unnormalized weights and prints the average `real` time per call;
--   2. draws a large number of samples from a few normalized distributions
--      and asks the file-level `tester` to verify that the empirical
--      frequencies match the probabilities within `TOLERANCE`.
--
-- Takes no arguments and returns nothing; results are reported through
-- `print` and `tester:eq`.
local function checkMultinomial()
  local class_counts = {10, 100, 1000}
  local sample_counts = {10, 100, 1000, 10000}
  local BENCH_DISTS = 100      -- distributions per benchmark tensor
  local BENCH_REPEATS = 10     -- timed calls averaged per configuration
  local CHECK_DISTS = 3        -- distributions used for the frequency check
  local CHECK_DRAWS = 5000000  -- samples drawn per distribution for the check
  local TOLERANCE = 0.01       -- tolerance passed to tester:eq

  -- ipairs (not pairs) so the sizes are visited in the declared order.
  for _, n_class in ipairs(class_counts) do
    for _, n_sample in ipairs(sample_counts) do
      print("")
      print("Benchmarking multinomial with "..n_class.." classes and "..n_sample.." samples")
      torch.seed()
      local weights = torch.CudaDoubleTensor(BENCH_DISTS, n_class):uniform(0,1)
      local timer = torch.Timer()
      -- Wait for outstanding GPU work before starting the clock so that
      -- only the multinomial calls below are measured.
      cutorch.synchronize()
      timer:reset()
      for _ = 1, BENCH_REPEATS do
        torch.multinomial(weights, n_sample, true)
        -- Synchronize after each call so the call's GPU work is included.
        cutorch.synchronize()
      end
      print("[CUDA] : torch.multinomial draw: "..(timer:time().real/BENCH_REPEATS).." seconds (hot)")
    end

    -- Statistical check: normalize each row into a probability vector.
    torch.seed()
    local probs = torch.CudaDoubleTensor(CHECK_DISTS, n_class):uniform(0,1)
    for i = 1, CHECK_DISTS do
      probs[i]:div(probs[i]:sum())
    end

    local output = torch.multinomial(probs, CHECK_DRAWS, true)

    -- Histogram the drawn class indices on the CPU, then turn the counts
    -- into relative frequencies.
    local counts = torch.Tensor(CHECK_DISTS, n_class):zero()
    for i = 1, CHECK_DISTS do
      output[i]:long():apply(function(x) counts[{i, x}] = counts[{i, x}] + 1 end)
      counts[i]:div(counts[i]:sum())
    end

    tester:eq(probs:double(), counts, TOLERANCE, "probs and counts should be approximately equal for n_class = "..n_class)
  end
end
-- Register the multinomial benchmark/check with the tester and execute it.
tester:add(checkMultinomial)
tester:run()

0 comments on commit 6946f0e

Please sign in to comment.