Skip to content
This repository has been archived by the owner on Nov 24, 2022. It is now read-only.

Implement & use safeFromIntegral #776

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 28 additions & 18 deletions asterius/src/Asterius/Backends/Binaryen.hs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StrictData #-}
{-# OPTIONS_GHC -Wno-overflowed-literals #-}
{-# OPTIONS_GHC -Wno-orphans -Wno-overflowed-literals #-}

-- |
-- Module : Asterius.Backends.Binaryen
Expand All @@ -28,6 +31,7 @@ import qualified Asterius.Internals.Arena as A
import Asterius.Internals.Barf
import Asterius.Internals.MagicNumber
import Asterius.Internals.Marshal
import Asterius.Internals.SafeFromIntegral
import Asterius.Types
import qualified Asterius.Types.SymbolMap as SM
import Asterius.TypesConv
Expand Down Expand Up @@ -62,6 +66,12 @@ import Foreign.C
import GHC.Exts
import Language.Haskell.GHC.Toolkit.Constants

deriving newtype instance Bounded Binaryen.Index
deriving newtype instance Enum Binaryen.Index
deriving newtype instance Integral Binaryen.Index
deriving newtype instance Ord Binaryen.Index
deriving newtype instance Real Binaryen.Index

newtype MarshalError
= UnsupportedExpression Expression
deriving (Show)
Expand All @@ -85,7 +95,7 @@ marshalValueTypes vts = do
a <- askArena
lift $ do
(vts', vtl) <- marshalV a $ map marshalValueType vts
Binaryen.Type.create vts' (fromIntegral vtl)
Binaryen.Type.create vts' (safeFromIntegral vtl)

marshalReturnTypes :: [ValueType] -> CodeGen Binaryen.Type
marshalReturnTypes vts = case vts of
Expand Down Expand Up @@ -292,7 +302,7 @@ marshalExpression e = case e of
lift $ do
(bsp, bl) <- marshalV a bs
np <- marshalBS a name
Binaryen.block m np bsp (fromIntegral bl) rts
Binaryen.block m np bsp (safeFromIntegral bl) rts
If {..} -> do
c <- marshalExpression condition
t <- marshalExpression ifTrue
Expand Down Expand Up @@ -321,7 +331,7 @@ marshalExpression e = case e of
ns <- forM names $ marshalBS a
(nsp, nl) <- marshalV a ns
dn <- marshalBS a defaultName
Binaryen.switch m nsp (fromIntegral nl) dn c (coerce nullPtr)
Binaryen.switch m nsp (safeFromIntegral nl) dn c (coerce nullPtr)
Call {..} -> do
verbose_err <- isVerboseErrOn
func_sym_map <- askFunctionsSymbolMap
Expand All @@ -343,7 +353,7 @@ marshalExpression e = case e of
lift $ do
(ops, osl) <- marshalV a os
tp <- marshalBS a (entityName target)
Binaryen.call m tp ops (fromIntegral osl) rts
Binaryen.call m tp ops (safeFromIntegral osl) rts
| verbose_err ->
marshalExpression $
barf (entityName target) callReturnTypes
Expand All @@ -358,7 +368,7 @@ marshalExpression e = case e of
lift $ do
(ops, osl) <- marshalV a os
tp <- marshalBS a target'
Binaryen.call m tp ops (fromIntegral osl) rts
Binaryen.call m tp ops (safeFromIntegral osl) rts
CallIndirect {..} -> do
t <- marshalExpression indirectTarget
os <- forM operands marshalExpression
Expand All @@ -367,7 +377,7 @@ marshalExpression e = case e of
a <- askArena
lift $ do
(ops, osl) <- marshalV a os
Binaryen.callIndirect m t ops (fromIntegral osl) pt rt
Binaryen.callIndirect m t ops (safeFromIntegral osl) pt rt
GetLocal {..} -> do
m <- askModuleRef
lift $ Binaryen.localGet m (coerce index) $ marshalValueType valueType
Expand Down Expand Up @@ -505,7 +515,7 @@ marshalExpression e = case e of
func_sym_map <- askFunctionsSymbolMap
m <- askModuleRef
if | Just x <- SM.lookup unresolvedSymbol ss_sym_map ->
lift $ Binaryen.constInt64 m $ x + fromIntegral symbolOffset
lift $ Binaryen.constInt64 m $ x + safeFromIntegral symbolOffset
| Just x <- SM.lookup unresolvedSymbol func_sym_map ->
let base =
GetGlobal
Expand All @@ -515,7 +525,7 @@ marshalExpression e = case e of
in marshalExpression $
addInt64
(extendUInt32 base)
(constI64 $ fromIntegral x + symbolOffset)
(constI64 $ safeFromIntegral x + symbolOffset)
| verbose_err ->
marshalExpression $ barf (entityName unresolvedSymbol) [I64]
| otherwise ->
Expand Down Expand Up @@ -548,7 +558,7 @@ marshalFunction k (pt, rt) Function {..} = do
lift $ do
(vtp, vtl) <- marshalV a $ map marshalValueType varTypes
np <- marshalBS a k
Binaryen.addFunction m np pt rt vtp (fromIntegral vtl) b
Binaryen.addFunction m np pt rt vtp (safeFromIntegral vtl) b

marshalFunctionImport ::
Binaryen.Module ->
Expand Down Expand Up @@ -581,10 +591,10 @@ marshalFunctionTable m tbl_slots FunctionTable {..} = do
(fnp, fnl) <- marshalV a func_name_ptrs
Binaryen.setFunctionTable
m
(fromIntegral tbl_slots)
(safeFromIntegral tbl_slots)
(-1)
fnp
(fromIntegral fnl)
(safeFromIntegral fnl)
o

marshalMemorySegments :: Int -> [DataSegment] -> CodeGen ()
Expand All @@ -603,17 +613,17 @@ marshalMemorySegments mbs segs = do
( \DataSegment {..} ->
flip runReaderT env $ marshalExpression $ ConstI32 offset
)
(seg_sizes, _) <- marshalV a $ map (fromIntegral . BS.length . content) segs
(seg_sizes, _) <- marshalV a $ map (safeFromIntegral . BS.length . content) segs
Binaryen.setMemory
m
(fromIntegral $ mbs * (mblock_size `quot` wasmPageSize))
(safeFromIntegral $ mbs * (mblock_size `quot` wasmPageSize))
(-1)
nullPtr
seg_bufs
seg_passives
seg_offsets
seg_sizes
(fromIntegral segs_len)
(safeFromIntegral segs_len)
0

marshalTableImport :: Binaryen.Module -> TableImport -> CodeGen ()
Expand Down Expand Up @@ -642,7 +652,7 @@ marshalGlobalImport m GlobalImport {..} = do
emp <- marshalBS a externalModuleName
ebp <- marshalBS a externalBaseName
let (ty, mut) = marshalGlobalType globalType
Binaryen.addGlobalImport m inp emp ebp ty (fromIntegral mut)
Binaryen.addGlobalImport m inp emp ebp ty (safeFromIntegral mut)

marshalGlobalExport ::
Binaryen.Module -> GlobalExport -> CodeGen Binaryen.Export
Expand Down Expand Up @@ -738,7 +748,7 @@ relooperAddBranch bm k ab = case ab of
(bm M.! k)
(bm M.! to)
(coerce idp)
(fromIntegral idn)
(safeFromIntegral idn)
(coerce nullPtr)

relooperRun :: RelooperRun -> CodeGen Binaryen.Expression
Expand All @@ -760,7 +770,7 @@ serializeModule m = alloca $ \(buf_p :: Ptr (Ptr ())) ->
Binaryen.allocateAndWriteMut m nullPtr buf_p len_p src_map_p
buf <- peek buf_p
len <- peek len_p
BS.unsafePackMallocCStringLen (castPtr buf, fromIntegral len)
BS.unsafePackMallocCStringLen (castPtr buf, safeFromIntegral len)

serializeModuleSExpr :: Binaryen.Module -> IO BS.ByteString
serializeModuleSExpr m =
Expand Down
11 changes: 6 additions & 5 deletions asterius/src/Asterius/Backends/WasmToolkit.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import Asterius.Builtins
import Asterius.EDSL (addInt64, constI64, extendUInt32)
import Asterius.Internals.Barf
import Asterius.Internals.MagicNumber
import Asterius.Internals.SafeFromIntegral
import Asterius.Passes.Relooper
import Asterius.TypeInfer
import Asterius.Types
Expand Down Expand Up @@ -155,7 +156,7 @@ makeImportSection Module {..} ModuleSymbolTable {..} = pure Wasm.ImportSection
importName = coerce $ SBS.toShort externalBaseName,
importDescription = Wasm.ImportMemory $ Wasm.MemoryType $ Wasm.Limits
{ minLimit =
fromIntegral $
safeFromIntegral $
memoryMBlocks
* (mblock_size `quot` wasmPageSize),
maxLimit = Nothing
Expand All @@ -167,7 +168,7 @@ makeImportSection Module {..} ModuleSymbolTable {..} = pure Wasm.ImportSection
{ moduleName = coerce $ SBS.toShort externalModuleName,
importName = coerce $ SBS.toShort externalBaseName,
importDescription = Wasm.ImportTable $ Wasm.TableType Wasm.AnyFunc $ Wasm.Limits
{ minLimit = fromIntegral tableSlots,
{ minLimit = safeFromIntegral tableSlots,
maxLimit = Nothing
}
}
Expand Down Expand Up @@ -347,7 +348,7 @@ makeLocalContext Module {} Function {..} =
$ sort
$ zip varTypes [arity ..]
where
arity = fromIntegral $ length $ paramTypes functionType
arity = safeFromIntegral $ length $ paramTypes functionType

lookupLocalContext :: LocalContext -> BinaryenIndex -> Wasm.LocalIndex
lookupLocalContext LocalContext {..} i = coerce $ case Map.lookup i localMap of
Expand Down Expand Up @@ -785,7 +786,7 @@ makeInstructions expr =
func_sym_map <- askFunctionsSymbolMap
if | Just x <- SM.lookup unresolvedSymbol ss_sym_map ->
pure $ unitBag Wasm.I64Const
{ i64ConstValue = x + fromIntegral symbolOffset
{ i64ConstValue = x + safeFromIntegral symbolOffset
}
| Just x <- SM.lookup unresolvedSymbol func_sym_map ->
let base =
Expand All @@ -796,7 +797,7 @@ makeInstructions expr =
in makeInstructions $
addInt64
(extendUInt32 base)
(constI64 $ fromIntegral x + symbolOffset)
(constI64 $ safeFromIntegral x + symbolOffset)
| verbose_err ->
makeInstructions $ barf (entityName unresolvedSymbol) [I64]
| otherwise ->
Expand Down
5 changes: 3 additions & 2 deletions asterius/src/Asterius/Builtins.hs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import Asterius.Builtins.Time
import Asterius.EDSL
import Asterius.Internals
import Asterius.Internals.MagicNumber
import Asterius.Internals.SafeFromIntegral
import Asterius.Types
import qualified Asterius.Types.SymbolMap as SM
import qualified Data.ByteString as BS
Expand Down Expand Up @@ -1557,7 +1558,7 @@ genWrap ti b x = Block
},
Load
{ signed = False,
bytes = fromIntegral b,
bytes = safeFromIntegral b,
offset = 0,
valueType = I32,
ptr = wrapInt64 (symbol "__asterius_i64_slot")
Expand Down Expand Up @@ -1599,7 +1600,7 @@ genExtend b to sext x = Block
},
Load
{ signed = sext == Sext,
bytes = fromIntegral b,
bytes = safeFromIntegral b,
offset = 0,
valueType = to,
ptr = wrapInt64 (symbol "__asterius_i64_slot")
Expand Down
25 changes: 13 additions & 12 deletions asterius/src/Asterius/Builtins/Posix.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ module Asterius.Builtins.Posix
where

import Asterius.EDSL
import Asterius.Internals.SafeFromIntegral
import Asterius.Types
import qualified Asterius.Types.SymbolMap as SM
import qualified Data.ByteString as BS
Expand Down Expand Up @@ -286,17 +287,17 @@ posixConstants =
}
)
| (k, v) <-
[ ("__hscore_sizeof_stat", fromIntegral sizeof_stat),
("__hscore_o_rdonly", fromIntegral o_RDONLY),
("__hscore_o_wronly", fromIntegral o_WRONLY),
("__hscore_o_rdwr", fromIntegral o_RDWR),
("__hscore_o_append", fromIntegral o_APPEND),
("__hscore_o_creat", fromIntegral o_CREAT),
("__hscore_o_excl", fromIntegral o_EXCL),
("__hscore_o_trunc", fromIntegral o_TRUNC),
("__hscore_o_noctty", fromIntegral o_NOCTTY),
("__hscore_o_nonblock", fromIntegral o_NONBLOCK),
("__hscore_o_binary", fromIntegral o_BINARY)
[ ("__hscore_sizeof_stat", safeFromIntegral sizeof_stat),
("__hscore_o_rdonly", safeFromIntegral o_RDONLY),
("__hscore_o_wronly", safeFromIntegral o_WRONLY),
("__hscore_o_rdwr", safeFromIntegral o_RDWR),
("__hscore_o_append", safeFromIntegral o_APPEND),
("__hscore_o_creat", safeFromIntegral o_CREAT),
("__hscore_o_excl", safeFromIntegral o_EXCL),
("__hscore_o_trunc", safeFromIntegral o_TRUNC),
("__hscore_o_noctty", safeFromIntegral o_NOCTTY),
("__hscore_o_nonblock", safeFromIntegral o_NONBLOCK),
("__hscore_o_binary", safeFromIntegral o_BINARY)
]
]
}
Expand All @@ -311,7 +312,7 @@ offset_stat_mtime,
unsafePerformIO $
allocaBytes sizeof_stat $ \p -> do
forM_ [0 .. sizeof_stat - 1] $
\i -> pokeByteOff p i (fromIntegral i :: Word8)
\i -> pokeByteOff p i (safeFromIntegral i :: Word8)
_mtime <- (.&. 0xFF) . fromEnum <$> st_mtime p
_size <- (.&. 0xFF) . fromEnum <$> st_size p
_mode <- (.&. 0xFF) . fromEnum <$> st_mode p
Expand Down
29 changes: 15 additions & 14 deletions asterius/src/Asterius/CodeGen.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import Asterius.CodeGen.Droppable
import Asterius.EDSL
import Asterius.Internals
import Asterius.Internals.Name
import Asterius.Internals.SafeFromIntegral
import Asterius.Passes.All
import Asterius.Passes.GlobalRegs
import Asterius.Resolve
Expand Down Expand Up @@ -124,20 +125,20 @@ marshalCmmStatic st = case st of
<$> dispatchAllCmmWidth
w
( if x < 0
then encodeStorable (fromIntegral x :: Int8)
else encodeStorable (fromIntegral x :: Word8)
then encodeStorable (safeFromIntegral x :: Int8)
else encodeStorable (safeFromIntegral x :: Word8)
)
( if x < 0
then encodeStorable (fromIntegral x :: Int16)
else encodeStorable (fromIntegral x :: Word16)
then encodeStorable (safeFromIntegral x :: Int16)
else encodeStorable (safeFromIntegral x :: Word16)
)
( if x < 0
then encodeStorable (fromIntegral x :: Int32)
else encodeStorable (fromIntegral x :: Word32)
then encodeStorable (safeFromIntegral x :: Int32)
else encodeStorable (safeFromIntegral x :: Word32)
)
( if x < 0
then encodeStorable (fromIntegral x :: Int64)
else encodeStorable (fromIntegral x :: Word64)
then encodeStorable (safeFromIntegral x :: Int64)
else encodeStorable (safeFromIntegral x :: Word64)
)
GHC.CmmFloat x w ->
Serialized
Expand Down Expand Up @@ -209,8 +210,8 @@ marshalCmmLit lit = case lit of
GHC.CmmInt x w ->
dispatchCmmWidth
w
(ConstI32 $ fromIntegral x, I32)
(ConstI64 $ fromIntegral x, I64)
(ConstI32 $ safeFromIntegral x, I32)
(ConstI64 $ safeFromIntegral x, I64)
GHC.CmmFloat x w ->
dispatchCmmWidth
w
Expand Down Expand Up @@ -297,7 +298,7 @@ marshalCmmRegOff r o = do
( Binary
{ binaryOp = AddInt32,
operand0 = re,
operand1 = ConstI32 $ fromIntegral o
operand1 = ConstI32 $ safeFromIntegral o
},
vt
)
Expand All @@ -306,7 +307,7 @@ marshalCmmRegOff r o = do
( Binary
{ binaryOp = AddInt64,
operand0 = re,
operand1 = ConstI64 $ fromIntegral o
operand1 = ConstI64 $ safeFromIntegral o
},
vt
)
Expand Down Expand Up @@ -1582,7 +1583,7 @@ marshalCmmBlockBranch instr = case instr of
a <- marshalAndCastCmmExpr cml_arg I64
brs <- for (GHC.switchTargetsCases st) $ \(idx, lbl) -> do
dest <- marshalLabel lbl
pure (dest, [fromIntegral $ idx - fst (GHC.switchTargetsRange st)])
pure (dest, [safeFromIntegral $ idx - fst (GHC.switchTargetsRange st)])
(needs_unreachable, dest_def) <- case GHC.switchTargetsDefault st of
Just lbl -> do
klbl <- marshalLabel lbl
Expand All @@ -1597,7 +1598,7 @@ marshalCmmBlockBranch instr = case instr of
(l, _) -> Binary
{ binaryOp = SubInt64,
operand0 = a,
operand1 = ConstI64 $ fromIntegral l
operand1 = ConstI64 $ safeFromIntegral l
}
},
[ AddBranchForSwitch {to = dest, indexes = tags}
Expand Down
Loading