Skip to content

Commit

Permalink
Add Support for Pre-approved IsValidSignature
Browse files Browse the repository at this point in the history
This commit demonstrates how `isValidSignature` can be modified to
support pre-approved signatures where the `caller` has an implicit
signature (similar to how `execTransaction` works).
  • Loading branch information
nlordell committed Dec 11, 2024
1 parent febab5e commit 8a9b056
Show file tree
Hide file tree
Showing 14 changed files with 156 additions and 88 deletions.
2 changes: 1 addition & 1 deletion certora/specs/NativeTokenRefund.spec
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// This spec is a separate file because we summarize checkSignatures here

methods {
function checkSignatures(bytes32, bytes memory) internal => NONDET;
function checkSignatures(address, bytes32, bytes memory) internal => NONDET;

function getNativeTokenBalanceFor(address) external returns (uint256) envfree;
function getSafeGuard() external returns (address) envfree;
Expand Down
2 changes: 1 addition & 1 deletion certora/specs/Safe.spec
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ methods {
function execTransactionFromModule(address,uint256,bytes,Enum.Operation) external returns (bool);
function execTransaction(address,uint256,bytes,Enum.Operation,uint256,uint256,uint256,address,address,bytes) external returns (bool);

function checkSignatures(bytes32, bytes memory) internal => NONDET;
function checkSignatures(address, bytes32, bytes memory) internal => NONDET;
}

definition reachableOnly(method f) returns bool =
Expand Down
17 changes: 8 additions & 9 deletions certora/specs/Signatures.spec
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ methods {
function getThreshold() external returns (uint256) envfree;
function nonce() external returns (uint256) envfree;
function isOwner(address) external returns (bool) envfree;
function checkSignatures(address, bytes32, bytes) external envfree;
function checkNSignatures(address, bytes32, bytes, uint256) external envfree;

// harnessed
function signatureSplitPublic(bytes,uint256) external returns (uint8,bytes32,bytes32) envfree;
Expand All @@ -27,8 +29,7 @@ methods {
) internal returns (bytes32) => CONSTANT;

// optional
function checkSignatures(bytes32,bytes) external;
function execTransaction(address,uint256,bytes,Enum.Operation,uint256,uint256,uint256,address,address,bytes) external returns (bool);
function execTransaction(address, uint256, bytes, Enum.Operation, uint256, uint256, uint256, address, address, bytes) external returns (bool);
}

definition MAX_UINT256() returns uint256 = 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff;
Expand All @@ -46,7 +47,6 @@ function signatureSplitGhost(bytes signatures, uint256 pos) returns (uint8,bytes
rule checkSignatures() {
bytes32 dataHash;
address executor;
env e;
bytes signaturesAB;
bytes signaturesA;
bytes signaturesB;
Expand All @@ -66,14 +66,13 @@ rule checkSignatures() {
require !isOwner(currentContract);
require getThreshold() == 2;
require getCurrentOwner(dataHash, vA, rA, sA) < getCurrentOwner(dataHash, vB, rB, sB);
require executor == e.msg.sender;

checkNSignatures@withrevert(e, executor, dataHash, signaturesA, 1);
checkNSignatures@withrevert(executor, dataHash, signaturesA, 1);
bool successA = !lastReverted;
checkNSignatures@withrevert(e, executor, dataHash, signaturesB, 1);
checkNSignatures@withrevert(executor, dataHash, signaturesB, 1);
bool successB = !lastReverted;

checkSignatures@withrevert(e, dataHash, signaturesAB);
checkSignatures@withrevert(executor, dataHash, signaturesAB);
bool successAB = !lastReverted;

assert (successA && successB) <=> successAB, "checkNSignatures called twice separately must be equivalent to checkSignatures";
Expand Down Expand Up @@ -107,8 +106,8 @@ rule ownerSignaturesAreProvidedForExecTransaction(
);

env e;
require e.msg.value == 0;
checkSignatures@withrevert(e, transactionHash, signatures);

checkSignatures@withrevert(e.msg.sender, transactionHash, signatures);
bool checkSignaturesOk = !lastReverted;

execTransaction(e, to, value, data, operation, safeTxGas, baseGas, gasPrice, gasToken, refundReceiver, signatures);
Expand Down
8 changes: 4 additions & 4 deletions contracts/Safe.sol
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ contract Safe is
// We use the post-increment here, so the current nonce value is used and incremented afterwards.
nonce++
);
checkSignatures(txHash, signatures);
checkSignatures(msg.sender, txHash, signatures);
}
address guard = getGuard();
{
Expand Down Expand Up @@ -267,12 +267,12 @@ contract Safe is
/**
* @inheritdoc ISafe
*/
function checkSignatures(bytes32 dataHash, bytes memory signatures) public view override {
function checkSignatures(address executor, bytes32 dataHash, bytes memory signatures) public view override {
// Load threshold to avoid multiple storage loads
uint256 _threshold = threshold;
// Check that a threshold is set
if (_threshold == 0) revertWithError("GS001");
checkNSignatures(msg.sender, dataHash, signatures, _threshold);
checkNSignatures(executor, dataHash, signatures, _threshold);
}

/**
Expand Down Expand Up @@ -343,7 +343,7 @@ contract Safe is
*/
function checkSignatures(bytes32 dataHash, bytes calldata data, bytes memory signatures) external view {
data;
checkSignatures(dataHash, signatures);
checkSignatures(msg.sender, dataHash, signatures);
}

/**
Expand Down
2 changes: 1 addition & 1 deletion contracts/handler/CompatibilityFallbackHandler.sol
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ contract CompatibilityFallbackHandler is TokenCallbackHandler, ISignatureValidat
if (_signature.length == 0) {
require(safe.signedMessages(messageHash) != 0, "Hash not approved");
} else {
safe.checkSignatures(messageHash, _signature);
safe.checkSignatures(_msgSender(), messageHash, _signature);
}
return EIP1271_MAGIC_VALUE;
}
Expand Down
18 changes: 9 additions & 9 deletions contracts/handler/extensible/ERC165Handler.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
pragma solidity >=0.7.0 <0.9.0;

import {IERC165} from "../../interfaces/IERC165.sol";
import {Safe, MarshalLib, ExtensibleBase} from "./ExtensibleBase.sol";
import {ISafe, MarshalLib, ExtensibleBase} from "./ExtensibleBase.sol";

interface IERC165Handler {
function safeInterfaces(Safe safe, bytes4 interfaceId) external view returns (bool);
function safeInterfaces(ISafe safe, bytes4 interfaceId) external view returns (bool);

function setSupportedInterface(bytes4 interfaceId, bool supported) external;

Expand All @@ -17,12 +17,12 @@ interface IERC165Handler {
abstract contract ERC165Handler is ExtensibleBase, IERC165Handler {
// --- events ---

event AddedInterface(Safe indexed safe, bytes4 interfaceId);
event RemovedInterface(Safe indexed safe, bytes4 interfaceId);
event AddedInterface(ISafe indexed safe, bytes4 interfaceId);
event RemovedInterface(ISafe indexed safe, bytes4 interfaceId);

// --- storage ---

mapping(Safe => mapping(bytes4 => bool)) public override safeInterfaces;
mapping(ISafe => mapping(bytes4 => bool)) public override safeInterfaces;

// --- setters ---

Expand All @@ -32,7 +32,7 @@ abstract contract ERC165Handler is ExtensibleBase, IERC165Handler {
* @param supported True if the interface is supported, false otherwise
*/
function setSupportedInterface(bytes4 interfaceId, bool supported) public override onlySelf {
Safe safe = Safe(payable(_manager()));
ISafe safe = ISafe(payable(_manager()));
// invalid interface id per ERC165 spec
require(interfaceId != 0xffffffff, "invalid interface id");
bool current = safeInterfaces[safe][interfaceId];
Expand All @@ -51,7 +51,7 @@ abstract contract ERC165Handler is ExtensibleBase, IERC165Handler {
* @param handlerWithSelectors The handlers encoded with the 4-byte selectors of the methods
*/
function addSupportedInterfaceBatch(bytes4 _interfaceId, bytes32[] calldata handlerWithSelectors) external override onlySelf {
Safe safe = Safe(payable(_msgSender()));
ISafe safe = ISafe(payable(_msgSender()));
bytes4 interfaceId;
for (uint256 i = 0; i < handlerWithSelectors.length; i++) {
(bool isStatic, bytes4 selector, address handlerAddress) = MarshalLib.decodeWithSelector(handlerWithSelectors[i]);
Expand All @@ -73,7 +73,7 @@ abstract contract ERC165Handler is ExtensibleBase, IERC165Handler {
* @param selectors The selectors of the methods to remove
*/
function removeSupportedInterfaceBatch(bytes4 _interfaceId, bytes4[] calldata selectors) external override onlySelf {
Safe safe = Safe(payable(_msgSender()));
ISafe safe = ISafe(payable(_msgSender()));
bytes4 interfaceId;
for (uint256 i = 0; i < selectors.length; i++) {
_setSafeMethod(safe, selectors[i], bytes32(0));
Expand All @@ -99,7 +99,7 @@ abstract contract ERC165Handler is ExtensibleBase, IERC165Handler {
interfaceId == type(IERC165).interfaceId ||
interfaceId == type(IERC165Handler).interfaceId ||
_supportsInterface(interfaceId) ||
safeInterfaces[Safe(payable(_manager()))][interfaceId];
safeInterfaces[ISafe(payable(_manager()))][interfaceId];
}

// --- internal ---
Expand Down
22 changes: 11 additions & 11 deletions contracts/handler/extensible/ExtensibleBase.sol
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
// SPDX-License-Identifier: LGPL-3.0-only
pragma solidity >=0.7.0 <0.9.0;

import {Safe} from "../../Safe.sol";
import {ISafe} from "../../interfaces/ISafe.sol";
import {HandlerContext} from "../HandlerContext.sol";
import {MarshalLib} from "./MarshalLib.sol";

interface IFallbackMethod {
function handle(Safe safe, address sender, uint256 value, bytes calldata data) external returns (bytes memory result);
function handle(ISafe safe, address sender, uint256 value, bytes calldata data) external returns (bytes memory result);
}

interface IStaticFallbackMethod {
function handle(Safe safe, address sender, uint256 value, bytes calldata data) external view returns (bytes memory result);
function handle(ISafe safe, address sender, uint256 value, bytes calldata data) external view returns (bytes memory result);
}

/**
Expand All @@ -20,9 +20,9 @@ interface IStaticFallbackMethod {
*/
abstract contract ExtensibleBase is HandlerContext {
// --- events ---
event AddedSafeMethod(Safe indexed safe, bytes4 selector, bytes32 method);
event ChangedSafeMethod(Safe indexed safe, bytes4 selector, bytes32 oldMethod, bytes32 newMethod);
event RemovedSafeMethod(Safe indexed safe, bytes4 selector);
event AddedSafeMethod(ISafe indexed safe, bytes4 selector, bytes32 method);
event ChangedSafeMethod(ISafe indexed safe, bytes4 selector, bytes32 oldMethod, bytes32 newMethod);
event RemovedSafeMethod(ISafe indexed safe, bytes4 selector);

// --- storage ---

Expand All @@ -31,7 +31,7 @@ abstract contract ExtensibleBase is HandlerContext {
// - The first byte is 0x00 if the method is static and 0x01 if the method is not static
// - The last 20 bytes are the address of the handler contract
// The method is encoded / decoded using the MarshalLib
mapping(Safe => mapping(bytes4 => bytes32)) public safeMethods;
mapping(ISafe => mapping(bytes4 => bytes32)) public safeMethods;

// --- modifiers ---
modifier onlySelf() {
Expand All @@ -44,7 +44,7 @@ abstract contract ExtensibleBase is HandlerContext {

// --- internal ---

function _setSafeMethod(Safe safe, bytes4 selector, bytes32 newMethod) internal {
function _setSafeMethod(ISafe safe, bytes4 selector, bytes32 newMethod) internal {
(, address newHandler) = MarshalLib.decode(newMethod);
bytes32 oldMethod = safeMethods[safe][selector];
(, address oldHandler) = MarshalLib.decode(oldMethod);
Expand All @@ -67,8 +67,8 @@ abstract contract ExtensibleBase is HandlerContext {
* @return safe The safe whose FallbackManager is making this call
* @return sender The original `msg.sender` (as received by the FallbackManager)
*/
function _getContext() internal view returns (Safe safe, address sender) {
safe = Safe(payable(_manager()));
function _getContext() internal view returns (ISafe safe, address sender) {
safe = ISafe(payable(_manager()));
sender = _msgSender();
}

Expand All @@ -79,7 +79,7 @@ abstract contract ExtensibleBase is HandlerContext {
* @return isStatic Whether the method is static (`view`) or not
* @return handler the address of the handler contract
*/
function _getContextAndHandler() internal view returns (Safe safe, address sender, bool isStatic, address handler) {
function _getContextAndHandler() internal view returns (ISafe safe, address sender, bool isStatic, address handler) {
(safe, sender) = _getContext();
(isStatic, handler) = MarshalLib.decode(safeMethods[safe][msg.sig]);
}
Expand Down
6 changes: 3 additions & 3 deletions contracts/handler/extensible/FallbackHandler.sol
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// SPDX-License-Identifier: LGPL-3.0-only
pragma solidity >=0.7.0 <0.9.0;

import {Safe, IStaticFallbackMethod, IFallbackMethod, ExtensibleBase} from "./ExtensibleBase.sol";
import {ISafe, IStaticFallbackMethod, IFallbackMethod, ExtensibleBase} from "./ExtensibleBase.sol";

interface IFallbackHandler {
function setSafeMethod(bytes4 selector, bytes32 newMethod) external;
Expand All @@ -22,15 +22,15 @@ abstract contract FallbackHandler is ExtensibleBase, IFallbackHandler {
* @param newMethod A contract that implements the `IFallbackMethod` or `IStaticFallbackMethod` interface
*/
function setSafeMethod(bytes4 selector, bytes32 newMethod) public override onlySelf {
_setSafeMethod(Safe(payable(_msgSender())), selector, newMethod);
_setSafeMethod(ISafe(payable(_msgSender())), selector, newMethod);
}

// --- fallback ---

// solhint-disable-next-line
fallback(bytes calldata) external returns (bytes memory result) {
require(msg.data.length >= 24, "invalid method selector");
(Safe safe, address sender, bool isStatic, address handler) = _getContextAndHandler();
(ISafe safe, address sender, bool isStatic, address handler) = _getContextAndHandler();
require(handler != address(0), "method handler not set");

if (isStatic) {
Expand Down
29 changes: 17 additions & 12 deletions contracts/handler/extensible/SignatureVerifierMuxer.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// solhint-disable one-contract-per-file
pragma solidity >=0.7.0 <0.9.0;

import {Safe, ExtensibleBase} from "./ExtensibleBase.sol";
import {ISafe, ExtensibleBase} from "./ExtensibleBase.sol";

interface ERC1271 {
function isValidSignature(bytes32 hash, bytes calldata signature) external view returns (bytes4 magicValue);
Expand All @@ -28,7 +28,7 @@ interface ISafeSignatureVerifier {
* @return magic The magic value that should be returned if the signature is valid (0x1626ba7e)
*/
function isValidSafeSignature(
Safe safe,
ISafe safe,
address sender,
bytes32 _hash,
bytes32 domainSeparator,
Expand All @@ -39,7 +39,7 @@ interface ISafeSignatureVerifier {
}

interface ISignatureVerifierMuxer {
function domainVerifiers(Safe safe, bytes32 domainSeparator) external view returns (ISafeSignatureVerifier);
function domainVerifiers(ISafe safe, bytes32 domainSeparator) external view returns (ISafeSignatureVerifier);

function setDomainVerifier(bytes32 domainSeparator, ISafeSignatureVerifier verifier) external;
}
Expand All @@ -61,25 +61,25 @@ abstract contract SignatureVerifierMuxer is ExtensibleBase, ERC1271, ISignatureV
bytes4 private constant SAFE_SIGNATURE_MAGIC_VALUE = 0x5fd7e97d;

// --- storage ---
mapping(Safe => mapping(bytes32 => ISafeSignatureVerifier)) public override domainVerifiers;
mapping(ISafe => mapping(bytes32 => ISafeSignatureVerifier)) public override domainVerifiers;

// --- events ---
event AddedDomainVerifier(Safe indexed safe, bytes32 domainSeparator, ISafeSignatureVerifier verifier);
event AddedDomainVerifier(ISafe indexed safe, bytes32 domainSeparator, ISafeSignatureVerifier verifier);
event ChangedDomainVerifier(
Safe indexed safe,
ISafe indexed safe,
bytes32 domainSeparator,
ISafeSignatureVerifier oldVerifier,
ISafeSignatureVerifier newVerifier
);
event RemovedDomainVerifier(Safe indexed safe, bytes32 domainSeparator);
event RemovedDomainVerifier(ISafe indexed safe, bytes32 domainSeparator);

/**
* Setter for the signature muxer
* @param domainSeparator The domainSeparator authorised for the `ISafeSignatureVerifier`
* @param newVerifier A contract that implements `ISafeSignatureVerifier`
*/
function setDomainVerifier(bytes32 domainSeparator, ISafeSignatureVerifier newVerifier) public override onlySelf {
Safe safe = Safe(payable(_msgSender()));
ISafe safe = ISafe(payable(_msgSender()));
ISafeSignatureVerifier oldVerifier = domainVerifiers[safe][domainSeparator];
if (address(newVerifier) == address(0) && address(oldVerifier) != address(0)) {
delete domainVerifiers[safe][domainSeparator];
Expand All @@ -102,7 +102,7 @@ abstract contract SignatureVerifierMuxer is ExtensibleBase, ERC1271, ISignatureV
* @return magic Standardised ERC1271 return value
*/
function isValidSignature(bytes32 _hash, bytes calldata signature) external view override returns (bytes4 magic) {
(Safe safe, address sender) = _getContext();
(ISafe safe, address sender) = _getContext();

// Check if the signature is for an `ISafeSignatureVerifier` and if it is valid for the domain.
if (signature.length >= 4) {
Expand Down Expand Up @@ -144,7 +144,7 @@ abstract contract SignatureVerifierMuxer is ExtensibleBase, ERC1271, ISignatureV
}

// domainVerifier doesn't exist or the signature is invalid for the domain - fall back to the default
return defaultIsValidSignature(safe, _hash, signature);
return defaultIsValidSignature(safe, sender, _hash, signature);
}

/**
Expand All @@ -153,7 +153,12 @@ abstract contract SignatureVerifierMuxer is ExtensibleBase, ERC1271, ISignatureV
* @param _hash Hash of the data that is signed
* @param signature The signature to be verified
*/
function defaultIsValidSignature(Safe safe, bytes32 _hash, bytes memory signature) internal view returns (bytes4 magic) {
function defaultIsValidSignature(
ISafe safe,
address sender,
bytes32 _hash,
bytes memory signature
) internal view returns (bytes4 magic) {
bytes memory messageData = EIP712.encodeMessageData(
safe.domainSeparator(),
SAFE_MSG_TYPEHASH,
Expand All @@ -165,7 +170,7 @@ abstract contract SignatureVerifierMuxer is ExtensibleBase, ERC1271, ISignatureV
require(safe.signedMessages(messageHash) != 0, "Hash not approved");
} else {
// threshold signatures
safe.checkSignatures(messageHash, messageData, signature);
safe.checkSignatures(sender, messageHash, signature);
}
magic = ERC1271.isValidSignature.selector;
}
Expand Down
7 changes: 5 additions & 2 deletions contracts/interfaces/ISafe.sol
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,15 @@ interface ISafe is IModuleManager, IGuardManager, IOwnerManager, IFallbackManage
) external payable returns (bool success);

/**
* @notice Checks whether the signature provided is valid for the provided data and hash. Reverts otherwise.
* @notice Checks whether the signature provided is valid for the provided data and hash and executor. Reverts otherwise.
* @param executor Address that executes the transaction.
* ⚠️⚠️⚠️ Make sure that the executor address is a legitimate executor.
* Incorrectly passed the executor might reduce the threshold by 1 signature. ⚠️⚠️⚠️
* @param dataHash Hash of the data (could be either a message hash or transaction hash)
* @param signatures Signature data that should be verified.
* Can be packed ECDSA signature ({bytes32 r}{bytes32 s}{uint8 v}), contract signature (EIP-1271) or approved hash.
*/
function checkSignatures(bytes32 dataHash, bytes memory signatures) external view;
function checkSignatures(address executor, bytes32 dataHash, bytes memory signatures) external view;

/**
* @notice Checks whether the signature provided is valid for the provided data and hash. Reverts otherwise.
Expand Down
Loading

0 comments on commit 8a9b056

Please sign in to comment.