diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml
index 4e60aab..14be783 100644
--- a/.buildkite/pipeline.yml
+++ b/.buildkite/pipeline.yml
@@ -1,16 +1,31 @@
steps:
- - label: "Julia v1"
- plugins:
- - JuliaCI/julia#v1:
- version: "1"
- - JuliaCI/julia-test#v1: ~
- - JuliaCI/julia-coverage#v1:
- codecov: true
- agents:
- queue: "juliagpu"
- cuda: "*"
- if: build.message !~ /\[skip tests\]/
- timeout_in_minutes: 60
+  # Run the test suite once per supported Julia release (see `matrix` below).
+  - group: ":julia: Julia"
+    key: "julia"
+    steps:
+      - label: "Julia {{matrix.julia}}"
+        plugins:
+          - JuliaCI/julia#v1:
+              version: "{{matrix.julia}}"
+          - JuliaCI/julia-test#v1:
+              test_args: "--quickfail"
+          - JuliaCI/julia-coverage#v1:
+              # Keep uploading to Codecov: the previous pipeline set this and
+              # the README badge still points at Codecov.
+              # NOTE(review): confirm SECRET_CODECOV_TOKEN in `env` is the
+              # credential this upload uses.
+              codecov: true
+              dirs:
+                - src
+        agents:
+          queue: "juliagpu"
+          cuda: "*"
+        timeout_in_minutes: 120
+        matrix:
+          setup:
+            julia:
+              - "1.8"
+              - "1.9"
+              - "1.10"
+          adjustments:
+            # Extra matrix entry for nightly; its failures do not fail the build.
+            - with:
+                julia: "nightly"
+              soft_fail: true
+
env:
- SECRET_CODECOV_TOKEN: "AFBHqF1xnrD/W69t402L6WLKqP1pBRph9mzRKUGd+V7m+uPZRCsZE7bKLmWmWGeGsBn94P8aGPepa4s9rpmJTffrkR9v9yn8S0IdbT4ETOkgFVqyB+OYtcu1zrsK/MujKYDDHg0GQl5DqYWoCqen6Xyaty+oQQUpwE943SzMDy4S0ezBFlC2o4UxHCQF9PIcz0/DQuaGQVWnLknUrNL18GKb7RsieJYWRBSSUZLRtfzAALZsdNRTTGfW2tWiXZtW2lBin3MJ7OBxYju4SiIN2FUvKk+raLwl0fNceh/eEni5hgfGvqbkv1ugULanPILYRhNijRVDsGWMamcZP0T3/w==;U2FsdGVkX1+pFc8npLm1aKY00zEDGz75z8BzjOtEp0Kca4XPmxw8T0sPunb29IO6da0FJmOjZXaYmWzI8jAwDA=="
+ SECRET_CODECOV_TOKEN: "AFBHqF1xnrD/W69t402L6WLKqP1pBRph9mzRKUGd+V7m+uPZRCsZE7bKLmWmWGeGsBn94P8aGPepa4s9rpmJTffrkR9v9yn8S0IdbT4ETOkgFVqyB+OYtcu1zrsK/MujKYDDHg0GQl5DqYWoCqen6Xyaty+oQQUpwE943SzMDy4S0ezBFlC2o4UxHCQF9PIcz0/DQuaGQVWnLknUrNL18GKb7RsieJYWRBSSUZLRtfzAALZsdNRTTGfW2tWiXZtW2lBin3MJ7OBxYju4SiIN2FUvKk+raLwl0fNceh/eEni5hgfGvqbkv1ugULanPILYRhNijRVDsGWMamcZP0T3/w==;U2FsdGVkX1+pFc8npLm1aKY00zEDGz75z8BzjOtEp0Kca4XPmxw8T0sPunb29IO6da0FJmOjZXaYmWzI8jAwDA=="
\ No newline at end of file
diff --git a/Project.toml b/Project.toml
index 8c0997f..62ab6ee 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,11 +6,12 @@ version = "0.1.0"
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+TropicalGemmC_jll = "4f4992fb-2984-5eba-87b8-475305d0f5fc"
TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
[compat]
CUDA = "5"
+TropicalGemmC_jll = "0.1"
TropicalNumbers = "0.6.2"
julia = "1"
diff --git a/README.md b/README.md
index 93441fe..367f7e3 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
# CuTropicalGEMM
[![Build status](https://badge.buildkite.com/06c24dc7b1a9d7c38897acd21575ffd678ee03de190c0b8d81.svg)](https://buildkite.com/julialang/cutropicalgemm-dot-jl)
-[![Coverage](https://codecov.io/gh/ArrogantGao/CuTropicalGEMM.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/ArrogantGao/CuTropicalGEMM.jl)
+[![Coverage](https://codecov.io/gh/TensorBFS/CuTropicalGEMM.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/TensorBFS/CuTropicalGEMM.jl)
CuTropicalGEMM is an open source
@@ -87,6 +87,8 @@ The performance of `Cublas` on normal GEMM is used as a reference.
Please open an [issue](https://github.com/TensorBFS/CuTropicalGEMM.jl/issues)
if you encounter any problems, or have any feature requests.
+To inspect the underlying `C-CUDA` code, see the [TropicalGemm_Cuda](https://github.com/ArrogantGao/TropicalGemm_Cuda) repository.
+
It is also welcomed for any suggestions about the issues marked as `enhancement`, please let us know if you have any idea about them.
## Acknowalgement
diff --git a/deps/CMakeLists.txt b/deps/CMakeLists.txt
deleted file mode 100644
index c378724..0000000
--- a/deps/CMakeLists.txt
+++ /dev/null
@@ -1,38 +0,0 @@
-cmake_minimum_required(VERSION 3.18) # 3.18 or later is required for CUDA_ARCHITECTURES support
-project(TropicalGEMM CXX CUDA)
-
-set(CUDA_FILES tropicalgemm_kernels.cu)
-set(LIBRARY_NAME kernels.so)
-
-set(MACRO_COMBINATIONS
- PlusMul_FP32
- PlusMul_FP64
- PlusMul_INT32
- PlusMul_INT64
- TropicalAndOr_Bool
- TropicalMaxMul_FP32
- TropicalMaxMul_FP64
- TropicalMaxMul_INT32
- TropicalMaxMul_INT64
- TropicalMaxPlus_FP32
- TropicalMaxPlus_FP64
- TropicalMinPlus_FP32
- TropicalMinPlus_FP64
-)
-
-# Set the desired GPU architectures, e.g., 60 (Pascal), 70 (Volta), 75 (Turing), 80 (Ampere)
-set(CMAKE_CUDA_ARCHITECTURES 60 70 75 80)
-
-# Set the output directory for the shared libraries
-set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/lib)
-
-foreach(MACRO IN LISTS MACRO_COMBINATIONS)
- string(REPLACE "_" ";" MACRO_SPLIT ${MACRO})
- list(GET MACRO_SPLIT 0 MACRO_1)
- list(GET MACRO_SPLIT 1 MACRO_2)
-
- set(TARGET_NAME "_${MACRO}")
- add_library(${TARGET_NAME} SHARED ${CUDA_FILES})
- target_compile_definitions(${TARGET_NAME} PRIVATE ${MACRO_1} ${MACRO_2})
- set_target_properties(${TARGET_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
-endforeach()
\ No newline at end of file
diff --git a/deps/build.jl b/deps/build.jl
deleted file mode 100644
index e0dbb0c..0000000
--- a/deps/build.jl
+++ /dev/null
@@ -1,11 +0,0 @@
-const script_dir = dirname(@__FILE__)
-const src_dir = joinpath(script_dir)
-const build_dir = joinpath(script_dir, "build")
-
-mkpath(build_dir)
-
-cd(build_dir) do
- run(`cmake $src_dir`)
- run(`make clean`)
- run(`make`)
-end
\ No newline at end of file
diff --git a/deps/tropicalgemm_kernels.cu b/deps/tropicalgemm_kernels.cu
deleted file mode 100644
index 8d3ac9b..0000000
--- a/deps/tropicalgemm_kernels.cu
+++ /dev/null
@@ -1,775 +0,0 @@
-#include
-#include
-
-#include
-#include
-#include
-#include
-#include
-
-// CUDA runtime
-#include
-#include
-
-#define CONCATENATE_(x, y) x##y
-#define CONCATENATETHREE_(x, y, z) x##y##z
-
-#define CONCATENATE(x, y) CONCATENATE_(x, y)
-#define CONCATENATETHREE(x, y, z) CONCATENATETHREE_(x, y, z)
-
-// The macro
-#define OFFSET_row(row, col, ld) ((row) * (ld) + (col))
-#define OFFSET_col(row, col, ld) ((col) * (ld) + (row))
-
-// The Tropical algebras
-#ifdef PlusMul
-#define OPERATOR_ADD(a, b) (a + b)
-#define OPERATOR_MUL(a, b) (a * b)
-#define PADDING 0
-#define FUNCNAME _plusmul
-#endif
-
-#ifdef TropicalAndOr
-#define OPERATOR_ADD(a, b) (a || b)
-#define OPERATOR_MUL(a, b) (a && b)
-#define PADDING false
-#define FUNCNAME _andor
-#endif
-
-#ifdef TropicalMaxMul
-#define OPERATOR_ADD(a, b) max(a, b)
-#define OPERATOR_MUL(a, b) (a * b)
-#define PADDING 0
-#define FUNCNAME _maxmul
-#endif
-
-#ifdef TropicalMaxPlus
-#define OPERATOR_ADD(a, b) max(a, b)
-#define OPERATOR_MUL(a, b) (a + b)
-#define PADDING -INFINITY
-#define FUNCNAME _maxplus
-#endif
-
-#ifdef TropicalMinPlus
-#define OPERATOR_ADD(a, b) min(a, b)
-#define OPERATOR_MUL(a, b) (a + b)
-#define PADDING INFINITY
-#define FUNCNAME _minplus
-#endif
-
-// Types
-
-#ifdef Bool
-#define TYPE bool
-#define TYPENAME BOOL
-#endif
-
-#ifdef FP32
-#define TYPE float
-#define TYPENAME FLOAT
-#endif
-
-#ifdef FP64
-#define TYPE double
-#define TYPENAME DOUBLE
-#endif
-
-#ifdef INT32
-#define TYPE int
-#define TYPENAME INT
-#endif
-
-#ifdef INT64
-#define TYPE long
-#define TYPENAME LONG
-#endif
-
-
-#define TT _TT
-#define TN _TN
-#define NT _NT
-#define NN _NN
-
-template <
- const int BLOCK_SIZE_M, // width of block of C that each thread block calculate
- const int BLOCK_SIZE_K, // height of block of A that each thread block load into shared memory
- const int BLOCK_SIZE_N, // height of block of C that each thread block calculate
- const int THREAD_SIZE_M, // height of block of C that each thread calculate
- const int THREAD_SIZE_N
- >
-__global__ void CONCATENATETHREE(TYPENAME, FUNCNAME, TT)(
- TYPE * __restrict__ A,
- TYPE * __restrict__ B,
- TYPE * __restrict__ C,
- TYPE alpha,
- TYPE beta,
- int M,
- int N,
- int K,
- int DIM_GRID_X,
- int DIM_GRID_Y
- ) {
-
- // size of thread block
- const int bszx = BLOCK_SIZE_N / THREAD_SIZE_N;
- const int bszy = BLOCK_SIZE_M / THREAD_SIZE_M;
- const int THREAD_NUM_PER_BLOCK = bszy * bszx;
-
- // thread id
- const int tid = threadIdx.y * bszx + threadIdx.x;
- int BLOCK_IDX = blockIdx.x % DIM_GRID_X;
- int BLOCK_IDY = blockIdx.x / DIM_GRID_X;
-
- // shared memory
-
- __shared__ TYPE As[BLOCK_SIZE_M * BLOCK_SIZE_K]; // avoid bank conflict
- __shared__ TYPE Bs[BLOCK_SIZE_K * BLOCK_SIZE_N];
- // registers for C
- TYPE accum[THREAD_SIZE_M * THREAD_SIZE_N] = {0};
- TYPE regs_a[THREAD_SIZE_M] = {0};
- TYPE regs_b[THREAD_SIZE_N] = {0};
-
- // init the accum as tropical zero
- #pragma unroll
- for (int thread_y = 0; thread_y < THREAD_SIZE_M; ++thread_y) {
- #pragma unroll
- for (int thread_x = 0; thread_x < THREAD_SIZE_N; ++thread_x) {
- accum[OFFSET_row(thread_y, thread_x, THREAD_SIZE_N)] = PADDING;
- }
- }
-
- // row number and col number that needs to be loaded blockIdx.y this thread
- const int A_TILE_ROW = tid / BLOCK_SIZE_K;
- const int A_TILE_COL = tid % BLOCK_SIZE_K;
-
- const int B_TILE_ROW = tid / BLOCK_SIZE_N;
- const int B_TILE_COL = tid % BLOCK_SIZE_N;
-
- // row stride that thread uses to load multiple rows of a tile
- const int A_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_K;
- const int B_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_N;
-
- // const int A_S = BLOCK_SIZE_M / THREAD_SIZE_M;
- // const int B_S = BLOCK_SIZE_N / THREAD_SIZE_N;
-
- // can not unroll since K can not be determined at this point
- for (int tile_idx = 0 ; tile_idx < K ; tile_idx += BLOCK_SIZE_K) {
-
- #pragma unroll
- for ( int i = 0 ; i < BLOCK_SIZE_M ; i += A_TILE_ROW_STRIDE) {
- const int row = BLOCK_SIZE_M * BLOCK_IDY + i + A_TILE_ROW ;
- const int col = A_TILE_COL + tile_idx;
- if (tile_idx > K - BLOCK_SIZE_K || BLOCK_IDY == DIM_GRID_Y - 1) {
- As[OFFSET_row(i + A_TILE_ROW, A_TILE_COL, BLOCK_SIZE_K)] = row < M && col < K ? A[OFFSET_row(
- row, // row
- col, // col
- K )] : PADDING;
- } else {
- As[OFFSET_row(i + A_TILE_ROW, A_TILE_COL, BLOCK_SIZE_K)] = A[OFFSET_row(
- row, // row
- col, // col
- K )];
- }
- }
-
- // load B from global memory to shared memory
- #pragma unroll
- for ( int i = 0 ; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
- const int row = tile_idx + i + B_TILE_ROW;
- const int col = B_TILE_COL + BLOCK_SIZE_N * BLOCK_IDX;
- if (BLOCK_IDX == DIM_GRID_X -1 || tile_idx > K - BLOCK_SIZE_K) {
- Bs[OFFSET_row(i + B_TILE_ROW, B_TILE_COL, BLOCK_SIZE_N)] = row < K && col < N ? B[OFFSET_row(
- row, // row
- col, // col
- N )] : PADDING;
- } else {
- Bs[OFFSET_row(i + B_TILE_ROW, B_TILE_COL, BLOCK_SIZE_N)] = B[OFFSET_row(
- row, // row
- col, // col
- N )];
- }
- }
-
- __syncthreads();
-
- // compute c
- #pragma unroll
- for (int k = 0; k < BLOCK_SIZE_K; ++ k) {
-
- #pragma unroll
- for (int thread_y = 0; thread_y < THREAD_SIZE_M; ++thread_y) {
- regs_a[thread_y] = As[OFFSET_row(thread_y + THREAD_SIZE_M * threadIdx.y, k, BLOCK_SIZE_K)];
- }
-
- #pragma unroll
- for (int thread_x = 0; thread_x < THREAD_SIZE_N; ++thread_x) {
- regs_b[thread_x] = Bs[OFFSET_row(k, thread_x + THREAD_SIZE_N * threadIdx.x, BLOCK_SIZE_N)];
- }
-
- #pragma unroll
- for (int thread_y = 0; thread_y < THREAD_SIZE_M; ++thread_y) {
- #pragma unroll
- for (int thread_x = 0; thread_x < THREAD_SIZE_N; ++thread_x) {
- accum[OFFSET_row(thread_y, thread_x, THREAD_SIZE_N)] = OPERATOR_ADD(OPERATOR_MUL(regs_a[thread_y], regs_b[thread_x]), accum[OFFSET_row(thread_y, thread_x, THREAD_SIZE_N)]);
- }
- }
-
- }
- __syncthreads();
- }
-
- // store back to C
- #pragma unroll
- for (int thread_y = 0; thread_y < THREAD_SIZE_M; ++thread_y) {
- #pragma unroll
- for (int thread_x = 0; thread_x < THREAD_SIZE_N; ++thread_x) {
- const int row = BLOCK_SIZE_M * BLOCK_IDY + THREAD_SIZE_M * threadIdx.y + thread_y;
- const int col = BLOCK_SIZE_N * BLOCK_IDX + THREAD_SIZE_N * threadIdx.x + thread_x;
- if (BLOCK_IDX == DIM_GRID_X -1 || BLOCK_IDY == DIM_GRID_Y - 1) {
- if (row < M && col < N) {
- C[OFFSET_col(row, col, M)] = OPERATOR_ADD(
- OPERATOR_MUL(C[OFFSET_col(row, col, M)], beta),
- OPERATOR_MUL(accum[OFFSET_row(thread_y, thread_x, THREAD_SIZE_N)], alpha)
- );
- }
- } else {
- C[OFFSET_col(row, col, M)] = OPERATOR_ADD(
- OPERATOR_MUL(C[OFFSET_col(row, col, M)], beta),
- OPERATOR_MUL(accum[OFFSET_row(thread_y, thread_x, THREAD_SIZE_N)], alpha)
- );
- }
- }
- }
- __syncthreads();
-}
-
-template <
- const int BLOCK_SIZE_M, // width of block of C that each thread block calculate
- const int BLOCK_SIZE_K, // height of block of A that each thread block load into shared memory
- const int BLOCK_SIZE_N, // height of block of C that each thread block calculate
- const int THREAD_SIZE_M, // height of block of C that each thread calculate
- const int THREAD_SIZE_N
- >
-__global__ void CONCATENATETHREE(TYPENAME, FUNCNAME, TN)(
- TYPE * __restrict__ A,
- TYPE * __restrict__ B,
- TYPE * __restrict__ C,
- TYPE alpha,
- TYPE beta,
- int M,
- int N,
- int K,
- int DIM_GRID_X,
- int DIM_GRID_Y
- ) {
-
- // size of thread block
- const int bszx = BLOCK_SIZE_N / THREAD_SIZE_N;
- const int bszy = BLOCK_SIZE_M / THREAD_SIZE_M;
- const int THREAD_NUM_PER_BLOCK = bszy * bszx;
-
- // thread id
- const int tid_A = threadIdx.y * bszx + threadIdx.x;
- const int tid_B = threadIdx.y + threadIdx.x * bszy;
-
- int BLOCK_IDX = blockIdx.x % DIM_GRID_X;
- int BLOCK_IDY = blockIdx.x / DIM_GRID_X;
-
- // shared memory
-
- __shared__ TYPE As[BLOCK_SIZE_M * BLOCK_SIZE_K]; // avoid bank conflict
- __shared__ TYPE Bs[BLOCK_SIZE_K * BLOCK_SIZE_N];
- // registers for C
- TYPE accum[THREAD_SIZE_M * THREAD_SIZE_N] = {0};
- TYPE regs_a[THREAD_SIZE_M] = {0};
- TYPE regs_b[THREAD_SIZE_N] = {0};
-
- #pragma unroll
- for (int thread_y = 0; thread_y < THREAD_SIZE_M; ++thread_y) {
- #pragma unroll
- for (int thread_x = 0; thread_x < THREAD_SIZE_N; ++thread_x) {
- accum[OFFSET_row(thread_y, thread_x, THREAD_SIZE_N)] = PADDING;
- }
- }
-
- // row number and col number that needs to be loaded blockIdx.y this thread
- const int A_TILE_ROW = tid_A / BLOCK_SIZE_K;
- const int A_TILE_COL = tid_A % BLOCK_SIZE_K;
-
- const int B_TILE_ROW = tid_B % BLOCK_SIZE_K;
- const int B_TILE_COL = tid_B / BLOCK_SIZE_K;
-
- // row stride that thread uses to load multiple rows of a tile
- const int A_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_K;
- const int B_TILE_COL_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_K;
-
- // const int A_S = BLOCK_SIZE_M / THREAD_SIZE_M;
- // const int B_S = BLOCK_SIZE_N / THREAD_SIZE_N;
-
- // can not unroll since K can not be determined at this point
- for (int tile_idx = 0 ; tile_idx < K ; tile_idx += BLOCK_SIZE_K) {
-
- #pragma unroll
- for ( int i = 0 ; i < BLOCK_SIZE_M ; i += A_TILE_ROW_STRIDE) {
- const int row = BLOCK_SIZE_M * BLOCK_IDY + i + A_TILE_ROW ;
- const int col = A_TILE_COL + tile_idx;
- if (tile_idx > K - BLOCK_SIZE_K || BLOCK_IDY == DIM_GRID_Y - 1) {
- As[OFFSET_row(i + A_TILE_ROW, A_TILE_COL, BLOCK_SIZE_K)] = row < M && col < K ? A[OFFSET_row(
- row, // row
- col, // col
- K )] : PADDING;
- } else {
- As[OFFSET_row(i + A_TILE_ROW, A_TILE_COL, BLOCK_SIZE_K)] = A[OFFSET_row(
- row, // row
- col, // col
- K )];
- }
- }
-
- // load B from global memory to shared memory
- #pragma unroll
- for ( int i = 0 ; i < BLOCK_SIZE_N; i += B_TILE_COL_STRIDE) {
- const int row = tile_idx + B_TILE_ROW;
- const int col = B_TILE_COL + i + BLOCK_SIZE_N * BLOCK_IDX;
- if (BLOCK_IDX == DIM_GRID_X -1 || tile_idx > K - BLOCK_SIZE_K) {
- Bs[OFFSET_row(B_TILE_ROW, i + B_TILE_COL, BLOCK_SIZE_N)] = row < K && col < N ? B[OFFSET_col(row, col, K)] : PADDING;
- } else {
- Bs[OFFSET_row(B_TILE_ROW, i + B_TILE_COL, BLOCK_SIZE_N)] = B[OFFSET_col(row, col, K)];
- }
- }
-
- __syncthreads();
-
- // compute c
- #pragma unroll
- for (int k = 0; k < BLOCK_SIZE_K; ++ k) {
-
- #pragma unroll
- for (int thread_y = 0; thread_y < THREAD_SIZE_M; ++thread_y) {
- regs_a[thread_y] = As[OFFSET_row(thread_y + THREAD_SIZE_M * threadIdx.y, k, BLOCK_SIZE_K)];
- }
-
- #pragma unroll
- for (int thread_x = 0; thread_x < THREAD_SIZE_N; ++thread_x) {
- regs_b[thread_x] = Bs[OFFSET_row(k, thread_x + THREAD_SIZE_N * threadIdx.x, BLOCK_SIZE_N)];
- }
-
- #pragma unroll
- for (int thread_y = 0; thread_y < THREAD_SIZE_M; ++thread_y) {
- #pragma unroll
- for (int thread_x = 0; thread_x < THREAD_SIZE_N; ++thread_x) {
- accum[OFFSET_row(thread_y, thread_x, THREAD_SIZE_N)] = OPERATOR_ADD(OPERATOR_MUL(regs_a[thread_y], regs_b[thread_x]), accum[OFFSET_row(thread_y, thread_x, THREAD_SIZE_N)]);
- }
- }
-
- }
- __syncthreads();
- }
-
- // store back to C
- #pragma unroll
- for (int thread_y = 0; thread_y < THREAD_SIZE_M; ++thread_y) {
- #pragma unroll
- for (int thread_x = 0; thread_x < THREAD_SIZE_N; ++thread_x) {
- const int row = BLOCK_SIZE_M * BLOCK_IDY + THREAD_SIZE_M * threadIdx.y + thread_y;
- const int col = BLOCK_SIZE_N * BLOCK_IDX + THREAD_SIZE_N * threadIdx.x + thread_x;
- if (BLOCK_IDX == DIM_GRID_X -1 || BLOCK_IDY == DIM_GRID_Y - 1) {
- if (row < M && col < N) {
- C[OFFSET_col(row, col, M)] = OPERATOR_ADD(
- OPERATOR_MUL(C[OFFSET_col(row, col, M)], beta),
- OPERATOR_MUL(accum[OFFSET_row(thread_y, thread_x, THREAD_SIZE_N)], alpha)
- );
- }
- } else {
- C[OFFSET_col(row, col, M)] = OPERATOR_ADD(
- OPERATOR_MUL(C[OFFSET_col(row, col, M)], beta),
- OPERATOR_MUL(accum[OFFSET_row(thread_y, thread_x, THREAD_SIZE_N)], alpha)
- );
- }
- }
- }
- __syncthreads();
-}
-
-template <
- const int BLOCK_SIZE_M, // width of block of C that each thread block calculate
- const int BLOCK_SIZE_K, // height of block of A that each thread block load into shared memory
- const int BLOCK_SIZE_N, // height of block of C that each thread block calculate
- const int THREAD_SIZE_M, // height of block of C that each thread calculate
- const int THREAD_SIZE_N // width of block of C that each thread calculate
- >
-__global__ void CONCATENATETHREE(TYPENAME, FUNCNAME, NT)(
- TYPE * __restrict__ A,
- TYPE * __restrict__ B,
- TYPE * __restrict__ C,
- TYPE alpha,
- TYPE beta,
- int M,
- int N,
- int K,
- int DIM_GRID_X,
- int DIM_GRID_Y
- ) {
-
- // size of thread block
- const int bszm = BLOCK_SIZE_M / THREAD_SIZE_M;
- const int bszn = BLOCK_SIZE_N / THREAD_SIZE_N;
- const int THREAD_NUM_PER_BLOCK = bszm * bszn;
-
- const int BLOCK_SIZE_MK = BLOCK_SIZE_M * BLOCK_SIZE_K;
- const int BLOCK_SIZE_KN = BLOCK_SIZE_K * BLOCK_SIZE_N;
- const int THREAD_SIZE_MN = THREAD_SIZE_M * THREAD_SIZE_N;
-
- int BLOCK_IDX = blockIdx.x % DIM_GRID_X;
- int BLOCK_IDY = blockIdx.x / DIM_GRID_X;
-
- // thread id
- const int tid = threadIdx.y * bszm + threadIdx.x;
-
- // shared memory
- // directly use 1d shared memory to avoid the conflict of col-major and row-major
- __shared__ TYPE As[BLOCK_SIZE_MK]; // avoid bank conflict
- __shared__ TYPE Bs[BLOCK_SIZE_KN];
-
- // registers for C
- TYPE accum[THREAD_SIZE_MN] = {0};
- TYPE regs_a[THREAD_SIZE_M] = {0};
- TYPE regs_b[THREAD_SIZE_N] = {0};
-
- #pragma unroll
- for (int thread_m = 0; thread_m < THREAD_SIZE_M; ++thread_m) {
- #pragma unroll
- for (int thread_n = 0; thread_n < THREAD_SIZE_N; ++thread_n) {
- accum[OFFSET_col(thread_m, thread_n, THREAD_SIZE_M)] = PADDING;
- }
- }
-
- // row number and col number that needs to be loaded blockIdx.y this thread
- const int A_TILE_COL = tid / BLOCK_SIZE_M;
- const int A_TILE_ROW = tid % BLOCK_SIZE_M;
-
- const int B_TILE_ROW = tid / BLOCK_SIZE_N;
- const int B_TILE_COL = tid % BLOCK_SIZE_N;
-
- // col stride that thread uses to load multiple rows of a tile
- // how many cols that the threads load in one iteration
- const int A_TILE_COL_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_M;
- // const int B_TILE_COL_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_K;
- const int B_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_N;
-
- for (int tile_idx = 0 ; tile_idx < K ;) {
-
- // load A from global memory to shared memory
- #pragma unroll
- for ( int i = 0 ; i < BLOCK_SIZE_K ; i += A_TILE_COL_STRIDE) {
- const int row = BLOCK_SIZE_M * BLOCK_IDX + A_TILE_ROW ;
- const int col = A_TILE_COL + i + tile_idx;
-
- if (BLOCK_IDX == DIM_GRID_X -1 || tile_idx >= K - BLOCK_SIZE_K) {
- As[OFFSET_col(A_TILE_ROW, i + A_TILE_COL, BLOCK_SIZE_M)] = row < M && col < K ? A[OFFSET_col(row, col, M)] : PADDING;
- } else {
- As[OFFSET_col(A_TILE_ROW, i + A_TILE_COL, BLOCK_SIZE_M)] = A[OFFSET_col(row, col, M)];
- }
- }
-
- // load B from global memory to shared memory
- #pragma unroll
- for ( int i = 0 ; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
- const int row = tile_idx + i + B_TILE_ROW;
- const int col = B_TILE_COL + BLOCK_SIZE_N * BLOCK_IDY;
-
- if (BLOCK_IDY == DIM_GRID_Y -1 || tile_idx > K - BLOCK_SIZE_K) {
- Bs[OFFSET_row(i + B_TILE_ROW, B_TILE_COL, BLOCK_SIZE_N)] = row < K && col < N ? B[OFFSET_row(row, col, N)] : PADDING;
- } else {
- Bs[OFFSET_row(i + B_TILE_ROW, B_TILE_COL, BLOCK_SIZE_N)] = B[OFFSET_row(row, col, N)];
- }
- }
-
- __syncthreads();
-
- // compute c
- #pragma unroll
- for (int k = 0; k < BLOCK_SIZE_K; k += 1) {
-
- // load A and B from shared memory to registers
- #pragma unroll
- for (int thread_m = 0; thread_m < THREAD_SIZE_M; ++thread_m) {
- regs_a[thread_m] = As[OFFSET_col(threadIdx.x * THREAD_SIZE_M + thread_m, k, BLOCK_SIZE_M)];
- }
-
- #pragma unroll
- for (int thread_n = 0; thread_n < THREAD_SIZE_N; ++thread_n) {
- regs_b[thread_n] = Bs[OFFSET_row(k, threadIdx.y * THREAD_SIZE_N + thread_n, BLOCK_SIZE_N)];
- }
-
- #pragma unroll
- for (int thread_m = 0; thread_m < THREAD_SIZE_M; ++thread_m) {
- #pragma unroll
- for (int thread_n = 0; thread_n < THREAD_SIZE_N; ++thread_n) {
- accum[OFFSET_col(thread_m, thread_n, THREAD_SIZE_M)] = OPERATOR_ADD(OPERATOR_MUL(regs_a[thread_m], regs_b[thread_n]), accum[OFFSET_col(thread_m, thread_n, THREAD_SIZE_M)]);
- }
- }
-
- }
- __syncthreads();
- tile_idx += BLOCK_SIZE_K;
- }
-
- // store back to C
- #pragma unroll
- for (int thread_m = 0; thread_m < THREAD_SIZE_M; ++thread_m) {
- #pragma unroll
- for (int thread_n = 0; thread_n < THREAD_SIZE_N; ++thread_n) {
- const int col = BLOCK_SIZE_N * BLOCK_IDY + THREAD_SIZE_N * threadIdx.y + thread_n;
- const int row = BLOCK_SIZE_M * BLOCK_IDX + THREAD_SIZE_M * threadIdx.x + thread_m;
- if (BLOCK_IDX == DIM_GRID_X -1 || BLOCK_IDY == DIM_GRID_Y - 1) {
- if (row < M && col < N) {
- C[OFFSET_col(row, col, M)] = OPERATOR_ADD(
- OPERATOR_MUL(accum[OFFSET_col(thread_m, thread_n, THREAD_SIZE_M)], alpha),
- OPERATOR_MUL(C[OFFSET_col(row, col, M)], beta)
- );
- }
- } else {
- C[OFFSET_col(row, col, M)] = OPERATOR_ADD(
- OPERATOR_MUL(accum[OFFSET_col(thread_m, thread_n, THREAD_SIZE_M)], alpha),
- OPERATOR_MUL(C[OFFSET_col(row, col, M)], beta)
- );
- }
- }
- }
- __syncthreads();
-}
-
-template <
- const int BLOCK_SIZE_M, // width of block of C that each thread block calculate
- const int BLOCK_SIZE_K, // height of block of A that each thread block load into shared memory
- const int BLOCK_SIZE_N, // height of block of C that each thread block calculate
- const int THREAD_SIZE_M, // height of block of C that each thread calculate
- const int THREAD_SIZE_N // width of block of C that each thread calculate
- >
-__global__ void CONCATENATETHREE(TYPENAME, FUNCNAME, NN)(
- TYPE * __restrict__ A,
- TYPE * __restrict__ B,
- TYPE * __restrict__ C,
- TYPE alpha,
- TYPE beta,
- int M,
- int N,
- int K,
- int DIM_GRID_X,
- int DIM_GRID_Y
- ) {
-
- // size of thread block
- const int bszm = BLOCK_SIZE_M / THREAD_SIZE_M;
- const int bszn = BLOCK_SIZE_N / THREAD_SIZE_N;
- const int THREAD_NUM_PER_BLOCK = bszm * bszn;
-
- const int BLOCK_SIZE_MK = BLOCK_SIZE_M * BLOCK_SIZE_K;
- const int BLOCK_SIZE_KN = BLOCK_SIZE_K * BLOCK_SIZE_N;
- const int THREAD_SIZE_MN = THREAD_SIZE_M * THREAD_SIZE_N;
-
- int BLOCK_IDX = blockIdx.x % DIM_GRID_X;
- int BLOCK_IDY = blockIdx.x / DIM_GRID_X;
-
- // thread id
- const int tid = threadIdx.y * bszm + threadIdx.x;
-
- // shared memory
- // directly use 1d shared memory to avoid the conflict of col-major and row-major
- __shared__ TYPE As[BLOCK_SIZE_MK]; // avoid bank conflict
- __shared__ TYPE Bs[BLOCK_SIZE_KN];
-
- // registers for C
- TYPE accum[THREAD_SIZE_MN] = {0};
- TYPE regs_a[THREAD_SIZE_M] = {0};
- TYPE regs_b[THREAD_SIZE_N] = {0};
-
- #pragma unroll
- for (int thread_m = 0; thread_m < THREAD_SIZE_M; ++thread_m) {
- #pragma unroll
- for (int thread_n = 0; thread_n < THREAD_SIZE_N; ++thread_n) {
- accum[OFFSET_col(thread_m, thread_n, THREAD_SIZE_M)] = PADDING;
- }
- }
-
- // row number and col number that needs to be loaded blockIdx.y this thread
- const int A_TILE_COL = tid / BLOCK_SIZE_M;
- const int A_TILE_ROW = tid % BLOCK_SIZE_M;
-
- const int B_TILE_COL = tid / BLOCK_SIZE_K;
- const int B_TILE_ROW = tid % BLOCK_SIZE_K;
-
- // col stride that thread uses to load multiple rows of a tile
- // how many cols that the threads load in one iteration
- const int A_TILE_COL_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_M;
- const int B_TILE_COL_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_K;
-
- // number of threads in M and N direction (used when calculating C)
- // const int A_S = BLOCK_SIZE_M / THREAD_SIZE_M;
- // const int B_S = BLOCK_SIZE_N / THREAD_SIZE_N;
-
- for (int tile_idx = 0 ; tile_idx < K ;) {
-
- // load A from global memory to shared memory
- #pragma unroll
- for ( int i = 0 ; i < BLOCK_SIZE_K ; i += A_TILE_COL_STRIDE) {
- const int row = BLOCK_SIZE_M * BLOCK_IDX + A_TILE_ROW ;
- const int col = A_TILE_COL + i + tile_idx;
-
- if (BLOCK_IDX == DIM_GRID_X -1 || tile_idx >= K - BLOCK_SIZE_K) {
- As[OFFSET_col(A_TILE_ROW, i + A_TILE_COL, BLOCK_SIZE_M)] = row < M && col < K ? A[OFFSET_col(row, col, M)] : PADDING;
- } else {
- As[OFFSET_col(A_TILE_ROW, i + A_TILE_COL, BLOCK_SIZE_M)] = A[OFFSET_col(row, col, M)];
- }
- }
-
- // load B from global memory to shared memory
- #pragma unroll
- for ( int i = 0 ; i < BLOCK_SIZE_N; i += B_TILE_COL_STRIDE) {
- const int row = tile_idx + B_TILE_ROW;
- const int col = BLOCK_SIZE_N * BLOCK_IDY + i + B_TILE_COL;
-
- if (tile_idx >= K - BLOCK_SIZE_K || BLOCK_IDY == DIM_GRID_Y - 1) {
- Bs[OFFSET_col(B_TILE_ROW, i + B_TILE_COL, BLOCK_SIZE_K)] = row < K && col < N ? B[OFFSET_col(row, col, K)] : PADDING;
- } else {
- Bs[OFFSET_col(B_TILE_ROW, i + B_TILE_COL, BLOCK_SIZE_K)] = B[OFFSET_col(row, col, K)];
- }
- }
-
- __syncthreads();
-
- // compute c
- #pragma unroll
- for (int k = 0; k < BLOCK_SIZE_K; k += 1) {
-
- // load A and B from shared memory to registers
- #pragma unroll
- for (int thread_m = 0; thread_m < THREAD_SIZE_M; ++thread_m) {
- regs_a[thread_m] = As[OFFSET_col(threadIdx.x * THREAD_SIZE_M + thread_m, k, BLOCK_SIZE_M)];
- }
-
- #pragma unroll
- for (int thread_n = 0; thread_n < THREAD_SIZE_N; ++thread_n) {
- regs_b[thread_n] = Bs[OFFSET_col(k, threadIdx.y * THREAD_SIZE_N + thread_n, BLOCK_SIZE_K)];
- }
-
- #pragma unroll
- for (int thread_m = 0; thread_m < THREAD_SIZE_M; ++thread_m) {
- #pragma unroll
- for (int thread_n = 0; thread_n < THREAD_SIZE_N; ++thread_n) {
- accum[OFFSET_col(thread_m, thread_n, THREAD_SIZE_M)] = OPERATOR_ADD(OPERATOR_MUL(regs_a[thread_m], regs_b[thread_n]), accum[OFFSET_col(thread_m, thread_n, THREAD_SIZE_M)]);
- }
- }
-
- }
- __syncthreads();
- tile_idx += BLOCK_SIZE_K;
- }
-
- // store back to C
- #pragma unroll
- for (int thread_m = 0; thread_m < THREAD_SIZE_M; ++thread_m) {
- #pragma unroll
- for (int thread_n = 0; thread_n < THREAD_SIZE_N; ++thread_n) {
- const int col = BLOCK_SIZE_N * BLOCK_IDY + THREAD_SIZE_N * threadIdx.y + thread_n;
- const int row = BLOCK_SIZE_M * BLOCK_IDX + THREAD_SIZE_M * threadIdx.x + thread_m;
-
- if (BLOCK_IDX == DIM_GRID_X -1 || BLOCK_IDY == DIM_GRID_Y - 1) {
- if (row < M && col < N) {
- C[OFFSET_col(row, col, M)] = OPERATOR_ADD(
- OPERATOR_MUL(accum[OFFSET_col(thread_m, thread_n, THREAD_SIZE_M)], alpha),
- OPERATOR_MUL(C[OFFSET_col(row, col, M)], beta)
- );
- }
- } else {
- C[OFFSET_col(row, col, M)] = OPERATOR_ADD(
- OPERATOR_MUL(accum[OFFSET_col(thread_m, thread_n, THREAD_SIZE_M)], alpha),
- OPERATOR_MUL(C[OFFSET_col(row, col, M)], beta)
- );
- }
- }
- }
- __syncthreads();
-}
-
-extern "C"{
-void CONCATENATE(TYPENAME, FUNCNAME)(const int m, const int n, const int k, TYPE *d_A, TYPE *d_B, TYPE *d_C, TYPE alpha, TYPE beta, const char TA, const char TB){
- // TA and TB are 'T' or 'N'
-
- const char T = 'T';
- const char N = 'N';
-
- const int BLOCK_SIZE_M = 64;
- const int BLOCK_SIZE_K = 32;
- const int BLOCK_SIZE_N = 64;
- const int THREAD_SIZE_M = 4;
- const int THREAD_SIZE_N = 4;
-
-
- if (TA == T && TB == T) {
- dim3 dimBlock(BLOCK_SIZE_N / THREAD_SIZE_N, BLOCK_SIZE_M / THREAD_SIZE_M);
-
- int DIM_GRID_X = n / BLOCK_SIZE_N;
- int DIM_GRID_Y = m / BLOCK_SIZE_M;
- if (n % BLOCK_SIZE_N != 0)
- DIM_GRID_X++;
- if (m % BLOCK_SIZE_M != 0)
- DIM_GRID_Y++;
-
- dim3 dimGrid(DIM_GRID_X * DIM_GRID_Y);
-
- CONCATENATETHREE(TYPENAME, FUNCNAME, TT)
- <<< dimGrid, dimBlock >>>(d_A, d_B, d_C, alpha, beta, m, n, k, DIM_GRID_X, DIM_GRID_Y);
- }
-
- if (TA == T && TB == N) {
- dim3 dimBlock(BLOCK_SIZE_N / THREAD_SIZE_N, BLOCK_SIZE_M / THREAD_SIZE_M);
-
- int DIM_GRID_X = n / BLOCK_SIZE_N;
- int DIM_GRID_Y = m / BLOCK_SIZE_M;
- if (n % BLOCK_SIZE_N != 0)
- DIM_GRID_X++;
- if (m % BLOCK_SIZE_M != 0)
- DIM_GRID_Y++;
-
- dim3 dimGrid(DIM_GRID_X * DIM_GRID_Y);
-
- CONCATENATETHREE(TYPENAME, FUNCNAME, TN)
- <<< dimGrid, dimBlock >>>(d_A, d_B, d_C, alpha, beta, m, n, k, DIM_GRID_X, DIM_GRID_Y);
- }
-
- if (TA == N && TB == T) {
- dim3 dimBlock(BLOCK_SIZE_M / THREAD_SIZE_M, BLOCK_SIZE_N / THREAD_SIZE_N);
-
- int DIM_GRID_X = m / BLOCK_SIZE_M;
- int DIM_GRID_Y = n / BLOCK_SIZE_N;
- if (m % BLOCK_SIZE_M != 0)
- DIM_GRID_X++;
- if (n % BLOCK_SIZE_N != 0)
- DIM_GRID_Y++;
-
- dim3 dimGrid(DIM_GRID_X * DIM_GRID_Y);
-
- CONCATENATETHREE(TYPENAME, FUNCNAME, NT)
- <<< dimGrid, dimBlock >>>(d_A, d_B, d_C, alpha, beta, m, n, k, DIM_GRID_X, DIM_GRID_Y);
- }
-
- if (TA == N && TB == N) {
- dim3 dimBlock(BLOCK_SIZE_M / THREAD_SIZE_M, BLOCK_SIZE_N / THREAD_SIZE_N);
-
- int DIM_GRID_X = m / BLOCK_SIZE_M;
- int DIM_GRID_Y = n / BLOCK_SIZE_N;
- if (m % BLOCK_SIZE_M != 0)
- DIM_GRID_X++;
- if (n % BLOCK_SIZE_N != 0)
- DIM_GRID_Y++;
-
- dim3 dimGrid(DIM_GRID_X * DIM_GRID_Y);
-
- CONCATENATETHREE(TYPENAME, FUNCNAME, NN)
- <<< dimGrid, dimBlock >>>(d_A, d_B, d_C, alpha, beta, m, n, k, DIM_GRID_X, DIM_GRID_Y);
- }
-
-}
-}
\ No newline at end of file
diff --git a/src/CuTropicalGEMM.jl b/src/CuTropicalGEMM.jl
index d622ec4..5e74fb8 100644
--- a/src/CuTropicalGEMM.jl
+++ b/src/CuTropicalGEMM.jl
@@ -1,24 +1,16 @@
module CuTropicalGEMM
-using CUDA, TropicalNumbers, LinearAlgebra
+using CUDA, TropicalNumbers, LinearAlgebra, TropicalGemmC_jll
export matmul!
-const path = @__DIR__
const Symbol_FP32 = (:FP32, "FP32")
const Symbol_FP64 = (:FP64, "FP64")
const Symbol_INT32 = (:INT32, "INT32")
const Symbol_INT64 = (:INT64, "INT64")
const Symbol_Bool = (:Bool, "Bool")
-for (Algebra, SAlgebra, Symbol_types) in [(:PlusMul, "PlusMul", [Symbol_FP32, Symbol_FP64, Symbol_INT32, Symbol_INT64]), (:TropicalAndOr, "TropicalAndOr", [Symbol_Bool]), (:TropicalMaxMul, "TropicalMaxMul", [Symbol_FP32, Symbol_FP64, Symbol_INT32, Symbol_INT64]), (:TropicalMaxPlus, "TropicalMaxPlus", [Symbol_FP32, Symbol_FP64]), (:TropicalMinPlus, "TropicalMinPlus", [Symbol_FP32, Symbol_FP64])]
- for Symbol_type in Symbol_types
- t, st = Symbol_type
- @eval const $(Symbol("lib_$(Algebra)_$(t)")) = joinpath(path, "../deps/lib", "lib_" * $SAlgebra * "_" * $st * ".so")
- end
-end
-
const CTranspose{T} = Transpose{T, <:CuVecOrMat{T}}
include("tropical_gemms.jl")
-end
+end
\ No newline at end of file
diff --git a/src/tropical_gemms.jl b/src/tropical_gemms.jl
index c3af854..03fcaf5 100644
--- a/src/tropical_gemms.jl
+++ b/src/tropical_gemms.jl
@@ -15,9 +15,9 @@ for (TA, tA) in [(:CuVecOrMat, 'N'), (:CTranspose, 'T')]
for (TB, tB) in [(:CuVecOrMat, 'N'), (:CTranspose, 'T')]
for (TT, CT, funcname, lib) in [
(:Float32, :Cfloat, :FLOAT_plusmul, :lib_PlusMul_FP32), (:Float64, :Cdouble, :DOUBLE_plusmul, :lib_PlusMul_FP64), (:Int32, :Cint, :INT_plusmul, :lib_PlusMul_INT32), (:Int64, :Clong, :LONG_plusmul, :lib_PlusMul_INT64),
- (:TropicalAndOr, :Bool, :BOOL_andor, :lib_TropicalAndOr_Bool),
- (:TropicalMaxPlusF32, :Cfloat, :FLOAT_maxplus, :lib_TropicalMaxPlus_FP32), (:TropicalMaxPlusF64, :Cdouble, :DOUBLE_maxplus, :lib_TropicalMaxPlus_FP64),
- (:TropicalMinPlusF32, :Cfloat, :FLOAT_minplus, :lib_TropicalMinPlus_FP32), (:TropicalMinPlusF64, :Cdouble, :DOUBLE_minplus, :lib_TropicalMinPlus_FP64),
+ (:TropicalAndOr, :Bool, :BOOL_andor, :TropicalAndOr_Bool),
+ (:TropicalMaxPlusF32, :Cfloat, :FLOAT_maxplus, :TropicalMaxPlus_FP32), (:TropicalMaxPlusF64, :Cdouble, :DOUBLE_maxplus, :TropicalMaxPlus_FP64),
+ (:TropicalMinPlusF32, :Cfloat, :FLOAT_minplus, :TropicalMinPlus_FP32), (:TropicalMinPlusF64, :Cdouble, :DOUBLE_minplus, :TropicalMinPlus_FP64),
(:TropicalMaxMulF32, :Cfloat, :FLOAT_maxmul, :lib_TropicalMaxMul_FP32), (:TropicalMaxMulF64, :Cdouble, :DOUBLE_maxmul, :lib_TropicalMaxMul_FP64), (:TropicalMaxMulI32, :Cint, :INT_maxmul, :lib_TropicalMaxMul_INT32), (:TropicalMaxMulI64, :Clong, :LONG_maxmul, :lib_TropicalMaxMul_INT64)
]
@eval function matmul!(C::CuVecOrMat{T}, A::$TA{T}, B::$TB{T}, α::T, β::T) where {T<:$TT}
diff --git a/test/Project.toml b/test/Project.toml
index 416b83f..f25053b 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,5 +1,6 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
-TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
\ No newline at end of file
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+TropicalGemmC_jll = "4f4992fb-2984-5eba-87b8-475305d0f5fc"
+TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
\ No newline at end of file
diff --git a/test/tropical_gemms.jl b/test/tropical_gemms.jl
index c86f295..961472d 100644
--- a/test/tropical_gemms.jl
+++ b/test/tropical_gemms.jl
@@ -1,7 +1,7 @@
@testset "Testing the gemms" begin
for (MT, DT) in [(Real, [Float32, Float64, Int32, Int64]), (TropicalAndOr, [Bool]), (TropicalMaxPlus, [Float32, Float64]), (TropicalMinPlus, [Float32, Float64]), (TropicalMaxMul, [Float32, Float64]), (TropicalMaxMul, [Int32, Int64])]
for T in DT
- for (M, N, K) in [(0, 0, 0), (2, 0, 0), (2, 2, 0), (5, 6, 7), (101, 102, 103)]
+ for (M, N, K) in [(0, 0, 0), (2, 0, 0), (2, 2, 0), (5, 6, 7), (66, 67, 33)]
if MT == Real
TT = T
elseif MT == TropicalAndOr