Skip to content

Commit

Permalink
Merge pull request #4659 from dmitrygx/topic/uct/tcp_unknown_conn_1_7
Browse files Browse the repository at this point in the history
UCT/TCP/GTEST: Protect against connection from non-UCX sock-based app
  • Loading branch information
shamisp authored Jan 12, 2020
2 parents b2201ce + d286f8d commit ee2252b
Show file tree
Hide file tree
Showing 8 changed files with 436 additions and 63 deletions.
19 changes: 19 additions & 0 deletions src/ucs/sys/sock.c
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,25 @@ ucs_status_t ucs_socket_setopt(int fd, int level, int optname,
return UCS_OK;
}

const char *ucs_socket_getname_str(int fd, char *str, size_t max_size)
{
struct sockaddr_storage sock_addr = {0}; /* Suppress Clang false-positive */
socklen_t addr_size;
int ret;

addr_size = sizeof(sock_addr);
ret = getsockname(fd, (struct sockaddr*)&sock_addr,
&addr_size);
if (ret < 0) {
ucs_debug("getsockname(fd=%d) failed: %m", fd);
ucs_strncpy_safe(str, "-", max_size);
return str;
}

return ucs_sockaddr_str((const struct sockaddr*)&sock_addr,
str, max_size);
}

ucs_status_t ucs_socket_connect(int fd, const struct sockaddr *dest_addr)
{
char dest_str[UCS_SOCKADDR_STRING_LEN];
Expand Down
13 changes: 13 additions & 0 deletions src/ucs/sys/sock.h
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,19 @@ const char* ucs_sockaddr_str(const struct sockaddr *sock_addr,
char *str, size_t max_size);


/**
* Extract the IP address from a given socket fd and return it as a string.
*
* @param [in] fd Socket fd.
* @param [out] str A string filled with the IP address.
* @param [in] max_size Size of a string (considering '\0'-terminated symbol)
*
* @return ip_str if the sock_addr has a valid IP address or 'Invalid address'
* otherwise.
*/
const char *ucs_socket_getname_str(int fd, char *str, size_t max_size);


/**
* Return a value indicating the relationships between passed sockaddr structures.
*
Expand Down
18 changes: 9 additions & 9 deletions src/uct/tcp/tcp.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@

#define UCT_TCP_CONFIG_PREFIX "TCP_"

/* Magic number that is used by TCP to identify its peers */
#define UCT_TCP_MAGIC_NUMBER 0xCAFEBABE12345678lu

/* Maximum number of events to wait on event set */
#define UCT_TCP_MAX_EVENTS 16

Expand Down Expand Up @@ -71,9 +74,13 @@ typedef enum uct_tcp_ep_conn_state {
* After it is done, it sends `UCT_TCP_CM_CONN_REQ` to the peer.
* All AM operations return `UCS_ERR_NO_RESOURCE` error to a caller. */
UCT_TCP_EP_CONN_STATE_CONNECTING,
/* EP is receiving the magic number in order to verify a peer. EP is moved
* to this state after accept() completed. */
UCT_TCP_EP_CONN_STATE_RECV_MAGIC_NUMBER,
/* EP is accepting connection from a peer, i.e. accept() returns socket fd
* on which a connection was accepted, this EP was created using this socket
* and now it is waiting for `UCT_TCP_CM_CONN_REQ` from a peer. */
* fd and the magic number was received and verified by EP and now it is
* waiting for `UCT_TCP_CM_CONN_REQ` from a peer. */
UCT_TCP_EP_CONN_STATE_ACCEPTING,
/* EP is waiting for `UCT_TCP_CM_CONN_ACK` message from a peer after sending
* `UCT_TCP_CM_CONN_REQ`.
Expand Down Expand Up @@ -129,6 +136,7 @@ KHASH_INIT(uct_tcp_cm_eps, struct sockaddr_in, ucs_list_link_t*,
typedef struct uct_tcp_cm_state {
const char *name; /* CM state name */
uct_tcp_ep_progress_t tx_progress; /* TX progress function */
uct_tcp_ep_progress_t rx_progress; /* RX progress function */
} uct_tcp_cm_state_t;


Expand Down Expand Up @@ -353,8 +361,6 @@ void uct_tcp_ep_remove(uct_tcp_iface_t *iface, uct_tcp_ep_t *ep);

void uct_tcp_ep_add(uct_tcp_iface_t *iface, uct_tcp_ep_t *ep);

unsigned uct_tcp_ep_progress_rx(uct_tcp_ep_t *ep);

void uct_tcp_ep_mod_events(uct_tcp_ep_t *ep, int add, int remove);

void uct_tcp_ep_pending_queue_dispatch(uct_tcp_ep_t *ep);
Expand Down Expand Up @@ -409,12 +415,6 @@ ucs_status_t uct_tcp_cm_handle_incoming_conn(uct_tcp_iface_t *iface,

ucs_status_t uct_tcp_cm_conn_start(uct_tcp_ep_t *ep);

static inline unsigned uct_tcp_ep_progress_tx(uct_tcp_ep_t *ep)
{
return uct_tcp_ep_cm_state[ep->conn_state].tx_progress(ep);
}


/**
* Query for active network devices under /sys/class/net, as determined by
* ucs_netif_is_active(). 'md' parameter is not used, and is added for
Expand Down
83 changes: 49 additions & 34 deletions src/uct/tcp/tcp_cm.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ void uct_tcp_cm_change_conn_state(uct_tcp_ep_t *ep,

switch(ep->conn_state) {
case UCT_TCP_EP_CONN_STATE_CONNECTING:
ucs_assertv(iface->config.conn_nb, "ep=%p", ep);
/* Fall through */
case UCT_TCP_EP_CONN_STATE_WAITING_ACK:
if (old_conn_state == UCT_TCP_EP_CONN_STATE_CLOSED) {
uct_tcp_iface_outstanding_inc(iface);
Expand Down Expand Up @@ -61,13 +59,15 @@ void uct_tcp_cm_change_conn_state(uct_tcp_ep_t *ep,
(old_conn_state == UCT_TCP_EP_CONN_STATE_WAITING_ACK) ||
(old_conn_state == UCT_TCP_EP_CONN_STATE_WAITING_REQ)) {
uct_tcp_iface_outstanding_dec(iface);
} else if (old_conn_state == UCT_TCP_EP_CONN_STATE_ACCEPTING) {
} else if ((old_conn_state == UCT_TCP_EP_CONN_STATE_ACCEPTING) ||
(old_conn_state == UCT_TCP_EP_CONN_STATE_RECV_MAGIC_NUMBER)) {
/* Since ep::peer_addr is 0'ed, we have to print w/o peer's address */
full_log = 0;
}
break;
default:
ucs_assert(ep->conn_state == UCT_TCP_EP_CONN_STATE_ACCEPTING);
ucs_assert((ep->conn_state == UCT_TCP_EP_CONN_STATE_ACCEPTING) ||
(ep->conn_state == UCT_TCP_EP_CONN_STATE_RECV_MAGIC_NUMBER));
/* Since ep::peer_addr is 0'ed and client's <address:port>
* has already been logged, print w/o peer's address */
full_log = 0;
Expand Down Expand Up @@ -143,8 +143,9 @@ static void uct_tcp_cm_trace_conn_pkt(const uct_tcp_ep_t *ep,

ucs_status_t uct_tcp_cm_send_event(uct_tcp_ep_t *ep, uct_tcp_cm_conn_event_t event)
{
uct_tcp_iface_t *iface = ucs_derived_of(ep->super.super.iface,
uct_tcp_iface_t);
uct_tcp_iface_t *iface = ucs_derived_of(ep->super.super.iface,
uct_tcp_iface_t);
size_t magic_number_length = 0;
void *pkt_buf;
size_t pkt_length, cm_pkt_length;
uct_tcp_cm_conn_req_pkt_t *conn_pkt;
Expand All @@ -157,20 +158,29 @@ ucs_status_t uct_tcp_cm_send_event(uct_tcp_ep_t *ep, uct_tcp_cm_conn_event_t eve
UCT_TCP_CM_CONN_WAIT_REQ)),
"ep=%p", ep);

pkt_length = sizeof(*pkt_hdr);
pkt_length = sizeof(*pkt_hdr);
if (event == UCT_TCP_CM_CONN_REQ) {
cm_pkt_length = sizeof(*conn_pkt);
cm_pkt_length = sizeof(*conn_pkt);
if (ep->conn_state == UCT_TCP_EP_CONN_STATE_CONNECTING) {
magic_number_length = sizeof(uint64_t);
}
} else {
cm_pkt_length = sizeof(event);
cm_pkt_length = sizeof(event);
}
pkt_length += cm_pkt_length;
pkt_buf = ucs_alloca(pkt_length);

pkt_hdr = (uct_tcp_am_hdr_t*)pkt_buf;
pkt_length += cm_pkt_length + magic_number_length;
pkt_buf = ucs_alloca(pkt_length);
pkt_hdr = (uct_tcp_am_hdr_t*)(UCS_PTR_BYTE_OFFSET(pkt_buf,
magic_number_length));
pkt_hdr->am_id = UCT_AM_ID_MAX;
pkt_hdr->length = cm_pkt_length;

if (event == UCT_TCP_CM_CONN_REQ) {
if (ep->conn_state == UCT_TCP_EP_CONN_STATE_CONNECTING) {
ucs_assert(magic_number_length == sizeof(uint64_t));
*(uint64_t*)pkt_buf = UCT_TCP_MAGIC_NUMBER;
}

conn_pkt = (uct_tcp_cm_conn_req_pkt_t*)(pkt_hdr + 1);
conn_pkt->event = UCT_TCP_CM_CONN_REQ;
conn_pkt->iface_addr = iface->config.ifaddr;
Expand Down Expand Up @@ -508,29 +518,42 @@ unsigned uct_tcp_cm_handle_conn_pkt(uct_tcp_ep_t **ep, void *pkt, uint32_t lengt
return 0;
}

unsigned uct_tcp_cm_conn_progress(uct_tcp_ep_t *ep)
static ucs_status_t uct_tcp_cm_conn_complete(uct_tcp_ep_t *ep,
unsigned *progress_count_p)
{
ucs_status_t status;

if (!ucs_socket_is_connected(ep->fd)) {
ucs_error("tcp_ep %p: connection establishment for "
"socket fd %d was unsuccessful", ep, ep->fd);
goto err;
}

status = uct_tcp_cm_send_event(ep, UCT_TCP_CM_CONN_REQ);
if (status != UCS_OK) {
return 0;
goto out;
}

uct_tcp_cm_change_conn_state(ep, UCT_TCP_EP_CONN_STATE_WAITING_ACK);
uct_tcp_ep_mod_events(ep, UCS_EVENT_SET_EVREAD, 0);

ucs_assertv((ep->tx.length == 0) && (ep->tx.offset == 0) &&
(ep->tx.buf == NULL), "ep=%p", ep);
return 1;
out:
if (progress_count_p != NULL) {
*progress_count_p = (status == UCS_OK);
}
return status;
}

err:
unsigned uct_tcp_cm_conn_progress(uct_tcp_ep_t *ep)
{
unsigned progress_count;

if (!ucs_socket_is_connected(ep->fd)) {
ucs_error("tcp_ep %p: connection establishment for "
"socket fd %d was unsuccessful", ep, ep->fd);
goto err;
}

uct_tcp_cm_conn_complete(ep, &progress_count);
return progress_count;

err:
uct_tcp_ep_set_failed(ep);
return 0;
}
Expand All @@ -547,13 +570,13 @@ ucs_status_t uct_tcp_cm_conn_start(uct_tcp_ep_t *ep)
return UCS_ERR_TIMED_OUT;
}

uct_tcp_cm_change_conn_state(ep, UCT_TCP_EP_CONN_STATE_CONNECTING);

status = ucs_socket_connect(ep->fd, (const struct sockaddr*)&ep->peer_addr);
if (UCS_STATUS_IS_ERR(status)) {
return status;
} else if (status == UCS_INPROGRESS) {
uct_tcp_cm_change_conn_state(ep, UCT_TCP_EP_CONN_STATE_CONNECTING);
uct_tcp_ep_mod_events(ep, UCS_EVENT_SET_EVWRITE, 0);

return UCS_OK;
}

Expand All @@ -566,15 +589,7 @@ ucs_status_t uct_tcp_cm_conn_start(uct_tcp_ep_t *ep)
}
}

status = uct_tcp_cm_send_event(ep, UCT_TCP_CM_CONN_REQ);
if (status != UCS_OK) {
return status;
}

uct_tcp_cm_change_conn_state(ep, UCT_TCP_EP_CONN_STATE_WAITING_ACK);
uct_tcp_ep_mod_events(ep, UCS_EVENT_SET_EVREAD, 0);

return UCS_OK;
return uct_tcp_cm_conn_complete(ep, NULL);
}

/* This function is called from async thread */
Expand All @@ -592,7 +607,7 @@ ucs_status_t uct_tcp_cm_handle_incoming_conn(uct_tcp_iface_t *iface,
return status;
}

uct_tcp_cm_change_conn_state(ep, UCT_TCP_EP_CONN_STATE_ACCEPTING);
uct_tcp_cm_change_conn_state(ep, UCT_TCP_EP_CONN_STATE_RECV_MAGIC_NUMBER);
uct_tcp_ep_mod_events(ep, UCS_EVENT_SET_EVREAD, 0);

ucs_debug("tcp_iface %p: accepted connection from "
Expand Down
Loading

0 comments on commit ee2252b

Please sign in to comment.