Skip to content

Commit

Permalink
Add NaN and Inf checking to MPFR functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jmcardon committed Oct 16, 2024
1 parent 73492d5 commit 420a113
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 9 deletions.
1 change: 1 addition & 0 deletions pact-tests/pact-tests/ops.repl
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@
(expect "exp 16.0" 8886110.520507872104644775390625 (exp 16.0))
(expect "exp 15" 3269017.3724721106700599193572998046875 (exp 15))
(expect "exp 15.0" 3269017.3724721106700599193572998046875 (exp 15.0))
(expect-failure "exp produces +inf if the operand is too large" (exp (^ 420 420)))

"===== math.pow"
(expect "^ 0 0" 1 (^ 0 0))
Expand Down
50 changes: 41 additions & 9 deletions pact/Pact/Core/Trans/MPFR.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
-- |
-- |
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveFunctor #-}
Expand Down Expand Up @@ -40,9 +40,9 @@ import System.IO.Unsafe (unsafePerformIO)

data TransResult a
= TransNumber !a
| TransNaN !a
| TransInf !a
| TransNegInf !a
| TransNaN
| TransInf
| TransNegInf
deriving (Show, Eq, Ord, Functor, Foldable, Traversable)

data MPZ = MPZ {
Expand Down Expand Up @@ -157,6 +157,17 @@ foreign import ccall "mpfr_exp"
foreign import ccall "mpfr_sqrt"
c'mpfr_sqrt :: Mpfr_t -> Mpfr_t -> CInt -> IO ()

foreign import ccall "mpfr_number_p"
c'mpfr_number_p :: Mpfr_t -> IO CInt

foreign import ccall "mpfr_inf_p"
c'mpfr_inf_p :: Mpfr_t -> IO CInt

foreign import ccall "mpfr_nan_p"
c'mpfr_nan_p :: Mpfr_t -> IO CInt

foreign import ccall "mpfr_sgn"
c'mpfr_sgn :: Mpfr_t -> IO CInt
{-------------------------------------------------------------------------
-- OPERATIONS
-------------------------------------------------------------------------}
Expand Down Expand Up @@ -214,9 +225,9 @@ mpfr2Dec m =
where
readResultNumber :: String -> TransResult Decimal
readResultNumber (' ':s) = readResultNumber s
readResultNumber "nan" = TransNaN 0
readResultNumber "inf" = TransInf 0
readResultNumber "-inf" = TransNegInf 0
readResultNumber "nan" = TransNaN
readResultNumber "inf" = TransInf
readResultNumber "-inf" = TransNegInf
readResultNumber "0" = TransNumber 0
readResultNumber n =
TransNumber (fromRational (read (trimZeroes n) % 1))
Expand All @@ -235,7 +246,28 @@ mpfr_arity1 f x = unsafePerformIO $
dec2Mpfr x $ \x' ->
withTemp $ \y' -> do
f y' x' rounding
mpfr2Dec y'
checkOutput y'

checkOutput :: Mpfr_t -> IO (TransResult Decimal)
checkOutput result = do
is_num <- c'mpfr_number_p result
if is_num == 1 then
mpfr2Dec result
else do
-- check for infinity
is_inf <- c'mpfr_inf_p result
if is_inf == 1 then do
-- check for sign
sgn <- c'mpfr_sgn result
if sgn > 0 then pure TransInf
else pure TransNegInf
else do
is_nan <- c'mpfr_nan_p result
if is_nan == 1 then pure $ TransNaN
-- The only remaining case is underflow, since overflows will go into +Inf
else pure $ TransNumber 0



mpfr_arity2
:: (Mpfr_t -> Mpfr_t -> Mpfr_t -> CInt -> IO ())
Expand All @@ -245,4 +277,4 @@ mpfr_arity2 f x y = unsafePerformIO $
dec2Mpfr y $ \y' ->
withTemp $ \z' -> do
f z' x' y' rounding
mpfr2Dec z'
checkOutput z'

0 comments on commit 420a113

Please sign in to comment.