Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduced type set option (on by default for Android) #435

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ FILE(APPEND THCUNN_generic_h.lua "]]")

FILE(GLOB luasrc *.lua)

IF (ANDROID)
ADD_DEFINITIONS(-DTHC_MIN_MATH)
ENDIF()

ADD_SUBDIRECTORY(lib)

INSTALL(
Expand Down
11 changes: 8 additions & 3 deletions THCUNN.lua
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,10 @@ local function_names_generic = extract_function_names_generic(THCUNN_generic_h)
THNN.kernels['torch.CudaTensor'] = THNN.bind(THCUNN.C, function_names_generic, 'Cuda', THCUNN.getState)
torch.getmetatable('torch.CudaTensor').THNN = THNN.kernels['torch.CudaTensor']

THNN.kernels['torch.CudaDoubleTensor'] = THNN.bind(THCUNN.C, function_names_generic, 'CudaDouble', THCUNN.getState)
torch.getmetatable('torch.CudaDoubleTensor').THNN = THNN.kernels['torch.CudaDoubleTensor']
if not cutorch.minMath then
THNN.kernels['torch.CudaDoubleTensor'] = THNN.bind(THCUNN.C, function_names_generic, 'CudaDouble', THCUNN.getState)
torch.getmetatable('torch.CudaDoubleTensor').THNN = THNN.kernels['torch.CudaDoubleTensor']
end

if cutorch.hasHalf then
-- in order to call 'half' functions from lua, convert real arguments from
Expand Down Expand Up @@ -164,7 +166,10 @@ local function Module__converter(type)
end
end

rawset(torch.getmetatable('nn.Module'), 'cudaDouble', Module__converter('torch.CudaDoubleTensor'))
if not cutorch.minMath then
rawset(torch.getmetatable('nn.Module'), 'cudaDouble', Module__converter('torch.CudaDoubleTensor'))
end

if cutorch.hasHalf then
rawset(torch.getmetatable('nn.Module'), 'cudaHalf', Module__converter('torch.CudaHalfTensor'))
end
Expand Down
3 changes: 1 addition & 2 deletions init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,4 @@ require('cunn.DataParallelTable')

nn.Module._flattenTensorBuffer['torch.CudaTensor'] = torch.FloatTensor.new
nn.Module._flattenTensorBuffer['torch.CudaDoubleTensor'] = torch.DoubleTensor.new
-- FIXME: change this to torch.HalfTensor when available
nn.Module._flattenTensorBuffer['torch.CudaHalfTensor'] = torch.FloatTensor.new
nn.Module._flattenTensorBuffer['torch.CudaHalfTensor'] = torch.HalfTensor.new
5 changes: 2 additions & 3 deletions lib/THCUNN/SparseLinear.cu
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,5 @@ void THNN_CudaHalfSparseLinear_updateParameters(
#endif

#include "generic/SparseLinear.cu"
#include "THCGenerateFloatType.h"
#include "generic/SparseLinear.cu"
#include "THCGenerateDoubleType.h"
#include "THCGenerateFloatTypes.h"

35 changes: 35 additions & 0 deletions lib/THCUNN/THCGenerateFloatTypes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#ifndef THC_GENERIC_FILE
#error "You must define THC_GENERIC_FILE before including THGenerateFloatTypes.h"
#endif

#define THCGenerateFloatTypes

#define THCTypeIdxByte 1
#define THCTypeIdxChar 2
#define THCTypeIdxShort 3
#define THCTypeIdxInt 4
#define THCTypeIdxLong 5
#define THCTypeIdxFloat 6
#define THCTypeIdxDouble 7
#define THCTypeIdxHalf 8
#define THCTypeIdx_(T) TH_CONCAT_2(THCTypeIdx,T)

# ifndef THC_MIN_MATH
# include "THCGenerateHalfType.h"
# include "THCGenerateDoubleType.h"
# endif

#include "THCGenerateFloatType.h"

#undef THCTypeIdxByte
#undef THCTypeIdxChar
#undef THCTypeIdxShort
#undef THCTypeIdxInt
#undef THCTypeIdxLong
#undef THCTypeIdxFloat
#undef THCTypeIdxDouble
#undef THCTypeIdxHalf
#undef THCTypeIdx_

#undef THCGenerateFloatTypes
#undef THC_GENERIC_FILE
15 changes: 8 additions & 7 deletions test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,21 @@ local THC = ffi.os == 'Windows' and ffi.load('THC') or ffi.C
--e.g.: th -lcunn -e "nn.testcuda{'Sigmoid_forward'}"

local typenames = {
'torch.CudaTensor',
'torch.CudaDoubleTensor',
'torch.CudaTensor'
}

local t2cpu = {
['torch.CudaTensor'] = 'torch.FloatTensor',
['torch.CudaDoubleTensor'] = 'torch.DoubleTensor',

['torch.CudaTensor'] = 'torch.FloatTensor'
}

local function checkHalf()
if cutorch.hasHalf then
if not cutorch.minMath then
if cutorch.hasHalf then
table.insert(typenames, 'torch.CudaHalfTensor')
t2cpu['torch.CudaHalfTensor'] = 'torch.FloatTensor'
t2cpu['torch.CudaHalfTensor'] = 'torch.HalfTensor'
end
table.insert(typenames, 'torch.CudaDoubleTensor')
t2cpu['torch.CudaDoubleTensor'] = 'torch.DoubleTensor'
end
end

Expand Down