diff --git a/xls/jit/BUILD b/xls/jit/BUILD index 2d468886ca..bd62a30fe8 100644 --- a/xls/jit/BUILD +++ b/xls/jit/BUILD @@ -521,6 +521,7 @@ cc_library( ":orc_jit", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", + "@com_google_absl//absl/base:log_severity", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", diff --git a/xls/jit/block_jit.cc b/xls/jit/block_jit.cc index 7c8108c5ff..45797c68b7 100644 --- a/xls/jit/block_jit.cc +++ b/xls/jit/block_jit.cc @@ -63,27 +63,49 @@ absl::StatusOr> BlockJit::Create( std::unique_ptr BlockJit::NewContinuation() { return std::unique_ptr(new BlockJitContinuation( block_, this, runtime_, function_.temp_buffer_size, + /*register_sizes=*/ absl::MakeSpan(function_.input_buffer_sizes) .subspan(block_->GetInputPorts().size()), + /*register_alignments=*/ + absl::MakeSpan(function_.input_buffer_prefered_alignments) + .subspan(block_->GetInputPorts().size()), + /*output_port_sizes=*/ absl::MakeSpan(function_.output_buffer_sizes) .subspan(0, block_->GetOutputPorts().size()), + /*output_port_alignments=*/ + absl::MakeSpan(function_.output_buffer_prefered_alignments) + .subspan(0, block_->GetOutputPorts().size()), + /*input_port_sizes=*/ absl::MakeSpan(function_.input_buffer_sizes) + .subspan(0, block_->GetInputPorts().size()), + /*input_port_alignments=*/ + absl::MakeSpan(function_.input_buffer_prefered_alignments) .subspan(0, block_->GetInputPorts().size()))); } absl::Status BlockJit::RunOneCycle(BlockJitContinuation& continuation) { - function_.function(continuation.function_inputs().data(), - continuation.function_outputs().data(), - runtime_->AsStack(continuation.temp_buffer()).data(), - &continuation.GetEvents(), /*user_data=*/nullptr, runtime_, - /*continuation_point=*/0); + function_.RunJittedFunction( + continuation.function_inputs().data(), + continuation.function_outputs().data(), + runtime_->AsStack(continuation.temp_buffer()).data(), + &continuation.GetEvents(), /*user_data=*/nullptr, runtime_, + /*continuation_point=*/0); continuation.SwapRegisters(); return absl::OkStatus(); } namespace { -int64_t SumElements(absl::Span v) { - return absl::c_accumulate(v, int64_t{0}); +// Determine how much memory is needed to hold all the arguments of the given +// sizes. This is larger than just the sum to allow for space for all of them to +// be properly aligned. +int64_t FindFullBufferSize(JitRuntime* runtime, absl::Span sizes, + absl::Span alignments) { + XLS_CHECK_EQ(sizes.size(), alignments.size()); + int64_t total = 0; + for (int64_t i = 0; i < sizes.size(); ++i) { + total += runtime->ShouldAllocateForAlignment(sizes[i], alignments[i]); + } + return total; } std::vector CombineLists(absl::Span a, @@ -101,14 +123,19 @@ std::vector CombineLists(absl::Span a, // Find the start-pointer of each argument in the argument arena. The size of // each argument is given by the sizes span. -std::vector CalculatePointers(uint8_t* base_ptr, - absl::Span sizes) { - size_t tot = 0; +std::vector CalculatePointers(JitRuntime* runtime, + absl::Span base_buffer, + absl::Span sizes, + absl::Span alignments) { + XLS_CHECK_EQ(sizes.size(), alignments.size()); std::vector out; out.reserve(sizes.size()); - for (size_t s : sizes) { - out.push_back(base_ptr + tot); - tot += s; + for (int64_t i = 0; i < sizes.size(); ++i) { + auto aligned_span = runtime->AsAligned(base_buffer, alignments[i]); + XLS_CHECK_GE(aligned_span.size(), sizes[i]) + << "Buffer for " << i << " element not large enough!"; + out.push_back(aligned_span.data()); + base_buffer = aligned_span.subspan(sizes[i]); } return out; } @@ -117,23 +144,35 @@ std::vector CalculatePointers(uint8_t* base_ptr, BlockJitContinuation::BlockJitContinuation( Block* block, BlockJit* jit, JitRuntime* runtime, size_t temp_size, absl::Span register_sizes, + absl::Span register_alignments, absl::Span output_port_sizes, - absl::Span input_port_sizes) + absl::Span output_port_alignments, + absl::Span input_port_sizes, + absl::Span input_port_alignments) : block_(block), block_jit_(jit), runtime_(runtime), - register_arena_left_(SumElements(register_sizes), 0), + register_arena_left_( + FindFullBufferSize(runtime, register_sizes, register_alignments), 0), register_arena_right_(register_arena_left_.size(), 0xff), - output_port_arena_(SumElements(output_port_sizes), 0xff), - input_port_arena_(SumElements(input_port_sizes), 0xff), + output_port_arena_(FindFullBufferSize(runtime, output_port_sizes, + output_port_alignments), + 0xff), + input_port_arena_( + FindFullBufferSize(runtime, input_port_sizes, input_port_alignments), + 0xff), register_pointers_(BlockJitContinuation::IOSpace( - CalculatePointers(register_arena_left_.data(), register_sizes), - CalculatePointers(register_arena_right_.data(), register_sizes), + CalculatePointers(runtime, absl::MakeSpan(register_arena_left_), + register_sizes, register_alignments), + CalculatePointers(runtime, absl::MakeSpan(register_arena_right_), + register_sizes, register_alignments), BlockJitContinuation::IOSpace::RegisterSpace::kLeft)), output_port_pointers_( - CalculatePointers(output_port_arena_.data(), output_port_sizes)), + CalculatePointers(runtime, absl::MakeSpan(output_port_arena_), + output_port_sizes, output_port_alignments)), input_port_pointers_( - CalculatePointers(input_port_arena_.data(), input_port_sizes)), + CalculatePointers(runtime, absl::MakeSpan(input_port_arena_), + input_port_sizes, input_port_alignments)), full_input_pointer_set_(BlockJitContinuation::IOSpace( CombineLists(input_port_pointers_, register_pointers_.left()), CombineLists(input_port_pointers_, register_pointers_.right()), @@ -461,9 +500,7 @@ class BlockContinuationJitWrapper final : public BlockContinuation { } return *temporary_regs_; } - const InterpreterEvents& events() final { - return continuation_->GetEvents(); - } + const InterpreterEvents& events() final { return continuation_->GetEvents(); } absl::Status RunOneCycle( const absl::flat_hash_map& inputs) final { temporary_outputs_.reset(); diff --git a/xls/jit/block_jit.h b/xls/jit/block_jit.h index 174bde9726..d7ce08c5e7 100644 --- a/xls/jit/block_jit.h +++ b/xls/jit/block_jit.h @@ -176,8 +176,11 @@ class BlockJitContinuation { BlockJitContinuation(Block* block, BlockJit* jit, JitRuntime* runtime, size_t temp_size, absl::Span register_sizes, + absl::Span register_alignments, absl::Span output_port_sizes, - absl::Span input_port_sizes); + absl::Span output_port_alignments, + absl::Span input_port_sizes, + absl::Span input_port_alignments); void SwapRegisters() { register_pointers_.Swap(); @@ -206,21 +209,24 @@ class BlockJitContinuation { // memory live for the pointers. std::vector input_port_arena_; - // The register pointer file. + // The register pointer file, aligned as required. IOSpace register_pointers_; - // The output port pointers. + // The output port pointers, aligned as required. const std::vector output_port_pointers_; - // The input port pointers. + // The input port pointers, aligned as required. const std::vector input_port_pointers_; - // The input value pointers the register pointers. + // The input value pointers the register pointers, aligned as required. IOSpace full_input_pointer_set_; - // The output value pointers followed by the register pointers. + // The output value pointers followed by the register pointers, aligned as + // required. IOSpace full_output_pointer_set_; // Data block to store temporary data in. std::vector temp_data_arena_; + // Data block to store temporary data in, aligned as required. + uint8_t* temp_data_ptr_; InterpreterEvents events_; diff --git a/xls/jit/function_base_jit.cc b/xls/jit/function_base_jit.cc index 9fc9897949..b97197bba8 100644 --- a/xls/jit/function_base_jit.cc +++ b/xls/jit/function_base_jit.cc @@ -15,6 +15,7 @@ #include #include +#include #include #include #include @@ -47,6 +48,7 @@ #include "xls/common/status/status_macros.h" #include "xls/ir/block.h" #include "xls/ir/call_graph.h" +#include "xls/ir/events.h" #include "xls/ir/function_base.h" #include "xls/ir/node.h" #include "xls/ir/node_iterator.h" @@ -56,6 +58,7 @@ #include "xls/ir/type.h" #include "xls/jit/ir_builder_visitor.h" #include "xls/jit/jit_channel_queue.h" +#include "xls/jit/jit_runtime.h" #include "xls/jit/llvm_type_converter.h" #include "xls/jit/orc_jit.h" @@ -1266,12 +1269,24 @@ absl::StatusOr BuildFunctionAndDependencies( for (const Node* input : GetJittedFunctionInputs(xls_function)) { jitted_function.input_buffer_sizes.push_back( jit_context.type_converter().GetTypeByteSize(input->GetType())); + jitted_function.input_buffer_prefered_alignments.push_back( + jit_context.type_converter().GetTypePreferredAlignment( + input->GetType())); + jitted_function.input_buffer_abi_alignments.push_back( + jit_context.type_converter().GetTypeAbiAlignment( + input->GetType())); jitted_function.packed_input_buffer_sizes.push_back( jit_context.type_converter().GetPackedTypeByteSize(input->GetType())); } for (const Node* output : GetJittedFunctionOutputs(xls_function)) { jitted_function.output_buffer_sizes.push_back( jit_context.type_converter().GetTypeByteSize(OutputType(output))); + jitted_function.output_buffer_prefered_alignments.push_back( + jit_context.type_converter().GetTypePreferredAlignment( + OutputType(output))); + jitted_function.output_buffer_abi_alignments.push_back( + jit_context.type_converter().GetTypeAbiAlignment( + OutputType(output))); jitted_function.packed_output_buffer_sizes.push_back( jit_context.type_converter().GetPackedTypeByteSize(OutputType(output))); } @@ -1312,4 +1327,37 @@ absl::StatusOr BuildBlockFunction(Block* block, /*build_packed_wrapper=*/false); } +namespace { +absl::Status VerifyOffsetAlignments(uint8_t const* const* const ptrs, + absl::Span alignments) { + for (int64_t i = 0; i < alignments.size(); ++i) { + XLS_RET_CHECK_EQ(absl::bit_cast(ptrs[i]) % alignments[i], 0) + << "value at index " << i << " is not aligned to " << alignments[i] + << ". Value is 0x" << std::hex << absl::bit_cast(ptrs[i]); + } + return absl::OkStatus(); +} +} // namespace + +int64_t JittedFunctionBase::RunJittedFunction( + const uint8_t* const* inputs, uint8_t* const* outputs, void* temp_buffer, + InterpreterEvents* events, void* user_data, JitRuntime* jit_runtime, + int64_t continuation_point) const { + XLS_DCHECK_OK(VerifyOffsetAlignments(inputs, input_buffer_abi_alignments)); + XLS_DCHECK_OK(VerifyOffsetAlignments(outputs, output_buffer_abi_alignments)); + return function(inputs, outputs, temp_buffer, events, user_data, jit_runtime, + continuation_point); +} + +std::optional JittedFunctionBase::RunPackedJittedFunction( + const uint8_t* const* inputs, uint8_t* const* outputs, void* temp_buffer, + InterpreterEvents* events, void* user_data, JitRuntime* jit_runtime, + int64_t continuation_point) const { + // TODO(allight): Do actual checks here. + if (packed_function) { + return (*packed_function)(inputs, outputs, temp_buffer, events, user_data, + jit_runtime, continuation_point); + } + return std::nullopt; +} } // namespace xls diff --git a/xls/jit/function_base_jit.h b/xls/jit/function_base_jit.h index 3854618c4d..5a74f74d9f 100644 --- a/xls/jit/function_base_jit.h +++ b/xls/jit/function_base_jit.h @@ -14,13 +14,16 @@ #ifndef XLS_JIT_FUNCTION_BASE_JIT_H_ #define XLS_JIT_FUNCTION_BASE_JIT_H_ +#include #include #include #include #include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" #include "xls/ir/events.h" #include "xls/ir/function.h" +#include "xls/ir/function_base.h" #include "xls/ir/proc.h" #include "xls/jit/jit_channel_queue.h" #include "xls/jit/jit_runtime.h" @@ -68,16 +71,37 @@ struct JittedFunctionBase { std::string function_name; JitFunctionType function; + // Execute the actual function (after verifying some invariants) + int64_t RunJittedFunction(const uint8_t* const* inputs, + uint8_t* const* outputs, void* temp_buffer, + InterpreterEvents* events, void* user_data, + JitRuntime* jit_runtime, + int64_t continuation_point) const; + // Name and function pointer for the jitted function which accepts/produces // arguments/results in a packed format. Only exists for JITted // xls::Functions, not procs. std::optional packed_function_name; std::optional packed_function; + // Execute the actual function (after verifying some invariants) + std::optional RunPackedJittedFunction( + const uint8_t* const* inputs, uint8_t* const* outputs, void* temp_buffer, + InterpreterEvents* events, void* user_data, JitRuntime* jit_runtime, + int64_t continuation_point) const; + // Sizes of the inputs/outputs in native LLVM format for `function_base`. std::vector input_buffer_sizes; std::vector output_buffer_sizes; + // alignment preferences of each input/output buffer. + std::vector input_buffer_prefered_alignments; + std::vector output_buffer_prefered_alignments; + + // alignment ABI requirements of each input/output buffer. + std::vector input_buffer_abi_alignments; + std::vector output_buffer_abi_alignments; + // Sizes of the inputs/outputs in packed format for `function_base`. std::vector packed_input_buffer_sizes; std::vector packed_output_buffer_sizes; diff --git a/xls/jit/function_jit.cc b/xls/jit/function_jit.cc index e6c45c89a8..2ebafc42d3 100644 --- a/xls/jit/function_jit.cc +++ b/xls/jit/function_jit.cc @@ -73,17 +73,32 @@ absl::StatusOr> FunctionJit::CreateInternal( llvm::DataLayout data_layout, OrcJit::CreateDataLayout(/*aot_specification=*/emit_object_code)); jit->jit_runtime_ = std::make_unique(data_layout); + JitRuntime& runtime = *jit->jit_runtime_; XLS_ASSIGN_OR_RETURN(jit->jitted_function_base_, BuildFunction(xls_function, *jit->orc_jit_)); // Pre-allocate argument, result, and temporary buffers. - for (int64_t i = 0; i < xls_function->params().size(); ++i) { - jit->arg_buffers_.push_back(std::vector(jit->GetArgTypeSize(i))); - jit->arg_buffer_ptrs_.push_back(jit->arg_buffers_.back().data()); + for (int i = 0; i < xls_function->params().size(); ++i) { + jit->arg_buffers_.push_back( + std::vector(runtime.ShouldAllocateForAlignment( + jit->GetArgTypeSize(i), jit->GetArgTypeAlignment(i)))); + jit->arg_buffer_ptrs_.push_back( + runtime + .AsAligned(absl::MakeSpan(jit->arg_buffers_.back()), + jit->GetArgTypeAlignment(i)) + .data()); } - jit->result_buffer_.resize(jit->GetReturnTypeSize()); + jit->result_buffer_.resize(runtime.ShouldAllocateForAlignment( + jit->GetReturnTypeSize(), jit->GetReturnTypeAlignment())); + jit->result_buffer_ptr_ = + runtime + .AsAligned(absl::MakeSpan(jit->result_buffer_), + jit->GetReturnTypeAlignment()) + .data(); jit->temp_buffer_.resize( - jit->jit_runtime_->ShouldAllocateForStack(jit->GetTempBufferSize())); + runtime.ShouldAllocateForStack(jit->GetTempBufferSize())); + jit->temp_buffer_ptr_ = + runtime.AsStack(absl::MakeSpan(jit->temp_buffer_)).data(); return jit; } @@ -115,9 +130,9 @@ absl::StatusOr> FunctionJit::Run( absl::MakeSpan(arg_buffer_ptrs_))); InterpreterEvents events; - InvokeJitFunction(arg_buffer_ptrs_, result_buffer_.data(), &events); + InvokeJitFunction(arg_buffer_ptrs_, result_buffer_ptr_, &events); Value result = jit_runtime_->UnpackBuffer( - result_buffer_.data(), xls_function_->return_value()->GetType()); + result_buffer_ptr_, xls_function_->return_value()->GetType()); return InterpreterResult{std::move(result), std::move(events)}; } @@ -153,9 +168,8 @@ void FunctionJit::InvokeJitFunction( absl::Span arg_buffers, uint8_t* output_buffer, InterpreterEvents* events) { uint8_t* output_buffers[1] = {output_buffer}; - jitted_function_base_.function( - arg_buffers.data(), output_buffers, - jit_runtime_->AsStack(absl::MakeSpan(temp_buffer_)).data(), events, + jitted_function_base_.RunJittedFunction( + arg_buffers.data(), output_buffers, temp_buffer_ptr_, events, /*user_data=*/nullptr, runtime(), /*continuation_point=*/0); } diff --git a/xls/jit/function_jit.h b/xls/jit/function_jit.h index 5fa518a047..ae13e5bd24 100644 --- a/xls/jit/function_jit.h +++ b/xls/jit/function_jit.h @@ -15,6 +15,7 @@ #ifndef XLS_JIT_FUNCTION_JIT_H_ #define XLS_JIT_FUNCTION_JIT_H_ +#include #include #include #include @@ -117,7 +118,7 @@ class FunctionJit { InterpreterEvents events; uint8_t* output_buffers[1] = {result_buffer}; - jitted_function_base_.packed_function.value()( + jitted_function_base_.RunPackedJittedFunction( arg_buffers, output_buffers, temp_buffer_.data(), &events, /*user_data=*/nullptr, runtime(), /*continuation_point=*/0); @@ -146,9 +147,15 @@ class FunctionJit { int64_t GetArgTypeSize(int arg_index) const { return jitted_function_base_.input_buffer_sizes.at(arg_index); } + int64_t GetArgTypeAlignment(int arg_index) const { + return jitted_function_base_.input_buffer_prefered_alignments.at(arg_index); + } int64_t GetReturnTypeSize() const { return jitted_function_base_.output_buffer_sizes[0]; } + int64_t GetReturnTypeAlignment() const { + return jitted_function_base_.output_buffer_prefered_alignments[0]; + } // Gets the size of the compiled function's arguments (or return value) in the // packed layout. @@ -227,12 +234,19 @@ class FunctionJit { // Buffers to hold the arguments, result, temporary storage. This is allocated // once and then re-used with each invocation of Run. Not thread-safe. + // These buffer pointers cannot be used directly as they might not be + // correctly aligned. the '_ptr[s]_' version must be used instead. std::vector> arg_buffers_; std::vector result_buffer_; std::vector temp_buffer_; - // Raw pointers to the buffers held in `arg_buffers_`. + // Raw pointers to the buffers held in `arg_buffers_` with each pointer having + // correct alignment. std::vector arg_buffer_ptrs_; + // Raw pointer to the buffer 'result_buffer_' with correct alignment. + uint8_t* result_buffer_ptr_ = nullptr; + // Raw pointer to the buffer 'temp_buffer_' with correct alignment. + uint8_t* temp_buffer_ptr_ = nullptr; JittedFunctionBase jitted_function_base_; std::unique_ptr jit_runtime_; diff --git a/xls/jit/function_jit_test.cc b/xls/jit/function_jit_test.cc index 840e009b4a..e0c0513a1e 100644 --- a/xls/jit/function_jit_test.cc +++ b/xls/jit/function_jit_test.cc @@ -249,7 +249,10 @@ TEST(FunctionJitTest, PackedAndUnpackedSmokeWide) { XLS_ASSERT_OK_AND_ASSIGN(auto jit, FunctionJit::Create(function)); - uint8_t input_data[] = {0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa}; + // 80-bit data is represented as an i128 with 16 byte alignment. + // TODO(allight): 2023-11-30: The fact this is needed is unfortunate. + alignas(16) + uint8_t input_data[] = {0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa}; { uint8_t output_data[10]; PackedBitsView<80> input(input_data, 0); @@ -260,7 +263,9 @@ TEST(FunctionJitTest, PackedAndUnpackedSmokeWide) { } { - uint8_t output_data[10]; + // 80-bit data is represented as an i128 with 16 byte alignment. + // TODO(allight): 2023-11-30: The fact this is needed is unfortunate. + alignas(16) uint8_t output_data[10]; BitsView<80> input(input_data); MutableBitsView<80> output(output_data); XLS_ASSERT_OK(jit->RunWithUnpackedViews(input, output)); diff --git a/xls/jit/jit_runtime.cc b/xls/jit/jit_runtime.cc index 22172495f4..a44a2ede0c 100644 --- a/xls/jit/jit_runtime.cc +++ b/xls/jit/jit_runtime.cc @@ -147,10 +147,10 @@ void JitRuntime::BlitValueToBuffer(const Value& value, const Type* type, BlitValueToBufferInternal(value, type, buffer); } -absl::Span JitRuntime::AsStack(absl::Span buffer) { - return buffer.subspan( - llvm::offsetToAlignment(reinterpret_cast(buffer.data()), - data_layout_.getStackAlignment())); +absl::Span JitRuntime::AsAligned(absl::Span buffer, + int64_t alignment) const { + return buffer.subspan(llvm::offsetToAlignment( + reinterpret_cast(buffer.data()), llvm::Align(alignment))); } void JitRuntime::BlitValueToBufferInternal(const Value& value, const Type* type, diff --git a/xls/jit/jit_runtime.h b/xls/jit/jit_runtime.h index ef305a1479..7c61dd1851 100644 --- a/xls/jit/jit_runtime.h +++ b/xls/jit/jit_runtime.h @@ -57,6 +57,16 @@ class JitRuntime { const llvm::DataLayout& data_layout() { return data_layout_; } + // Returns the number of bytes that should be allocated for a native LLVM + // value storing `size` bytes with `alignment` alignment. + // + // Note: A user cannot just allocate `size` bytes in every scenario, as there + // may be alignment constraints. This method tells us how much to + // overallocate. + size_t ShouldAllocateForAlignment(size_t size, int64_t alignment) const { + return size + alignment - 1; + } + // Returns the number of bytes that should be allocated for a native LLVM // stack storing `size` bytes. // @@ -64,22 +74,39 @@ class JitRuntime { // may be stack alignment constraints. This method tells us how much to // overallocate. size_t ShouldAllocateForStack(size_t size) { - return size + data_layout_.getStackAlignment().value() - 1; + return ShouldAllocateForAlignment(size, + data_layout_.getStackAlignment().value()); } + // Converts the provided buffer into a native LLVM buffer by aligning to the + // given alignment. + // + // May reduce the size of the buffer; if the buffer was allocated to hold at + // least `ShouldAllocateForAlignment(size, alignment)` bytes, then the result + // is guaranteed to hold at least `size` bytes. + absl::Span AsAligned(absl::Span buffer, + int64_t alignment) const; + // Converts the provided buffer into a native LLVM stack by aligning to the // memory model's requirements. // // May reduce the size of the buffer; if the buffer was allocated to hold at // least `ShouldAllocateForStack(size)` bytes, then the result is guaranteed // to hold at least `size` bytes. - absl::Span AsStack(absl::Span buffer); + absl::Span AsStack(absl::Span buffer) { + return AsAligned(buffer, data_layout_.getStackAlignment().value()); + } int64_t GetTypeByteSize(Type* xls_type) { absl::MutexLock lock(&mutex_); return type_converter_->GetTypeByteSize(xls_type); } + int64_t GetTypeAlignment(Type* xls_type) { + absl::MutexLock lock(&mutex_); + return type_converter_->GetTypePreferredAlignment(xls_type); + } + private: Value UnpackBufferInternal(const uint8_t* buffer, const Type* result_type, bool unpoison) ABSL_SHARED_LOCKS_REQUIRED(mutex_); diff --git a/xls/jit/llvm_type_converter.cc b/xls/jit/llvm_type_converter.cc index 21513a1e5b..5ae41ca06d 100644 --- a/xls/jit/llvm_type_converter.cc +++ b/xls/jit/llvm_type_converter.cc @@ -169,6 +169,19 @@ int64_t LlvmTypeConverter::GetTypeByteSize(const Type* type) const { return data_layout_.getTypeAllocSize(ConvertToLlvmType(type)).getFixedValue(); } +int64_t LlvmTypeConverter::GetTypeAbiAlignment(const Type* type) const { + return data_layout_.getABITypeAlign(ConvertToLlvmType(type)).value(); +} +int64_t LlvmTypeConverter::GetTypePreferredAlignment(const Type* type) const { + // NB We ask for stack alignment since we often memcpy things around which is + // slightly faster at higher alignments. + llvm::Align alignment = + data_layout_.getPrefTypeAlign(ConvertToLlvmType(type)); + if (data_layout_.exceedsNaturalStackAlignment(alignment)) { + return alignment.value(); + } + return data_layout_.getStackAlignment().value(); +} int64_t LlvmTypeConverter::AlignFor(const Type* type, int64_t offset) const { llvm::Align alignment = data_layout_.getPrefTypeAlign(ConvertToLlvmType(type)); diff --git a/xls/jit/llvm_type_converter.h b/xls/jit/llvm_type_converter.h index c52d07a01f..648aa1bca8 100644 --- a/xls/jit/llvm_type_converter.h +++ b/xls/jit/llvm_type_converter.h @@ -69,6 +69,12 @@ class LlvmTypeConverter { // DataLayout object can handle ~all of the work for us. int64_t GetTypeByteSize(const Type* type) const; + // Returns the preferred alignment for the given type. + int64_t GetTypePreferredAlignment(const Type* type) const; + + // Returns the alignment requirement for the given type. + int64_t GetTypeAbiAlignment(const Type* type) const; + // Returns the next position (starting from offset) where LLVM would consider // an object of the given type to have ended; specifically, the next position // that matches the greater of the stack alignment and the type's preferred diff --git a/xls/jit/proc_jit.cc b/xls/jit/proc_jit.cc index e3016d70a6..1e7b436aa6 100644 --- a/xls/jit/proc_jit.cc +++ b/xls/jit/proc_jit.cc @@ -46,13 +46,19 @@ ProcJitContinuation::ProcJitContinuation(Proc* proc, int64_t temp_buffer_size, // Pre-allocate input, output, and temporary buffers. for (Param* param : proc->params()) { int64_t param_size = jit_runtime_->GetTypeByteSize(param->GetType()); - int64_t buffer_size = jit_runtime_->ShouldAllocateForStack(param_size); + int64_t param_align = jit_runtime->GetTypeAlignment(param->GetType()); + int64_t buffer_size = + jit_runtime_->ShouldAllocateForAlignment(param_size, param_align); input_buffers_.push_back(std::vector(buffer_size)); output_buffers_.push_back(std::vector(buffer_size)); input_ptrs_.push_back( - jit_runtime_->AsStack(absl::MakeSpan(input_buffers_.back())).data()); + jit_runtime_ + ->AsAligned(absl::MakeSpan(input_buffers_.back()), param_align) + .data()); output_ptrs_.push_back( - jit_runtime_->AsStack(absl::MakeSpan(output_buffers_.back())).data()); + jit_runtime_ + ->AsAligned(absl::MakeSpan(output_buffers_.back()), param_align) + .data()); } // Write initial state value to the input_buffer. @@ -112,7 +118,7 @@ absl::StatusOr ProcJit::Tick(ProcContinuation& continuation) const { // The jitted function returns the early exit point at which execution // halted. A return value of zero indicates that the tick completed. - int64_t next_continuation_point = jitted_function_base_.function( + int64_t next_continuation_point = jitted_function_base_.RunJittedFunction( cont->GetInputBuffers().data(), cont->GetOutputBuffers().data(), cont->GetTempBuffer().data(), &cont->GetEvents(), /*user_data=*/nullptr, runtime(), cont->GetContinuationPoint()); diff --git a/xls/jit/proc_jit.h b/xls/jit/proc_jit.h index f421ca11aa..6ed9cbf8bc 100644 --- a/xls/jit/proc_jit.h +++ b/xls/jit/proc_jit.h @@ -85,11 +85,14 @@ class ProcJitContinuation : public ProcContinuation { InterpreterEvents events_; // Buffers to hold inputs, outputs, and temporary storage. This is allocated - // once and then re-used with each invocation of Run. Not thread-safe. + // once and then re-used with each invocation of Run. Not thread-safe. These + // cannot be used directly because the pointers might not be aligned. Use the + // '_ptr_' versions instead. std::vector> input_buffers_; std::vector> output_buffers_; - // Raw pointers to the buffers held in `input_buffers_` and `output_buffers_`. + // Raw pointers to the buffers held in `input_buffers_` and `output_buffers_`, + // aligned as required. std::vector input_ptrs_; std::vector output_ptrs_; std::vector temp_buffer_; diff --git a/xls/tools/benchmark_main.cc b/xls/tools/benchmark_main.cc index a649ec46f6..bbe2fc5836 100644 --- a/xls/tools/benchmark_main.cc +++ b/xls/tools/benchmark_main.cc @@ -553,8 +553,15 @@ absl::StatusOr ConvertToJitArguments( pointers.reserve(args.size()); for (int64_t i = 0; i < args.size(); ++i) { buffers.push_back( - std::vector(runtime->GetTypeByteSize(params[i]), 0)); - pointers.push_back(buffers.back().data()); + std::vector(runtime->ShouldAllocateForAlignment( + runtime->GetTypeByteSize(params[i]), + runtime->GetTypeAlignment(params[i])), + 0)); + pointers.push_back(runtime + ->AsAligned(absl::MakeSpan(buffers.back()), + runtime->GetTypeAlignment(params[i])) + .data()); + XLS_CHECK_NE(pointers.back(), nullptr); } arg_buffers.push_back(std::move(buffers)); arg_pointers.push_back(pointers); @@ -593,15 +600,18 @@ absl::Status RunFunctionInterpreterAndJit(Function* function, // The JIT is much faster so run many times. InterpreterEvents events; - std::vector result_buffer(jit->GetReturnTypeSize()); + std::vector result_buffer(jit->runtime()->ShouldAllocateForAlignment( + jit->GetReturnTypeSize(), jit->GetReturnTypeAlignment())); + absl::Span result_aligned = jit->runtime()->AsAligned( + absl::MakeSpan(result_buffer), jit->GetReturnTypeAlignment()); XLS_ASSIGN_OR_RETURN( float jit_run_rate, CountRate( [&]() -> absl::Status { for (int64_t i = 0; i < kJitRunMultiplier; ++i) { for (const std::vector& pointers : jit_arg_pointers) { - XLS_CHECK_OK(jit->RunWithViews( - pointers, absl::MakeSpan(result_buffer), &events)); + XLS_CHECK_OK( + jit->RunWithViews(pointers, result_aligned, &events)); } } return absl::OkStatus();