Skip to content

Commit

Permalink
prov/shm: use owner-allocated srx
Browse files Browse the repository at this point in the history
The peer API has been updated to specify that the owner must allocate
the peer's fid_peer_srx. The shm implementation was allocating its
own internal fid_peer_srx.
This updates the shm implementation to assume it has a unique
fid_peer_srx and updates the imported fid_peer_srx peer_ops, saving
a pointer to the fid_peer_srx instead of the internal fid_ep which
required a wrapper function to get back to the fid_peer_srx

Signed-off-by: Alexia Ingerson <alexia.ingerson@intel.com>
  • Loading branch information
aingerson committed Sep 11, 2024
1 parent 5de922c commit f5a600f
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 57 deletions.
7 changes: 1 addition & 6 deletions prov/shm/src/smr.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ struct smr_ep {
const char *name;
uint64_t msg_id;
struct smr_region *volatile region;
struct fid_ep *srx;
struct fid_peer_srx *srx;
struct ofi_bufpool *cmd_ctx_pool;
struct ofi_bufpool *unexp_buf_pool;
struct ofi_bufpool *pend_buf_pool;
Expand All @@ -236,11 +236,6 @@ struct smr_ep {
void (*smr_progress_ipc_list)(struct smr_ep *ep);
};

static inline struct fid_peer_srx *smr_get_peer_srx(struct smr_ep *ep)
{
return container_of(ep->srx, struct fid_peer_srx, ep_fid);
}

#define smr_ep_rx_flags(smr_ep) ((smr_ep)->util_ep.rx_op_flags)
#define smr_ep_tx_flags(smr_ep) ((smr_ep)->util_ep.tx_op_flags)

Expand Down
5 changes: 2 additions & 3 deletions prov/shm/src/smr_av.c
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ static int smr_av_insert(struct fid_av *av_fid, const void *addr, size_t count,
struct util_ep *util_ep;
struct smr_av *smr_av;
struct smr_ep *smr_ep;
struct fid_peer_srx *srx;
struct dlist_entry *av_entry;
fi_addr_t util_addr;
int64_t shm_id = -1;
Expand Down Expand Up @@ -173,8 +172,8 @@ static int smr_av_insert(struct fid_av *av_fid, const void *addr, size_t count,
smr_ep = container_of(util_ep, struct smr_ep, util_ep);
smr_ep->region->max_sar_buf_per_peer =
SMR_MAX_PEERS / smr_av->smr_map.num_peers;
srx = smr_get_peer_srx(smr_ep);
srx->owner_ops->foreach_unspec_addr(srx, &smr_get_addr);
smr_ep->srx->owner_ops->foreach_unspec_addr(smr_ep->srx,
&smr_get_addr);
}

}
Expand Down
39 changes: 17 additions & 22 deletions prov/shm/src/smr_ep.c
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ int smr_ep_getopt(fid_t fid, int level, int optname, void *optval,
struct smr_ep *smr_ep =
container_of(fid, struct smr_ep, util_ep.ep_fid);

return smr_ep->srx->ops->getopt(&smr_ep->srx->fid, level, optname,
optval, optlen);
return smr_ep->srx->ep_fid.ops->getopt(&smr_ep->srx->ep_fid.fid, level,
optname, optval, optlen);
}

int smr_ep_setopt(fid_t fid, int level, int optname, const void *optval,
Expand All @@ -134,7 +134,7 @@ int smr_ep_setopt(fid_t fid, int level, int optname, const void *optval,
return -FI_ENOPROTOOPT;

if (optname == FI_OPT_MIN_MULTI_RECV) {
srx = util_get_peer_srx(smr_ep->srx)->ep_fid.fid.context;
srx = smr_ep->srx->ep_fid.fid.context;
srx->min_multi_recv_size = *(size_t *)optval;
return FI_SUCCESS;
}
Expand All @@ -159,7 +159,7 @@ static ssize_t smr_ep_cancel(fid_t ep_fid, void *context)
struct smr_ep *ep;

ep = container_of(ep_fid, struct smr_ep, util_ep.ep_fid);
return ep->srx->ops->cancel(&ep->srx->fid, context);
return ep->srx->ep_fid.ops->cancel(&ep->srx->ep_fid.fid, context);
}

static struct fi_ops_ep smr_ep_ops = {
Expand Down Expand Up @@ -808,9 +808,7 @@ static int smr_ep_close(struct fid *fid)
if (ep->srx) {
/* shm is an owner provider */
if (ep->util_ep.ep_fid.msg != &smr_no_recv_msg_ops)
(void) util_srx_close(&ep->srx->fid);
else /* shm is a peer provider */
free(ep->srx);
(void) util_srx_close(&ep->srx->ep_fid.fid);
}

ofi_endpoint_close(&ep->util_ep);
Expand Down Expand Up @@ -1073,6 +1071,7 @@ int smr_srx_context(struct fid_domain *domain, struct fi_rx_attr *attr,
if (attr->op_flags & FI_PEER) {
smr_domain->srx = ((struct fi_peer_srx_context *)
(context))->srx;
smr_domain->srx->peer_ops = &smr_srx_peer_ops;
return FI_SUCCESS;
}
FI_WARN(&smr_prov, FI_LOG_EP_CTRL,
Expand All @@ -1085,7 +1084,6 @@ static int smr_ep_bind(struct fid *ep_fid, struct fid *bfid, uint64_t flags)
struct smr_ep *ep;
struct util_av *av;
int ret = 0;
struct fid_peer_srx *srx, *srx_b;

ep = container_of(ep_fid, struct smr_ep, util_ep.ep_fid.fid);
switch (bfid->fclass) {
Expand All @@ -1109,16 +1107,10 @@ static int smr_ep_bind(struct fid *ep_fid, struct fid *bfid, uint64_t flags)
struct util_cntr, cntr_fid.fid), flags);
break;
case FI_CLASS_SRX_CTX:
srx = calloc(1, sizeof(*srx));
srx_b = container_of(bfid, struct fid_peer_srx, ep_fid.fid);
srx->peer_ops = &smr_srx_peer_ops;
srx->owner_ops = srx_b->owner_ops;
srx->ep_fid.fid.context = srx_b->ep_fid.fid.context;
ep->srx = &srx->ep_fid;
ep->srx = container_of(bfid, struct fid_peer_srx, ep_fid.fid);
break;
default:
FI_WARN(&smr_prov, FI_LOG_EP_CTRL,
"invalid fid class\n");
FI_WARN(&smr_prov, FI_LOG_EP_CTRL, "invalid fid class\n");
ret = -FI_EINVAL;
break;
}
Expand All @@ -1131,6 +1123,7 @@ static int smr_ep_ctrl(struct fid *fid, int command, void *arg)
struct smr_domain *domain;
struct smr_ep *ep;
struct smr_av *av;
struct fid_ep *srx;
int ret;

ep = container_of(fid, struct smr_ep, util_ep.ep_fid.fid);
Expand Down Expand Up @@ -1171,15 +1164,17 @@ static int smr_ep_ctrl(struct fid *fid, int command, void *arg)
ret = util_ep_srx_context(&domain->util_domain,
ep->rx_size, SMR_IOV_LIMIT,
SMR_INJECT_SIZE, &smr_update,
&ep->util_ep.lock, &ep->srx);
&ep->util_ep.lock, &srx);
if (ret)
return ret;

util_get_peer_srx(ep->srx)->peer_ops =
&smr_srx_peer_ops;
ret = util_srx_bind(&ep->srx->fid,
&ep->util_ep.rx_cq->cq_fid.fid,
FI_RECV);
ep->srx = container_of(srx, struct fid_peer_srx,
ep_fid.fid);
ep->srx->peer_ops = &smr_srx_peer_ops;

ret = util_srx_bind(&ep->srx->ep_fid.fid,
&ep->util_ep.rx_cq->cq_fid.fid,
FI_RECV);
if (ret)
return ret;
} else {
Expand Down
28 changes: 15 additions & 13 deletions prov/shm/src/smr_msg.c
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ static ssize_t smr_recvmsg(struct fid_ep *ep_fid, const struct fi_msg *msg,

ep = container_of(ep_fid, struct smr_ep, util_ep.ep_fid.fid);

return util_srx_generic_recv(ep->srx, msg->msg_iov, msg->desc,
return util_srx_generic_recv(&ep->srx->ep_fid, msg->msg_iov, msg->desc,
msg->iov_count, msg->addr, msg->context,
flags | ep->util_ep.rx_msg_flags);
}
Expand All @@ -58,8 +58,8 @@ static ssize_t smr_recvv(struct fid_ep *ep_fid, const struct iovec *iov,

ep = container_of(ep_fid, struct smr_ep, util_ep.ep_fid.fid);

return util_srx_generic_recv(ep->srx, iov, desc, count, src_addr,
context, smr_ep_rx_flags(ep));
return util_srx_generic_recv(&ep->srx->ep_fid, iov, desc, count,
src_addr, context, smr_ep_rx_flags(ep));
}

static ssize_t smr_recv(struct fid_ep *ep_fid, void *buf, size_t len,
Expand All @@ -73,8 +73,8 @@ static ssize_t smr_recv(struct fid_ep *ep_fid, void *buf, size_t len,
iov.iov_base = buf;
iov.iov_len = len;

return util_srx_generic_recv(ep->srx, &iov, &desc, 1, src_addr, context,
smr_ep_rx_flags(ep));
return util_srx_generic_recv(&ep->srx->ep_fid, &iov, &desc, 1, src_addr,
context, smr_ep_rx_flags(ep));
}

static ssize_t smr_generic_sendmsg(struct smr_ep *ep, const struct iovec *iov,
Expand Down Expand Up @@ -293,8 +293,9 @@ static ssize_t smr_trecv(struct fid_ep *ep_fid, void *buf, size_t len,
iov.iov_base = buf;
iov.iov_len = len;

return util_srx_generic_trecv(ep->srx, &iov, &desc, 1, src_addr, context,
tag, ignore, smr_ep_rx_flags(ep));
return util_srx_generic_trecv(&ep->srx->ep_fid, &iov, &desc, 1,
src_addr, context, tag, ignore,
smr_ep_rx_flags(ep));
}

static ssize_t smr_trecvv(struct fid_ep *ep_fid, const struct iovec *iov,
Expand All @@ -305,8 +306,9 @@ static ssize_t smr_trecvv(struct fid_ep *ep_fid, const struct iovec *iov,

ep = container_of(ep_fid, struct smr_ep, util_ep.ep_fid.fid);

return util_srx_generic_trecv(ep->srx, iov, desc, count, src_addr,
context, tag, ignore, smr_ep_rx_flags(ep));
return util_srx_generic_trecv(&ep->srx->ep_fid, iov, desc, count,
src_addr, context, tag, ignore,
smr_ep_rx_flags(ep));
}

static ssize_t smr_trecvmsg(struct fid_ep *ep_fid,
Expand All @@ -316,10 +318,10 @@ static ssize_t smr_trecvmsg(struct fid_ep *ep_fid,

ep = container_of(ep_fid, struct smr_ep, util_ep.ep_fid.fid);

return util_srx_generic_trecv(ep->srx, msg->msg_iov, msg->desc,
msg->iov_count, msg->addr, msg->context,
msg->tag, msg->ignore,
flags | ep->util_ep.rx_msg_flags);
return util_srx_generic_trecv(&ep->srx->ep_fid, msg->msg_iov, msg->desc,
msg->iov_count, msg->addr, msg->context,
msg->tag, msg->ignore,
flags | ep->util_ep.rx_msg_flags);
}

static ssize_t smr_tsend(struct fid_ep *ep_fid, const void *buf, size_t len,
Expand Down
25 changes: 12 additions & 13 deletions prov/shm/src/smr_progress.c
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,7 @@ static int smr_start_common(struct smr_ep *ep, struct smr_cmd *cmd,
FI_WARN(&smr_prov, FI_LOG_EP_CTRL,
"unable to process rx completion\n");
}
smr_get_peer_srx(ep)->owner_ops->free_entry(rx_entry);
ep->srx->owner_ops->free_entry(rx_entry);
}

return 0;
Expand Down Expand Up @@ -836,7 +836,7 @@ static int smr_copy_saved(struct smr_cmd_ctx *cmd_ctx,
"unable to process rx completion\n");
return ret;
}
smr_get_peer_srx(cmd_ctx->ep)->owner_ops->free_entry(rx_entry);
cmd_ctx->ep->srx->owner_ops->free_entry(rx_entry);

return FI_SUCCESS;
}
Expand Down Expand Up @@ -983,7 +983,6 @@ static int smr_alloc_cmd_ctx(struct smr_ep *ep,

static int smr_progress_cmd_msg(struct smr_ep *ep, struct smr_cmd *cmd)
{
struct fid_peer_srx *peer_srx = smr_get_peer_srx(ep);
struct fi_peer_match_attr attr;
struct fi_peer_rx_entry *rx_entry;
int ret;
Expand All @@ -992,33 +991,33 @@ static int smr_progress_cmd_msg(struct smr_ep *ep, struct smr_cmd *cmd)
attr.msg_size = cmd->msg.hdr.size;
attr.tag = cmd->msg.hdr.tag;
if (cmd->msg.hdr.op == ofi_op_tagged) {
ret = peer_srx->owner_ops->get_tag(peer_srx, &attr, &rx_entry);
ret = ep->srx->owner_ops->get_tag(ep->srx, &attr, &rx_entry);
if (ret == -FI_ENOENT) {
ret = smr_alloc_cmd_ctx(ep, rx_entry, cmd);
if (ret) {
peer_srx->owner_ops->free_entry(rx_entry);
ep->srx->owner_ops->free_entry(rx_entry);
return ret;
}

ret = peer_srx->owner_ops->queue_tag(rx_entry);
ret = ep->srx->owner_ops->queue_tag(rx_entry);
if (ret) {
peer_srx->owner_ops->free_entry(rx_entry);
ep->srx->owner_ops->free_entry(rx_entry);
return ret;
}
goto out;
}
} else {
ret = peer_srx->owner_ops->get_msg(peer_srx, &attr, &rx_entry);
ret = ep->srx->owner_ops->get_msg(ep->srx, &attr, &rx_entry);
if (ret == -FI_ENOENT) {
ret = smr_alloc_cmd_ctx(ep, rx_entry, cmd);
if (ret) {
peer_srx->owner_ops->free_entry(rx_entry);
ep->srx->owner_ops->free_entry(rx_entry);
return ret;
}

ret = peer_srx->owner_ops->queue_msg(rx_entry);
ret = ep->srx->owner_ops->queue_msg(rx_entry);
if (ret) {
peer_srx->owner_ops->free_entry(rx_entry);
ep->srx->owner_ops->free_entry(rx_entry);
return ret;
}
goto out;
Expand Down Expand Up @@ -1338,7 +1337,7 @@ void smr_progress_ipc_list(struct smr_ep *ep)
ipc_entry->async_event);
dlist_remove(&ipc_entry->entry);
if (ipc_entry->rx_entry)
smr_get_peer_srx(ep)->owner_ops->free_entry(ipc_entry->rx_entry);
ep->srx->owner_ops->free_entry(ipc_entry->rx_entry);
ofi_buf_free(ipc_entry);
}
}
Expand Down Expand Up @@ -1444,7 +1443,7 @@ static void smr_progress_sar_list(struct smr_ep *ep)
"unable to process rx completion\n");
}
if (sar_entry->rx_entry)
smr_get_peer_srx(ep)->owner_ops->free_entry(sar_entry->rx_entry);
ep->srx->owner_ops->free_entry(sar_entry->rx_entry);

dlist_remove(&sar_entry->entry);
ofi_buf_free(sar_entry);
Expand Down

0 comments on commit f5a600f

Please sign in to comment.