// SPDX-License-Identifier: MIT
// OpenZeppelin Contracts (last updated v5.0.0) (utils/math/Math.sol)

pragma solidity ^0.8.20;

import {Panic} from "./Panic.sol";
import {SafeCast} from "./SafeCast.sol";

/**
 * @dev Standard math utilities missing in the Solidity language.
 *
 * NOTE(review): several functions rely on `SafeCast.toUint(bool)` and `Panic.panic(uint256)`, which are defined in
 * other files. The comments below assume `toUint` maps `true` to 1 and `false` to 0, and that `Panic.panic` reverts
 * with the given panic code -- confirm against those files.
 */
library Math {
    /**
     * @dev Rounding directions. The declaration order matters: {unsignedRoundsUp} relies on the modes that round
     * up for unsigned values (`Ceil` = 1, `Expand` = 3) having odd underlying values.
     */
    enum Rounding {
        Floor, // Toward negative infinity
        Ceil, // Toward positive infinity
        Trunc, // Toward zero
        Expand // Away from zero
    }

    /**
     * @dev Returns the addition of two unsigned integers, with a success flag (no overflow).
     *
     * On overflow the returned pair is `(false, 0)`.
     */
    function tryAdd(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
        unchecked {
            uint256 c = a + b;
            // The unchecked sum wraps modulo 2²⁵⁶; a wrapped sum is strictly smaller than `a`.
            if (c < a) return (false, 0);
            return (true, c);
        }
    }

    /**
     * @dev Returns the subtraction of two unsigned integers, with a success flag (no underflow).
     *
     * When `b > a` the returned pair is `(false, 0)`.
     */
    function trySub(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
        unchecked {
            if (b > a) return (false, 0);
            return (true, a - b);
        }
    }

    /**
     * @dev Returns the multiplication of two unsigned integers, with a success flag (no overflow).
     *
     * On overflow the returned pair is `(false, 0)`.
     */
    function tryMul(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
        unchecked {
            // Gas optimization: this is cheaper than requiring 'a' not being zero, but the
            // benefit is lost if 'b' is also tested.
            // See: https://github.com/OpenZeppelin/openzeppelin-contracts/pull/522
            if (a == 0) return (true, 0);
            uint256 c = a * b;
            // `a` is non-zero here. If the product wrapped, dividing it back by `a` does not give `b`.
            if (c / a != b) return (false, 0);
            return (true, c);
        }
    }

    /**
     * @dev Returns the division of two unsigned integers, with a success flag (no division by zero).
     *
     * When `b == 0` the returned pair is `(false, 0)`.
     */
    function tryDiv(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
        unchecked {
            if (b == 0) return (false, 0);
            return (true, a / b);
        }
    }

    /**
     * @dev Returns the remainder of dividing two unsigned integers, with a success flag (no division by zero).
     *
     * When `b == 0` the returned pair is `(false, 0)`.
     */
    function tryMod(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
        unchecked {
            if (b == 0) return (false, 0);
            return (true, a % b);
        }
    }

    /**
     * @dev Returns the largest of two numbers.
     */
    function max(uint256 a, uint256 b) internal pure returns (uint256) {
        return a > b ? a : b;
    }

    /**
     * @dev Returns the smallest of two numbers.
     */
    function min(uint256 a, uint256 b) internal pure returns (uint256) {
        return a < b ? a : b;
    }

    /**
     * @dev Returns the average of two numbers. The result is rounded towards
     * zero.
     */
    function average(uint256 a, uint256 b) internal pure returns (uint256) {
        // (a + b) / 2 can overflow.
        // Instead use a + b = 2 * (a & b) + (a ^ b): the bits common to both inputs contribute fully, and the
        // bits that differ contribute half.
        return (a & b) + (a ^ b) / 2;
    }

    /**
     * @dev Returns the ceiling of the division of two numbers.
     *
     * This differs from standard division with `/` in that it rounds towards infinity instead
     * of rounding towards zero.
     *
     * Calls `Panic.panic(Panic.DIVISION_BY_ZERO)` when `b == 0`.
     */
    function ceilDiv(uint256 a, uint256 b) internal pure returns (uint256) {
        if (b == 0) {
            // Guarantee the same behavior as in a regular Solidity division.
            Panic.panic(Panic.DIVISION_BY_ZERO);
        }

        // The following calculation ensures accurate ceiling division without overflow.
        // The `a == 0` case is answered directly, so in the other branch `a` is non-zero and `a - 1` cannot
        // underflow. The final `+ 1` could only overflow if (a - 1) / b were type(uint256).max, but the largest
        // value it can take is type(uint256).max - 1, which happens when a = type(uint256).max and b = 1.
        unchecked {
            return a == 0 ? 0 : (a - 1) / b + 1;
        }
    }

    /**
     * @dev Calculates floor(x * y / denominator) with full precision. Throws if result overflows a uint256 or
     * denominator == 0.
     *
     * Original credit to Remco Bloemen under MIT license (https://xn--2-umb.com/21/muldiv) with further edits by
     * Uniswap Labs also under MIT license.
     */
    function mulDiv(uint256 x, uint256 y, uint256 denominator) internal pure returns (uint256 result) {
        unchecked {
            // 512-bit multiply [prod1 prod0] = x * y. Compute the product mod 2²⁵⁶ and mod 2²⁵⁶ - 1, then use
            // the Chinese Remainder Theorem to reconstruct the 512 bit result. The result is stored in two 256-bit
            // variables such that product = prod1 * 2²⁵⁶ + prod0.
            uint256 prod0 = x * y; // Least significant 256 bits of the product
            uint256 prod1; // Most significant 256 bits of the product
            assembly {
                let mm := mulmod(x, y, not(0))
                prod1 := sub(sub(mm, prod0), lt(mm, prod0))
            }

            // Handle non-overflow cases, 256 by 256 division.
            if (prod1 == 0) {
                // Solidity will revert if denominator == 0, unlike the div opcode on its own.
                // The surrounding unchecked block does not change this fact.
                // See https://docs.soliditylang.org/en/latest/control-structures.html#checked-or-unchecked-arithmetic.
                return prod0 / denominator;
            }

            // Make sure the result is less than 2²⁵⁶. Also prevents denominator == 0.
            if (denominator <= prod1) {
                Panic.panic(denominator == 0 ? Panic.DIVISION_BY_ZERO : Panic.UNDER_OVERFLOW);
            }

            ///////////////////////////////////////////////
            // 512 by 256 division.
            ///////////////////////////////////////////////

            // Make division exact by subtracting the remainder from [prod1 prod0].
            uint256 remainder;
            assembly {
                // Compute remainder using mulmod.
                remainder := mulmod(x, y, denominator)

                // Subtract 256 bit number from 512 bit number.
                // The high word must be adjusted first: the borrow test reads prod0 before it is overwritten.
                prod1 := sub(prod1, gt(remainder, prod0))
                prod0 := sub(prod0, remainder)
            }

            // Factor powers of two out of denominator and compute largest power of two divisor of denominator.
            // Always >= 1. See https://cs.stackexchange.com/q/138556/92363.

            uint256 twos = denominator & (0 - denominator);
            assembly {
                // Divide denominator by twos.
                denominator := div(denominator, twos)

                // Divide [prod1 prod0] by twos.
                prod0 := div(prod0, twos)

                // Flip twos such that it is 2²⁵⁶ / twos. If twos is zero, then it becomes one.
                twos := add(div(sub(0, twos), twos), 1)
            }

            // Shift in bits from prod1 into prod0.
            prod0 |= prod1 * twos;

            // Invert denominator mod 2²⁵⁶. Now that denominator is an odd number, it has an inverse modulo 2²⁵⁶ such
            // that denominator * inv ≡ 1 mod 2²⁵⁶. Compute the inverse by starting with a seed that is correct for
            // four bits. That is, denominator * inv ≡ 1 mod 2⁴.
            uint256 inverse = (3 * denominator) ^ 2;

            // Use the Newton-Raphson iteration to improve the precision. Thanks to Hensel's lifting lemma, this also
            // works in modular arithmetic, doubling the correct bits in each step.
            inverse *= 2 - denominator * inverse; // inverse mod 2⁸
            inverse *= 2 - denominator * inverse; // inverse mod 2¹⁶
            inverse *= 2 - denominator * inverse; // inverse mod 2³²
            inverse *= 2 - denominator * inverse; // inverse mod 2⁶⁴
            inverse *= 2 - denominator * inverse; // inverse mod 2¹²⁸
            inverse *= 2 - denominator * inverse; // inverse mod 2²⁵⁶

            // Because the division is now exact we can divide by multiplying with the modular inverse of denominator.
            // This will give us the correct result modulo 2²⁵⁶. Since the preconditions guarantee that the outcome is
            // less than 2²⁵⁶, this is the final result. We don't need to compute the high bits of the result and prod1
            // is no longer required.
            result = prod0 * inverse;
            return result;
        }
    }

    /**
     * @dev Calculates x * y / denominator with full precision, following the selected rounding direction.
     *
     * One is added to the floored quotient when the rounding mode rounds up (see {unsignedRoundsUp}) and
     * `x * y` is not an exact multiple of `denominator`.
     */
    function mulDiv(uint256 x, uint256 y, uint256 denominator, Rounding rounding) internal pure returns (uint256) {
        return mulDiv(x, y, denominator) + SafeCast.toUint(unsignedRoundsUp(rounding) && mulmod(x, y, denominator) > 0);
    }

    /**
     * @dev Calculate the modular multiplicative inverse of a number in Z/nZ.
     *
     * If n is a prime, then Z/nZ is a field. In that case all elements are invertible, except 0.
     * If n is not a prime, then Z/nZ is not a field, and some elements might not be invertible.
     *
     * If the input value is not invertible, 0 is returned. 0 is also returned when `n == 0`.
     *
     * NOTE: If you know for sure that n is a (big) prime, it may be cheaper to use Fermat's little theorem and get the
     * inverse using `Math.modExp(a, n - 2, n)`.
     */
    function invMod(uint256 a, uint256 n) internal pure returns (uint256) {
        unchecked {
            if (n == 0) return 0;

            // The inverse modulo is calculated using the Extended Euclidean Algorithm (iterative version)
            // Used to compute integers x and y such that: ax + ny = gcd(a, n).
            // When the gcd is 1, then the inverse of a modulo n exists and it's x.
            // ax + ny = 1
            // ax = 1 + (-y)n
            // ax ≡ 1 (mod n) # x is the inverse of a modulo n

            // If the remainder is 0 the gcd is n right away.
            uint256 remainder = a % n;
            uint256 gcd = n;

            // Therefore the initial coefficients are:
            // ax + ny = gcd(a, n) = n
            // 0a + 1n = n
            int256 x = 0;
            int256 y = 1;

            while (remainder != 0) {
                uint256 quotient = gcd / remainder;

                (gcd, remainder) = (
                    // The old remainder is the next gcd to try.
                    remainder,
                    // Compute the next remainder.
                    // Can't overflow given that (a % gcd) * (gcd // (a % gcd)) <= gcd
                    // where gcd is at most n (capped to type(uint256).max)
                    gcd - remainder * quotient
                );

                (x, y) = (
                    // Increment the coefficient of a.
                    y,
                    // Decrement the coefficient of n.
                    // Can overflow, but the result is casted to uint256 so that the
                    // next value of y is "wrapped around" to a value between 0 and n - 1.
                    x - y * int256(quotient)
                );
            }

            if (gcd != 1) return 0; // No inverse exists.
            return x < 0 ? (n - uint256(-x)) : uint256(x); // Wrap the result if it's negative.
        }
    }

    /**
     * @dev Returns the modular exponentiation of the specified base, exponent and modulus (b ** e % m)
     *
     * Requirements:
     * - modulus can't be zero
     * - underlying staticcall to precompile must succeed
     *
     * If {tryModExp} reports a failure for either reason, this function calls
     * `Panic.panic(Panic.DIVISION_BY_ZERO)`.
     *
     * IMPORTANT: The result is only valid if the underlying call succeeds. When using this function, make
     * sure the chain you're using it on supports the precompiled contract for modular exponentiation
     * at address 0x05 as specified in https://eips.ethereum.org/EIPS/eip-198[EIP-198]. Otherwise,
     * the underlying function will succeed given the lack of a revert, but the result may be incorrectly
     * interpreted as 0.
     */
    function modExp(uint256 b, uint256 e, uint256 m) internal view returns (uint256) {
        (bool success, uint256 result) = tryModExp(b, e, m);
        if (!success) {
            Panic.panic(Panic.DIVISION_BY_ZERO);
        }
        return result;
    }

    /**
     * @dev Returns the modular exponentiation of the specified base, exponent and modulus (b ** e % m).
     * It includes a success flag indicating if the operation succeeded. Operation will be marked as failed if trying
     * to operate modulo 0 or if the underlying precompile reverted.
     *
     * IMPORTANT: The result is only valid if the success flag is true. When using this function, make sure the chain
     * you're using it on supports the precompiled contract for modular exponentiation at address 0x05 as specified in
     * https://eips.ethereum.org/EIPS/eip-198[EIP-198]. Otherwise, the underlying function will succeed given the lack
     * of a revert, but the result may be incorrectly interpreted as 0.
     */
    function tryModExp(uint256 b, uint256 e, uint256 m) internal view returns (bool success, uint256 result) {
        if (m == 0) return (false, 0);
        /// @solidity memory-safe-assembly
        assembly {
            // The call arguments are written starting at the free memory pointer; the pointer itself is not
            // advanced.
            let ptr := mload(0x40)
            // | Offset    | Content    | Content (Hex)                                                      |
            // |-----------|------------|--------------------------------------------------------------------|
            // | 0x00:0x1f | size of b  | 0x0000000000000000000000000000000000000000000000000000000000000020 |
            // | 0x20:0x3f | size of e  | 0x0000000000000000000000000000000000000000000000000000000000000020 |
            // | 0x40:0x5f | size of m  | 0x0000000000000000000000000000000000000000000000000000000000000020 |
            // | 0x60:0x7f | value of b | 0x<...............................b> |
            // | 0x80:0x9f | value of e | 0x<...............................e> |
            // | 0xa0:0xbf | value of m | 0x<...............................m> |
            mstore(ptr, 0x20)
            mstore(add(ptr, 0x20), 0x20)
            mstore(add(ptr, 0x40), 0x20)
            mstore(add(ptr, 0x60), b)
            mstore(add(ptr, 0x80), e)
            mstore(add(ptr, 0xa0), m)

            // Given the result < m, it's guaranteed to fit in 32 bytes,
            // so we can use the memory scratch space located at offset 0.
            success := staticcall(gas(), 0x05, ptr, 0xc0, 0x00, 0x20)
            result := mload(0x00)
        }
    }

    /**
     * @dev Variant of {modExp} that supports inputs of arbitrary length.
     *
     * Calls `Panic.panic(Panic.DIVISION_BY_ZERO)` when the underlying {tryModExp} reports a failure.
     */
    function modExp(bytes memory b, bytes memory e, bytes memory m) internal view returns (bytes memory) {
        (bool success, bytes memory result) = tryModExp(b, e, m);
        if (!success) {
            Panic.panic(Panic.DIVISION_BY_ZERO);
        }
        return result;
    }

    /**
     * @dev Variant of {tryModExp} that supports inputs of arbitrary length.
     *
     * Returns `(false, "")` when every byte of `m` is zero, which includes an empty `m`. Otherwise the returned
     * byte array has the same length as `m`.
     */
    function tryModExp(
        bytes memory b,
        bytes memory e,
        bytes memory m
    ) internal view returns (bool success, bytes memory result) {
        if (_zeroBytes(m)) return (false, new bytes(0));

        uint256 mLen = m.length;

        // Encode call args in result and move the free memory pointer
        result = abi.encodePacked(b.length, e.length, mLen, b, e, m);

        /// @solidity memory-safe-assembly
        assembly {
            let dataPtr := add(result, 0x20)
            // Write result on top of args to avoid allocating extra memory.
            success := staticcall(gas(), 0x05, dataPtr, mload(result), dataPtr, mLen)
            // Overwrite the length.
            // result.length > returndatasize() is guaranteed because returndatasize() == m.length
            mstore(result, mLen)
            // Set the memory pointer after the returned data.
            mstore(0x40, add(dataPtr, mLen))
        }
    }

    /**
     * @dev Returns whether the provided byte array is zero, i.e. whether it contains no non-zero byte.
     * An empty array is therefore reported as zero.
     */
    function _zeroBytes(bytes memory byteArray) private pure returns (bool) {
        for (uint256 i = 0; i < byteArray.length; ++i) {
            if (byteArray[i] != 0) {
                return false;
            }
        }
        return true;
    }

    /**
     * @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded
     * towards zero.
     *
     * This method is based on Newton's method for computing square roots; the algorithm is restricted to only
     * using integer operations.
     */
    function sqrt(uint256 a) internal pure returns (uint256) {
        unchecked {
            // Take care of easy edge cases when a == 0 or a == 1
            if (a <= 1) {
                return a;
            }

            // In this function, we use Newton's method to get a root of `f(x) := x² - a`. It involves building a
            // sequence x_n that converges toward sqrt(a). For each iteration x_n, we also define the error of
            // the current value as `ε_n = | x_n - sqrt(a) |`.
            //
            // For our first estimation, we consider `e` such that `2**e` is the smallest power of 2 which is bigger
            // than the square root of the target. (i.e. `2**(e-1) ≤ sqrt(a) < 2**e`). We know that `e ≤ 128` because
            // `(2¹²⁸)² = 2²⁵⁶` is bigger than any uint256.
            //
            // By noticing that
            // `2**(e-1) ≤ sqrt(a) < 2**e → (2**(e-1))² ≤ a < (2**e)² → 2**(2*e-2) ≤ a < 2**(2*e)`
            // we can deduce that `e - 1` is `log2(a) / 2`. We can thus compute `x_n = 2**(e-1)` using a method similar
            // to the msb function.
            uint256 aa = a;
            uint256 xn = 1;

            if (aa >= (1 << 128)) {
                aa >>= 128;
                xn <<= 64;
            }
            if (aa >= (1 << 64)) {
                aa >>= 64;
                xn <<= 32;
            }
            if (aa >= (1 << 32)) {
                aa >>= 32;
                xn <<= 16;
            }
            if (aa >= (1 << 16)) {
                aa >>= 16;
                xn <<= 8;
            }
            if (aa >= (1 << 8)) {
                aa >>= 8;
                xn <<= 4;
            }
            if (aa >= (1 << 4)) {
                aa >>= 4;
                xn <<= 2;
            }
            if (aa >= (1 << 2)) {
                xn <<= 1;
            }

            // We now have x_n such that `x_n = 2**(e-1) ≤ sqrt(a) < 2**e = 2 * x_n`. This implies ε_n ≤ 2**(e-1).
            //
            // We can refine our estimation by noticing that the middle of that interval minimizes the error.
            // If we move x_n to equal 2**(e-1) + 2**(e-2), then we reduce the error to ε_n ≤ 2**(e-2).
            // This is going to be our x_0 (and ε_0)
            xn = (3 * xn) >> 1; // ε_0 := | x_0 - sqrt(a) | ≤ 2**(e-2)

            // From here, Newton's method gives us:
            // x_{n+1} = (x_n + a / x_n) / 2
            //
            // One should note that:
            // x_{n+1}² - a = ((x_n + a / x_n) / 2)² - a
            //              = ((x_n² + a) / (2 * x_n))² - a
            //              = (x_n⁴ + 2 * a * x_n² + a²) / (4 * x_n²) - a
            //              = (x_n⁴ + 2 * a * x_n² + a² - 4 * a * x_n²) / (4 * x_n²)
            //              = (x_n⁴ - 2 * a * x_n² + a²) / (4 * x_n²)
            //              = (x_n² - a)² / (2 * x_n)²
            //              = ((x_n² - a) / (2 * x_n))²
            //              ≥ 0
            // Which proves that for all n ≥ 1, sqrt(a) ≤ x_n
            //
            // This gives us the proof of quadratic convergence of the sequence:
            // ε_{n+1} = | x_{n+1} - sqrt(a) |
            //         = | (x_n + a / x_n) / 2 - sqrt(a) |
            //         = | (x_n² + a - 2*x_n*sqrt(a)) / (2 * x_n) |
            //         = | (x_n - sqrt(a))² / (2 * x_n) |
            //         = | ε_n² / (2 * x_n) |
            //         = ε_n² / | (2 * x_n) |
            //
            // For the first iteration, we have a special case where x_0 is known:
            // ε_1 = ε_0² / | (2 * x_0) |
            //     ≤ (2**(e-2))² / (2 * (2**(e-1) + 2**(e-2)))
            //     ≤ 2**(2*e-4) / (3 * 2**(e-1))
            //     ≤ 2**(e-3) / 3
            //     ≤ 2**(e-3-log2(3))
            //     ≤ 2**(e-4.5)
            //
            // For the following iterations, we use the fact that, 2**(e-1) ≤ sqrt(a) ≤ x_n:
            // ε_{n+1} = ε_n² / | (2 * x_n) |
            //         ≤ (2**(e-k))² / (2 * 2**(e-1))
            //         ≤ 2**(2*e-2*k) / 2**e
            //         ≤ 2**(e-2*k)
            xn = (xn + a / xn) >> 1; // ε_1 := | x_1 - sqrt(a) | ≤ 2**(e-4.5)  -- special case, see above
            xn = (xn + a / xn) >> 1; // ε_2 := | x_2 - sqrt(a) | ≤ 2**(e-9)    -- general case with k = 4.5
            xn = (xn + a / xn) >> 1; // ε_3 := | x_3 - sqrt(a) | ≤ 2**(e-18)   -- general case with k = 9
            xn = (xn + a / xn) >> 1; // ε_4 := | x_4 - sqrt(a) | ≤ 2**(e-36)   -- general case with k = 18
            xn = (xn + a / xn) >> 1; // ε_5 := | x_5 - sqrt(a) | ≤ 2**(e-72)   -- general case with k = 36
            xn = (xn + a / xn) >> 1; // ε_6 := | x_6 - sqrt(a) | ≤ 2**(e-144)  -- general case with k = 72

            // Because e ≤ 128 (as discussed during the first estimation phase), we now have reached a precision
            // ε_6 ≤ 2**(e-144) < 1. Given we're operating on integers, then we can ensure that xn is now either
            // sqrt(a) or sqrt(a) + 1.
            // `xn > a / xn` detects the "+ 1" case, in which one is subtracted to round toward zero.
            return xn - SafeCast.toUint(xn > a / xn);
        }
    }

    /**
     * @dev Calculates sqrt(a), following the selected rounding direction.
     *
     * One is added to the floored root when the rounding mode rounds up (see {unsignedRoundsUp}) and `a` is not a
     * perfect square.
     */
    function sqrt(uint256 a, Rounding rounding) internal pure returns (uint256) {
        unchecked {
            uint256 result = sqrt(a);
            return result + SafeCast.toUint(unsignedRoundsUp(rounding) && result * result < a);
        }
    }

    /**
     * @dev Return the log in base 2 of a positive value rounded towards zero.
     * Returns 0 if given 0.
     */
    function log2(uint256 value) internal pure returns (uint256) {
        uint256 result = 0;
        uint256 exp;
        unchecked {
            // Branch-free binary search for the position of the highest set bit. Each step checks whether
            // `value` exceeds the largest k-bit number; if so it shifts `value` right by k bits and adds k to
            // the result.
            exp = 128 * SafeCast.toUint(value > (1 << 128) - 1);
            value >>= exp;
            result += exp;

            exp = 64 * SafeCast.toUint(value > (1 << 64) - 1);
            value >>= exp;
            result += exp;

            exp = 32 * SafeCast.toUint(value > (1 << 32) - 1);
            value >>= exp;
            result += exp;

            exp = 16 * SafeCast.toUint(value > (1 << 16) - 1);
            value >>= exp;
            result += exp;

            exp = 8 * SafeCast.toUint(value > (1 << 8) - 1);
            value >>= exp;
            result += exp;

            exp = 4 * SafeCast.toUint(value > (1 << 4) - 1);
            value >>= exp;
            result += exp;

            exp = 2 * SafeCast.toUint(value > (1 << 2) - 1);
            value >>= exp;
            result += exp;

            // Last step: no shift is needed, only the final bit of the result.
            result += SafeCast.toUint(value > 1);
        }
        return result;
    }

    /**
     * @dev Return the log in base 2, following the selected rounding direction, of a positive value.
     * Returns 0 if given 0.
     *
     * One is added to the floored logarithm when the rounding mode rounds up (see {unsignedRoundsUp}) and
     * `2**result` is strictly smaller than `value`.
     */
    function log2(uint256 value, Rounding rounding) internal pure returns (uint256) {
        unchecked {
            uint256 result = log2(value);
            return result + SafeCast.toUint(unsignedRoundsUp(rounding) && 1 << result < value);
        }
    }

    /**
     * @dev Return the log in base 10 of a positive value rounded towards zero.
     * Returns 0 if given 0.
     */
    function log10(uint256 value) internal pure returns (uint256) {
        uint256 result = 0;
        unchecked {
            // Successively strip 10**64, 10**32, ..., 10**2 from `value`, accumulating the exponents removed.
            if (value >= 10 ** 64) {
                value /= 10 ** 64;
                result += 64;
            }
            if (value >= 10 ** 32) {
                value /= 10 ** 32;
                result += 32;
            }
            if (value >= 10 ** 16) {
                value /= 10 ** 16;
                result += 16;
            }
            if (value >= 10 ** 8) {
                value /= 10 ** 8;
                result += 8;
            }
            if (value >= 10 ** 4) {
                value /= 10 ** 4;
                result += 4;
            }
            if (value >= 10 ** 2) {
                value /= 10 ** 2;
                result += 2;
            }
            // Last step: `value` is not needed afterwards, so it is not divided.
            if (value >= 10 ** 1) {
                result += 1;
            }
        }
        return result;
    }

    /**
     * @dev Return the log in base 10, following the selected rounding direction, of a positive value.
     * Returns 0 if given 0.
     *
     * One is added to the floored logarithm when the rounding mode rounds up (see {unsignedRoundsUp}) and
     * `10**result` is strictly smaller than `value`.
     */
    function log10(uint256 value, Rounding rounding) internal pure returns (uint256) {
        unchecked {
            uint256 result = log10(value);
            return result + SafeCast.toUint(unsignedRoundsUp(rounding) && 10 ** result < value);
        }
    }

    /**
     * @dev Return the log in base 256 of a positive value rounded towards zero.
     * Returns 0 if given 0.
     *
     * Adding one to the result gives the number of pairs of hex symbols needed to represent `value` as a hex string.
     */
    function log256(uint256 value) internal pure returns (uint256) {
        uint256 result = 0;
        uint256 isGt;
        unchecked {
            // Same branch-free search as in the base 2 logarithm, but counting whole bytes: a shift of 128, 64,
            // 32 or 16 bits adds 16, 8, 4 or 2 to the result respectively.
            isGt = SafeCast.toUint(value > (1 << 128) - 1);
            value >>= isGt * 128;
            result += isGt * 16;

            isGt = SafeCast.toUint(value > (1 << 64) - 1);
            value >>= isGt * 64;
            result += isGt * 8;

            isGt = SafeCast.toUint(value > (1 << 32) - 1);
            value >>= isGt * 32;
            result += isGt * 4;

            isGt = SafeCast.toUint(value > (1 << 16) - 1);
            value >>= isGt * 16;
            result += isGt * 2;

            // Last step: one more byte if what remains does not fit in 8 bits.
            result += SafeCast.toUint(value > (1 << 8) - 1);
        }
        return result;
    }

    /**
     * @dev Return the log in base 256, following the selected rounding direction, of a positive value.
     * Returns 0 if given 0.
     *
     * One is added to the floored logarithm when the rounding mode rounds up (see {unsignedRoundsUp}) and
     * `256**result`, computed as `1 << (result * 8)`, is strictly smaller than `value`.
     */
    function log256(uint256 value, Rounding rounding) internal pure returns (uint256) {
        unchecked {
            uint256 result = log256(value);
            return result + SafeCast.toUint(unsignedRoundsUp(rounding) && 1 << (result << 3) < value);
        }
    }

    /**
     * @dev Returns whether a provided rounding mode is considered rounding up for unsigned integers.
     *
     * This is true for the enum members with an odd underlying value, namely `Ceil` (1) and `Expand` (3).
     */
    function unsignedRoundsUp(Rounding rounding) internal pure returns (bool) {
        return uint8(rounding) % 2 == 1;
    }
}