Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UCM/BISTRO: Translate relative jumps #7866

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 122 additions & 48 deletions src/ucm/bistro/bistro_x86_64.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,19 @@ typedef struct {
char code[];
} ucm_bistro_orig_func_t;

/*
 * Patch by jumping to absolute address loaded from register.
 * ucm_bistro_patch() stores the hook address in 'ptr' and uses this structure
 * as the patch bytes; the fields are the instruction encoding in order.
 */
typedef struct ucm_bistro_jmp_rax_patch {
uint8_t mov_rax[2]; /* movabs $ptr, %rax (initialized to 0x48, 0xb8) */
void *ptr; /* 64-bit immediate operand: absolute jump target */
uint8_t jmp_rax[2]; /* jmp *%rax (initialized to 0xff, 0xe0) */
} UCS_S_PACKED ucm_bistro_jmp_rax_patch_t;

/*
 * Patch by jumping to relative address by immediate displacement.
 * ucm_bistro_patch() initializes 'jmp_rel' to 0xe9.
 */
typedef struct ucm_bistro_jmp_near_patch {
uint8_t jmp_rel; /* opcode: JMP rel32 */
int32_t disp; /* operand: signed 32-bit jump displacement */
} UCS_S_PACKED ucm_bistro_jmp_near_patch_t;

typedef struct {
uint8_t opcode; /* 0xff */
uint8_t modrm; /* 0x25 */
Expand All @@ -43,7 +56,24 @@ typedef struct {
uint8_t cmp_dptr_rax[2];
uint32_t cmp_value;
uint8_t pop_rax;
} UCS_S_PACKED ucm_bistro_compare_xlt_t;
} UCS_S_PACKED ucm_bistro_cmp_xlt_t;

/*
 * Code sequence that replaces a short conditional jump (Jcc rel8) when it is
 * relocated: the condition is kept, but the destination is reached through an
 * absolute address pushed on the stack followed by 'ret'.
 *
 * NOTE(review): in 64-bit mode 'push imm32' pushes a sign-extended 8-byte
 * value, so the 'hi' and 'lo' pushes would occupy two separate stack slots
 * rather than forming a single 64-bit return address - confirm that 'ret'
 * really lands on the intended destination.
 */
typedef struct {
uint8_t jmp_rel[2]; /* Jcc rel8: [0] original condition opcode, [1] 0x02 */
uint8_t jmp_out[2]; /* jmp rel8 (0xeb, 0x0b), taken when condition is false */
struct {
uint8_t push_imm; /* push imm32 opcode (0x68) */
uint32_t value; /* 32-bit half of the jump destination address */
} UCS_S_PACKED hi, lo; /* upper half is emitted first, then lower half */
uint8_t ret; /* ret (0xc3) */
} UCS_S_PACKED ucm_bistro_jcc_xlt_t;

/*
 * Relocation cursor shared by ucm_bistro_relocate_code() and
 * ucm_bistro_relocate_one(): each relocated instruction advances 'src_p' past
 * the bytes it consumed and 'dst_p' past the bytes it produced.
 */
typedef struct {
const void *src_p; /* Pointer to current source instruction */
const void *src_end; /* Upper limit for source instructions; starts at
                        UINTPTR_MAX and is lowered to a relocated jump's
                        target so that code past the target is not patched */
void *dst_p; /* Pointer to current destination instruction */
void *dst_end; /* Upper limit for destination instructions
                  (destination start + maximal destination length) */
} ucm_bistro_relocate_context_t;


/* REX prefix */
Expand Down Expand Up @@ -96,28 +126,40 @@ typedef struct {
/* ModR/M encoding for CMP [RIP+x], Imm32 */
#define UCM_BISTRO_X86_MODRM_CMP_RIP 0x3D /* 11 111 101 */

/* Jcc (conditional jump) opcodes range */
#define UCM_BISTRO_X86_JCC_FIRST 0x70
#define UCM_BISTRO_X86_JCC_LAST 0x7F


static ucs_status_t
ucm_bistro_relocate_one(void *dst, const void *src, size_t max_dst_length,
size_t *dst_length, size_t *src_length)
ucm_bistro_relocate_one(ucm_bistro_relocate_context_t *ctx)
{
const void *src_p = src;
ucm_bistro_compare_xlt_t cmp_xlt = {
const void *copy_src = ctx->src_p;
ucm_bistro_cmp_xlt_t cmp = {
.push_rax = 0x50,
.movabs_rax = {0x48, 0xb8},
.cmp_dptr_rax = {0x81, 0x38},
.pop_rax = 0x58
};
ucm_bistro_jcc_xlt_t jcc = {
.jmp_rel = {0x00, 0x02},
.jmp_out = {0xeb, 0x0b},
.hi = {0x68, 0},
.lo = {0x68, 0},
.ret = 0xc3
};
uint8_t rex, opcode, modrm, mod;
const void *copy_src;
size_t dst_length;
uint64_t jmpdest;
int32_t disp32;
uint32_t imm32;
int8_t disp8;

/* Check opcode and REX prefix */
opcode = *ucs_serialize_next(&src_p, const uint8_t);
opcode = *ucs_serialize_next(&ctx->src_p, const uint8_t);
if ((opcode & UCM_BISTRO_X86_REX_MASK) == UCM_BISTRO_X86_REX) {
rex = opcode;
opcode = *ucs_serialize_next(&src_p, const uint8_t);
opcode = *ucs_serialize_next(&ctx->src_p, const uint8_t);
} else {
rex = 0;
}
Expand All @@ -128,15 +170,15 @@ ucm_bistro_relocate_one(void *dst, const void *src, size_t max_dst_length,
goto out_copy_src;
} else if ((rex == UCM_BISTRO_X86_REX_W) &&
(opcode == UCM_BISTRO_X86_IMM_GRP1_EV_IZ)) {
modrm = *ucs_serialize_next(&src_p, const uint8_t);
modrm = *ucs_serialize_next(&ctx->src_p, const uint8_t);
if (modrm == UCM_BISTRO_X86_MODRM_SUB_SP) {
/* sub $imm32, %rsp */
ucs_serialize_next(&src_p, const uint32_t);
ucs_serialize_next(&ctx->src_p, const uint32_t);
goto out_copy_src;
}
} else if ((rex == UCM_BISTRO_X86_REX_W) &&
(opcode == UCM_BISTRO_X86_MOV_EV_GV)) {
modrm = *ucs_serialize_next(&src_p, const uint8_t);
modrm = *ucs_serialize_next(&ctx->src_p, const uint8_t);
mod = modrm >> UCM_BISTRO_X86_MODRM_MOD_SHIFT;
if (modrm == UCM_BISTRO_X86_MODRM_BP_SP) {
/* mov %rsp, %rbp */
Expand All @@ -147,22 +189,22 @@ ucm_bistro_relocate_one(void *dst, const void *src, size_t max_dst_length,
((modrm & UCS_MASK(UCM_BISTRO_X86_MODRM_RM_BITS)) ==
UCM_BISTRO_X86_MODRM_RM_SIB)) {
/* r/m = 0b100, mod = 0b00/0b01/0b10 */
ucs_serialize_next(&src_p, const uint8_t); /* skip SIB */
ucs_serialize_next(&ctx->src_p, const uint8_t); /* skip SIB */
if (mod == UCM_BISTRO_X86_MODRM_MOD_DISP8) {
ucs_serialize_next(&src_p, const uint8_t); /* skip disp8 */
ucs_serialize_next(&ctx->src_p, const uint8_t); /* skip disp8 */
goto out_copy_src;
} else if (mod == UCM_BISTRO_X86_MODRM_MOD_DISP32) {
ucs_serialize_next(&src_p, const uint32_t); /* skip disp32 */
ucs_serialize_next(&ctx->src_p, const uint32_t); /* skip disp32 */
goto out_copy_src;
}
}
} else if ((rex == 0) && ((opcode & UCM_BISTRO_X86_MOV_IR_MASK) ==
UCM_BISTRO_X86_MOV_IR)) {
/* mov $imm32, %reg */
ucs_serialize_next(&src_p, const uint32_t);
ucs_serialize_next(&ctx->src_p, const uint32_t);
goto out_copy_src;
} else if ((rex == 0) && (opcode == UCM_BISTRO_X86_IMM_GRP1_EV_IZ)) {
modrm = *ucs_serialize_next(&src_p, const uint8_t);
modrm = *ucs_serialize_next(&ctx->src_p, const uint8_t);
if (modrm == UCM_BISTRO_X86_MODRM_CMP_RIP) {
/*
* Since we can't assume the new code will be within 32-bit
Expand All @@ -175,29 +217,53 @@ ucm_bistro_relocate_one(void *dst, const void *src, size_t max_dst_length,
* cmpl $imm32, (%rax)
* pop %rax
*/
disp32 = *ucs_serialize_next(&src_p, const uint32_t);
imm32 = *ucs_serialize_next(&src_p, const uint32_t);
cmp_xlt.rax_value = (uintptr_t)UCS_PTR_BYTE_OFFSET(src_p, disp32);
cmp_xlt.cmp_value = imm32;
copy_src = &cmp_xlt;
*dst_length = sizeof(cmp_xlt);
disp32 = *ucs_serialize_next(&ctx->src_p, const int32_t);
imm32 = *ucs_serialize_next(&ctx->src_p, const uint32_t);
cmp.rax_value = (uintptr_t)UCS_PTR_BYTE_OFFSET(ctx->src_p, disp32);
cmp.cmp_value = imm32;
copy_src = &cmp;
dst_length = sizeof(cmp);
goto out_copy;
}
} else if ((rex == 0) && (opcode >= UCM_BISTRO_X86_JCC_FIRST) &&
(opcode <= UCM_BISTRO_X86_JCC_LAST)) {
/*
* Since we can't assume the new code will be within 32-bit range of the
* jump destination, we need to translate the code from:
* jCC $disp8
* to:
* jCC L1
* jmp L2 ; condition 'CC' did not hold
* L1: push $addrhi
* push $addrlo
* ret ; 64-bit jump to destination
* L2: ; continue execution
*/
disp8 = *ucs_serialize_next(&ctx->src_p, const int8_t);
jmpdest = (uintptr_t)UCS_PTR_BYTE_OFFSET(ctx->src_p, disp8);
jcc.jmp_rel[0] = opcode; /* keep original jump condition */
jcc.hi.value = jmpdest >> 32;
jcc.lo.value = jmpdest & UCS_MASK(32);
copy_src = &jcc;
dst_length = sizeof(jcc);
/* Prevent patching past jump target */
ctx->src_end = ucs_min(ctx->src_end, (void*)jmpdest);
goto out_copy;
}

/* Could not recognize the instruction */
return UCS_ERR_UNSUPPORTED;

out_copy_src:
copy_src = src;
*dst_length = UCS_PTR_BYTE_DIFF(src, src_p);
dst_length = UCS_PTR_BYTE_DIFF(copy_src, ctx->src_p);
out_copy:
if (*dst_length > max_dst_length) {
if (UCS_PTR_BYTE_OFFSET(ctx->dst_p, dst_length) > ctx->dst_end) {
return UCS_ERR_BUFFER_TOO_SMALL;
}

*src_length = UCS_PTR_BYTE_DIFF(src, src_p);
memcpy(dst, copy_src, *dst_length);
/* Copy 'dst_length' bytes to ctx->dst_p and advance it */
memcpy(ucs_serialize_next_raw(&ctx->dst_p, void, dst_length), copy_src,
dst_length);
return UCS_OK;
}

Expand All @@ -209,28 +275,30 @@ ucm_bistro_relocate_one(void *dst, const void *src, size_t max_dst_length,
*/
static ucs_status_t
ucm_bistro_relocate_code(void *dst, const void *src, size_t min_src_length,
size_t max_dst_length, size_t *dst_length,
size_t *src_length)
size_t max_dst_length, size_t *dst_length_p,
size_t *src_length_p)
{
size_t src_length_one, dst_length_one;
ucm_bistro_relocate_context_t ctx = {
.src_p = src,
.dst_p = dst,
.dst_end = UCS_PTR_BYTE_OFFSET(dst, max_dst_length),
.src_end = (void*)UINTPTR_MAX
};
ucs_status_t status;

*src_length = 0;
*dst_length = 0;
while (*src_length < min_src_length) {
status = ucm_bistro_relocate_one(UCS_PTR_BYTE_OFFSET(dst, *dst_length),
UCS_PTR_BYTE_OFFSET(src, *src_length),
max_dst_length - *dst_length,
&dst_length_one, &src_length_one);
while (ctx.src_p < UCS_PTR_BYTE_OFFSET(src, min_src_length)) {
status = ucm_bistro_relocate_one(&ctx);
if (status != UCS_OK) {
return status;
}

*dst_length += dst_length_one;
*src_length += src_length_one;
if (ctx.src_p > ctx.src_end) {
return UCS_ERR_UNSUPPORTED;
}
}

ucm_assert(*dst_length <= max_dst_length);
*src_length_p = UCS_PTR_BYTE_DIFF(src, ctx.src_p);
*dst_length_p = UCS_PTR_BYTE_DIFF(dst, ctx.dst_p);
return UCS_OK;
}

Expand Down Expand Up @@ -259,9 +327,13 @@ ucm_bistro_construct_orig_func(const void *func_ptr, size_t patch_len,
ucm_bistro_orig_func_t *orig_func;
ucs_status_t status;
char code_buf[64];
int dladdr_ret;
Dl_info dli;

/* Allocate executable page */
max_code_len = patch_len + sizeof(ucm_bistro_compare_xlt_t);
max_code_len = ucs_max(patch_len + sizeof(ucm_bistro_cmp_xlt_t) +
sizeof(ucm_bistro_jcc_xlt_t),
64);
orig_func = ucm_bistro_allocate_code(sizeof(*orig_func) + max_code_len +
sizeof(*jmp_back));
if (orig_func == NULL) {
Expand All @@ -274,7 +346,9 @@ ucm_bistro_construct_orig_func(const void *func_ptr, size_t patch_len,
status = ucm_bistro_relocate_code(orig_func->code, func_ptr, patch_len,
max_code_len, &code_len, &prefix_len);
if (status != UCS_OK) {
ucm_diag("'%s' could not patch by bistro, code:%s", symbol,
dladdr_ret = dladdr(func_ptr, &dli);
ucm_diag("failed to patch '%s' from %s length %zu code:%s", symbol,
(dladdr_ret != 0) ? dli.dli_fname : "(unknown)", patch_len,
ucm_bistro_dump_code(func_ptr, 16, code_buf,
sizeof(code_buf)));
return UCS_ERR_UNSUPPORTED;
Expand All @@ -298,9 +372,9 @@ ucs_status_t ucm_bistro_patch(void *func_ptr, void *hook, const char *symbol,
void **orig_func_p,
ucm_bistro_restore_point_t **rp)
{
ucm_bistro_jmp_r11_patch_t jmp_r11 = {
.mov_r11 = {0x49, 0xbb},
.jmp_r11 = {0x41, 0xff, 0xe3}
ucm_bistro_jmp_rax_patch_t jmp_rax = {
.mov_rax = {0x48, 0xb8},
.jmp_rax = {0xff, 0xe0}
};
ucm_bistro_jmp_near_patch_t jmp_near = {
.jmp_rel = 0xe9
Expand All @@ -320,9 +394,9 @@ ucs_status_t ucm_bistro_patch(void *func_ptr, void *hook, const char *symbol,
patch = &jmp_near;
patch_len = sizeof(jmp_near);
} else {
jmp_r11.ptr = hook;
patch = &jmp_r11;
patch_len = sizeof(jmp_r11);
jmp_rax.ptr = hook;
patch = &jmp_rax;
patch_len = sizeof(jmp_rax);
}

if (orig_func_p != NULL) {
Expand Down
15 changes: 0 additions & 15 deletions src/ucm/bistro/bistro_x86_64.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,10 @@
#include <stdint.h>

#include <ucs/type/status.h>
#include <ucs/sys/compiler_def.h>

#define UCM_BISTRO_PROLOGUE
#define UCM_BISTRO_EPILOGUE

/* Patch by jumping to absolute address loaded from register */
typedef struct ucm_bistro_jmp_r11_patch {
uint8_t mov_r11[2]; /* mov %r11, addr */
void *ptr;
uint8_t jmp_r11[3]; /* jmp r11 */
} UCS_S_PACKED ucm_bistro_jmp_r11_patch_t;


/* Patch by jumping to relative address by immediate displacement */
typedef struct ucm_bistro_jmp_near_patch {
uint8_t jmp_rel; /* opcode: JMP rel32 */
int32_t disp; /* operand: jump displacement */
} UCS_S_PACKED ucm_bistro_jmp_near_patch_t;


/**
* Set library function call hook using Binary Instrumentation
Expand Down
1 change: 1 addition & 0 deletions src/ucm/util/log.c
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ static void ucm_log_vsnprintf(char *buf, size_t max, const char *fmt, va_list ap
flags |= UCM_LOG_LTOA_PAD_LEFT;
break;
case 'l':
case 'z':
flags |= UCM_LOG_LTOA_FLAG_LONG;
break;
case '0':
Expand Down
16 changes: 16 additions & 0 deletions test/apps/test_cuda_hook.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <ucp/api/ucp.h>
#include <ucm/api/ucm.h>
#include <cuda_runtime.h>
#include <sys/mman.h>
#include <getopt.h>
#include <cuda.h>

Expand Down Expand Up @@ -81,11 +82,13 @@ int main(int argc, char **argv)
{
static const ucm_event_type_t memtype_events = UCM_EVENT_MEM_TYPE_ALLOC |
UCM_EVENT_MEM_TYPE_FREE;
static const size_t dummy_va_size = 4 * (1ul << 30); /* 4 GB */
static const int num_expected_events = 2;
ucp_context_h context;
ucs_status_t status;
ucp_params_t params;
int use_driver_api;
void *dummy_ptr;
int num_events;
int c;

Expand All @@ -104,6 +107,18 @@ int main(int argc, char **argv)
}
}

/* In order to test long jumps in bistro hooks code, increase address space
* separation by allocating a large VA space segment.
*/
dummy_ptr = mmap(NULL, dummy_va_size, PROT_READ | PROT_WRITE,
MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
if (dummy_ptr == MAP_FAILED) {
printf("failed to allocate dummy VA space: %m\n");
return -1;
}

printf("allocated dummy VA space at %p\n", dummy_ptr);

params.field_mask = UCP_PARAM_FIELD_FEATURES;
params.features = UCP_FEATURE_TAG | UCP_FEATURE_STREAM;
status = ucp_init(&params, NULL, &context);
Expand All @@ -126,5 +141,6 @@ int main(int argc, char **argv)

ucp_cleanup(context);

munmap(dummy_ptr, dummy_va_size);
return (num_events >= num_expected_events) ? 0 : -1;
}