Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prov/tcp: Restrict which EPs can be opened per domain #9092

Merged
merged 1 commit into from
Jun 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions prov/tcp/src/xnet.h
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@ struct xnet_xfer_entry {
struct xnet_domain {
struct util_domain util_domain;
struct xnet_progress progress;
enum fi_ep_type ep_type;
};

static inline struct xnet_progress *xnet_ep2_progress(struct xnet_ep *ep)
Expand Down
14 changes: 11 additions & 3 deletions prov/tcp/src/xnet_domain.c
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,21 @@ xnet_mr_regattr(struct fid *fid, const struct fi_mr_attr *attr,
return ret;
}

static int xnet_open_ep(struct fid_domain *domain, struct fi_info *info,
static int xnet_open_ep(struct fid_domain *domain_fid, struct fi_info *info,
struct fid_ep **ep_fid, void *context)
{
struct xnet_domain *domain;

domain = container_of(domain_fid, struct xnet_domain,
util_domain.domain_fid);
if (domain->ep_type != info->ep_attr->type)
return -FI_EINVAL;

if (info->ep_attr->type == FI_EP_MSG)
return xnet_endpoint(domain, info, ep_fid, context);
return xnet_endpoint(domain_fid, info, ep_fid, context);

if (info->ep_attr->type == FI_EP_RDM)
return xnet_rdm_ep(domain, info, ep_fid, context);
return xnet_rdm_ep(domain_fid, info, ep_fid, context);

return -FI_EINVAL;
}
Expand Down Expand Up @@ -232,6 +239,7 @@ int xnet_domain_open(struct fid_fabric *fabric_fid, struct fi_info *info,
if (ret)
goto close;

domain->ep_type = info->ep_attr->type;
domain->util_domain.domain_fid.fid.ops = &xnet_domain_fi_ops;
domain->util_domain.domain_fid.ops = &xnet_domain_ops;
domain->util_domain.domain_fid.mr = &xnet_domain_fi_ops_mr;
Expand Down
1 change: 1 addition & 0 deletions prov/tcp/src/xnet_ep.c
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,7 @@ int xnet_endpoint(struct fid_domain *domain, struct fi_info *info,
if (ret)
goto err1;

assert(info->ep_attr->type == FI_EP_MSG);
ofi_bsock_init(&ep->bsock, &xnet_ep2_progress(ep)->sockapi,
xnet_staging_sbuf_size, xnet_prefetch_rbuf_size,
&ep->util_ep.ep_fid);
Expand Down
1 change: 1 addition & 0 deletions prov/tcp/src/xnet_rdm.c
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,7 @@ int xnet_rdm_ep(struct fid_domain *domain, struct fi_info *info,
if (ret)
goto err1;

assert(info->ep_attr->type == FI_EP_RDM);
ret = xnet_init_rdm(rdm, info);
if (ret)
goto err2;
Expand Down
4 changes: 2 additions & 2 deletions prov/tcp/src/xnet_rdm_cm.c
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ static int xnet_open_conn(struct xnet_conn *conn, struct fi_info *info)
int ret;

assert(xnet_progress_locked(xnet_rdm2_progress(conn->rdm)));
ret = fi_endpoint(&conn->rdm->util_ep.domain->domain_fid, info,
&ep_fid, conn);
ret = xnet_endpoint(&conn->rdm->util_ep.domain->domain_fid, info,
&ep_fid, conn);
if (ret) {
XNET_WARN_ERR(FI_LOG_EP_CTRL, "fi_endpoint", ret);
return ret;
Expand Down