diff --git a/src/utils/LibBytes.sol b/src/utils/LibBytes.sol index 09b4a1ecc..2e2ad8ae9 100644 --- a/src/utils/LibBytes.sol +++ b/src/utils/LibBytes.sol @@ -627,6 +627,43 @@ library LibBytes { } } + /// @dev Directly returns `a` with minimal copying. + function directReturn(bytes[] memory a) internal pure { + assembly { + let n := mload(a) // `a.length`. + let o := add(a, 0x20) // Start of elements in `a`. + let u := a // Highest memory slot. + let w := not(0x1f) + for { let i := 0 } iszero(eq(i, n)) { i := add(i, 1) } { + let c := add(o, shl(5, i)) // Location of pointer to `a[i]`. + let s := mload(c) // `a[i]`. + let l := mload(s) // `a[i].length`. + let r := and(l, 0x1f) // `a[i].length % 32`. + let z := add(0x20, and(l, w)) // Offset of last word in `a[i]` from `s`. + // If `s` comes before `o`, or `s` is not zero right padded. + if iszero(lt(lt(s, o), or(iszero(r), iszero(shl(shl(3, r), mload(add(s, z))))))) { + let m := mload(0x40) + mstore(m, l) // Copy `a[i].length`. + for {} 1 {} { + mstore(add(m, z), mload(add(s, z))) // Copy `a[i]`, backwards. + z := add(z, w) // `sub(z, 0x20)`. + if iszero(z) { break } + } + let e := add(add(m, 0x20), l) + mstore(e, 0) // Zeroize the slot after the copied bytes. + mstore(0x40, add(e, 0x20)) // Allocate memory. + s := m + } + mstore(c, sub(s, o)) // Convert to calldata offset. + let t := add(l, add(s, 0x20)) + if iszero(lt(t, u)) { u := t } + } + let retStart := add(a, w) // Assumes `a` doesn't start from scratch space. + mstore(retStart, 0x20) // Store the return offset. + return(retStart, add(0x40, sub(u, retStart))) // End the transaction. + } + } + /// @dev Returns the word at `offset`, without any bounds checks. /// To load an address, you can use `address(bytes20(load(a, offset)))`. function load(bytes memory a, uint256 offset) internal pure returns (bytes32 result) { diff --git a/src/utils/g/LibBytes.sol b/src/utils/g/LibBytes.sol index 589718f82..3fa722203 100644 --- a/src/utils/g/LibBytes.sol +++ b/src/utils/g/LibBytes.sol @@ -631,6 +631,43 @@ library LibBytes { } } + /// @dev Directly returns `a` with minimal copying. + function directReturn(bytes[] memory a) internal pure { + assembly { + let n := mload(a) // `a.length`. + let o := add(a, 0x20) // Start of elements in `a`. + let u := a // Highest memory slot. + let w := not(0x1f) + for { let i := 0 } iszero(eq(i, n)) { i := add(i, 1) } { + let c := add(o, shl(5, i)) // Location of pointer to `a[i]`. + let s := mload(c) // `a[i]`. + let l := mload(s) // `a[i].length`. + let r := and(l, 0x1f) // `a[i].length % 32`. + let z := add(0x20, and(l, w)) // Offset of last word in `a[i]` from `s`. + // If `s` comes before `o`, or `s` is not zero right padded. + if iszero(lt(lt(s, o), or(iszero(r), iszero(shl(shl(3, r), mload(add(s, z))))))) { + let m := mload(0x40) + mstore(m, l) // Copy `a[i].length`. + for {} 1 {} { + mstore(add(m, z), mload(add(s, z))) // Copy `a[i]`, backwards. + z := add(z, w) // `sub(z, 0x20)`. + if iszero(z) { break } + } + let e := add(add(m, 0x20), l) + mstore(e, 0) // Zeroize the slot after the copied bytes. + mstore(0x40, add(e, 0x20)) // Allocate memory. + s := m + } + mstore(c, sub(s, o)) // Convert to calldata offset. + let t := add(l, add(s, 0x20)) + if iszero(lt(t, u)) { u := t } + } + let retStart := add(a, w) // Assumes `a` doesn't start from scratch space. + mstore(retStart, 0x20) // Store the return offset. + return(retStart, add(0x40, sub(u, retStart))) // End the transaction. + } + } + /// @dev Returns the word at `offset`, without any bounds checks. /// To load an address, you can use `address(bytes20(load(a, offset)))`. function load(bytes memory a, uint256 offset) internal pure returns (bytes32 result) { diff --git a/test/LibBytes.t.sol b/test/LibBytes.t.sol index 467fe3575..8bb25f51f 100644 --- a/test/LibBytes.t.sol +++ b/test/LibBytes.t.sol @@ -46,4 +46,95 @@ contract LibBytesTest is SoladyTest { function testEmptyCalldata() public { assertEq(LibBytes.emptyCalldata(), ""); } + + function testDirectReturn() public { + uint256 seed = 123; + bytes[] memory expected = _generateBytesArray(seed); + bytes[] memory computed = this.generateBytesArray(seed, false); + unchecked { + for (uint256 i; i != expected.length; ++i) { + _checkMemory(computed[i]); + assertEq(computed[i], expected[i]); + } + assertEq(computed.length, expected.length); + } + } + + function testDirectReturn(uint256 seed) public { + bytes[] memory expected = _generateBytesArray(seed); + (bool success, bytes memory encoded) = address(this).call( + abi.encodeWithSignature("generateBytesArray(uint256,bool)", seed, true) + ); + assertTrue(success); + bytes[] memory computed; + /// @solidity memory-safe-assembly + assembly { + let o := add(encoded, 0x20) + computed := add(o, mload(o)) + for { let i := 0 } lt(i, mload(computed)) { i := add(i, 1) } { + let c := add(add(0x20, computed), shl(5, i)) + mstore(c, add(add(0x20, computed), mload(c))) + } + } + unchecked { + for (uint256 i; i != expected.length; ++i) { + _checkMemory(computed[i]); + assertEq(computed[i], expected[i]); + } + assertEq(computed.length, expected.length); + } + if (seed & 0xf == 0) { + assertEq(abi.encode(expected), abi.encode(this.generateBytesArray(seed, true))); + } + } + + function generateBytesArray(uint256 seed, bool brutalized) + public + view + returns (bytes[] memory) + { + if (brutalized) { + _misalignFreeMemoryPointer(); + _brutalizeMemory(); + } + LibBytes.directReturn(_generateBytesArray(seed)); + } + + function _generateBytesArray(uint256 seed) internal pure returns (bytes[] memory a) { + bytes memory before = "hehe"; + /// @solidity memory-safe-assembly + assembly { + mstore(0x00, seed) + mstore(0x20, 0) + function _next() -> _r { + _r := keccak256(0x00, 0x40) + mstore(0x20, _r) + } + function _nextBytes() -> _b { + _b := mload(0x40) + let n_ := and(_next(), 0x7f) + mstore(_b, n_) + for { let i_ := 0 } lt(i_, n_) { i_ := add(i_, 0x20) } { + mstore(add(add(_b, 0x20), i_), _next()) + } + if and(1, _next()) { + mstore(0x40, add(n_, add(_b, 0x20))) + leave + } + mstore(add(n_, add(_b, 0x20)), 0) + mstore(0x40, add(n_, add(_b, 0x40))) + } + let n := and(_next(), 7) + a := mload(0x40) + mstore(a, n) + mstore(0x40, add(add(a, 0x20), shl(5, n))) + for { let i := 0 } lt(i, n) { i := add(1, i) } { + if iszero(and(7, _next())) { + mstore(add(add(a, 0x20), shl(5, i)), before) + continue + } + mstore(add(add(a, 0x20), shl(5, i)), _nextBytes()) + } + } + } }