diff --git a/src/auth/EnumerableRoles.sol b/src/auth/EnumerableRoles.sol index 59afefc73..3b6e236cd 100644 --- a/src/auth/EnumerableRoles.sol +++ b/src/auth/EnumerableRoles.sol @@ -244,17 +244,23 @@ abstract contract EnumerableRoles { } } - /// @dev Throws if the sender does not have `role`. + /// @dev Reverts if `msg.sender` does not have `role`. function _checkRole(uint256 role) internal view virtual { if (!_hasRole(msg.sender, role)) _revertEnumerableRolesUnauthorized(); } - /// @dev Throws if the sender does not have any roles in `encodedRoles`. + /// @dev Reverts if `msg.sender` does not have any role in `encodedRoles`. function _checkRoles(bytes memory encodedRoles) internal view virtual { if (!_hasAnyRoles(msg.sender, encodedRoles)) _revertEnumerableRolesUnauthorized(); } - /// @dev Throws if the sender does not have any roles in `encodedRoles`. + /// @dev Reverts if `msg.sender` is not the contract owner and does not have `role`. + function _checkOwnerOrRole(uint256 role) internal view virtual { + if (!_senderIsContractOwner()) _checkRole(role); + } + + /// @dev Reverts if `msg.sender` is not the contract owner and + /// does not have any role in `encodedRoles`. function _checkOwnerOrRoles(bytes memory encodedRoles) internal view virtual { if (!_senderIsContractOwner()) _checkRoles(encodedRoles); } @@ -276,6 +282,12 @@ abstract contract EnumerableRoles { _; } + /// @dev Marks a function as only callable by the owner or by an account with `role`. + modifier onlyOwnerOrRole(uint256 role) virtual { + _checkOwnerOrRole(role); + _; + } + /// @dev Marks a function as only callable by the owner or /// by an account with any role in `encodedRoles`. /// Checks for ownership first, then checks for roles. diff --git a/test/EnumerableRoles.t.sol b/test/EnumerableRoles.t.sol index e67783df5..ebc0c4eb3 100644 --- a/test/EnumerableRoles.t.sol +++ b/test/EnumerableRoles.t.sol @@ -210,6 +210,25 @@ contract EnumerableRolesTest is SoladyTest { } } + function testOnlyOwnerOrRole(uint256 allowedRole, uint256 holderRole) public { + address holder = _randomUniqueHashedAddress(); + assertEq(mockEnumerableRoles.owner(), address(this)); + if (holder == address(this)) return; + mockEnumerableRoles.setAllowedRole(allowedRole); + mockEnumerableRoles.setRoleDirect(holder, holderRole, true); + if (_randomChance(32)) { + mockEnumerableRoles.guardedByOnlyOwnerOrRole(); + } + if (holderRole != allowedRole) { + vm.prank(holder); + vm.expectRevert(EnumerableRoles.EnumerableRolesUnauthorized.selector); + mockEnumerableRoles.guardedByOnlyOwnerOrRole(); + } else { + vm.prank(holder); + mockEnumerableRoles.guardedByOnlyOwnerOrRole(); + } + } + function testSetAndGetRoles(bytes32) public { _TestTemps memory t; t.holders = _sampleUniqueAddresses(_randomUniform() & 7); diff --git a/test/utils/mocks/MockEnumerableRoles.sol b/test/utils/mocks/MockEnumerableRoles.sol index 717ae2c9b..03bd701fb 100644 --- a/test/utils/mocks/MockEnumerableRoles.sol +++ b/test/utils/mocks/MockEnumerableRoles.sol @@ -13,6 +13,7 @@ contract MockEnumerableRoles is EnumerableRoles, Brutalizer { address owner; bool ownerReverts; bytes allowedRolesEncoded; + uint256 allowedRole; } event Yo(); @@ -57,10 +58,18 @@ contract MockEnumerableRoles is EnumerableRoles, Brutalizer { $.allowedRolesEncoded = value; } + function setAllowedRole(uint256 role) public { + $.allowedRole = role; + } + function guardedByOnlyOwnerOrRoles() public onlyOwnerOrRoles($.allowedRolesEncoded) { emit Yo(); } + function guardedByOnlyOwnerOrRole() public onlyOwnerOrRole($.allowedRole) { + emit Yo(); + } + function guardedByOnlyRoles() public onlyRoles($.allowedRolesEncoded) { emit Yo(); }