Skip to content

Commit

Permalink
Squashed commits from Brad @ 171bf08 rebased onto main
Browse files Browse the repository at this point in the history
test(collectives): outline simpler test suite

chore: remove old collective implementations

chore: update interfaces of collectives

feat(co_broadcast): use wrapper to make contiguous

feat: re-implement co_sum

feat: create implementation of co_reduce

feat: create co_min implementation

feat: create implementation for co_max
  • Loading branch information
everythingfunctional authored and bonachea committed Jan 17, 2025
1 parent 98baa51 commit 6d6ac10
Show file tree
Hide file tree
Showing 13 changed files with 607 additions and 948 deletions.
38 changes: 1 addition & 37 deletions src/caffeine/caffeine.c
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ void caf_sync_all()
}

void caf_co_reduce(
CFI_cdesc_t* a_desc, int result_image, int num_elements, gex_Coll_ReduceFn_t user_op, void* client_data, gex_TM_t team
CFI_cdesc_t* a_desc, int result_image, size_t num_elements, gex_Coll_ReduceFn_t user_op, void* client_data, gex_TM_t team
)
{
char* a_address = (char*) a_desc->base_addr;
Expand Down Expand Up @@ -284,46 +284,10 @@ void caf_co_sum(CFI_cdesc_t* a_desc, int result_image, size_t num_elements, gex_
gex_Event_Wait(ev);
}

bool caf_same_cfi_type(CFI_cdesc_t* a_desc, CFI_cdesc_t* b_desc)
{
if (a_desc->type == b_desc->type) return true;
return false;
}

size_t caf_elem_len(CFI_cdesc_t* a_desc)
{
return a_desc->elem_len;
}

void caf_form_team(gex_TM_t current_team, gex_TM_t* new_team, int64_t team_number, int new_index)
{
// GASNet color argument is int (32-bit), check for value truncation:
assert((unsigned int)team_number == team_number);
gex_TM_Split(new_team, current_team, team_number, new_index, NULL, 0, GEX_FLAG_TM_NO_SCRATCH);
}

bool caf_numeric_type(CFI_cdesc_t* a_desc)
{
switch (a_desc->type)
{
case CFI_type_int32_t: return true;
case CFI_type_int64_t: return true;
case CFI_type_float: return true;
case CFI_type_double: return true;
case float_Complex_workaround: return true;
case double_Complex_workaround: return true;
default: return false;
}
}

#ifdef __GNUC__
bool caf_is_f_string(CFI_cdesc_t* a_desc){
if ( (a_desc->type - 5) % 256 == 0) return true;
return false;
}
#else // The code below is untested but believed to conform with the Fortran 2018 standard.
bool caf_is_f_string(CFI_cdesc_t* a_desc){
if (a_desc->type == CFI_type_char) return true;
return false;
}
#endif
12 changes: 11 additions & 1 deletion src/caffeine/collective_subroutines/co_broadcast_s.F90
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,20 @@
contains

module procedure prif_co_broadcast
call contiguous_co_broadcast(a, source_image, stat, errmsg, errmsg_alloc)
end procedure

subroutine contiguous_co_broadcast(a, source_image, stat, errmsg, errmsg_alloc)
type(*), intent(inout), target, contiguous :: a(..)
integer(c_int), intent(in) :: source_image
integer(c_int), intent(out), optional :: stat
character(len=*), intent(inout), optional :: errmsg
character(len=:), intent(inout), allocatable, optional :: errmsg_alloc

if (present(stat)) stat=0
call caf_co_broadcast(a, source_image, product(shape(a)), current_team%info%gex_team)
! With a compliant Fortran 2018 compiler, pass in c_sizeof(a) as the `Nelem` argument
! and eliminate the calculation of num_elements*sizeof(a) in caffeine.c.
end procedure
end subroutine

end submodule co_broadcast_s
68 changes: 46 additions & 22 deletions src/caffeine/collective_subroutines/co_max_s.F90
Original file line number Diff line number Diff line change
@@ -1,36 +1,60 @@
! Copyright (c), The Regents of the University of California
! Terms of use are as specified in LICENSE.txt

#include "assert_macros.h"

submodule(prif:prif_private_s) co_max_s
use iso_c_binding, only : c_funloc

use iso_c_binding, only: c_loc, c_f_pointer
implicit none

contains

module procedure prif_co_max
if (present(stat)) stat=0

if (caf_numeric_type(a)) then
call caf_co_max( &
a, optional_value(result_image), int(product(shape(a)), c_size_t), current_team%info%gex_team)
else if (caf_is_f_string(a)) then
call prif_co_reduce(a, c_funloc(reverse_alphabetize), optional_value(result_image), stat, errmsg, errmsg_alloc)
else
call prif_error_stop(.false._c_bool, stop_code_char="caf_co_max: unsupported type")
end if
call contiguous_co_max(a, result_image, stat, errmsg, errmsg_alloc)
end procedure

contains
subroutine contiguous_co_max(a, result_image, stat, errmsg, errmsg_alloc)
implicit none
type(*), intent(inout), target, contiguous :: a(..)
integer(c_int), intent(in), optional :: result_image
integer(c_int), intent(out), optional :: stat
character(len=*), intent(inout), optional :: errmsg
character(len=:), intent(inout), allocatable, optional :: errmsg_alloc

function reverse_alphabetize(lhs, rhs) result(last_alphabetically)
character(len=*), intent(in) :: lhs, rhs
character(len=len(lhs)) :: last_alphabetically
call_assert_diagnose(len(lhs)==len(rhs), "caf_co_max: LHS/RHS length match", lhs//" , "//rhs)
last_alphabetically = max(lhs,rhs)
end function
if (present(stat)) stat=0

call caf_co_max( &
a, &
optional_value(result_image), &
int(product(shape(a)), c_size_t), &
current_team%info%gex_team)
end subroutine

module procedure prif_co_max_character
call unimplemented("prif_co_max_character")
! integer(c_size_t), target :: char_len
! procedure(prif_operation_wrapper_interface), pointer :: op

! char_len = len(a)
! op => char_max_wrapper
! call prif_co_reduce(a, op, c_loc(char_len), result_image, stat, errmsg, errmsg_alloc)
end procedure

! subroutine char_max_wrapper(arg1, arg2_and_out, count, cdata) bind(C)
! type(c_ptr), intent(in), value :: arg1, arg2_and_out
! integer(c_size_t), intent(in), value :: count
! type(c_ptr), intent(in), value :: cdata

! integer(c_size_t), pointer :: char_len
! integer(c_size_t) :: i

! if (count == 0) return
! call c_f_pointer(cdata, char_len)
! block
! character(len=char_len,kind=c_char), pointer :: lhs(:), rhs_and_result(:)
! call c_f_pointer(arg1, lhs, [count])
! call c_f_pointer(arg2_and_out, rhs_and_result, [count])
! do i = 1, count
! if (lhs(i) <= rhs_and_result(i)) rhs_and_result(i) = lhs(i)
! end do
! end block
! end subroutine

end submodule co_max_s
68 changes: 46 additions & 22 deletions src/caffeine/collective_subroutines/co_min_s.F90
Original file line number Diff line number Diff line change
@@ -1,36 +1,60 @@
! Copyright (c), The Regents of the University of California
! Terms of use are as specified in LICENSE.txt

#include "assert_macros.h"

submodule(prif:prif_private_s) co_min_s
use iso_c_binding, only : c_funloc

use iso_c_binding, only: c_loc, c_f_pointer
implicit none

contains

module procedure prif_co_min
if (present(stat)) stat=0

if (caf_numeric_type(a)) then
call caf_co_min( &
a, optional_value(result_image), int(product(shape(a)), c_size_t), current_team%info%gex_team)
else if (caf_is_f_string(a)) then
call prif_co_reduce(a, c_funloc(alphabetize), optional_value(result_image), stat, errmsg, errmsg_alloc)
else
call prif_error_stop(.false._c_bool, stop_code_char="prif_co_min: unsupported type")
end if
call contiguous_co_min(a, result_image, stat, errmsg, errmsg_alloc)
end procedure

contains
subroutine contiguous_co_min(a, result_image, stat, errmsg, errmsg_alloc)
implicit none
type(*), intent(inout), target, contiguous :: a(..)
integer(c_int), intent(in), optional :: result_image
integer(c_int), intent(out), optional :: stat
character(len=*), intent(inout), optional :: errmsg
character(len=:), intent(inout), allocatable, optional :: errmsg_alloc

function alphabetize(lhs, rhs) result(first_alphabetically)
character(len=*), intent(in) :: lhs, rhs
character(len=len(lhs)) :: first_alphabetically
call_assert_diagnose(len(lhs)==len(rhs), "prif_co_min: LHS/RHS length match", lhs//" , "//rhs)
first_alphabetically = min(lhs,rhs)
end function
if (present(stat)) stat=0

call caf_co_min( &
a, &
optional_value(result_image), &
int(product(shape(a)), c_size_t), &
current_team%info%gex_team)
end subroutine

module procedure prif_co_min_character
call unimplemented("prif_co_min_character")
! integer(c_size_t), target :: char_len
! procedure(prif_operation_wrapper_interface), pointer :: op

! char_len = len(a)
! op => char_min_wrapper
! call prif_co_reduce(a, op, c_loc(char_len), result_image, stat, errmsg, errmsg_alloc)
end procedure

! subroutine char_min_wrapper(arg1, arg2_and_out, count, cdata) bind(C)
! type(c_ptr), intent(in), value :: arg1, arg2_and_out
! integer(c_size_t), intent(in), value :: count
! type(c_ptr), intent(in), value :: cdata

! integer(c_size_t), pointer :: char_len
! integer(c_size_t) :: i

! if (count == 0) return
! call c_f_pointer(cdata, char_len)
! block
! character(len=char_len,kind=c_char), pointer :: lhs(:), rhs_and_result(:)
! call c_f_pointer(arg1, lhs, [count])
! call c_f_pointer(arg2_and_out, rhs_and_result, [count])
! do i = 1, count
! if (lhs(i) <= rhs_and_result(i)) rhs_and_result(i) = lhs(i)
! end do
! end block
! end subroutine

end submodule co_min_s
Loading

0 comments on commit 6d6ac10

Please sign in to comment.