Skip to content

Commit

Permalink
Correctly align jit function inputs.
Browse files Browse the repository at this point in the history
We were assuming that the alignment of `std::vector<uint8_t>::data()` is always at least the alignment of an arbitrary XLS type. This held for a while, but a recent LLVM update changed S128 to have a 16-byte alignment, so the assumption is no longer always true.

Fixes: #1205
PiperOrigin-RevId: 587185658
  • Loading branch information
allight authored and copybara-github committed Dec 2, 2023
1 parent da93f3b commit 5afca97
Show file tree
Hide file tree
Showing 15 changed files with 275 additions and 61 deletions.
1 change: 1 addition & 0 deletions xls/jit/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,7 @@ cc_library(
":orc_jit",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:log_severity",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
Expand Down
85 changes: 61 additions & 24 deletions xls/jit/block_jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,27 +63,49 @@ absl::StatusOr<std::unique_ptr<BlockJit>> BlockJit::Create(
std::unique_ptr<BlockJitContinuation> BlockJit::NewContinuation() {
return std::unique_ptr<BlockJitContinuation>(new BlockJitContinuation(
block_, this, runtime_, function_.temp_buffer_size,
/*register_sizes=*/
absl::MakeSpan(function_.input_buffer_sizes)
.subspan(block_->GetInputPorts().size()),
/*register_alignments=*/
absl::MakeSpan(function_.input_buffer_prefered_alignments)
.subspan(block_->GetInputPorts().size()),
/*output_port_sizes=*/
absl::MakeSpan(function_.output_buffer_sizes)
.subspan(0, block_->GetOutputPorts().size()),
/*output_port_alignments=*/
absl::MakeSpan(function_.output_buffer_prefered_alignments)
.subspan(0, block_->GetOutputPorts().size()),
/*input_port_sizes=*/
absl::MakeSpan(function_.input_buffer_sizes)
.subspan(0, block_->GetInputPorts().size()),
/*input_port_alignments=*/
absl::MakeSpan(function_.input_buffer_prefered_alignments)
.subspan(0, block_->GetInputPorts().size())));
}

absl::Status BlockJit::RunOneCycle(BlockJitContinuation& continuation) {
function_.function(continuation.function_inputs().data(),
continuation.function_outputs().data(),
runtime_->AsStack(continuation.temp_buffer()).data(),
&continuation.GetEvents(), /*user_data=*/nullptr, runtime_,
/*continuation_point=*/0);
function_.RunJittedFunction(
continuation.function_inputs().data(),
continuation.function_outputs().data(),
runtime_->AsStack(continuation.temp_buffer()).data(),
&continuation.GetEvents(), /*user_data=*/nullptr, runtime_,
/*continuation_point=*/0);
continuation.SwapRegisters();
return absl::OkStatus();
}

namespace {
int64_t SumElements(absl::Span<int64_t const> v) {
return absl::c_accumulate(v, int64_t{0});
// Determine how much memory is needed to hold all the arguments of the given
// sizes. This is larger than just the sum to allow for space for all of them to
// be properly aligned.
int64_t FindFullBufferSize(JitRuntime* runtime, absl::Span<int64_t const> sizes,
absl::Span<int64_t const> alignments) {
XLS_CHECK_EQ(sizes.size(), alignments.size());
int64_t total = 0;
for (int64_t i = 0; i < sizes.size(); ++i) {
total += runtime->ShouldAllocateForAlignment(sizes[i], alignments[i]);
}
return total;
}

std::vector<uint8_t*> CombineLists(absl::Span<uint8_t* const> a,
Expand All @@ -101,14 +123,19 @@ std::vector<uint8_t*> CombineLists(absl::Span<uint8_t* const> a,

// Find the start-pointer of each argument in the argument arena. The size of
// each argument is given by the sizes span.
std::vector<uint8_t*> CalculatePointers(uint8_t* base_ptr,
absl::Span<const int64_t> sizes) {
size_t tot = 0;
std::vector<uint8_t*> CalculatePointers(JitRuntime* runtime,
absl::Span<uint8_t> base_buffer,
absl::Span<const int64_t> sizes,
absl::Span<const int64_t> alignments) {
XLS_CHECK_EQ(sizes.size(), alignments.size());
std::vector<uint8_t*> out;
out.reserve(sizes.size());
for (size_t s : sizes) {
out.push_back(base_ptr + tot);
tot += s;
for (int64_t i = 0; i < sizes.size(); ++i) {
auto aligned_span = runtime->AsAligned(base_buffer, alignments[i]);
XLS_CHECK_GE(aligned_span.size(), sizes[i])
<< "Buffer for " << i << " element not large enough!";
out.push_back(aligned_span.data());
base_buffer = aligned_span.subspan(sizes[i]);
}
return out;
}
Expand All @@ -117,23 +144,35 @@ std::vector<uint8_t*> CalculatePointers(uint8_t* base_ptr,
BlockJitContinuation::BlockJitContinuation(
Block* block, BlockJit* jit, JitRuntime* runtime, size_t temp_size,
absl::Span<const int64_t> register_sizes,
absl::Span<const int64_t> register_alignments,
absl::Span<const int64_t> output_port_sizes,
absl::Span<const int64_t> input_port_sizes)
absl::Span<const int64_t> output_port_alignments,
absl::Span<const int64_t> input_port_sizes,
absl::Span<const int64_t> input_port_alignments)
: block_(block),
block_jit_(jit),
runtime_(runtime),
register_arena_left_(SumElements(register_sizes), 0),
register_arena_left_(
FindFullBufferSize(runtime, register_sizes, register_alignments), 0),
register_arena_right_(register_arena_left_.size(), 0xff),
output_port_arena_(SumElements(output_port_sizes), 0xff),
input_port_arena_(SumElements(input_port_sizes), 0xff),
output_port_arena_(FindFullBufferSize(runtime, output_port_sizes,
output_port_alignments),
0xff),
input_port_arena_(
FindFullBufferSize(runtime, input_port_sizes, input_port_alignments),
0xff),
register_pointers_(BlockJitContinuation::IOSpace(
CalculatePointers(register_arena_left_.data(), register_sizes),
CalculatePointers(register_arena_right_.data(), register_sizes),
CalculatePointers(runtime, absl::MakeSpan(register_arena_left_),
register_sizes, register_alignments),
CalculatePointers(runtime, absl::MakeSpan(register_arena_right_),
register_sizes, register_alignments),
BlockJitContinuation::IOSpace::RegisterSpace::kLeft)),
output_port_pointers_(
CalculatePointers(output_port_arena_.data(), output_port_sizes)),
CalculatePointers(runtime, absl::MakeSpan(output_port_arena_),
output_port_sizes, output_port_alignments)),
input_port_pointers_(
CalculatePointers(input_port_arena_.data(), input_port_sizes)),
CalculatePointers(runtime, absl::MakeSpan(input_port_arena_),
input_port_sizes, input_port_alignments)),
full_input_pointer_set_(BlockJitContinuation::IOSpace(
CombineLists(input_port_pointers_, register_pointers_.left()),
CombineLists(input_port_pointers_, register_pointers_.right()),
Expand Down Expand Up @@ -461,9 +500,7 @@ class BlockContinuationJitWrapper final : public BlockContinuation {
}
return *temporary_regs_;
}
const InterpreterEvents& events() final {
return continuation_->GetEvents();
}
const InterpreterEvents& events() final { return continuation_->GetEvents(); }
absl::Status RunOneCycle(
const absl::flat_hash_map<std::string, Value>& inputs) final {
temporary_outputs_.reset();
Expand Down
18 changes: 12 additions & 6 deletions xls/jit/block_jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,11 @@ class BlockJitContinuation {
BlockJitContinuation(Block* block, BlockJit* jit, JitRuntime* runtime,
size_t temp_size,
absl::Span<const int64_t> register_sizes,
absl::Span<const int64_t> register_alignments,
absl::Span<const int64_t> output_port_sizes,
absl::Span<const int64_t> input_port_sizes);
absl::Span<const int64_t> output_port_alignments,
absl::Span<const int64_t> input_port_sizes,
absl::Span<const int64_t> input_port_alignments);

void SwapRegisters() {
register_pointers_.Swap();
Expand Down Expand Up @@ -206,21 +209,24 @@ class BlockJitContinuation {
// memory live for the pointers.
std::vector<uint8_t> input_port_arena_;

// The register pointer file.
// The register pointer file, aligned as required.
IOSpace register_pointers_;

// The output port pointers.
// The output port pointers, aligned as required.
const std::vector<uint8_t*> output_port_pointers_;

// The input port pointers.
// The input port pointers, aligned as required.
const std::vector<uint8_t*> input_port_pointers_;
// The input value pointers followed by the register pointers.
// The input value pointers followed by the register pointers, aligned as
// required.
IOSpace full_input_pointer_set_;
// The output value pointers followed by the register pointers.
// The output value pointers followed by the register pointers, aligned as
// required.
IOSpace full_output_pointer_set_;

// Data block to store temporary data in.
std::vector<uint8_t> temp_data_arena_;
// Data block to store temporary data in, aligned as required.
uint8_t* temp_data_ptr_;

InterpreterEvents events_;

Expand Down
48 changes: 48 additions & 0 deletions xls/jit/function_base_jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include <algorithm>
#include <cstdint>
#include <ios>
#include <iterator>
#include <limits>
#include <memory>
Expand Down Expand Up @@ -47,6 +48,7 @@
#include "xls/common/status/status_macros.h"
#include "xls/ir/block.h"
#include "xls/ir/call_graph.h"
#include "xls/ir/events.h"
#include "xls/ir/function_base.h"
#include "xls/ir/node.h"
#include "xls/ir/node_iterator.h"
Expand All @@ -56,6 +58,7 @@
#include "xls/ir/type.h"
#include "xls/jit/ir_builder_visitor.h"
#include "xls/jit/jit_channel_queue.h"
#include "xls/jit/jit_runtime.h"
#include "xls/jit/llvm_type_converter.h"
#include "xls/jit/orc_jit.h"

Expand Down Expand Up @@ -1266,12 +1269,24 @@ absl::StatusOr<JittedFunctionBase> BuildFunctionAndDependencies(
for (const Node* input : GetJittedFunctionInputs(xls_function)) {
jitted_function.input_buffer_sizes.push_back(
jit_context.type_converter().GetTypeByteSize(input->GetType()));
jitted_function.input_buffer_prefered_alignments.push_back(
jit_context.type_converter().GetTypePreferredAlignment(
input->GetType()));
jitted_function.input_buffer_abi_alignments.push_back(
jit_context.type_converter().GetTypeAbiAlignment(
input->GetType()));
jitted_function.packed_input_buffer_sizes.push_back(
jit_context.type_converter().GetPackedTypeByteSize(input->GetType()));
}
for (const Node* output : GetJittedFunctionOutputs(xls_function)) {
jitted_function.output_buffer_sizes.push_back(
jit_context.type_converter().GetTypeByteSize(OutputType(output)));
jitted_function.output_buffer_prefered_alignments.push_back(
jit_context.type_converter().GetTypePreferredAlignment(
OutputType(output)));
jitted_function.output_buffer_abi_alignments.push_back(
jit_context.type_converter().GetTypeAbiAlignment(
OutputType(output)));
jitted_function.packed_output_buffer_sizes.push_back(
jit_context.type_converter().GetPackedTypeByteSize(OutputType(output)));
}
Expand Down Expand Up @@ -1312,4 +1327,37 @@ absl::StatusOr<JittedFunctionBase> BuildBlockFunction(Block* block,
/*build_packed_wrapper=*/false);
}

namespace {
// Checks that each pointer in `ptrs` is a multiple of the alignment stored at
// the same index of `alignments`.
//
// Only the first `alignments.size()` entries of `ptrs` are read, so `ptrs`
// must hold at least that many pointers. Returns OK when every pointer passes;
// otherwise the XLS_RET_CHECK_EQ below fails (presumably returning a non-OK
// status from this function -- confirm against the macro definition) with a
// message naming the index, the alignment and the pointer value.
absl::Status VerifyOffsetAlignments(uint8_t const* const* const ptrs,
absl::Span<int64_t const> alignments) {
for (int64_t i = 0; i < alignments.size(); ++i) {
// The pointer is reinterpreted as an integer so its address can be tested
// with `%`. The decimal fields are streamed before std::hex so that only the
// pointer value is printed in hexadecimal.
XLS_RET_CHECK_EQ(absl::bit_cast<uintptr_t>(ptrs[i]) % alignments[i], 0)
<< "value at index " << i << " is not aligned to " << alignments[i]
<< ". Value is 0x" << std::hex << absl::bit_cast<uintptr_t>(ptrs[i]);
}
return absl::OkStatus();
}
} // namespace

// Invokes the jitted `function`, forwarding every argument unchanged and
// returning whatever `function` returns.
//
// Before the call, the `inputs` and `outputs` pointer arrays are checked
// against `input_buffer_abi_alignments` and `output_buffer_abi_alignments`
// respectively. The checks go through XLS_DCHECK_OK (presumably only active in
// debug builds -- confirm against the macro definition).
int64_t JittedFunctionBase::RunJittedFunction(
const uint8_t* const* inputs, uint8_t* const* outputs, void* temp_buffer,
InterpreterEvents* events, void* user_data, JitRuntime* jit_runtime,
int64_t continuation_point) const {
// Alignment is checked against the ABI alignments, not the preferred ones.
XLS_DCHECK_OK(VerifyOffsetAlignments(inputs, input_buffer_abi_alignments));
XLS_DCHECK_OK(VerifyOffsetAlignments(outputs, output_buffer_abi_alignments));
return function(inputs, outputs, temp_buffer, events, user_data, jit_runtime,
continuation_point);
}

std::optional<int64_t> JittedFunctionBase::RunPackedJittedFunction(
const uint8_t* const* inputs, uint8_t* const* outputs, void* temp_buffer,
InterpreterEvents* events, void* user_data, JitRuntime* jit_runtime,
int64_t continuation_point) const {
// TODO(allight): Do actual checks here.
if (packed_function) {
return (*packed_function)(inputs, outputs, temp_buffer, events, user_data,
jit_runtime, continuation_point);
}
return std::nullopt;
}
} // namespace xls
24 changes: 24 additions & 0 deletions xls/jit/function_base_jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@
#ifndef XLS_JIT_FUNCTION_BASE_JIT_H_
#define XLS_JIT_FUNCTION_BASE_JIT_H_

#include <cstdint>
#include <optional>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "xls/ir/events.h"
#include "xls/ir/function.h"
#include "xls/ir/function_base.h"
#include "xls/ir/proc.h"
#include "xls/jit/jit_channel_queue.h"
#include "xls/jit/jit_runtime.h"
Expand Down Expand Up @@ -68,16 +71,37 @@ struct JittedFunctionBase {
std::string function_name;
JitFunctionType function;

// Execute the actual function (after verifying some invariants)
int64_t RunJittedFunction(const uint8_t* const* inputs,
uint8_t* const* outputs, void* temp_buffer,
InterpreterEvents* events, void* user_data,
JitRuntime* jit_runtime,
int64_t continuation_point) const;

// Name and function pointer for the jitted function which accepts/produces
// arguments/results in a packed format. Only exists for JITted
// xls::Functions, not procs.
std::optional<std::string> packed_function_name;
std::optional<JitFunctionType> packed_function;

// Execute the packed function if one exists (no invariants are verified yet).
// Returns std::nullopt when `packed_function` is unset.
std::optional<int64_t> RunPackedJittedFunction(
const uint8_t* const* inputs, uint8_t* const* outputs, void* temp_buffer,
InterpreterEvents* events, void* user_data, JitRuntime* jit_runtime,
int64_t continuation_point) const;

// Sizes of the inputs/outputs in native LLVM format for `function_base`.
std::vector<int64_t> input_buffer_sizes;
std::vector<int64_t> output_buffer_sizes;

// alignment preferences of each input/output buffer.
std::vector<int64_t> input_buffer_prefered_alignments;
std::vector<int64_t> output_buffer_prefered_alignments;

// alignment ABI requirements of each input/output buffer.
std::vector<int64_t> input_buffer_abi_alignments;
std::vector<int64_t> output_buffer_abi_alignments;

// Sizes of the inputs/outputs in packed format for `function_base`.
std::vector<int64_t> packed_input_buffer_sizes;
std::vector<int64_t> packed_output_buffer_sizes;
Expand Down
34 changes: 24 additions & 10 deletions xls/jit/function_jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,32 @@ absl::StatusOr<std::unique_ptr<FunctionJit>> FunctionJit::CreateInternal(
llvm::DataLayout data_layout,
OrcJit::CreateDataLayout(/*aot_specification=*/emit_object_code));
jit->jit_runtime_ = std::make_unique<JitRuntime>(data_layout);
JitRuntime& runtime = *jit->jit_runtime_;
XLS_ASSIGN_OR_RETURN(jit->jitted_function_base_,
BuildFunction(xls_function, *jit->orc_jit_));

// Pre-allocate argument, result, and temporary buffers.
for (int64_t i = 0; i < xls_function->params().size(); ++i) {
jit->arg_buffers_.push_back(std::vector<uint8_t>(jit->GetArgTypeSize(i)));
jit->arg_buffer_ptrs_.push_back(jit->arg_buffers_.back().data());
for (int i = 0; i < xls_function->params().size(); ++i) {
jit->arg_buffers_.push_back(
std::vector<uint8_t>(runtime.ShouldAllocateForAlignment(
jit->GetArgTypeSize(i), jit->GetArgTypeAlignment(i))));
jit->arg_buffer_ptrs_.push_back(
runtime
.AsAligned(absl::MakeSpan(jit->arg_buffers_.back()),
jit->GetArgTypeAlignment(i))
.data());
}
jit->result_buffer_.resize(jit->GetReturnTypeSize());
jit->result_buffer_.resize(runtime.ShouldAllocateForAlignment(
jit->GetReturnTypeSize(), jit->GetReturnTypeAlignment()));
jit->result_buffer_ptr_ =
runtime
.AsAligned(absl::MakeSpan(jit->result_buffer_),
jit->GetReturnTypeAlignment())
.data();
jit->temp_buffer_.resize(
jit->jit_runtime_->ShouldAllocateForStack(jit->GetTempBufferSize()));
runtime.ShouldAllocateForStack(jit->GetTempBufferSize()));
jit->temp_buffer_ptr_ =
runtime.AsStack(absl::MakeSpan(jit->temp_buffer_)).data();

return jit;
}
Expand Down Expand Up @@ -115,9 +130,9 @@ absl::StatusOr<InterpreterResult<Value>> FunctionJit::Run(
absl::MakeSpan(arg_buffer_ptrs_)));

InterpreterEvents events;
InvokeJitFunction(arg_buffer_ptrs_, result_buffer_.data(), &events);
InvokeJitFunction(arg_buffer_ptrs_, result_buffer_ptr_, &events);
Value result = jit_runtime_->UnpackBuffer(
result_buffer_.data(), xls_function_->return_value()->GetType());
result_buffer_ptr_, xls_function_->return_value()->GetType());

return InterpreterResult<Value>{std::move(result), std::move(events)};
}
Expand Down Expand Up @@ -153,9 +168,8 @@ void FunctionJit::InvokeJitFunction(
absl::Span<const uint8_t* const> arg_buffers, uint8_t* output_buffer,
InterpreterEvents* events) {
uint8_t* output_buffers[1] = {output_buffer};
jitted_function_base_.function(
arg_buffers.data(), output_buffers,
jit_runtime_->AsStack(absl::MakeSpan(temp_buffer_)).data(), events,
jitted_function_base_.RunJittedFunction(
arg_buffers.data(), output_buffers, temp_buffer_ptr_, events,
/*user_data=*/nullptr, runtime(), /*continuation_point=*/0);
}

Expand Down
Loading

0 comments on commit 5afca97

Please sign in to comment.