Skip to content

Commit

Permalink
UCM/BISTRO: Translate relative jumps
Browse files Browse the repository at this point in the history
  • Loading branch information
yosefe committed Jan 23, 2022
1 parent ea16f70 commit 6c47694
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 63 deletions.
170 changes: 122 additions & 48 deletions src/ucm/bistro/bistro_x86_64.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,19 @@ typedef struct {
char code[];
} ucm_bistro_orig_func_t;

/*
 * Patch by jumping to absolute address loaded from register.
 *
 * Packed byte image written over the start of the hooked function when the
 * hook is not reachable by a near jump: a 2-byte opcode, the pointer-sized
 * immediate, then a 2-byte indirect jump through the same register.
 * ucm_bistro_patch() fills the opcodes with {0x48, 0xb8} / {0xff, 0xe0} and
 * stores the hook address in 'ptr'.
 *
 * NOTE(review): the sequence overwrites %rax before the hook runs. In the
 * System V AMD64 ABI %al carries the vector register count for variadic
 * calls - confirm that no patched function is variadic.
 */
typedef struct ucm_bistro_jmp_rax_patch {
uint8_t mov_rax[2]; /* movabs $addr, %rax (REX.W prefix + opcode) */
void *ptr; /* immediate operand: absolute address to jump to */
uint8_t jmp_rax[2]; /* jmp *%rax */
} UCS_S_PACKED ucm_bistro_jmp_rax_patch_t;

/*
 * Patch by jumping to relative address by immediate displacement.
 *
 * Packed 5-byte image of a near jump. ucm_bistro_patch() sets the opcode to
 * 0xe9 and uses this form instead of the longer absolute-jump patch.
 */
typedef struct ucm_bistro_jmp_near_patch {
uint8_t jmp_rel; /* opcode: JMP rel32 (0xe9) */
int32_t disp; /* operand: signed 32-bit jump displacement */
} UCS_S_PACKED ucm_bistro_jmp_near_patch_t;

typedef struct {
uint8_t opcode; /* 0xff */
uint8_t modrm; /* 0x25 */
Expand All @@ -43,7 +56,24 @@ typedef struct {
uint8_t cmp_dptr_rax[2];
uint32_t cmp_value;
uint8_t pop_rax;
} UCS_S_PACKED ucm_bistro_compare_xlt_t;
} UCS_S_PACKED ucm_bistro_cmp_xlt_t;

/*
 * Packed byte image that replaces a short conditional jump (Jcc rel8) when the
 * instruction is relocated by ucm_bistro_relocate_one(). The relocated code
 * may be out of rel8/rel32 range of the original target, so the jump is
 * rewritten to reach the absolute destination through the stack:
 *
 *   jmp_rel : Jcc +2   - same condition as the original, skips 'jmp_out'
 *   jmp_out : jmp +11  - condition false: skip push/push/ret and continue
 *   hi, lo  : push imm32 of the upper, then the lower destination half
 *   ret     : transfer control to the pushed address
 *
 * NOTE(review): in 64-bit mode 'push imm32' sign-extends the immediate and
 * pushes 8 bytes, so two pushes occupy 16 bytes of stack and 'ret' pops only
 * the last 8. Confirm that this sequence really produces the full 64-bit
 * destination when the upper half is non-zero or bit 31 of the lower half is
 * set.
 */
typedef struct {
uint8_t jmp_rel[2]; /* Jcc rel8: opcode copied from source, disp8 = 2 */
uint8_t jmp_out[2]; /* jmp rel8 (0xeb), disp8 = 0x0b = 2 pushes + ret */
struct {
uint8_t push_imm; /* push imm32 opcode (0x68) */
uint32_t value; /* 32-bit half of the absolute jump destination */
} UCS_S_PACKED hi, lo; /* hi = destination >> 32, lo = low 32 bits */
uint8_t ret; /* ret (0xc3) */
} UCS_S_PACKED ucm_bistro_jcc_xlt_t;

/*
 * Cursor state shared between ucm_bistro_relocate_code() and
 * ucm_bistro_relocate_one() while instructions are copied/translated from the
 * original function into the relocation buffer.
 *
 * - src_p advances as source instruction bytes are consumed.
 * - src_end starts at UINTPTR_MAX and is lowered to the target of every
 *   translated conditional jump; relocation fails with UCS_ERR_UNSUPPORTED if
 *   src_p moves past it, so a jump target is never covered by the patch.
 * - dst_p advances by the number of bytes emitted for each instruction.
 * - dst_end is dst + max_dst_length; emitting past it fails with
 *   UCS_ERR_BUFFER_TOO_SMALL.
 */
typedef struct {
const void *src_p; /* Pointer to current source instruction */
const void *src_end; /* Upper limit for source instructions */
void *dst_p; /* Pointer to current destination instruction */
void *dst_end; /* Upper limit for destination instructions */
} ucm_bistro_relocate_context_t;


/* REX prefix */
Expand Down Expand Up @@ -96,28 +126,40 @@ typedef struct {
/*
 * ModR/M encoding for CMP [RIP+x], Imm32:
 * mod = 00, reg = 111 (/7 selects CMP in immediate group 1),
 * r/m = 101 (RIP-relative disp32 in 64-bit mode)
 */
#define UCM_BISTRO_X86_MODRM_CMP_RIP 0x3D /* 00 111 101 */

/*
 * Jcc (conditional jump) opcodes range: the one-byte short forms, each
 * followed by a signed 8-bit displacement. The low opcode bits select the
 * condition, so the opcode is copied as-is into the translated sequence.
 */
#define UCM_BISTRO_X86_JCC_FIRST 0x70
#define UCM_BISTRO_X86_JCC_LAST 0x7F


static ucs_status_t
ucm_bistro_relocate_one(void *dst, const void *src, size_t max_dst_length,
size_t *dst_length, size_t *src_length)
ucm_bistro_relocate_one(ucm_bistro_relocate_context_t *ctx)
{
const void *src_p = src;
ucm_bistro_compare_xlt_t cmp_xlt = {
const void *copy_src = ctx->src_p;
ucm_bistro_cmp_xlt_t cmp = {
.push_rax = 0x50,
.movabs_rax = {0x48, 0xb8},
.cmp_dptr_rax = {0x81, 0x38},
.pop_rax = 0x58
};
ucm_bistro_jcc_xlt_t jcc = {
.jmp_rel = {0x00, 0x02},
.jmp_out = {0xeb, 0x0b},
.hi = {0x68, 0},
.lo = {0x68, 0},
.ret = 0xc3
};
uint8_t rex, opcode, modrm, mod;
const void *copy_src;
size_t dst_length;
uint64_t jmpdest;
int32_t disp32;
uint32_t imm32;
int8_t disp8;

/* Check opcode and REX prefix */
opcode = *ucs_serialize_next(&src_p, const uint8_t);
opcode = *ucs_serialize_next(&ctx->src_p, const uint8_t);
if ((opcode & UCM_BISTRO_X86_REX_MASK) == UCM_BISTRO_X86_REX) {
rex = opcode;
opcode = *ucs_serialize_next(&src_p, const uint8_t);
opcode = *ucs_serialize_next(&ctx->src_p, const uint8_t);
} else {
rex = 0;
}
Expand All @@ -128,15 +170,15 @@ ucm_bistro_relocate_one(void *dst, const void *src, size_t max_dst_length,
goto out_copy_src;
} else if ((rex == UCM_BISTRO_X86_REX_W) &&
(opcode == UCM_BISTRO_X86_IMM_GRP1_EV_IZ)) {
modrm = *ucs_serialize_next(&src_p, const uint8_t);
modrm = *ucs_serialize_next(&ctx->src_p, const uint8_t);
if (modrm == UCM_BISTRO_X86_MODRM_SUB_SP) {
/* sub $imm32, %rsp */
ucs_serialize_next(&src_p, const uint32_t);
ucs_serialize_next(&ctx->src_p, const uint32_t);
goto out_copy_src;
}
} else if ((rex == UCM_BISTRO_X86_REX_W) &&
(opcode == UCM_BISTRO_X86_MOV_EV_GV)) {
modrm = *ucs_serialize_next(&src_p, const uint8_t);
modrm = *ucs_serialize_next(&ctx->src_p, const uint8_t);
mod = modrm >> UCM_BISTRO_X86_MODRM_MOD_SHIFT;
if (modrm == UCM_BISTRO_X86_MODRM_BP_SP) {
/* mov %rsp, %rbp */
Expand All @@ -147,22 +189,22 @@ ucm_bistro_relocate_one(void *dst, const void *src, size_t max_dst_length,
((modrm & UCS_MASK(UCM_BISTRO_X86_MODRM_RM_BITS)) ==
UCM_BISTRO_X86_MODRM_RM_SIB)) {
/* r/m = 0b100, mod = 0b00/0b01/0b10 */
ucs_serialize_next(&src_p, const uint8_t); /* skip SIB */
ucs_serialize_next(&ctx->src_p, const uint8_t); /* skip SIB */
if (mod == UCM_BISTRO_X86_MODRM_MOD_DISP8) {
ucs_serialize_next(&src_p, const uint8_t); /* skip disp8 */
ucs_serialize_next(&ctx->src_p, const uint8_t); /* skip disp8 */
goto out_copy_src;
} else if (mod == UCM_BISTRO_X86_MODRM_MOD_DISP32) {
ucs_serialize_next(&src_p, const uint32_t); /* skip disp32 */
ucs_serialize_next(&ctx->src_p, const uint32_t); /* skip disp32 */
goto out_copy_src;
}
}
} else if ((rex == 0) && ((opcode & UCM_BISTRO_X86_MOV_IR_MASK) ==
UCM_BISTRO_X86_MOV_IR)) {
/* mov $imm32, %reg */
ucs_serialize_next(&src_p, const uint32_t);
ucs_serialize_next(&ctx->src_p, const uint32_t);
goto out_copy_src;
} else if ((rex == 0) && (opcode == UCM_BISTRO_X86_IMM_GRP1_EV_IZ)) {
modrm = *ucs_serialize_next(&src_p, const uint8_t);
modrm = *ucs_serialize_next(&ctx->src_p, const uint8_t);
if (modrm == UCM_BISTRO_X86_MODRM_CMP_RIP) {
/*
* Since we can't assume the new code will be within 32-bit
Expand All @@ -175,29 +217,53 @@ ucm_bistro_relocate_one(void *dst, const void *src, size_t max_dst_length,
* cmpl $imm32, (%rax)
* pop %rax
*/
disp32 = *ucs_serialize_next(&src_p, const uint32_t);
imm32 = *ucs_serialize_next(&src_p, const uint32_t);
cmp_xlt.rax_value = (uintptr_t)UCS_PTR_BYTE_OFFSET(src_p, disp32);
cmp_xlt.cmp_value = imm32;
copy_src = &cmp_xlt;
*dst_length = sizeof(cmp_xlt);
disp32 = *ucs_serialize_next(&ctx->src_p, const int32_t);
imm32 = *ucs_serialize_next(&ctx->src_p, const uint32_t);
cmp.rax_value = (uintptr_t)UCS_PTR_BYTE_OFFSET(ctx->src_p, disp32);
cmp.cmp_value = imm32;
copy_src = &cmp;
dst_length = sizeof(cmp);
goto out_copy;
}
} else if ((rex == 0) && (opcode >= UCM_BISTRO_X86_JCC_FIRST) &&
(opcode <= UCM_BISTRO_X86_JCC_LAST)) {
/*
* Since we can't assume the new code will be within 32-bit range of the
* jump destination, we need to translate the code from:
* jCC $disp8
* to:
* jCC L1
* jmp L2 ; condition 'CC' did not hold
* L1: push $addrhi
* push $addrlo
* ret ; 64-bit jump to destination
* L2: ; continue execution
*/
disp8 = *ucs_serialize_next(&ctx->src_p, const int8_t);
jmpdest = (uintptr_t)UCS_PTR_BYTE_OFFSET(ctx->src_p, disp8);
jcc.jmp_rel[0] = opcode; /* keep original jump condition */
jcc.hi.value = jmpdest >> 32;
jcc.lo.value = jmpdest & UCS_MASK(32);
copy_src = &jcc;
dst_length = sizeof(jcc);
/* Prevent patching past jump target */
ctx->src_end = ucs_min(ctx->src_end, (void*)jmpdest);
goto out_copy;
}

/* Could not recognize the instruction */
return UCS_ERR_UNSUPPORTED;

out_copy_src:
copy_src = src;
*dst_length = UCS_PTR_BYTE_DIFF(src, src_p);
dst_length = UCS_PTR_BYTE_DIFF(copy_src, ctx->src_p);
out_copy:
if (*dst_length > max_dst_length) {
if (UCS_PTR_BYTE_OFFSET(ctx->dst_p, dst_length) > ctx->dst_end) {
return UCS_ERR_BUFFER_TOO_SMALL;
}

*src_length = UCS_PTR_BYTE_DIFF(src, src_p);
memcpy(dst, copy_src, *dst_length);
/* Copy 'dst_length' bytes to ctx->dst_p and advance it */
memcpy(ucs_serialize_next_raw(&ctx->dst_p, void, dst_length), copy_src,
dst_length);
return UCS_OK;
}

Expand All @@ -209,28 +275,30 @@ ucm_bistro_relocate_one(void *dst, const void *src, size_t max_dst_length,
*/
static ucs_status_t
ucm_bistro_relocate_code(void *dst, const void *src, size_t min_src_length,
size_t max_dst_length, size_t *dst_length,
size_t *src_length)
size_t max_dst_length, size_t *dst_length_p,
size_t *src_length_p)
{
size_t src_length_one, dst_length_one;
ucm_bistro_relocate_context_t ctx = {
.src_p = src,
.dst_p = dst,
.dst_end = UCS_PTR_BYTE_OFFSET(dst, max_dst_length),
.src_end = (void*)UINTPTR_MAX
};
ucs_status_t status;

*src_length = 0;
*dst_length = 0;
while (*src_length < min_src_length) {
status = ucm_bistro_relocate_one(UCS_PTR_BYTE_OFFSET(dst, *dst_length),
UCS_PTR_BYTE_OFFSET(src, *src_length),
max_dst_length - *dst_length,
&dst_length_one, &src_length_one);
while (ctx.src_p < UCS_PTR_BYTE_OFFSET(src, min_src_length)) {
status = ucm_bistro_relocate_one(&ctx);
if (status != UCS_OK) {
return status;
}

*dst_length += dst_length_one;
*src_length += src_length_one;
if (ctx.src_p > ctx.src_end) {
return UCS_ERR_UNSUPPORTED;
}
}

ucm_assert(*dst_length <= max_dst_length);
*src_length_p = UCS_PTR_BYTE_DIFF(src, ctx.src_p);
*dst_length_p = UCS_PTR_BYTE_DIFF(dst, ctx.dst_p);
return UCS_OK;
}

Expand Down Expand Up @@ -259,9 +327,13 @@ ucm_bistro_construct_orig_func(const void *func_ptr, size_t patch_len,
ucm_bistro_orig_func_t *orig_func;
ucs_status_t status;
char code_buf[64];
int dladdr_ret;
Dl_info dli;

/* Allocate executable page */
max_code_len = patch_len + sizeof(ucm_bistro_compare_xlt_t);
max_code_len = ucs_max(patch_len + sizeof(ucm_bistro_cmp_xlt_t) +
sizeof(ucm_bistro_jcc_xlt_t),
64);
orig_func = ucm_bistro_allocate_code(sizeof(*orig_func) + max_code_len +
sizeof(*jmp_back));
if (orig_func == NULL) {
Expand All @@ -274,7 +346,9 @@ ucm_bistro_construct_orig_func(const void *func_ptr, size_t patch_len,
status = ucm_bistro_relocate_code(orig_func->code, func_ptr, patch_len,
max_code_len, &code_len, &prefix_len);
if (status != UCS_OK) {
ucm_diag("'%s' could not patch by bistro, code:%s", symbol,
dladdr_ret = dladdr(func_ptr, &dli);
ucm_diag("failed to patch '%s' from %s length %zu code:%s", symbol,
(dladdr_ret != 0) ? dli.dli_fname : "(unknown)", patch_len,
ucm_bistro_dump_code(func_ptr, 16, code_buf,
sizeof(code_buf)));
return UCS_ERR_UNSUPPORTED;
Expand All @@ -298,9 +372,9 @@ ucs_status_t ucm_bistro_patch(void *func_ptr, void *hook, const char *symbol,
void **orig_func_p,
ucm_bistro_restore_point_t **rp)
{
ucm_bistro_jmp_r11_patch_t jmp_r11 = {
.mov_r11 = {0x49, 0xbb},
.jmp_r11 = {0x41, 0xff, 0xe3}
ucm_bistro_jmp_rax_patch_t jmp_rax = {
.mov_rax = {0x48, 0xb8},
.jmp_rax = {0xff, 0xe0}
};
ucm_bistro_jmp_near_patch_t jmp_near = {
.jmp_rel = 0xe9
Expand All @@ -320,9 +394,9 @@ ucs_status_t ucm_bistro_patch(void *func_ptr, void *hook, const char *symbol,
patch = &jmp_near;
patch_len = sizeof(jmp_near);
} else {
jmp_r11.ptr = hook;
patch = &jmp_r11;
patch_len = sizeof(jmp_r11);
jmp_rax.ptr = hook;
patch = &jmp_rax;
patch_len = sizeof(jmp_rax);
}

if (orig_func_p != NULL) {
Expand Down
15 changes: 0 additions & 15 deletions src/ucm/bistro/bistro_x86_64.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,10 @@
#include <stdint.h>

#include <ucs/type/status.h>
#include <ucs/sys/compiler_def.h>

#define UCM_BISTRO_PROLOGUE
#define UCM_BISTRO_EPILOGUE

/* Patch by jumping to absolute address loaded from register */
typedef struct ucm_bistro_jmp_r11_patch {
uint8_t mov_r11[2]; /* mov %r11, addr */
void *ptr;
uint8_t jmp_r11[3]; /* jmp r11 */
} UCS_S_PACKED ucm_bistro_jmp_r11_patch_t;


/* Patch by jumping to relative address by immediate displacement */
typedef struct ucm_bistro_jmp_near_patch {
uint8_t jmp_rel; /* opcode: JMP rel32 */
int32_t disp; /* operand: jump displacement */
} UCS_S_PACKED ucm_bistro_jmp_near_patch_t;


/**
* Set library function call hook using Binary Instrumentation
Expand Down
1 change: 1 addition & 0 deletions src/ucm/util/log.c
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ static void ucm_log_vsnprintf(char *buf, size_t max, const char *fmt, va_list ap
flags |= UCM_LOG_LTOA_PAD_LEFT;
break;
case 'l':
case 'z':
flags |= UCM_LOG_LTOA_FLAG_LONG;
break;
case '0':
Expand Down
16 changes: 16 additions & 0 deletions test/apps/test_cuda_hook.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <ucp/api/ucp.h>
#include <ucm/api/ucm.h>
#include <cuda_runtime.h>
#include <sys/mman.h>
#include <getopt.h>
#include <cuda.h>

Expand Down Expand Up @@ -81,11 +82,13 @@ int main(int argc, char **argv)
{
static const ucm_event_type_t memtype_events = UCM_EVENT_MEM_TYPE_ALLOC |
UCM_EVENT_MEM_TYPE_FREE;
static const size_t dummy_va_size = 4 * (1ul << 30); /* 4 GB */
static const int num_expected_events = 2;
ucp_context_h context;
ucs_status_t status;
ucp_params_t params;
int use_driver_api;
void *dummy_ptr;
int num_events;
int c;

Expand All @@ -104,6 +107,18 @@ int main(int argc, char **argv)
}
}

/* In order to test long jumps in bistro hooks code, increase address space
* separation by allocating a large VA space segment.
*/
dummy_ptr = mmap(NULL, dummy_va_size, PROT_READ | PROT_WRITE,
MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
if (dummy_ptr == MAP_FAILED) {
printf("failed to allocate dummy VA space: %m\n");
return -1;
}

printf("allocated dummy VA space at %p\n", dummy_ptr);

params.field_mask = UCP_PARAM_FIELD_FEATURES;
params.features = UCP_FEATURE_TAG | UCP_FEATURE_STREAM;
status = ucp_init(&params, NULL, &context);
Expand All @@ -126,5 +141,6 @@ int main(int argc, char **argv)

ucp_cleanup(context);

munmap(dummy_ptr, dummy_va_size);
return (num_events >= num_expected_events) ? 0 : -1;
}

0 comments on commit 6c47694

Please sign in to comment.