From 82c2e0ba1825cc05fc08be91fbf41b7952984393 Mon Sep 17 00:00:00 2001 From: rayxpub Date: Sun, 25 Aug 2024 15:48:21 +0400 Subject: [PATCH] feat: implement solidity ast parsing --- README.md | 124 +++++++++++++++++++++------------ config/config.go | 1 + config/defaults.go | 1 + examples/RequestMeta.txt | 13 ---- main.go | 145 +++++++++++++++++++++++++++------------ parser/parsing.go | 133 +++++++++++++++++++++++++---------- script/Counter.s.sol | 19 ----- src/Counter.sol | 14 ---- src/Spack.sol | 23 +++++++ test/Counter.t.sol | 24 ------- 10 files changed, 304 insertions(+), 193 deletions(-) delete mode 100644 examples/RequestMeta.txt delete mode 100644 script/Counter.s.sol delete mode 100644 src/Counter.sol create mode 100644 src/Spack.sol delete mode 100644 test/Counter.t.sol diff --git a/README.md b/README.md index 9265b45..e8c3bf6 100644 --- a/README.md +++ b/README.md @@ -1,66 +1,106 @@ -## Foundry - -**Foundry is a blazing fast, portable and modular toolkit for Ethereum application development written in Rust.** - -Foundry consists of: +# Spack + +Spack parses Solidity structs and packs the fields efficiently to reduce the +number of storage slots they use. It also adds struct packing comments to clearly indicate +how the fields are packed. + +It can deal with comments and whitespace in the struct definition, and will +preserve them in the output. It handles unknown types by assuming they cannot be +packed, treating they as `bytes32`. + +## Disclaimer + +This code is a work in progress and can contain bugs. Use it at your own risk. +Feature request and bug reports are welcome. + +### Example + +input + +```solidity + struct RequestMeta { + uint64 completedRequests; + Custom.Datatype data; + address requestingContract; + uint72 adminFee; // in wei + address subscriptionOwner; + bytes32 flags; // 32 bytes of flags + uint96 availableBalance; // in wei. 0 if not specified. + uint64 subscriptionId; + uint64 initiatedRequests;// number of requests initiated by this contract + uint32 callbackGasLimit; + uint16 dataVersion; + } +``` -- **Forge**: Ethereum testing framework (like Truffle, Hardhat and DappTools). -- **Cast**: Swiss army knife for interacting with EVM smart contracts, sending transactions and getting chain data. -- **Anvil**: Local Ethereum node, akin to Ganache, Hardhat Network. -- **Chisel**: Fast, utilitarian, and verbose solidity REPL. +output + +```solidity + struct RequestMeta { + Custom.Datatype data; // + bytes32 flags; // + address requestingContract; // ──╮ + uint96 availableBalance; // ─────╯ + address subscriptionOwner; // ───╮ + uint64 completedRequests; // │ + uint32 callbackGasLimit; // ─────╯ + uint72 adminFee; // ─────────────╮ + uint64 subscriptionId; // │ + uint64 initiatedRequests; // │ + uint16 dataVersion; // ──────────╯ + } +``` -## Documentation +## Quickstart -https://book.getfoundry.sh/ +build -## Usage +```bash +go build +``` -### Build +Selecting all structs from a contract -```shell -$ forge build +```bash +./spack -c ``` -### Test +Selecting a specific struct from a contract -```shell -$ forge test +```bash +./spack -c -s ``` -### Format +Counting storage slots -```shell -$ forge fmt +```bash +./spack -c Spack count ``` -### Gas Snapshots +Printing structs without optimizations but with struct packing comments -```shell -$ forge snapshot +```bash +./spack -c Spack -u ``` -### Anvil +Printing the struct with optimizations and struct packing comments: -```shell -$ anvil -``` +## Commands and flags -### Deploy +### Commands -```shell -$ forge script script/Counter.s.sol:CounterScript --rpc-url --private-key -``` +- `pack` - packs the struct +- `count` - counts the number of storage slots the struct uses -### Cast +### Flags -```shell -$ cast -``` +- `-c` or `--contract` - load all structs from a contract +- `-s` or `--struct` - load a specific struct from a contract +- `-u` or `--unoptimized` - print structs without optimizations but with struct packing comments +- `--cfg` or `--config` - load the config from a file -### Help +## TODO -```shell -$ forge --help -$ anvil --help -$ cast --help -``` +- [ ] Add more flexible command line options +- [ ] Add tests +- [ ] Improve error handling diff --git a/config/config.go b/config/config.go index aecfdf9..2eed2fe 100644 --- a/config/config.go +++ b/config/config.go @@ -8,6 +8,7 @@ import ( type GlobalConfig struct { PrintingConfig PrintingConfig + OutDir string } type PrintingConfig struct { diff --git a/config/defaults.go b/config/defaults.go index 68fd485..bcb77f9 100644 --- a/config/defaults.go +++ b/config/defaults.go @@ -3,6 +3,7 @@ package config func GetDefaultConfig() GlobalConfig { return GlobalConfig{ PrintingConfig: GetDefaultPrintingConfig(), + OutDir: "./out/", } } diff --git a/examples/RequestMeta.txt b/examples/RequestMeta.txt deleted file mode 100644 index 296a097..0000000 --- a/examples/RequestMeta.txt +++ /dev/null @@ -1,13 +0,0 @@ -struct RequestMeta { - uint64 completedRequests; - Custom.Datatype data; - address requestingContract; - uint72 adminFee; // in wei - address subscriptionOwner; - bytes32 flags; // 32 bytes of flags - uint96 availableBalance; // in wei. 0 if not specified. - uint64 subscriptionId; - uint64 initiatedRequests;// number of requests initiated by this contract - uint32 callbackGasLimit; - uint16 dataVersion; -} \ No newline at end of file diff --git a/main.go b/main.go index 17a4208..d5cf6c4 100644 --- a/main.go +++ b/main.go @@ -1,9 +1,11 @@ package main import ( + "encoding/json" "fmt" "log" "os" + "os/exec" "sort" "github.com/pkg/errors" @@ -22,17 +24,23 @@ func main() { } func newSpackApp() *cli.App { - var configFile string - var readFromFile, unpacked bool + var contract, solidityStruct, configFile string + var unpacked bool app := &cli.App{ Name: "Spack", Usage: "pack Solidity structs", Flags: []cli.Flag{ - &cli.BoolFlag{ - Name: "file", - Aliases: []string{"f"}, - Usage: "loads a Solidity struct from a file", - Destination: &readFromFile, + &cli.StringFlag{ + Name: "contract", + Aliases: []string{"c"}, + Usage: "loads all Solidity structs from a contract", + Destination: &contract, + }, + &cli.StringFlag{ + Name: "struct", + Aliases: []string{"s"}, + Usage: "loads a single Solidity struct from a contract", + Destination: &solidityStruct, }, &cli.BoolFlag{ Name: "unpacked", @@ -42,7 +50,7 @@ func newSpackApp() *cli.App { }, &cli.StringFlag{ Name: "config", - Aliases: []string{"c"}, + Aliases: []string{"cfg"}, Usage: "location of the config file", Destination: &configFile, }, @@ -53,7 +61,7 @@ func newSpackApp() *cli.App { Aliases: []string{"p"}, Usage: "packs a Solidity struct", Action: func(c *cli.Context) error { - appConfig, err := NewAppSettings(configFile, readFromFile, unpacked, c.Args()) + appConfig, err := NewAppSettings(configFile, contract, solidityStruct, unpacked, c.Args()) if err != nil { return err } @@ -61,7 +69,11 @@ func newSpackApp() *cli.App { if err != nil { return err } - fmt.Println(result) + + for _, r := range result { + fmt.Println(r) + } + return nil }, }, @@ -70,7 +82,7 @@ func newSpackApp() *cli.App { Aliases: []string{"c"}, Usage: "count the slots of the given struct", Action: func(c *cli.Context) error { - appConfig, err := NewAppSettings(configFile, readFromFile, unpacked, c.Args()) + appConfig, err := NewAppSettings(configFile, contract, solidityStruct, unpacked, c.Args()) if err != nil { return err } @@ -92,13 +104,15 @@ func newSpackApp() *cli.App { } type AppSettings struct { - readFromFile bool + outDir string + contract string + solidityStruct string unpacked bool printer *printer.Printer args cli.Args } -func NewAppSettings(configFile string, readFromFile, unpacked bool, args cli.Args) (AppSettings, error) { +func NewAppSettings(configFile string, contract string, solidityStruct string, unpacked bool, args cli.Args) (AppSettings, error) { configuration := config.GetDefaultConfig() // If the user specified a config file, load it if configFile != "" { @@ -115,60 +129,101 @@ func NewAppSettings(configFile string, readFromFile, unpacked bool, args cli.Arg } return AppSettings{ + outDir: configuration.OutDir, printer: &newPrinter, args: args, - readFromFile: readFromFile, + contract: contract, + solidityStruct: solidityStruct, unpacked: unpacked, }, nil } -func pack(settings *AppSettings) (string, error) { - solidityStruct, err := getStruct(settings) +func pack(settings *AppSettings) ([]string, error) { + solidityStructs, err := getStructs(settings) if err != nil { - return "", errors.Wrap(err, "Error parsing struct") - } - if settings.unpacked { - solidityStruct.StorageSlots = packStructCurrentFieldOrder(solidityStruct.Fields) - return settings.printer.PrintSolidityStruct(solidityStruct), nil + return []string{}, errors.Wrap(err, "Error parsing struct") } - solidityStruct.StorageSlots = packStructOptimal(solidityStruct.Fields) + var results []string + + for _, solidityStruct := range solidityStructs { + if settings.unpacked { + solidityStruct.StorageSlots = packStructCurrentFieldOrder(solidityStruct.Fields) + results = append(results, settings.printer.PrintSolidityStruct(solidityStruct)) + } - return settings.printer.PrintSolidityStruct(solidityStruct), nil + solidityStruct.StorageSlots = packStructOptimal(solidityStruct.Fields) + + results = append(results, settings.printer.PrintSolidityStruct(solidityStruct)) + } + + return results, nil } -func count(settings *AppSettings) (int, error) { - structDef, err := getStruct(settings) +func count(settings *AppSettings) ([]int, error) { + structDef, err := getStructs(settings) if err != nil { - return 0, errors.Wrap(err, "Error parsing struct") + return []int{}, errors.Wrap(err, "Error parsing struct") } - if settings.unpacked { - structDef.StorageSlots = packStructCurrentFieldOrder(structDef.Fields) - return len(structDef.StorageSlots), nil + + var slots []int + + for _, structDef := range structDef { + if settings.unpacked { + structDef.StorageSlots = packStructCurrentFieldOrder(structDef.Fields) + slots = append(slots, len(structDef.StorageSlots)) + } + + structDef.StorageSlots = packStructOptimal(structDef.Fields) + slots = append(slots, len(structDef.StorageSlots)) } - structDef.StorageSlots = packStructOptimal(structDef.Fields) - return len(structDef.StorageSlots), nil + return slots, nil } -func getStruct(settings *AppSettings) (solidity.Struct, error) { - input := settings.args.Get(0) - if input == "" { - return solidity.Struct{}, errors.New("No input specified") +func getStructs(settings *AppSettings) ([]solidity.Struct, error) { + cmd := exec.Command("forge", "build", "--ast") + + _, err := cmd.Output() + + if err != nil { + return []solidity.Struct{}, errors.Wrap(err, "Failed to run forge compile") + } + + if settings.contract == "" { + return []solidity.Struct{}, errors.New("No input specified") } - structString := input - if settings.readFromFile { - fileByes, err := os.ReadFile(input) - if err != nil { - panic(err) - } - structString = string(fileByes) + rawData, err := os.ReadFile((settings.outDir + settings.contract + ".sol/" + settings.contract + ".json")) + + if err != nil { + return []solidity.Struct{}, errors.Wrap(err, "Failed to read file") } - structDef, err := parser.ParseStruct(structString) + // Create a map to extract the AST data + var data map[string]interface{} + if err := json.Unmarshal(rawData, &data); err != nil { + return []solidity.Struct{}, errors.New("Failed to parse AST") + } + + // Extract the AST data + astData, err := json.Marshal(data["ast"]) + if err != nil { + return []solidity.Struct{}, errors.New("Failed to extract AST data") + } + + // Cast astData to SolidityAST + var ast parser.SolidityAST + err = json.Unmarshal(astData, &ast) + if err != nil { - return solidity.Struct{}, errors.Wrap(err, "Error parsing struct") + return []solidity.Struct{}, errors.New("Failed to cast AST data") } - return structDef, nil + + s, err := ast.ParseStructs(settings.solidityStruct) + if err != nil { + return []solidity.Struct{}, errors.Wrap(err, "Failed to parse structs") + } + + return s, nil } diff --git a/parser/parsing.go b/parser/parsing.go index cc3feeb..8f3d3bd 100644 --- a/parser/parsing.go +++ b/parser/parsing.go @@ -1,54 +1,115 @@ package parser import ( - "regexp" + "strconv" + "strings" "github.com/pkg/errors" - "github.com/rensr/spack/solidity" ) -func ParseStruct(structDefString string) (solidity.Struct, error) { - structName, err := parseStructName(structDefString) - if err != nil { - return solidity.Struct{}, err - } +type SolidityAST struct { + Nodes []ASTNode `json:"nodes"` +} - return solidity.Struct{ - Name: structName, - Fields: parseStructFields(structDefString), - }, nil +type ASTNode struct { + NodeType string `json:"nodeType"` + Name string `json:"name"` + Nodes []ContractNode `json:"nodes"` } -func parseStructName(structString string) (string, error) { - structNameRegex, err := regexp.Compile(`struct ([a-zA-Z0-9]*) {`) - if err != nil { - return "", err - } +type ContractNode struct { + NodeType string `json:"nodeType"` + Name string `json:"name"` + Nodes []Member `json:"members"` +} - nameMatch := structNameRegex.FindStringSubmatch(structString) - if len(nameMatch) < 2 { - return "", errors.New("could not find struct name") - } +type Member struct { + NodeType string `json:"nodeType"` + TypeName TypeName `json:"typeName"` + Name string `json:"name"` +} - return nameMatch[1], nil +type TypeName struct { + NodeType string `json:"nodeType"` + TypeDescriptions TypeDescription `json:"typeDescriptions"` } -func parseStructFields(structString string) []solidity.DataDef { - componentRegex := regexp.MustCompile(`\s*([a-zA-Z0-9\][.]+)\s+([a-zA-Z0-9_]+)\s*;[ \t]*(?://)?(.*)\n`) - matches := componentRegex.FindAllStringSubmatch(structString, -1) - - var fields []solidity.DataDef - for _, match := range matches { - dataType := solidity.DataType(match[1]) - newField := solidity.DataDef{ - Name: match[2], - Type: dataType, - Comment: match[3], - Size: dataType.Size(), +type TypeDescription struct { + Type string `json:"typeString"` +} + +// Parses the Solidity AST object and returns a list of structs +func (s *SolidityAST) ParseStructs(solidityStruct string) ([]solidity.Struct, error) { + var structs []solidity.Struct + for _, node := range s.Nodes { + if node.NodeType == "ContractDefinition" { + singleLookup := solidityStruct != "" + for _, contractNode := range node.Nodes { + if contractNode.NodeType == "StructDefinition" { + if singleLookup && contractNode.Name != solidityStruct { + continue + } + ss := solidity.Struct{Name: contractNode.Name} + for _, member := range contractNode.Nodes { + size, err := member.ComputeSize() + if err != nil { + return nil, errors.Wrap(err, "Failed to compute size") + } + if strings.Contains(member.TypeName.TypeDescriptions.Type, "struct") { + member.TypeName.TypeDescriptions.Type = strings.Trim(member.TypeName.TypeDescriptions.Type, "struct " + node.Name + ".") + } + + ss.Fields = append(ss.Fields, solidity.DataDef{Name: member.Name, Comment: "", Type: solidity.DataType(member.TypeName.TypeDescriptions.Type), Size: size}) + } + structs = append(structs, ss) + } + } + } } + return structs, nil +} - fields = append(fields, newField) +// Computes the size of a member +// For dynamic types such as strings, bytes and arrays, it returns 32 +func (m *Member) ComputeSize() (uint8, error) { + if m.NodeType == "UserDefinedTypeName" { + return 32, nil } - return fields -} + + switch m.TypeName.TypeDescriptions.Type { + case "bool": return 1, nil + case "address": return 20, nil + case "string": return 32, nil + case "bytes": return 32, nil + default: + if strings.Contains(m.TypeName.TypeDescriptions.Type, "[]") { + return 32, nil + } + if strings.Contains(m.TypeName.TypeDescriptions.Type, "uint") { + sizeString := strings.Trim(m.TypeName.TypeDescriptions.Type, "uint") + size, err := strconv.ParseUint(sizeString, 10, 64) + if err != nil { + return 0, err + } + return uint8(size) / 8, nil + } + if strings.Contains(m.TypeName.TypeDescriptions.Type, "int") { + sizeString := strings.Trim(m.TypeName.TypeDescriptions.Type, "int") + size, err := strconv.ParseUint(sizeString, 10, 64) + if err != nil { + return 0, err + } + return uint8(size) / 8, nil + } + if strings.Contains(m.TypeName.TypeDescriptions.Type, "bytes") { + sizeString := strings.Trim(m.TypeName.TypeDescriptions.Type, "bytes") + size, err := strconv.ParseUint(sizeString, 10, 64) + if err != nil { + return 0, err + } + return uint8(size), nil + } + return 32, nil + } +} \ No newline at end of file diff --git a/script/Counter.s.sol b/script/Counter.s.sol deleted file mode 100644 index cdc1fe9..0000000 --- a/script/Counter.s.sol +++ /dev/null @@ -1,19 +0,0 @@ -// SPDX-License-Identifier: UNLICENSED -pragma solidity ^0.8.13; - -import {Script, console} from "forge-std/Script.sol"; -import {Counter} from "../src/Counter.sol"; - -contract CounterScript is Script { - Counter public counter; - - function setUp() public {} - - function run() public { - vm.startBroadcast(); - - counter = new Counter(); - - vm.stopBroadcast(); - } -} diff --git a/src/Counter.sol b/src/Counter.sol deleted file mode 100644 index aded799..0000000 --- a/src/Counter.sol +++ /dev/null @@ -1,14 +0,0 @@ -// SPDX-License-Identifier: UNLICENSED -pragma solidity ^0.8.13; - -contract Counter { - uint256 public number; - - function setNumber(uint256 newNumber) public { - number = newNumber; - } - - function increment() public { - number++; - } -} diff --git a/src/Spack.sol b/src/Spack.sol new file mode 100644 index 0000000..e92e35a --- /dev/null +++ b/src/Spack.sol @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity 0.8.24; + +contract Spack { + struct RequestMeta { + uint64 completedRequests; + DataType data; + address requestingContract; + uint72 adminFee; // in wei + address subscriptionOwner; + bytes32 flags; // 32 bytes of flags + uint96 availableBalance; // in wei. 0 if not specified. + uint64 subscriptionId; + uint64 initiatedRequests; // number of requests initiated by this contract + uint32 callbackGasLimit; + uint16 dataVersion; + } + + struct DataType { + uint96 timestamp; + address sender; + } +} diff --git a/test/Counter.t.sol b/test/Counter.t.sol deleted file mode 100644 index 54b724f..0000000 --- a/test/Counter.t.sol +++ /dev/null @@ -1,24 +0,0 @@ -// SPDX-License-Identifier: UNLICENSED -pragma solidity ^0.8.13; - -import {Test, console} from "forge-std/Test.sol"; -import {Counter} from "../src/Counter.sol"; - -contract CounterTest is Test { - Counter public counter; - - function setUp() public { - counter = new Counter(); - counter.setNumber(0); - } - - function test_Increment() public { - counter.increment(); - assertEq(counter.number(), 1); - } - - function testFuzz_SetNumber(uint256 x) public { - counter.setNumber(x); - assertEq(counter.number(), x); - } -}