Skip to content

Commit

Permalink
feat: extract structured payload (#18)
Browse files Browse the repository at this point in the history
* add decode_raw

* add tests

* fix typo

* rename check

* specify vec

* import vec in tests

* Updated CHANGELOG.md

* use advance_unchecked
  • Loading branch information
Wollac committed Jun 28, 2024
1 parent d966d6e commit cd9a3ff
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 5 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added

- New `decode_raw` static methods to `Header`

## [0.3.5] - 2024-05-22

### Changed
Expand Down
114 changes: 109 additions & 5 deletions crates/rlp/src/header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,7 @@ impl Header {
}

// SAFETY: this is already checked in `decode`
if buf.remaining() < payload_length {
unsafe { unreachable_unchecked() }
}
let bytes = unsafe { buf.get_unchecked(..payload_length) };
buf.advance(payload_length);
let bytes = unsafe { advance_unchecked(buf, payload_length) };
Ok(bytes)
}

Expand All @@ -108,6 +104,40 @@ impl Header {
core::str::from_utf8(bytes).map_err(|_| Error::Custom("invalid string"))
}

/// Extracts the next payload from the given buffer, advancing it.
///
/// # Errors
///
/// Returns an error if the buffer is too short, the header is invalid or one of the headers one
/// level deeper is invalid.
#[inline]
pub fn decode_raw<'a>(buf: &mut &'a [u8]) -> Result<PayloadView<'a>> {
let Self { list, payload_length } = Self::decode(buf)?;
// SAFETY: this is already checked in `decode`
let mut payload = unsafe { advance_unchecked(buf, payload_length) };

if !list {
return Ok(PayloadView::String(payload));
}

let mut items = alloc::vec::Vec::new();
while !payload.is_empty() {
// decode the next header without advancing in the payload
let Self { payload_length, .. } = Self::decode(&mut &payload[..])?;
// the length of the RLP encoding is the length of the header plus its payload length
// if payload length is 1 and the first byte is in [0x00, 0x7F], then there is no header
let rlp_length = if payload_length == 1 && payload[0] <= 0x7F {
1
} else {
payload_length + crate::length_of_length(payload_length)
};
items.push(&payload[..rlp_length]);
payload.advance(rlp_length);
}

return Ok(PayloadView::List(items));
}

/// Encodes the header into the `out` buffer.
#[inline]
pub fn encode(&self, out: &mut dyn BufMut) {
Expand All @@ -130,6 +160,12 @@ impl Header {
}
}

/// Structured representation of an RLP payload.
pub enum PayloadView<'a> {
String(&'a [u8]),
List(alloc::vec::Vec<&'a [u8]>),
}

/// Same as `buf.first().ok_or(Error::InputTooShort)`.
#[inline(always)]
fn get_next_byte(buf: &[u8]) -> Result<u8> {
Expand All @@ -139,3 +175,71 @@ fn get_next_byte(buf: &[u8]) -> Result<u8> {
// SAFETY: length checked above
Ok(*unsafe { buf.get_unchecked(0) })
}

/// Same as `let (bytes, rest) = buf.split_at(cnt); *buf = rest; bytes`.
#[inline(always)]
unsafe fn advance_unchecked<'a>(buf: &mut &'a [u8], cnt: usize) -> &'a [u8] {
if buf.remaining() < cnt {
unreachable_unchecked()
}
let bytes = &buf[..cnt];
buf.advance(cnt);
bytes
}

#[cfg(test)]
mod tests {
use super::*;
use crate::Encodable;
use alloc::vec::Vec;
use core::fmt::Debug;

fn check_decode_raw_list<T: Encodable + Debug>(input: Vec<T>) {
let encoded = crate::encode(&input);
let expected: Vec<_> = input.iter().map(crate::encode).collect();
let mut buf = encoded.as_slice();
assert!(
matches!(Header::decode_raw(&mut buf), Ok(PayloadView::List(v)) if v == expected),
"input: {:?}, expected list: {:?}",
input,
expected
);
assert!(buf.is_empty(), "buffer was not advanced");
}

fn check_decode_raw_string(input: &str) {
let encoded = crate::encode(input);
let expected = Header::decode_bytes(&mut &encoded[..], false).unwrap();
let mut buf = encoded.as_slice();
assert!(
matches!(Header::decode_raw(&mut buf), Ok(PayloadView::String(v)) if v == expected),
"input: {}, expected string: {:?}",
input,
expected
);
assert!(buf.is_empty(), "buffer was not advanced");
}

#[test]
fn decode_raw() {
// empty list
check_decode_raw_list(Vec::<u64>::new());
// list of an empty RLP list
check_decode_raw_list(vec![Vec::<u64>::new()]);
// list of an empty RLP string
check_decode_raw_list(vec![""]);
// list of two RLP strings
check_decode_raw_list(vec![0xBBCCB5_u64, 0xFFC0B5_u64]);
// list of three RLP lists of various lengths
check_decode_raw_list(vec![vec![0u64], vec![1u64, 2u64], vec![3u64, 4u64, 5u64]]);
// list of four empty RLP strings
check_decode_raw_list(vec![0u64; 4]);
// list of all one-byte strings, some will have an RLP header and some won't
check_decode_raw_list((0u64..0xFF).collect());

// strings of various lengths
check_decode_raw_string("");
check_decode_raw_string(" ");
check_decode_raw_string("test1234");
}
}

0 comments on commit cd9a3ff

Please sign in to comment.