diff --git a/.dockerignore b/.dockerignore index e3760b68cb..b5efcaad57 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,3 +1,5 @@ build/ llvm-*/ +.github +.circleci diff --git a/.github/workflows/build-macos.yml b/.github/workflows/build-macos.yml index 135cedb853..3151ac5637 100644 --- a/.github/workflows/build-macos.yml +++ b/.github/workflows/build-macos.yml @@ -29,11 +29,11 @@ jobs: with: go-version: '1.20' cache: true - - name: Cache LLVM source - uses: actions/cache@v3 + - name: Restore LLVM source cache + uses: actions/cache/restore@v3 id: cache-llvm-source with: - key: llvm-source-15-macos-v2 + key: llvm-source-15-macos-v3 path: | llvm-project/clang/lib/Headers llvm-project/clang/include @@ -43,11 +43,22 @@ jobs: - name: Download LLVM source if: steps.cache-llvm-source.outputs.cache-hit != 'true' run: make llvm-source - - name: Cache LLVM build - uses: actions/cache@v3 + - name: Save LLVM source cache + uses: actions/cache/save@v3 + if: steps.cache-llvm-source.outputs.cache-hit != 'true' + with: + key: ${{ steps.cache-llvm-source.outputs.cache-primary-key }} + path: | + llvm-project/clang/lib/Headers + llvm-project/clang/include + llvm-project/compiler-rt + llvm-project/lld/include + llvm-project/llvm/include + - name: Restore LLVM build cache + uses: actions/cache/restore@v3 id: cache-llvm-build with: - key: llvm-build-15-macos-v3 + key: llvm-build-15-macos-v4 path: llvm-build - name: Build LLVM if: steps.cache-llvm-build.outputs.cache-hit != 'true' @@ -61,6 +72,12 @@ jobs: # build! make llvm-build find llvm-build -name CMakeFiles -prune -exec rm -r '{}' \; + - name: Save LLVM build cache + uses: actions/cache/save@v3 + if: steps.cache-llvm-build.outputs.cache-hit != 'true' + with: + key: ${{ steps.cache-llvm-build.outputs.cache-primary-key }} + path: llvm-build - name: Cache wasi-libc sysroot uses: actions/cache@v3 id: cache-wasi-libc @@ -70,9 +87,11 @@ jobs: - name: Build wasi-libc if: steps.cache-wasi-libc.outputs.cache-hit != 'true' run: make wasi-libc + - name: make gen-device + run: make -j3 gen-device - name: Test TinyGo shell: bash - run: make test GOTESTFLAGS="-v -short" + run: make test GOTESTFLAGS="-short" - name: Build TinyGo release tarball run: make release -j3 - name: Test stdlib packages diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index fe6a82adbd..8980436fd4 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -31,7 +31,7 @@ jobs: with: images: | tinygo/tinygo-dev - ghcr.io/${{ github.repository }}/tinygo-dev + ghcr.io/${{ github.repository_owner }}/tinygo-dev tags: | type=sha,format=long type=raw,value=latest @@ -53,6 +53,7 @@ jobs: push: true tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} + build-contexts: tinygo-llvm-build=docker-image://tinygo/llvm-15 cache-from: type=gha cache-to: type=gha,mode=max - name: Trigger Drivers repo build on Github Actions @@ -69,21 +70,31 @@ jobs: -H "Accept: application/vnd.github.v3+json" \ https://api.github.com/repos/tinygo-org/bluetooth/actions/workflows/linux.yml/dispatches \ -d '{"ref": "dev"}' - - name: Trigger TinyFS repo build on CircleCI + - name: Trigger TinyFS repo build on Github Actions run: | - curl --location --request POST 'https://circleci.com/api/v2/project/github/tinygo-org/tinyfs/pipeline' \ - --header 'Content-Type: application/json' \ - -d '{"branch": "dev"}' \ - -u "${{ secrets.CIRCLECI_API_TOKEN }}" - - name: Trigger TinyFont repo build on CircleCI + curl -X POST \ + -H "Authorization: Bearer 
${{secrets.GHA_ACCESS_TOKEN}}" \ + -H "Accept: application/vnd.github.v3+json" \ + https://api.github.com/repos/tinygo-org/tinyfs/actions/workflows/build.yml/dispatches \ + -d '{"ref": "dev"}' + - name: Trigger TinyFont repo build on Github Actions run: | - curl --location --request POST 'https://circleci.com/api/v2/project/github/tinygo-org/tinyfont/pipeline' \ - --header 'Content-Type: application/json' \ - -d '{"branch": "dev"}' \ - -u "${{ secrets.CIRCLECI_API_TOKEN }}" - - name: Trigger TinyDraw repo build on CircleCI + curl -X POST \ + -H "Authorization: Bearer ${{secrets.GHA_ACCESS_TOKEN}}" \ + -H "Accept: application/vnd.github.v3+json" \ + https://api.github.com/repos/tinygo-org/tinyfont/actions/workflows/build.yml/dispatches \ + -d '{"ref": "dev"}' + - name: Trigger TinyDraw repo build on Github Actions run: | - curl --location --request POST 'https://circleci.com/api/v2/project/github/tinygo-org/tinydraw/pipeline' \ - --header 'Content-Type: application/json' \ - -d '{"branch": "dev"}' \ - -u "${{ secrets.CIRCLECI_API_TOKEN }}" + curl -X POST \ + -H "Authorization: Bearer ${{secrets.GHA_ACCESS_TOKEN}}" \ + -H "Accept: application/vnd.github.v3+json" \ + https://api.github.com/repos/tinygo-org/tinydraw/actions/workflows/build.yml/dispatches \ + -d '{"ref": "dev"}' + - name: Trigger TinyTerm repo build on Github Actions + run: | + curl -X POST \ + -H "Authorization: Bearer ${{secrets.GHA_ACCESS_TOKEN}}" \ + -H "Accept: application/vnd.github.v3+json" \ + https://api.github.com/repos/tinygo-org/tinyterm/actions/workflows/build.yml/dispatches \ + -d '{"ref": "dev"}' diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index 413ffff967..bd020b77e0 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -39,11 +39,11 @@ jobs: path: | ~/.cache/go-build ~/go/pkg/mod - - name: Cache LLVM source - uses: actions/cache@v3 + - name: Restore LLVM source cache + uses: actions/cache/restore@v3 id: cache-llvm-source with: - key: llvm-source-15-linux-alpine-v2 + key: llvm-source-15-linux-alpine-v3 path: | llvm-project/clang/lib/Headers llvm-project/clang/include @@ -53,11 +53,22 @@ jobs: - name: Download LLVM source if: steps.cache-llvm-source.outputs.cache-hit != 'true' run: make llvm-source - - name: Cache LLVM build - uses: actions/cache@v3 + - name: Save LLVM source cache + uses: actions/cache/save@v3 + if: steps.cache-llvm-source.outputs.cache-hit != 'true' + with: + key: ${{ steps.cache-llvm-source.outputs.cache-primary-key }} + path: | + llvm-project/clang/lib/Headers + llvm-project/clang/include + llvm-project/compiler-rt + llvm-project/lld/include + llvm-project/llvm/include + - name: Restore LLVM build cache + uses: actions/cache/restore@v3 id: cache-llvm-build with: - key: llvm-build-15-linux-alpine-v3 + key: llvm-build-15-linux-alpine-v4 path: llvm-build - name: Build LLVM if: steps.cache-llvm-build.outputs.cache-hit != 'true' @@ -71,6 +82,12 @@ jobs: make llvm-build # Remove unnecessary object files (to reduce cache size). 
find llvm-build -name CMakeFiles -prune -exec rm -r '{}' \; + - name: Save LLVM build cache + uses: actions/cache/save@v3 + if: steps.cache-llvm-build.outputs.cache-hit != 'true' + with: + key: ${{ steps.cache-llvm-build.outputs.cache-primary-key }} + path: llvm-build - name: Cache Binaryen uses: actions/cache@v3 id: cache-binaryen @@ -122,7 +139,9 @@ jobs: cache: true - name: Install wasmtime run: | - curl https://wasmtime.dev/install.sh -sSf | bash + mkdir -p $HOME/.wasmtime $HOME/.wasmtime/bin + curl https://github.com/bytecodealliance/wasmtime/releases/download/v5.0.0/wasmtime-v5.0.0-x86_64-linux.tar.xz -o wasmtime-v5.0.0-x86_64-linux.tar.xz -SfL + tar -C $HOME/.wasmtime/bin --wildcards -xf wasmtime-v5.0.0-x86_64-linux.tar.xz --strip-components=1 wasmtime-v5.0.0-x86_64-linux/* echo "$HOME/.wasmtime/bin" >> $GITHUB_PATH - name: Download release artifact uses: actions/download-artifact@v3 @@ -166,13 +185,15 @@ jobs: node-version: '14' - name: Install wasmtime run: | - curl https://wasmtime.dev/install.sh -sSf | bash + mkdir -p $HOME/.wasmtime $HOME/.wasmtime/bin + curl -L https://github.com/bytecodealliance/wasmtime/releases/download/v5.0.0/wasmtime-v5.0.0-x86_64-linux.tar.xz -o wasmtime-v5.0.0-x86_64-linux.tar.xz -SfL + tar -C $HOME/.wasmtime/bin --wildcards -xf wasmtime-v5.0.0-x86_64-linux.tar.xz --strip-components=1 wasmtime-v5.0.0-x86_64-linux/* echo "$HOME/.wasmtime/bin" >> $GITHUB_PATH - - name: Cache LLVM source - uses: actions/cache@v3 + - name: Restore LLVM source cache + uses: actions/cache/restore@v3 id: cache-llvm-source with: - key: llvm-source-15-linux-asserts-v2 + key: llvm-source-15-linux-asserts-v3 path: | llvm-project/clang/lib/Headers llvm-project/clang/include @@ -182,11 +203,22 @@ jobs: - name: Download LLVM source if: steps.cache-llvm-source.outputs.cache-hit != 'true' run: make llvm-source - - name: Cache LLVM build - uses: actions/cache@v3 + - name: Save LLVM source cache + uses: actions/cache/save@v3 + if: steps.cache-llvm-source.outputs.cache-hit != 'true' + with: + key: ${{ steps.cache-llvm-source.outputs.cache-primary-key }} + path: | + llvm-project/clang/lib/Headers + llvm-project/clang/include + llvm-project/compiler-rt + llvm-project/lld/include + llvm-project/llvm/include + - name: Restore LLVM build cache + uses: actions/cache/restore@v3 id: cache-llvm-build with: - key: llvm-build-15-linux-asserts-v3 + key: llvm-build-15-linux-asserts-v4 path: llvm-build - name: Build LLVM if: steps.cache-llvm-build.outputs.cache-hit != 'true' @@ -198,6 +230,12 @@ jobs: make llvm-build ASSERT=1 # Remove unnecessary object files (to reduce cache size). find llvm-build -name CMakeFiles -prune -exec rm -r '{}' \; + - name: Save LLVM build cache + uses: actions/cache/save@v3 + if: steps.cache-llvm-build.outputs.cache-hit != 'true' + with: + key: ${{ steps.cache-llvm-build.outputs.cache-primary-key }} + path: llvm-build - name: Cache Binaryen uses: actions/cache@v3 id: cache-binaryen @@ -237,7 +275,7 @@ jobs: # in that process to avoid doing lots of duplicate work and to avoid # complications around precompiled libraries such as compiler-rt shipped as # part of the release tarball. 
- runs-on: ubuntu-18.04 + runs-on: ubuntu-20.04 needs: build-linux steps: - name: Checkout @@ -254,11 +292,11 @@ jobs: with: go-version: '1.20' cache: true - - name: Cache LLVM source - uses: actions/cache@v3 + - name: Restore LLVM source cache + uses: actions/cache/restore@v3 id: cache-llvm-source with: - key: llvm-source-15-linux-v2 + key: llvm-source-15-linux-v3 path: | llvm-project/clang/lib/Headers llvm-project/clang/include @@ -268,11 +306,22 @@ jobs: - name: Download LLVM source if: steps.cache-llvm-source.outputs.cache-hit != 'true' run: make llvm-source - - name: Cache LLVM build - uses: actions/cache@v3 + - name: Save LLVM source cache + uses: actions/cache/save@v3 + if: steps.cache-llvm-source.outputs.cache-hit != 'true' + with: + key: ${{ steps.cache-llvm-source.outputs.cache-primary-key }} + path: | + llvm-project/clang/lib/Headers + llvm-project/clang/include + llvm-project/compiler-rt + llvm-project/lld/include + llvm-project/llvm/include + - name: Restore LLVM build cache + uses: actions/cache/restore@v3 id: cache-llvm-build with: - key: llvm-build-15-linux-arm-v3 + key: llvm-build-15-linux-arm-v4 path: llvm-build - name: Build LLVM if: steps.cache-llvm-build.outputs.cache-hit != 'true' @@ -286,6 +335,12 @@ jobs: make llvm-build CROSS=arm-linux-gnueabihf # Remove unnecessary object files (to reduce cache size). find llvm-build -name CMakeFiles -prune -exec rm -r '{}' \; + - name: Save LLVM build cache + uses: actions/cache/save@v3 + if: steps.cache-llvm-build.outputs.cache-hit != 'true' + with: + key: ${{ steps.cache-llvm-build.outputs.cache-primary-key }} + path: llvm-build - name: Cache Binaryen uses: actions/cache@v3 id: cache-binaryen @@ -336,7 +391,7 @@ jobs: # in that process to avoid doing lots of duplicate work and to avoid # complications around precompiled libraries such as compiler-rt shipped as # part of the release tarball. - runs-on: ubuntu-18.04 + runs-on: ubuntu-20.04 needs: build-linux steps: - name: Checkout @@ -354,11 +409,11 @@ jobs: with: go-version: '1.20' cache: true - - name: Cache LLVM source - uses: actions/cache@v3 + - name: Restore LLVM source cache + uses: actions/cache/restore@v3 id: cache-llvm-source with: - key: llvm-source-15-linux-v2 + key: llvm-source-15-linux-v3 path: | llvm-project/clang/lib/Headers llvm-project/clang/include @@ -368,11 +423,22 @@ jobs: - name: Download LLVM source if: steps.cache-llvm-source.outputs.cache-hit != 'true' run: make llvm-source - - name: Cache LLVM build - uses: actions/cache@v3 + - name: Save LLVM source cache + uses: actions/cache/save@v3 + if: steps.cache-llvm-source.outputs.cache-hit != 'true' + with: + key: ${{ steps.cache-llvm-source.outputs.cache-primary-key }} + path: | + llvm-project/clang/lib/Headers + llvm-project/clang/include + llvm-project/compiler-rt + llvm-project/lld/include + llvm-project/llvm/include + - name: Restore LLVM build cache + uses: actions/cache/restore@v3 id: cache-llvm-build with: - key: llvm-build-15-linux-arm64-v3 + key: llvm-build-15-linux-arm64-v4 path: llvm-build - name: Build LLVM if: steps.cache-llvm-build.outputs.cache-hit != 'true' @@ -384,6 +450,12 @@ jobs: make llvm-build CROSS=aarch64-linux-gnu # Remove unnecessary object files (to reduce cache size). 
find llvm-build -name CMakeFiles -prune -exec rm -r '{}' \; + - name: Save LLVM build cache + uses: actions/cache/save@v3 + if: steps.cache-llvm-build.outputs.cache-hit != 'true' + with: + key: ${{ steps.cache-llvm-build.outputs.cache-primary-key }} + path: llvm-build - name: Cache Binaryen uses: actions/cache@v3 id: cache-binaryen diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index dd878fa16b..fd46e8d874 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -15,6 +15,12 @@ jobs: build-windows: runs-on: windows-2022 steps: + - name: Configure pagefile + uses: al-cheb/configure-pagefile-action@v1.3 + with: + minimum-size: 8GB + maximum-size: 24GB + disk-root: "C:" - uses: brechtm/setup-scoop@v2 with: scoop_update: 'false' @@ -31,11 +37,11 @@ jobs: with: go-version: '1.20' cache: true - - name: Cache LLVM source - uses: actions/cache@v3 + - name: Restore cached LLVM source + uses: actions/cache/restore@v3 id: cache-llvm-source with: - key: llvm-source-15-windows-v2 + key: llvm-source-15-windows-v4 path: | llvm-project/clang/lib/Headers llvm-project/clang/include @@ -45,11 +51,22 @@ jobs: - name: Download LLVM source if: steps.cache-llvm-source.outputs.cache-hit != 'true' run: make llvm-source - - name: Cache LLVM build - uses: actions/cache@v3 + - name: Save cached LLVM source + uses: actions/cache/save@v3 + if: steps.cache-llvm-source.outputs.cache-hit != 'true' + with: + key: ${{ steps.cache-llvm-source.outputs.cache-primary-key }} + path: | + llvm-project/clang/lib/Headers + llvm-project/clang/include + llvm-project/compiler-rt + llvm-project/lld/include + llvm-project/llvm/include + - name: Restore cached LLVM build + uses: actions/cache/restore@v3 id: cache-llvm-build with: - key: llvm-build-15-windows-v3 + key: llvm-build-15-windows-v6 path: llvm-build - name: Build LLVM if: steps.cache-llvm-build.outputs.cache-hit != 'true' @@ -62,6 +79,12 @@ jobs: make llvm-build CCACHE=OFF # Remove unnecessary object files (to reduce cache size). 
find llvm-build -name CMakeFiles -prune -exec rm -r '{}' \; + - name: Save cached LLVM build + uses: actions/cache/save@v3 + if: steps.cache-llvm-build.outputs.cache-hit != 'true' + with: + key: ${{ steps.cache-llvm-build.outputs.cache-primary-key }} + path: llvm-build - name: Cache wasi-libc sysroot uses: actions/cache@v3 id: cache-wasi-libc @@ -74,9 +97,11 @@ jobs: - name: Install wasmtime run: | scoop install wasmtime + - name: make gen-device + run: make -j3 gen-device - name: Test TinyGo shell: bash - run: make test GOTESTFLAGS="-v -short" + run: make test GOTESTFLAGS="-short" - name: Build TinyGo release tarball shell: bash run: make build/release -j4 @@ -100,6 +125,12 @@ jobs: runs-on: windows-2022 needs: build-windows steps: + - name: Configure pagefile + uses: al-cheb/configure-pagefile-action@v1.3 + with: + minimum-size: 8GB + maximum-size: 24GB + disk-root: "C:" - uses: brechtm/setup-scoop@v2 with: scoop_update: 'false' @@ -131,6 +162,12 @@ jobs: runs-on: windows-2022 needs: build-windows steps: + - name: Configure pagefile + uses: al-cheb/configure-pagefile-action@v1.3 + with: + minimum-size: 8GB + maximum-size: 24GB + disk-root: "C:" - name: Checkout uses: actions/checkout@v3 - name: Install Go @@ -154,6 +191,12 @@ jobs: runs-on: windows-2022 needs: build-windows steps: + - name: Configure pagefile + uses: al-cheb/configure-pagefile-action@v1.3 + with: + minimum-size: 8GB + maximum-size: 24GB + disk-root: "C:" - uses: brechtm/setup-scoop@v2 with: scoop_update: 'false' diff --git a/Makefile b/Makefile index 53cc98d2b8..f119b61c2c 100644 --- a/Makefile +++ b/Makefile @@ -30,7 +30,7 @@ GO ?= go export GOROOT = $(shell $(GO) env GOROOT) # Flags to pass to go test. -GOTESTFLAGS ?= -v +GOTESTFLAGS ?= # md5sum binary MD5SUM = md5sum @@ -120,10 +120,8 @@ ifeq ($(OS),Windows_NT) START_GROUP = -Wl,--start-group END_GROUP = -Wl,--end-group - # LLVM compiled using MinGW on Windows appears to have problems with threads. - # Without this flag, linking results in errors like these: - # libLLVMSupport.a(Threading.cpp.obj):Threading.cpp:(.text+0x55): undefined reference to `std::thread::hardware_concurrency()' - LLVM_OPTION += -DLLVM_ENABLE_THREADS=OFF -DLLVM_ENABLE_PIC=OFF + # PIC needs to be disabled for libclang to work. 
+ LLVM_OPTION += -DLLVM_ENABLE_PIC=OFF CGO_CPPFLAGS += -DCINDEX_NO_EXPORTS CGO_LDFLAGS += -static -static-libgcc -static-libstdc++ @@ -310,6 +308,7 @@ TEST_PACKAGES_FAST = \ math \ math/cmplx \ net \ + net/http/internal \ net/http/internal/ascii \ net/mail \ os \ @@ -474,6 +473,8 @@ smoketest: @$(MD5SUM) test.hex $(TINYGO) build -size short -o test.hex -target=pca10040 examples/pininterrupt @$(MD5SUM) test.hex + $(TINYGO) build -size short -o test.hex -target=nano-rp2040 examples/rtcinterrupt + @$(MD5SUM) test.hex $(TINYGO) build -size short -o test.hex -target=pca10040 examples/serial @$(MD5SUM) test.hex $(TINYGO) build -size short -o test.hex -target=pca10040 examples/systick diff --git a/builder/build.go b/builder/build.go index 78a9dc6629..fcd7a2f073 100644 --- a/builder/build.go +++ b/builder/build.go @@ -169,6 +169,7 @@ func Build(pkgName, outpath, tmpdir string, config *compileopts.Config) (BuildRe CodeModel: config.CodeModel(), RelocationModel: config.RelocationModel(), SizeLevel: sizeLevel, + TinyGoVersion: goenv.Version, Scheduler: config.Scheduler(), AutomaticStackSize: config.AutomaticStackSize(), @@ -190,6 +191,9 @@ func Build(pkgName, outpath, tmpdir string, config *compileopts.Config) (BuildRe lprogram, err := loader.Load(config, pkgName, config.ClangHeaders, types.Config{ Sizes: compiler.Sizes(machine), }) + if err != nil { + return BuildResult{}, err + } result := BuildResult{ ModuleRoot: lprogram.MainPkg().Module.Dir, MainDir: lprogram.MainPkg().Dir, @@ -199,9 +203,6 @@ func Build(pkgName, outpath, tmpdir string, config *compileopts.Config) (BuildRe // If there is no module root, just the regular root. result.ModuleRoot = lprogram.MainPkg().Root } - if err != nil { // failed to load AST - return result, err - } err = lprogram.Parse() if err != nil { return result, err @@ -305,7 +306,6 @@ func Build(pkgName, outpath, tmpdir string, config *compileopts.Config) (BuildRe actionID := packageAction{ ImportPath: pkg.ImportPath, CompilerBuildID: string(compilerBuildID), - TinyGoVersion: goenv.Version, LLVMVersion: llvm.Version, Config: compilerConfig, CFlags: pkg.CFlags, @@ -594,12 +594,7 @@ func Build(pkgName, outpath, tmpdir string, config *compileopts.Config) (BuildRe defer llvmBuf.Dispose() return result, os.WriteFile(outpath, llvmBuf.Bytes(), 0666) case ".bc": - var buf llvm.MemoryBuffer - if config.UseThinLTO() { - buf = llvm.WriteThinLTOBitcodeToMemoryBuffer(mod) - } else { - buf = llvm.WriteBitcodeToMemoryBuffer(mod) - } + buf := llvm.WriteThinLTOBitcodeToMemoryBuffer(mod) defer buf.Dispose() return result, os.WriteFile(outpath, buf.Bytes(), 0666) case ".ll": @@ -621,16 +616,7 @@ func Build(pkgName, outpath, tmpdir string, config *compileopts.Config) (BuildRe dependencies: []*compileJob{programJob}, result: objfile, run: func(*compileJob) error { - var llvmBuf llvm.MemoryBuffer - if config.UseThinLTO() { - llvmBuf = llvm.WriteThinLTOBitcodeToMemoryBuffer(mod) - } else { - var err error - llvmBuf, err = machine.EmitToMemoryBuffer(mod, llvm.ObjectFile) - if err != nil { - return err - } - } + llvmBuf := llvm.WriteThinLTOBitcodeToMemoryBuffer(mod) defer llvmBuf.Dispose() return os.WriteFile(objfile, llvmBuf.Bytes(), 0666) }, @@ -664,7 +650,7 @@ func Build(pkgName, outpath, tmpdir string, config *compileopts.Config) (BuildRe job := &compileJob{ description: "compile extra file " + path, run: func(job *compileJob) error { - result, err := compileAndCacheCFile(abspath, tmpdir, config.CFlags(), config.UseThinLTO(), config.Options.PrintCommands) + result, err := 
compileAndCacheCFile(abspath, tmpdir, config.CFlags(), config.Options.PrintCommands) job.result = result return err }, @@ -682,7 +668,7 @@ func Build(pkgName, outpath, tmpdir string, config *compileopts.Config) (BuildRe job := &compileJob{ description: "compile CGo file " + abspath, run: func(job *compileJob) error { - result, err := compileAndCacheCFile(abspath, tmpdir, pkg.CFlags, config.UseThinLTO(), config.Options.PrintCommands) + result, err := compileAndCacheCFile(abspath, tmpdir, pkg.CFlags, config.Options.PrintCommands) job.result = result return err }, @@ -741,36 +727,34 @@ func Build(pkgName, outpath, tmpdir string, config *compileopts.Config) (BuildRe } ldflags = append(ldflags, dependency.result) } - if config.UseThinLTO() { - ldflags = append(ldflags, "-mllvm", "-mcpu="+config.CPU()) - if config.GOOS() == "windows" { - // Options for the MinGW wrapper for the lld COFF linker. - ldflags = append(ldflags, - "-Xlink=/opt:lldlto="+strconv.Itoa(optLevel), - "--thinlto-cache-dir="+filepath.Join(cacheDir, "thinlto")) - } else if config.GOOS() == "darwin" { - // Options for the ld64-compatible lld linker. - ldflags = append(ldflags, - "--lto-O"+strconv.Itoa(optLevel), - "-cache_path_lto", filepath.Join(cacheDir, "thinlto")) - } else { - // Options for the ELF linker. - ldflags = append(ldflags, - "--lto-O"+strconv.Itoa(optLevel), - "--thinlto-cache-dir="+filepath.Join(cacheDir, "thinlto"), - ) - } - if config.CodeModel() != "default" { - ldflags = append(ldflags, - "-mllvm", "-code-model="+config.CodeModel()) - } - if sizeLevel >= 2 { - // Workaround with roughly the same effect as - // https://reviews.llvm.org/D119342. - // Can hopefully be removed in LLVM 15. - ldflags = append(ldflags, - "-mllvm", "--rotation-max-header-size=0") - } + ldflags = append(ldflags, "-mllvm", "-mcpu="+config.CPU()) + if config.GOOS() == "windows" { + // Options for the MinGW wrapper for the lld COFF linker. + ldflags = append(ldflags, + "-Xlink=/opt:lldlto="+strconv.Itoa(optLevel), + "--thinlto-cache-dir="+filepath.Join(cacheDir, "thinlto")) + } else if config.GOOS() == "darwin" { + // Options for the ld64-compatible lld linker. + ldflags = append(ldflags, + "--lto-O"+strconv.Itoa(optLevel), + "-cache_path_lto", filepath.Join(cacheDir, "thinlto")) + } else { + // Options for the ELF linker. + ldflags = append(ldflags, + "--lto-O"+strconv.Itoa(optLevel), + "--thinlto-cache-dir="+filepath.Join(cacheDir, "thinlto"), + ) + } + if config.CodeModel() != "default" { + ldflags = append(ldflags, + "-mllvm", "-code-model="+config.CodeModel()) + } + if sizeLevel >= 2 { + // Workaround with roughly the same effect as + // https://reviews.llvm.org/D119342. + // Can hopefully be removed in LLVM 15. + ldflags = append(ldflags, + "-mllvm", "--rotation-max-header-size=0") } if config.Options.PrintCommands != nil { config.Options.PrintCommands(config.Target.Linker, ldflags...) @@ -1069,10 +1053,6 @@ func optimizeProgram(mod llvm.Module, config *compileopts.Config) error { } } - if config.GOOS() != "darwin" && !config.UseThinLTO() { - transform.ApplyFunctionSections(mod) // -ffunction-sections - } - // Insert values from -ldflags="-X ..." into the IR. err = setGlobalValues(mod, config.Options.GlobalValues) if err != nil { diff --git a/builder/cc.go b/builder/cc.go index 080ef2bff1..b2cc739e3b 100644 --- a/builder/cc.go +++ b/builder/cc.go @@ -56,7 +56,7 @@ import ( // depfile but without invalidating its name. For this reason, the depfile is // written on each new compilation (even when it seems unnecessary). 
However, it // could in rare cases lead to a stale file fetched from the cache. -func compileAndCacheCFile(abspath, tmpdir string, cflags []string, thinlto bool, printCommands func(string, ...string)) (string, error) { +func compileAndCacheCFile(abspath, tmpdir string, cflags []string, printCommands func(string, ...string)) (string, error) { // Hash input file. fileHash, err := hashFile(abspath) if err != nil { @@ -67,11 +67,6 @@ func compileAndCacheCFile(abspath, tmpdir string, cflags []string, thinlto bool, unlock := lock(filepath.Join(goenv.Get("GOCACHE"), fileHash+".c.lock")) defer unlock() - ext := ".o" - if thinlto { - ext = ".bc" - } - // Create cache key for the dependencies file. buf, err := json.Marshal(struct { Path string @@ -104,7 +99,7 @@ func compileAndCacheCFile(abspath, tmpdir string, cflags []string, thinlto bool, } // Obtain hashes of all the files listed as a dependency. - outpath, err := makeCFileCachePath(dependencies, depfileNameHash, ext) + outpath, err := makeCFileCachePath(dependencies, depfileNameHash) if err == nil { if _, err := os.Stat(outpath); err == nil { return outpath, nil @@ -117,7 +112,7 @@ func compileAndCacheCFile(abspath, tmpdir string, cflags []string, thinlto bool, return "", err } - objTmpFile, err := os.CreateTemp(goenv.Get("GOCACHE"), "tmp-*"+ext) + objTmpFile, err := os.CreateTemp(goenv.Get("GOCACHE"), "tmp-*.bc") if err != nil { return "", err } @@ -127,11 +122,8 @@ func compileAndCacheCFile(abspath, tmpdir string, cflags []string, thinlto bool, return "", err } depTmpFile.Close() - flags := append([]string{}, cflags...) // copy cflags - flags = append(flags, "-MD", "-MV", "-MTdeps", "-MF", depTmpFile.Name()) // autogenerate dependencies - if thinlto { - flags = append(flags, "-flto=thin") - } + flags := append([]string{}, cflags...) // copy cflags + flags = append(flags, "-MD", "-MV", "-MTdeps", "-MF", depTmpFile.Name(), "-flto=thin") // autogenerate dependencies flags = append(flags, "-c", "-o", objTmpFile.Name(), abspath) if strings.ToLower(filepath.Ext(abspath)) == ".s" { // If this is an assembly file (.s or .S, lowercase or uppercase), then @@ -189,7 +181,7 @@ func compileAndCacheCFile(abspath, tmpdir string, cflags []string, thinlto bool, } // Move temporary object file to final location. - outpath, err := makeCFileCachePath(dependencySlice, depfileNameHash, ext) + outpath, err := makeCFileCachePath(dependencySlice, depfileNameHash) if err != nil { return "", err } @@ -204,7 +196,7 @@ func compileAndCacheCFile(abspath, tmpdir string, cflags []string, thinlto bool, // Create a cache path (a path in GOCACHE) to store the output of a compiler // job. This path is based on the dep file name (which is a hash of metadata // including compiler flags) and the hash of all input files in the paths slice. -func makeCFileCachePath(paths []string, depfileNameHash, ext string) (string, error) { +func makeCFileCachePath(paths []string, depfileNameHash string) (string, error) { // Hash all input files. 
fileHashes := make(map[string]string, len(paths)) for _, path := range paths { @@ -229,7 +221,7 @@ func makeCFileCachePath(paths []string, depfileNameHash, ext string) (string, er outFileNameBuf := sha512.Sum512_224(buf) cacheKey := hex.EncodeToString(outFileNameBuf[:]) - outpath := filepath.Join(goenv.Get("GOCACHE"), "obj-"+cacheKey+ext) + outpath := filepath.Join(goenv.Get("GOCACHE"), "obj-"+cacheKey+".bc") return outpath, nil } diff --git a/builder/lld.cpp b/builder/lld.cpp index c7688f0c24..6cecbebe88 100644 --- a/builder/lld.cpp +++ b/builder/lld.cpp @@ -3,25 +3,39 @@ // This file provides C wrappers for liblld. #include <lld/Common/Driver.h> +#include <llvm/Support/Parallel.h> extern "C" { +static void configure() { +#if _WIN64 + // This is a hack to work around a hang in the LLD linker on Windows, with + // -DLLVM_ENABLE_THREADS=ON. It has a similar effect as the -threads=1 + // linker flag, but with support for the COFF linker. + llvm::parallel::strategy = llvm::hardware_concurrency(1); +#endif +} + bool tinygo_link_elf(int argc, char **argv) { + configure(); std::vector<const char *> args(argv, argv + argc); return lld::elf::link(args, llvm::outs(), llvm::errs(), false, false); } bool tinygo_link_macho(int argc, char **argv) { + configure(); std::vector<const char *> args(argv, argv + argc); return lld::macho::link(args, llvm::outs(), llvm::errs(), false, false); } bool tinygo_link_mingw(int argc, char **argv) { + configure(); std::vector<const char *> args(argv, argv + argc); return lld::mingw::link(args, llvm::outs(), llvm::errs(), false, false); } bool tinygo_link_wasm(int argc, char **argv) { + configure(); std::vector<const char *> args(argv, argv + argc); return lld::wasm::link(args, llvm::outs(), llvm::errs(), false, false); } diff --git a/builder/sizes.go b/builder/sizes.go index e55970f24b..caa3ca33f4 100644 --- a/builder/sizes.go +++ b/builder/sizes.go @@ -75,6 +75,7 @@ func (ps *packageSize) RAM() uint64 { type addressLine struct { Address uint64 Length uint64 // length of this chunk + Align uint64 // (maximum) alignment of this line File string // file path as stored in DWARF IsVariable bool // true if this is a variable (or constant), false if it is code } @@ -86,6 +87,7 @@ type memorySection struct { Type memoryType Address uint64 Size uint64 + Align uint64 } type memoryType int @@ -117,17 +119,13 @@ var ( // alloc: heap allocations during init interpretation // pack: data created when storing a constant in an interface for example // string: buffer behind strings - packageSymbolRegexp = regexp.MustCompile(`\$(alloc|embedfsfiles|embedfsslice|embedslice|pack|string)(\.[0-9]+)?$`) - - // Reflect sidetables. Created by the reflect lowering pass. - // See src/reflect/sidetables.go. - reflectDataRegexp = regexp.MustCompile(`^reflect\.[a-zA-Z]+Sidetable$`) + packageSymbolRegexp = regexp.MustCompile(`\$(alloc|pack|string)(\.[0-9]+)?$`) ) // readProgramSizeFromDWARF reads the source location for each line of code and // each variable in the program, as far as this is stored in the DWARF debug // information.
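The builder/cc.go hunks above drop the configurable extension from makeCFileCachePath: with ThinLTO always enabled, every cached C compilation result is bitcode, so the cache path is always "obj-<hash>.bc", derived from the depfile-name hash plus the hashes of all input files. A minimal standalone sketch of that content-addressed scheme; the names here (cachePath, metadataHash) are illustrative, not TinyGo's actual helpers:

```go
// Sketch of a content-addressed cache path: hash every input file, combine
// those hashes with a metadata hash (compiler flags etc.), and use the result
// as the file name. Mirrors the scheme of makeCFileCachePath above; names are
// illustrative only.
package main

import (
	"crypto/sha512"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
)

func cachePath(cacheDir, metadataHash string, inputs []string) (string, error) {
	fileHashes := make(map[string]string, len(inputs))
	for _, path := range inputs {
		data, err := os.ReadFile(path)
		if err != nil {
			return "", err
		}
		sum := sha512.Sum512_224(data)
		fileHashes[path] = hex.EncodeToString(sum[:])
	}
	// Fold the metadata hash and all per-file hashes into a single key.
	buf, err := json.Marshal(struct {
		MetadataHash string
		FileHashes   map[string]string
	}{metadataHash, fileHashes})
	if err != nil {
		return "", err
	}
	key := sha512.Sum512_224(buf)
	return filepath.Join(cacheDir, "obj-"+hex.EncodeToString(key[:])+".bc"), nil
}

func main() {
	path, err := cachePath(os.TempDir(), "example-cflags-hash", []string{"main.c"})
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		return
	}
	fmt.Println(path)
}
```

Changing any input file or any compiler flag yields a different key, so a stale object is simply never looked up again.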
-func readProgramSizeFromDWARF(data *dwarf.Data, codeOffset uint64, skipTombstone bool) ([]addressLine, error) { +func readProgramSizeFromDWARF(data *dwarf.Data, codeOffset, codeAlignment uint64, skipTombstone bool) ([]addressLine, error) { r := data.Reader() var lines []*dwarf.LineFile var addresses []addressLine @@ -199,6 +197,7 @@ func readProgramSizeFromDWARF(data *dwarf.Data, codeOffset uint64, skipTombstone line := addressLine{ Address: prevLineEntry.Address + codeOffset, Length: lineEntry.Address - prevLineEntry.Address, + Align: codeAlignment, File: prevLineEntry.File.Name, } if line.Length != 0 { @@ -223,20 +222,9 @@ func readProgramSizeFromDWARF(data *dwarf.Data, codeOffset uint64, skipTombstone // Try to parse the location. While this could in theory be a very // complex expression, usually it's just a DW_OP_addr opcode // followed by an address. - locationCode := location.Val.([]uint8) - if locationCode[0] != 3 { // DW_OP_addr - continue - } - var addr uint64 - switch len(locationCode) { - case 1 + 2: - addr = uint64(binary.LittleEndian.Uint16(locationCode[1:])) - case 1 + 4: - addr = uint64(binary.LittleEndian.Uint32(locationCode[1:])) - case 1 + 8: - addr = binary.LittleEndian.Uint64(locationCode[1:]) - default: - continue // unknown address + addr, err := readDWARFConstant(r.AddressSize(), location.Val.([]uint8)) + if err != nil { + continue // ignore the error, we don't know what to do with it } // Parse the type of the global variable, which (importantly) @@ -247,9 +235,16 @@ func readProgramSizeFromDWARF(data *dwarf.Data, codeOffset uint64, skipTombstone return nil, err } + // Read alignment, if it's stored as part of the debug information. + var alignment uint64 + if attr := e.AttrField(dwarf.AttrAlignment); attr != nil { + alignment = uint64(attr.Val.(int64)) + } + addresses = append(addresses, addressLine{ Address: addr, Length: uint64(typ.Size()), + Align: alignment, File: lines[file.Val.(int64)].Name, IsVariable: true, }) @@ -260,6 +255,52 @@ func readProgramSizeFromDWARF(data *dwarf.Data, codeOffset uint64, skipTombstone return addresses, nil } +// Parse a DWARF constant. For addresses, this is usually a very simple +// expression. +func readDWARFConstant(addressSize int, bytecode []byte) (uint64, error) { + var addr uint64 + for len(bytecode) != 0 { + op := bytecode[0] + bytecode = bytecode[1:] + switch op { + case 0x03: // DW_OP_addr + switch addressSize { + case 2: + addr = uint64(binary.LittleEndian.Uint16(bytecode)) + case 4: + addr = uint64(binary.LittleEndian.Uint32(bytecode)) + case 8: + addr = binary.LittleEndian.Uint64(bytecode) + default: + panic("unexpected address size") + } + bytecode = bytecode[addressSize:] + case 0x23: // DW_OP_plus_uconst + offset, n := readULEB128(bytecode) + addr += offset + bytecode = bytecode[n:] + default: + return 0, fmt.Errorf("unknown DWARF opcode: 0x%x", op) + } + } + return addr, nil +} + +// Source: https://en.wikipedia.org/wiki/LEB128#Decode_unsigned_integer +func readULEB128(buf []byte) (result uint64, n int) { + var shift uint8 + for { + b := buf[n] + n++ + result |= uint64(b&0x7f) << shift + if b&0x80 == 0 { + break + } + shift += 7 + } + return +} + // Read a MachO object file and return a line table. // Also return an index from symbol name to start address in the line table. 
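The readDWARFConstant and readULEB128 helpers added above decode the common shape of a DWARF location expression: a DW_OP_addr opcode (0x03) followed by a little-endian address, optionally followed by DW_OP_plus_uconst (0x23) with a ULEB128-encoded offset. A small self-contained check of that decoding, using made-up example bytes:

```go
// Standalone check of the ULEB128 and DW_OP_addr/DW_OP_plus_uconst decoding
// used above. The byte sequences are made-up examples.
package main

import (
	"encoding/binary"
	"fmt"
)

// readULEB128 decodes an unsigned LEB128 value and reports how many bytes it
// consumed (same algorithm as the helper added in builder/sizes.go).
func readULEB128(buf []byte) (result uint64, n int) {
	var shift uint8
	for {
		b := buf[n]
		n++
		result |= uint64(b&0x7f) << shift
		if b&0x80 == 0 {
			return
		}
		shift += 7
	}
}

func main() {
	// Classic LEB128 example: 0xE5 0x8E 0x26 encodes 624485.
	v, n := readULEB128([]byte{0xe5, 0x8e, 0x26})
	fmt.Println(v, n) // 624485 3

	// DW_OP_addr 0x20000000 (4-byte address) followed by DW_OP_plus_uconst 16.
	expr := []byte{0x03, 0x00, 0x00, 0x00, 0x20, 0x23, 0x10}
	addr := uint64(binary.LittleEndian.Uint32(expr[1:5]))
	offset, _ := readULEB128(expr[6:])
	fmt.Printf("0x%x\n", addr+offset) // 0x20000010
}
```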
func readMachOSymbolAddresses(path string) (map[string]int, []addressLine, error) { @@ -281,7 +322,7 @@ func readMachOSymbolAddresses(path string) (map[string]int, []addressLine, error if err != nil { return nil, nil, err } - lines, err := readProgramSizeFromDWARF(dwarf, 0, false) + lines, err := readProgramSizeFromDWARF(dwarf, 0, 0, false) if err != nil { return nil, nil, err } @@ -338,10 +379,15 @@ func loadProgramSize(path string, packagePathMap map[string]string) (*programSiz // Load the binary file, which could be in a number of file formats. var sections []memorySection if file, err := elf.NewFile(f); err == nil { + var codeAlignment uint64 + switch file.Machine { + case elf.EM_ARM: + codeAlignment = 4 // usually 2, but can be 4 + } // Read DWARF information. The error is intentionally ignored. data, _ := file.DWARF() if data != nil { - addresses, err = readProgramSizeFromDWARF(data, 0, true) + addresses, err = readProgramSizeFromDWARF(data, 0, codeAlignment, true) if err != nil { // However, _do_ report an error here. Something must have gone // wrong while trying to parse DWARF data. @@ -375,7 +421,7 @@ func loadProgramSize(path string, packagePathMap map[string]string) (*programSiz if section.Flags&elf.SHF_ALLOC == 0 { continue } - if packageSymbolRegexp.MatchString(symbol.Name) || reflectDataRegexp.MatchString(symbol.Name) { + if packageSymbolRegexp.MatchString(symbol.Name) || symbol.Name == "__isr_vector" { addresses = append(addresses, addressLine{ Address: symbol.Value, Length: symbol.Size, @@ -399,6 +445,7 @@ func loadProgramSize(path string, packagePathMap map[string]string) (*programSiz sections = append(sections, memorySection{ Address: section.Addr, Size: section.Size, + Align: section.Addralign, Type: memoryStack, }) } else { @@ -406,6 +453,7 @@ func loadProgramSize(path string, packagePathMap map[string]string) (*programSiz sections = append(sections, memorySection{ Address: section.Addr, Size: section.Size, + Align: section.Addralign, Type: memoryBSS, }) } @@ -414,6 +462,7 @@ func loadProgramSize(path string, packagePathMap map[string]string) (*programSiz sections = append(sections, memorySection{ Address: section.Addr, Size: section.Size, + Align: section.Addralign, Type: memoryCode, }) } else if section.Type == elf.SHT_PROGBITS && section.Flags&elf.SHF_WRITE != 0 { @@ -421,6 +470,7 @@ func loadProgramSize(path string, packagePathMap map[string]string) (*programSiz sections = append(sections, memorySection{ Address: section.Addr, Size: section.Size, + Align: section.Addralign, Type: memoryData, }) } else if section.Type == elf.SHT_PROGBITS { @@ -428,6 +478,7 @@ func loadProgramSize(path string, packagePathMap map[string]string) (*programSiz sections = append(sections, memorySection{ Address: section.Addr, Size: section.Size, + Align: section.Addralign, Type: memoryROData, }) } @@ -454,6 +505,7 @@ func loadProgramSize(path string, packagePathMap map[string]string) (*programSiz sections = append(sections, memorySection{ Address: section.Addr, Size: uint64(section.Size), + Align: uint64(section.Align), Type: memoryCode, }) } else if sectionType == 1 { // S_ZEROFILL @@ -461,6 +513,7 @@ func loadProgramSize(path string, packagePathMap map[string]string) (*programSiz sections = append(sections, memorySection{ Address: section.Addr, Size: uint64(section.Size), + Align: uint64(section.Align), Type: memoryBSS, }) } else if segment.Maxprot&0b011 == 0b001 { // --r (read-only data) @@ -468,6 +521,7 @@ func loadProgramSize(path string, packagePathMap map[string]string) 
(*programSiz sections = append(sections, memorySection{ Address: section.Addr, Size: uint64(section.Size), + Align: uint64(section.Align), Type: memoryROData, }) } else { @@ -475,6 +529,7 @@ func loadProgramSize(path string, packagePathMap map[string]string) (*programSiz sections = append(sections, memorySection{ Address: section.Addr, Size: uint64(section.Size), + Align: uint64(section.Align), Type: memoryData, }) } @@ -558,7 +613,7 @@ func loadProgramSize(path string, packagePathMap map[string]string) (*programSiz // Read DWARF information. The error is intentionally ignored. data, _ := file.DWARF() if data != nil { - addresses, err = readProgramSizeFromDWARF(data, 0, true) + addresses, err = readProgramSizeFromDWARF(data, 0, 0, true) if err != nil { // However, _do_ report an error here. Something must have gone // wrong while trying to parse DWARF data. @@ -630,7 +685,7 @@ func loadProgramSize(path string, packagePathMap map[string]string) (*programSiz // Read DWARF information. The error is intentionally ignored. data, _ := file.DWARF() if data != nil { - addresses, err = readProgramSizeFromDWARF(data, codeOffset, true) + addresses, err = readProgramSizeFromDWARF(data, codeOffset, 0, true) if err != nil { // However, _do_ report an error here. Something must have gone // wrong while trying to parse DWARF data. @@ -790,10 +845,18 @@ func readSection(section memorySection, addresses []addressLine, addSize func(st if addr < line.Address { // There is a gap: there is a space between the current and the // previous line entry. - addSize("(unknown)", line.Address-addr, false) - if sizesDebug { - fmt.Printf("%08x..%08x %5d: unknown (gap)\n", addr, line.Address, line.Address-addr) + // Check whether this is caused by alignment requirements. + addrAligned := (addr + line.Align - 1) &^ (line.Align - 1) + if line.Align > 1 && addrAligned >= line.Address { + // It is, assume that's what causes the gap. + addSize("(padding)", line.Address-addr, true) + } else { + addSize("(unknown)", line.Address-addr, false) + if sizesDebug { + fmt.Printf("%08x..%08x %5d: unknown (gap), alignment=%d\n", addr, line.Address, line.Address-addr, line.Align) + } } + addr = line.Address } if addr > line.Address+line.Length { // The current line is already covered by a previous line entry. @@ -815,9 +878,16 @@ func readSection(section memorySection, addresses []addressLine, addSize func(st } if addr < sectionEnd { // There is a gap at the end of the section. - addSize("(unknown)", sectionEnd-addr, false) - if sizesDebug { - fmt.Printf("%08x..%08x %5d: unknown (end)\n", addr, sectionEnd, sectionEnd-addr) + addrAligned := (addr + section.Align - 1) &^ (section.Align - 1) + if section.Align > 1 && addrAligned >= sectionEnd { + // The gap is caused by the section alignment. + // For example, if a .rodata section ends with a non-aligned string. + addSize("(padding)", sectionEnd-addr, true) + } else { + addSize("(unknown)", sectionEnd-addr, false) + if sizesDebug { + fmt.Printf("%08x..%08x %5d: unknown (end), alignment=%d\n", addr, sectionEnd, sectionEnd-addr, section.Align) + } } } } @@ -833,12 +903,15 @@ func findPackagePath(path string, packagePathMap map[string]string) string { // package, with a "C" prefix. For example: "C compiler-rt" for the // compiler runtime library from LLVM. 
packagePath = "C " + strings.Split(strings.TrimPrefix(path, filepath.Join(goenv.Get("TINYGOROOT"), "lib")), string(os.PathSeparator))[1] + } else if strings.HasPrefix(path, filepath.Join(goenv.Get("TINYGOROOT"), "llvm-project")) { + packagePath = "C compiler-rt" } else if packageSymbolRegexp.MatchString(path) { // Parse symbol names like main$alloc or runtime$string. packagePath = path[:strings.LastIndex(path, "$")] - } else if reflectDataRegexp.MatchString(path) { - // Parse symbol names like reflect.structTypesSidetable. - packagePath = "Go reflect data" + } else if path == "__isr_vector" { + packagePath = "C interrupt vector" + } else if path == "" { + packagePath = "Go types" } else if path == "" { // Interface type assert, generated by the interface lowering pass. packagePath = "Go interface assert" diff --git a/builder/sizes_test.go b/builder/sizes_test.go new file mode 100644 index 0000000000..e935b6ec45 --- /dev/null +++ b/builder/sizes_test.go @@ -0,0 +1,92 @@ +package builder + +import ( + "runtime" + "testing" + "time" + + "github.com/tinygo-org/tinygo/compileopts" +) + +var sema = make(chan struct{}, runtime.NumCPU()) + +type sizeTest struct { + target string + path string + codeSize uint64 + rodataSize uint64 + dataSize uint64 + bssSize uint64 +} + +// Test whether code and data size is as expected for the given targets. +// This tests both the logic of loadProgramSize and checks that code size +// doesn't change unintentionally. +// +// If you find that code or data size is reduced, then great! You can reduce the +// number in this test. +// If you find that the code or data size is increased, take a look as to why +// this is. It could be due to an update (LLVM version, Go version, etc) which +// is fine, but it could also mean that a recent change introduced this size +// increase. If so, please consider whether this new feature is indeed worth the +// size increase for all users. +func TestBinarySize(t *testing.T) { + if runtime.GOOS == "linux" && !hasBuiltinTools { + // Debian LLVM packages are modified a bit and tend to produce + // different machine code. Ideally we'd fix this (with some attributes + // or something?), but for now skip it. + t.Skip("Skip: using external LLVM version so binary size might differ") + } + + // This is a small number of very diverse targets that we want to test. + tests := []sizeTest{ + // microcontrollers + {"hifive1b", "examples/echo", 4556, 272, 0, 2252}, + {"microbit", "examples/serial", 2680, 380, 8, 2256}, + {"wioterminal", "examples/pininterrupt", 6109, 1471, 116, 6816}, + + // TODO: also check wasm. Right now this is difficult, because + // wasm binaries are run through wasm-opt and therefore the + // output varies by binaryen version. + } + for _, tc := range tests { + tc := tc + t.Run(tc.target+"/"+tc.path, func(t *testing.T) { + t.Parallel() + + // Build the binary. + options := compileopts.Options{ + Target: tc.target, + Opt: "z", + Semaphore: sema, + InterpTimeout: 60 * time.Second, + Debug: true, + VerifyIR: true, + } + target, err := compileopts.LoadTarget(&options) + if err != nil { + t.Fatal("could not load target:", err) + } + config := &compileopts.Config{ + Options: &options, + Target: target, + } + result, err := Build(tc.path, "", t.TempDir(), config) + if err != nil { + t.Fatal("could not build:", err) + } + + // Check whether the size of the binary matches the expected size. 
+ sizes, err := loadProgramSize(result.Executable, nil) + if err != nil { + t.Fatal("could not read program size:", err) + } + if sizes.Code != tc.codeSize || sizes.ROData != tc.rodataSize || sizes.Data != tc.dataSize || sizes.BSS != tc.bssSize { + t.Errorf("Unexpected code size when compiling: -target=%s %s", tc.target, tc.path) + t.Errorf(" code rodata data bss") + t.Errorf("expected: %6d %6d %6d %6d", tc.codeSize, tc.rodataSize, tc.dataSize, tc.bssSize) + t.Errorf("actual: %6d %6d %6d %6d", sizes.Code, sizes.ROData, sizes.Data, sizes.BSS) + } + }) + } +} diff --git a/compileopts/config.go b/compileopts/config.go index 4a9670005f..bfb02f1b7a 100644 --- a/compileopts/config.go +++ b/compileopts/config.go @@ -191,14 +191,6 @@ func (c *Config) StackSize() uint64 { return c.Target.DefaultStackSize } -// UseThinLTO returns whether ThinLTO should be used for the given target. -func (c *Config) UseThinLTO() bool { - // All architectures support ThinLTO now. However, this code is kept for the - // time being in case there are regressions. The non-ThinLTO code support - // should be removed when it is proven to work reliably. - return true -} - // RP2040BootPatch returns whether the RP2040 boot patch should be applied that // calculates and patches in the checksum for the 2nd stage bootloader. func (c *Config) RP2040BootPatch() bool { diff --git a/compileopts/target.go b/compileopts/target.go index 92e9315ea2..30573a863f 100644 --- a/compileopts/target.go +++ b/compileopts/target.go @@ -326,9 +326,9 @@ func defaultTarget(goos, goarch, triple string) (*TargetSpec, error) { } if goarch != "wasm" { suffix := "" - if goos == "windows" { - // Windows uses a different calling convention from other operating - // systems so we need separate assembly files. + if goos == "windows" && goarch == "amd64" { + // Windows uses a different calling convention on amd64 from other + // operating systems so we need separate assembly files. suffix = "_windows" } spec.ExtraFiles = append(spec.ExtraFiles, "src/runtime/asm_"+goarch+suffix+".S") diff --git a/compiler/atomic.go b/compiler/atomic.go index 73761be47d..48a9fb2d28 100644 --- a/compiler/atomic.go +++ b/compiler/atomic.go @@ -13,8 +13,8 @@ import ( func (b *builder) createAtomicOp(name string) llvm.Value { switch name { case "AddInt32", "AddInt64", "AddUint32", "AddUint64", "AddUintptr": - ptr := b.getValue(b.fn.Params[0]) - val := b.getValue(b.fn.Params[1]) + ptr := b.getValue(b.fn.Params[0], getPos(b.fn)) + val := b.getValue(b.fn.Params[1], getPos(b.fn)) if strings.HasPrefix(b.Triple, "avr") { // AtomicRMW does not work on AVR as intended: // - There are some register allocation issues (fixed by https://reviews.llvm.org/D97127 which is not yet in a usable LLVM release) @@ -33,8 +33,8 @@ func (b *builder) createAtomicOp(name string) llvm.Value { // Return the new value, not the original value returned by atomicrmw. return b.CreateAdd(oldVal, val, "") case "SwapInt32", "SwapInt64", "SwapUint32", "SwapUint64", "SwapUintptr", "SwapPointer": - ptr := b.getValue(b.fn.Params[0]) - val := b.getValue(b.fn.Params[1]) + ptr := b.getValue(b.fn.Params[0], getPos(b.fn)) + val := b.getValue(b.fn.Params[1], getPos(b.fn)) isPointer := val.Type().TypeKind() == llvm.PointerTypeKind if isPointer { // atomicrmw only supports integers, so cast to an integer. 
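The readSection changes in builder/sizes.go above stop reporting every gap as "(unknown)": when rounding the current address up to the next symbol's alignment lands at or beyond that symbol, the gap is attributed to alignment padding instead. The rounding uses the usual (addr + align - 1) &^ (align - 1) trick; a tiny illustration with made-up addresses:

```go
// Illustration of the padding-vs-unknown gap check from readSection above:
// a gap counts as padding when aligning the current address up to the next
// symbol's alignment reaches (or passes) that symbol's start address.
package main

import "fmt"

func classifyGap(addr, next, align uint64) string {
	if align > 1 {
		aligned := (addr + align - 1) &^ (align - 1) // round addr up to align
		if aligned >= next {
			return "(padding)"
		}
	}
	return "(unknown)"
}

func main() {
	fmt.Println(classifyGap(0x1001, 0x1004, 4)) // (padding): explained by 4-byte alignment
	fmt.Println(classifyGap(0x1001, 0x1040, 4)) // (unknown): too large to be alignment padding
}
```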
@@ -48,21 +48,21 @@ func (b *builder) createAtomicOp(name string) llvm.Value { } return oldVal case "CompareAndSwapInt32", "CompareAndSwapInt64", "CompareAndSwapUint32", "CompareAndSwapUint64", "CompareAndSwapUintptr", "CompareAndSwapPointer": - ptr := b.getValue(b.fn.Params[0]) - old := b.getValue(b.fn.Params[1]) - newVal := b.getValue(b.fn.Params[2]) + ptr := b.getValue(b.fn.Params[0], getPos(b.fn)) + old := b.getValue(b.fn.Params[1], getPos(b.fn)) + newVal := b.getValue(b.fn.Params[2], getPos(b.fn)) tuple := b.CreateAtomicCmpXchg(ptr, old, newVal, llvm.AtomicOrderingSequentiallyConsistent, llvm.AtomicOrderingSequentiallyConsistent, true) swapped := b.CreateExtractValue(tuple, 1, "") return swapped case "LoadInt32", "LoadInt64", "LoadUint32", "LoadUint64", "LoadUintptr", "LoadPointer": - ptr := b.getValue(b.fn.Params[0]) + ptr := b.getValue(b.fn.Params[0], getPos(b.fn)) val := b.CreateLoad(b.getLLVMType(b.fn.Signature.Results().At(0).Type()), ptr, "") val.SetOrdering(llvm.AtomicOrderingSequentiallyConsistent) val.SetAlignment(b.targetData.PrefTypeAlignment(val.Type())) // required return val case "StoreInt32", "StoreInt64", "StoreUint32", "StoreUint64", "StoreUintptr", "StorePointer": - ptr := b.getValue(b.fn.Params[0]) - val := b.getValue(b.fn.Params[1]) + ptr := b.getValue(b.fn.Params[0], getPos(b.fn)) + val := b.getValue(b.fn.Params[1], getPos(b.fn)) if strings.HasPrefix(b.Triple, "avr") { // SelectionDAGBuilder is currently missing the "are unaligned atomics allowed" check for stores. vType := val.Type() diff --git a/compiler/channel.go b/compiler/channel.go index 0ce9aa66cd..c8c10fe0b6 100644 --- a/compiler/channel.go +++ b/compiler/channel.go @@ -14,7 +14,7 @@ import ( func (b *builder) createMakeChan(expr *ssa.MakeChan) llvm.Value { elementSize := b.targetData.TypeAllocSize(b.getLLVMType(expr.Type().Underlying().(*types.Chan).Elem())) elementSizeValue := llvm.ConstInt(b.uintptrType, elementSize, false) - bufSize := b.getValue(expr.Size) + bufSize := b.getValue(expr.Size, getPos(expr)) b.createChanBoundsCheck(elementSize, bufSize, expr.Size.Type().Underlying().(*types.Basic), expr.Pos()) if bufSize.Type().IntTypeWidth() < b.uintptrType.IntTypeWidth() { bufSize = b.CreateZExt(bufSize, b.uintptrType, "") @@ -27,8 +27,8 @@ func (b *builder) createMakeChan(expr *ssa.MakeChan) llvm.Value { // createChanSend emits a pseudo chan send operation. It is lowered to the // actual channel send operation during goroutine lowering. func (b *builder) createChanSend(instr *ssa.Send) { - ch := b.getValue(instr.Chan) - chanValue := b.getValue(instr.X) + ch := b.getValue(instr.Chan, getPos(instr)) + chanValue := b.getValue(instr.X, getPos(instr)) // store value-to-send valueType := b.getLLVMType(instr.X.Type()) @@ -62,7 +62,7 @@ func (b *builder) createChanSend(instr *ssa.Send) { // actual channel receive operation during goroutine lowering. func (b *builder) createChanRecv(unop *ssa.UnOp) llvm.Value { valueType := b.getLLVMType(unop.X.Type().Underlying().(*types.Chan).Elem()) - ch := b.getValue(unop.X) + ch := b.getValue(unop.X, getPos(unop)) // Allocate memory to receive into. 
isZeroSize := b.targetData.TypeAllocSize(valueType) == 0 @@ -140,7 +140,7 @@ func (b *builder) createSelect(expr *ssa.Select) llvm.Value { var selectStates []llvm.Value chanSelectStateType := b.getLLVMRuntimeType("chanSelectState") for _, state := range expr.States { - ch := b.getValue(state.Chan) + ch := b.getValue(state.Chan, state.Pos) selectState := llvm.ConstNull(chanSelectStateType) selectState = b.CreateInsertValue(selectState, ch, 0, "") switch state.Dir { @@ -156,7 +156,7 @@ func (b *builder) createSelect(expr *ssa.Select) llvm.Value { case types.SendOnly: // Store this value in an alloca and put a pointer to this alloca // in the send state. - sendValue := b.getValue(state.Send) + sendValue := b.getValue(state.Send, state.Pos) alloca := llvmutil.CreateEntryBlockAlloca(b.Builder, sendValue.Type(), "select.send.value") b.CreateStore(sendValue, alloca) ptr := b.CreateBitCast(alloca, b.i8ptrType, "") @@ -247,7 +247,7 @@ func (b *builder) createSelect(expr *ssa.Select) llvm.Value { func (b *builder) getChanSelectResult(expr *ssa.Extract) llvm.Value { if expr.Index == 0 { // index - value := b.getValue(expr.Tuple) + value := b.getValue(expr.Tuple, getPos(expr)) index := b.CreateExtractValue(value, expr.Index, "") if index.Type().IntTypeWidth() < b.intType.IntTypeWidth() { index = b.CreateSExt(index, b.intType, "") @@ -255,7 +255,7 @@ func (b *builder) getChanSelectResult(expr *ssa.Extract) llvm.Value { return index } else if expr.Index == 1 { // comma-ok - value := b.getValue(expr.Tuple) + value := b.getValue(expr.Tuple, getPos(expr)) return b.CreateExtractValue(value, expr.Index, "") } else { // Select statements are (index, ok, ...) where ... is a number of diff --git a/compiler/compiler.go b/compiler/compiler.go index 92e95a05a0..9e05ca5242 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -47,6 +47,7 @@ type Config struct { CodeModel string RelocationModel string SizeLevel int + TinyGoVersion string // for llvm.ident // Various compiler options that determine how code is generated. Scheduler string @@ -321,6 +322,14 @@ func CompilePackage(moduleName string, pkg *loader.Package, ssaPkg *ssa.Package, llvm.ConstInt(c.ctx.Int32Type(), 4, false).ConstantAsMetadata(), }), ) + if c.TinyGoVersion != "" { + // It is necessary to set llvm.ident, otherwise debugging on MacOS + // won't work. + c.mod.AddNamedMetadataOperand("llvm.ident", + c.ctx.MDNode(([]llvm.Metadata{ + c.ctx.MDString("TinyGo version " + c.TinyGoVersion), + }))) + } c.dibuilder.Finalize() c.dibuilder.Destroy() } @@ -340,12 +349,15 @@ func CompilePackage(moduleName string, pkg *loader.Package, ssaPkg *ssa.Package, return c.mod, c.diagnostics } +func (c *compilerContext) getRuntimeType(name string) types.Type { + return c.runtimePkg.Scope().Lookup(name).(*types.TypeName).Type() +} + // getLLVMRuntimeType obtains a named type from the runtime package and returns // it as a LLVM type, creating it if necessary. It is a shorthand for // getLLVMType(getRuntimeType(name)). func (c *compilerContext) getLLVMRuntimeType(name string) llvm.Type { - typ := c.runtimePkg.Scope().Lookup(name).(*types.TypeName).Type() - return c.getLLVMType(typ) + return c.getLLVMType(c.getRuntimeType(name)) } // getLLVMType returns a LLVM type for a Go type. 
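The CompilePackage hunk above attaches an llvm.ident entry carrying the TinyGo version to the module, which (per the comment) is needed for debugging to work on macOS. A rough sketch of the same call pattern against the tinygo.org/x/go-llvm bindings; it assumes a usable LLVM installation, and the module name and version string are placeholders:

```go
// Rough sketch (assumes the tinygo.org/x/go-llvm bindings and a usable LLVM
// installation): attach a module-level llvm.ident entry, as CompilePackage
// does above. The module name and version string are placeholders.
package main

import "tinygo.org/x/go-llvm"

func main() {
	ctx := llvm.NewContext()
	defer ctx.Dispose()
	mod := ctx.NewModule("example")
	defer mod.Dispose()

	// Results in: !llvm.ident = !{!0}  where  !0 = !{!"TinyGo version X.Y.Z"}
	mod.AddNamedMetadataOperand("llvm.ident",
		ctx.MDNode([]llvm.Metadata{
			ctx.MDString("TinyGo version X.Y.Z"),
		}))
}
```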
It doesn't recreate already @@ -557,10 +569,19 @@ func (c *compilerContext) createDIType(typ types.Type) llvm.Metadata { case *types.Map: return c.getDIType(types.NewPointer(c.program.ImportedPackage("runtime").Members["hashmap"].(*ssa.Type).Type())) case *types.Named: - return c.dibuilder.CreateTypedef(llvm.DITypedef{ + // Placeholder metadata node, to be replaced afterwards. + temporaryMDNode := c.dibuilder.CreateReplaceableCompositeType(llvm.Metadata{}, llvm.DIReplaceableCompositeType{ + Tag: dwarf.TagTypedef, + SizeInBits: sizeInBytes * 8, + AlignInBits: uint32(c.targetData.ABITypeAlignment(llvmType)) * 8, + }) + c.ditypes[typ] = temporaryMDNode + md := c.dibuilder.CreateTypedef(llvm.DITypedef{ Type: c.getDIType(typ.Underlying()), Name: typ.String(), }) + temporaryMDNode.ReplaceAllUsesWith(md) + return md case *types.Pointer: return c.dibuilder.CreatePointerType(llvm.DIPointerType{ Pointee: c.getDIType(typ.Elem()), @@ -622,13 +643,6 @@ func (c *compilerContext) createDIType(typ types.Type) llvm.Metadata { }, }) case *types.Struct: - // Placeholder metadata node, to be replaced afterwards. - temporaryMDNode := c.dibuilder.CreateReplaceableCompositeType(llvm.Metadata{}, llvm.DIReplaceableCompositeType{ - Tag: dwarf.TagStructType, - SizeInBits: sizeInBytes * 8, - AlignInBits: uint32(c.targetData.ABITypeAlignment(llvmType)) * 8, - }) - c.ditypes[typ] = temporaryMDNode elements := make([]llvm.Metadata, typ.NumFields()) for i := range elements { field := typ.Field(i) @@ -647,7 +661,6 @@ func (c *compilerContext) createDIType(typ types.Type) llvm.Metadata { AlignInBits: uint32(c.targetData.ABITypeAlignment(llvmType)) * 8, Elements: elements, }) - temporaryMDNode.ReplaceAllUsesWith(md) return md case *types.TypeParam: return c.getDIType(typ.Underlying()) @@ -857,14 +870,14 @@ func (c *compilerContext) createPackage(irbuilder llvm.Builder, pkg *ssa.Package if fn == nil { continue // probably a generic method } - if fn.Blocks == nil { - continue // external function - } if member.Type().String() != member.String() { // This is a member on a type alias. Do not build such a // function. continue } + if fn.Blocks == nil { + continue // external function + } if fn.Synthetic != "" && fn.Synthetic != "package initializer" { // This function is a kind of wrapper function (created by // the ssa package, not appearing in the source code) that @@ -958,6 +971,19 @@ func (c *compilerContext) createEmbedGlobal(member *ssa.Global, global llvm.Valu global.SetInitializer(sliceObj) global.SetVisibility(llvm.HiddenVisibility) + if c.Debug { + // Add debug info to the slice backing array. + position := c.program.Fset.Position(member.Pos()) + diglobal := c.dibuilder.CreateGlobalVariableExpression(llvm.Metadata{}, llvm.DIGlobalVariableExpression{ + File: c.getDIFile(position.Filename), + Line: position.Line, + Type: c.getDIType(types.NewArray(types.Typ[types.Byte], int64(len(file.Data)))), + LocalToUnit: true, + Expr: c.dibuilder.CreateExpression(nil), + }) + bufferGlobal.AddMetadata(0, diglobal) + } + case *types.Struct: // Assume this is an embed.FS struct: // https://cs.opensource.google/go/go/+/refs/tags/go1.18.2:src/embed/embed.go;l=148 @@ -998,11 +1024,12 @@ func (c *compilerContext) createEmbedGlobal(member *ssa.Global, global llvm.Valu }) // Make the backing array for the []files slice. This is a LLVM global. 
- embedFileStructType := c.getLLVMType(typ.Field(0).Type().(*types.Pointer).Elem().(*types.Slice).Elem()) + embedFileStructType := typ.Field(0).Type().(*types.Pointer).Elem().(*types.Slice).Elem() + llvmEmbedFileStructType := c.getLLVMType(embedFileStructType) var fileStructs []llvm.Value for _, file := range allFiles { - fileStruct := llvm.ConstNull(embedFileStructType) - name := c.createConst(ssa.NewConst(constant.MakeString(file.Name), types.Typ[types.String])) + fileStruct := llvm.ConstNull(llvmEmbedFileStructType) + name := c.createConst(ssa.NewConst(constant.MakeString(file.Name), types.Typ[types.String]), getPos(member)) fileStruct = c.builder.CreateInsertValue(fileStruct, name, 0, "") // "name" field if file.Hash != "" { data := c.getEmbedFileString(file) @@ -1010,13 +1037,25 @@ func (c *compilerContext) createEmbedGlobal(member *ssa.Global, global llvm.Valu } fileStructs = append(fileStructs, fileStruct) } - sliceDataInitializer := llvm.ConstArray(embedFileStructType, fileStructs) + sliceDataInitializer := llvm.ConstArray(llvmEmbedFileStructType, fileStructs) sliceDataGlobal := llvm.AddGlobal(c.mod, sliceDataInitializer.Type(), c.pkg.Path()+"$embedfsfiles") sliceDataGlobal.SetInitializer(sliceDataInitializer) sliceDataGlobal.SetLinkage(llvm.InternalLinkage) sliceDataGlobal.SetGlobalConstant(true) sliceDataGlobal.SetUnnamedAddr(true) sliceDataGlobal.SetAlignment(c.targetData.ABITypeAlignment(sliceDataInitializer.Type())) + if c.Debug { + // Add debug information for code size attribution (among others). + position := c.program.Fset.Position(member.Pos()) + diglobal := c.dibuilder.CreateGlobalVariableExpression(llvm.Metadata{}, llvm.DIGlobalVariableExpression{ + File: c.getDIFile(position.Filename), + Line: position.Line, + Type: c.getDIType(types.NewArray(embedFileStructType, int64(len(allFiles)))), + LocalToUnit: true, + Expr: c.dibuilder.CreateExpression(nil), + }) + sliceDataGlobal.AddMetadata(0, diglobal) + } // Create the slice object itself. // Because embed.FS refers to it as *[]embed.file instead of a plain @@ -1033,6 +1072,17 @@ func (c *compilerContext) createEmbedGlobal(member *ssa.Global, global llvm.Valu sliceGlobal.SetGlobalConstant(true) sliceGlobal.SetUnnamedAddr(true) sliceGlobal.SetAlignment(c.targetData.ABITypeAlignment(sliceInitializer.Type())) + if c.Debug { + position := c.program.Fset.Position(member.Pos()) + diglobal := c.dibuilder.CreateGlobalVariableExpression(llvm.Metadata{}, llvm.DIGlobalVariableExpression{ + File: c.getDIFile(position.Filename), + Line: position.Line, + Type: c.getDIType(types.NewSlice(embedFileStructType)), + LocalToUnit: true, + Expr: c.dibuilder.CreateExpression(nil), + }) + sliceGlobal.AddMetadata(0, diglobal) + } // Define the embed.FS struct. It has only one field: the files (as a // *[]embed.file). 
@@ -1272,7 +1322,7 @@ func (b *builder) createFunction() { } dbgVar := b.getLocalVariable(variable) pos := b.program.Fset.Position(instr.Pos()) - b.dibuilder.InsertValueAtEnd(b.getValue(instr.X), dbgVar, b.dibuilder.CreateExpression(nil), llvm.DebugLoc{ + b.dibuilder.InsertValueAtEnd(b.getValue(instr.X, getPos(instr)), dbgVar, b.dibuilder.CreateExpression(nil), llvm.DebugLoc{ Line: uint(pos.Line), Col: uint(pos.Column), Scope: b.difunc, @@ -1303,7 +1353,7 @@ func (b *builder) createFunction() { for _, phi := range b.phis { block := phi.ssa.Block() for i, edge := range phi.ssa.Edges { - llvmVal := b.getValue(edge) + llvmVal := b.getValue(edge, getPos(phi.ssa)) llvmBlock := b.blockExits[block.Preds[i]] phi.llvm.AddIncoming([]llvm.Value{llvmVal}, []llvm.BasicBlock{llvmBlock}) } @@ -1411,7 +1461,7 @@ func (b *builder) createInstruction(instr ssa.Instruction) { // Start a new goroutine. b.createGo(instr) case *ssa.If: - cond := b.getValue(instr.Cond) + cond := b.getValue(instr.Cond, getPos(instr)) block := instr.Block() blockThen := b.blockEntries[block.Succs[0]] blockElse := b.blockEntries[block.Succs[1]] @@ -1420,13 +1470,13 @@ func (b *builder) createInstruction(instr ssa.Instruction) { blockJump := b.blockEntries[instr.Block().Succs[0]] b.CreateBr(blockJump) case *ssa.MapUpdate: - m := b.getValue(instr.Map) - key := b.getValue(instr.Key) - value := b.getValue(instr.Value) + m := b.getValue(instr.Map, getPos(instr)) + key := b.getValue(instr.Key, getPos(instr)) + value := b.getValue(instr.Value, getPos(instr)) mapType := instr.Map.Type().Underlying().(*types.Map) b.createMapUpdate(mapType.Key(), m, key, value, instr.Pos()) case *ssa.Panic: - value := b.getValue(instr.X) + value := b.getValue(instr.X, getPos(instr)) b.createRuntimeInvoke("_panic", []llvm.Value{value}, "") b.CreateUnreachable() case *ssa.Return: @@ -1436,12 +1486,12 @@ func (b *builder) createInstruction(instr ssa.Instruction) { if len(instr.Results) == 0 { b.CreateRetVoid() } else if len(instr.Results) == 1 { - b.CreateRet(b.getValue(instr.Results[0])) + b.CreateRet(b.getValue(instr.Results[0], getPos(instr))) } else { // Multiple return values. Put them all in a struct. retVal := llvm.ConstNull(b.llvmFn.GlobalValueType().ReturnType()) for i, result := range instr.Results { - val := b.getValue(result) + val := b.getValue(result, getPos(instr)) retVal = b.CreateInsertValue(retVal, val, i, "") } b.CreateRet(retVal) @@ -1451,8 +1501,8 @@ func (b *builder) createInstruction(instr ssa.Instruction) { case *ssa.Send: b.createChanSend(instr) case *ssa.Store: - llvmAddr := b.getValue(instr.Addr) - llvmVal := b.getValue(instr.Val) + llvmAddr := b.getValue(instr.Addr, getPos(instr)) + llvmVal := b.getValue(instr.Val, getPos(instr)) b.createNilCheck(instr.Addr, llvmAddr, "store") if b.targetData.TypeAllocSize(llvmVal.Type()) == 0 { // nothing to store @@ -1711,7 +1761,7 @@ func (b *builder) createBuiltin(argTypes []types.Type, argValues []llvm.Value, c func (b *builder) createFunctionCall(instr *ssa.CallCommon) (llvm.Value, error) { var params []llvm.Value for _, param := range instr.Args { - params = append(params, b.getValue(param)) + params = append(params, b.getValue(param, getPos(instr))) } // Try to call the function directly for trivially static calls. 
@@ -1728,9 +1778,9 @@ func (b *builder) createFunctionCall(instr *ssa.CallCommon) (llvm.Value, error) case name == "device.AsmFull" || name == "device/arm.AsmFull" || name == "device/arm64.AsmFull" || name == "device/avr.AsmFull" || name == "device/riscv.AsmFull": return b.createInlineAsmFull(instr) case strings.HasPrefix(name, "device/arm.SVCall"): - return b.emitSVCall(instr.Args) + return b.emitSVCall(instr.Args, getPos(instr)) case strings.HasPrefix(name, "device/arm64.SVCall"): - return b.emitSV64Call(instr.Args) + return b.emitSV64Call(instr.Args, getPos(instr)) case strings.HasPrefix(name, "(device/riscv.CSR)."): return b.emitCSROperation(instr) case strings.HasPrefix(name, "syscall.Syscall") || strings.HasPrefix(name, "syscall.RawSyscall"): @@ -1767,7 +1817,7 @@ func (b *builder) createFunctionCall(instr *ssa.CallCommon) (llvm.Value, error) case *ssa.MakeClosure: // A call on a func value, but the callee is trivial to find. For // example: immediately applied functions. - funcValue := b.getValue(value) + funcValue := b.getValue(value, getPos(value)) context = b.extractFuncContext(funcValue) default: panic("StaticCallee returned an unexpected value") @@ -1782,7 +1832,7 @@ func (b *builder) createFunctionCall(instr *ssa.CallCommon) (llvm.Value, error) return b.createBuiltin(argTypes, params, call.Name(), instr.Pos()) } else if instr.IsInvoke() { // Interface method call (aka invoke call). - itf := b.getValue(instr.Value) // interface value (runtime._interface) + itf := b.getValue(instr.Value, getPos(instr)) // interface value (runtime._interface) typecode := b.CreateExtractValue(itf, 0, "invoke.func.typecode") value := b.CreateExtractValue(itf, 1, "invoke.func.value") // receiver // Prefix the params with receiver value and suffix with typecode. @@ -1793,7 +1843,7 @@ func (b *builder) createFunctionCall(instr *ssa.CallCommon) (llvm.Value, error) context = llvm.Undef(b.i8ptrType) } else { // Function pointer. - value := b.getValue(instr.Value) + value := b.getValue(instr.Value, getPos(instr)) // This is a func value, which cannot be called directly. We have to // extract the function pointer and context first from the func value. calleeType, callee, context = b.decodeFuncValue(value, instr.Value.Type().Underlying().(*types.Signature)) @@ -1811,10 +1861,18 @@ func (b *builder) createFunctionCall(instr *ssa.CallCommon) (llvm.Value, error) // getValue returns the LLVM value of a constant, function value, global, or // already processed SSA expression. -func (b *builder) getValue(expr ssa.Value) llvm.Value { +func (b *builder) getValue(expr ssa.Value, pos token.Pos) llvm.Value { switch expr := expr.(type) { case *ssa.Const: - return b.createConst(expr) + if pos == token.NoPos { + // If the position isn't known, at least try to find in which file + // it is defined. 
+ file := b.program.Fset.File(b.fn.Pos()) + if file != nil { + pos = file.Pos(0) + } + } + return b.createConst(expr, pos) case *ssa.Function: if b.getFunctionInfo(expr).exported { b.addError(expr.Pos(), "cannot use an exported function as value: "+expr.String()) @@ -1899,8 +1957,8 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { return buf, nil } case *ssa.BinOp: - x := b.getValue(expr.X) - y := b.getValue(expr.Y) + x := b.getValue(expr.X, getPos(expr)) + y := b.getValue(expr.Y, getPos(expr)) return b.createBinOp(expr.Op, expr.X.Type(), expr.Y.Type(), x, y, expr.Pos()) case *ssa.Call: return b.createFunctionCall(expr.Common()) @@ -1911,12 +1969,12 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { // This is different from how the official Go compiler works, because of // heap allocation and because it's easier to implement, see: // https://research.swtch.com/interfaces - return b.getValue(expr.X), nil + return b.getValue(expr.X, getPos(expr)), nil case *ssa.ChangeType: // This instruction changes the type, but the underlying value remains // the same. This is often a no-op, but sometimes we have to change the // LLVM type as well. - x := b.getValue(expr.X) + x := b.getValue(expr.X, getPos(expr)) llvmType := b.getLLVMType(expr.Type()) if x.Type() == llvmType { // Different Go type but same LLVM type (for example, named int). @@ -1945,20 +2003,20 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { case *ssa.Const: panic("const is not an expression") case *ssa.Convert: - x := b.getValue(expr.X) + x := b.getValue(expr.X, getPos(expr)) return b.createConvert(expr.X.Type(), expr.Type(), x, expr.Pos()) case *ssa.Extract: if _, ok := expr.Tuple.(*ssa.Select); ok { return b.getChanSelectResult(expr), nil } - value := b.getValue(expr.Tuple) + value := b.getValue(expr.Tuple, getPos(expr)) return b.CreateExtractValue(value, expr.Index, ""), nil case *ssa.Field: - value := b.getValue(expr.X) + value := b.getValue(expr.X, getPos(expr)) result := b.CreateExtractValue(value, expr.Field, "") return result, nil case *ssa.FieldAddr: - val := b.getValue(expr.X) + val := b.getValue(expr.X, getPos(expr)) // Check for nil pointer before calculating the address, from the spec: // > For an operand x of type T, the address operation &x generates a // > pointer of type *T to x. [...] 
If the evaluation of x would cause a @@ -1976,8 +2034,8 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { case *ssa.Global: panic("global is not an expression") case *ssa.Index: - collection := b.getValue(expr.X) - index := b.getValue(expr.Index) + collection := b.getValue(expr.X, getPos(expr)) + index := b.getValue(expr.Index, getPos(expr)) switch xType := expr.X.Type().Underlying().(type) { case *types.Basic: // extract byte from string @@ -2026,8 +2084,8 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { panic("unknown *ssa.Index type") } case *ssa.IndexAddr: - val := b.getValue(expr.X) - index := b.getValue(expr.Index) + val := b.getValue(expr.X, getPos(expr)) + index := b.getValue(expr.Index, getPos(expr)) // Get buffer pointer and length var bufptr, buflen llvm.Value @@ -2078,8 +2136,8 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { panic("unreachable") } case *ssa.Lookup: // map lookup - value := b.getValue(expr.X) - index := b.getValue(expr.Index) + value := b.getValue(expr.X, getPos(expr)) + index := b.getValue(expr.Index, getPos(expr)) valueType := expr.Type() if expr.CommaOk { valueType = valueType.(*types.Tuple).At(0).Type() @@ -2090,13 +2148,13 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { case *ssa.MakeClosure: return b.parseMakeClosure(expr) case *ssa.MakeInterface: - val := b.getValue(expr.X) + val := b.getValue(expr.X, getPos(expr)) return b.createMakeInterface(val, expr.X.Type(), expr.Pos()), nil case *ssa.MakeMap: return b.createMakeMap(expr) case *ssa.MakeSlice: - sliceLen := b.getValue(expr.Len) - sliceCap := b.getValue(expr.Cap) + sliceLen := b.getValue(expr.Len, getPos(expr)) + sliceCap := b.getValue(expr.Cap, getPos(expr)) sliceType := expr.Type().Underlying().(*types.Slice) llvmElemType := b.getLLVMType(sliceType.Elem()) elemSize := b.targetData.TypeAllocSize(llvmElemType) @@ -2148,8 +2206,8 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { return slice, nil case *ssa.Next: rangeVal := expr.Iter.(*ssa.Range).X - llvmRangeVal := b.getValue(rangeVal) - it := b.getValue(expr.Iter) + llvmRangeVal := b.getValue(rangeVal, getPos(expr)) + it := b.getValue(expr.Iter, getPos(expr)) if expr.IsString { return b.createRuntimeCall("stringNext", []llvm.Value{llvmRangeVal, it}, "range.next"), nil } else { // map @@ -2175,14 +2233,14 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { case *ssa.Select: return b.createSelect(expr), nil case *ssa.Slice: - value := b.getValue(expr.X) + value := b.getValue(expr.X, getPos(expr)) var lowType, highType, maxType *types.Basic var low, high, max llvm.Value if expr.Low != nil { lowType = expr.Low.Type().Underlying().(*types.Basic) - low = b.getValue(expr.Low) + low = b.getValue(expr.Low, getPos(expr)) low = b.extendInteger(low, lowType, b.uintptrType) } else { lowType = types.Typ[types.Uintptr] @@ -2191,7 +2249,7 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { if expr.High != nil { highType = expr.High.Type().Underlying().(*types.Basic) - high = b.getValue(expr.High) + high = b.getValue(expr.High, getPos(expr)) high = b.extendInteger(high, highType, b.uintptrType) } else { highType = types.Typ[types.Uintptr] @@ -2199,7 +2257,7 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { if expr.Max != nil { maxType = expr.Max.Type().Underlying().(*types.Basic) - max = b.getValue(expr.Max) + max = b.getValue(expr.Max, getPos(expr)) max = b.extendInteger(max, maxType, b.uintptrType) } else { 
maxType = types.Typ[types.Uintptr] @@ -2332,7 +2390,7 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) { // Conversion from a slice to an array pointer, as the name clearly // says. This requires a runtime check to make sure the slice is at // least as big as the array. - slice := b.getValue(expr.X) + slice := b.getValue(expr.X, getPos(expr)) sliceLen := b.CreateExtractValue(slice, 1, "") arrayLen := expr.Type().Underlying().(*types.Pointer).Elem().Underlying().(*types.Array).Len() b.createSliceToArrayPointerCheck(sliceLen, arrayLen) @@ -2765,7 +2823,7 @@ func (b *builder) createBinOp(op token.Token, typ, ytyp types.Type, x, y llvm.Va } // createConst creates a LLVM constant value from a Go constant. -func (c *compilerContext) createConst(expr *ssa.Const) llvm.Value { +func (c *compilerContext) createConst(expr *ssa.Const, pos token.Pos) llvm.Value { switch typ := expr.Type().Underlying().(type) { case *types.Basic: llvmType := c.getLLVMType(typ) @@ -2788,6 +2846,18 @@ func (c *compilerContext) createConst(expr *ssa.Const) llvm.Value { global.SetGlobalConstant(true) global.SetUnnamedAddr(true) global.SetAlignment(1) + if c.Debug { + // Unfortunately, expr.Pos() is always token.NoPos. + position := c.program.Fset.Position(pos) + diglobal := c.dibuilder.CreateGlobalVariableExpression(llvm.Metadata{}, llvm.DIGlobalVariableExpression{ + File: c.getDIFile(position.Filename), + Line: position.Line, + Type: c.getDIType(types.NewArray(types.Typ[types.Byte], int64(len(str)))), + LocalToUnit: true, + Expr: c.dibuilder.CreateExpression(nil), + }) + global.AddMetadata(0, diglobal) + } zero := llvm.ConstInt(c.ctx.Int32Type(), 0, false) strPtr = llvm.ConstInBoundsGEP(globalType, global, []llvm.Value{zero, zero}) } else { @@ -2811,15 +2881,15 @@ func (c *compilerContext) createConst(expr *ssa.Const) llvm.Value { n, _ := constant.Float64Val(expr.Value) return llvm.ConstFloat(llvmType, n) } else if typ.Kind() == types.Complex64 { - r := c.createConst(ssa.NewConst(constant.Real(expr.Value), types.Typ[types.Float32])) - i := c.createConst(ssa.NewConst(constant.Imag(expr.Value), types.Typ[types.Float32])) + r := c.createConst(ssa.NewConst(constant.Real(expr.Value), types.Typ[types.Float32]), pos) + i := c.createConst(ssa.NewConst(constant.Imag(expr.Value), types.Typ[types.Float32]), pos) cplx := llvm.Undef(c.ctx.StructType([]llvm.Type{c.ctx.FloatType(), c.ctx.FloatType()}, false)) cplx = c.builder.CreateInsertValue(cplx, r, 0, "") cplx = c.builder.CreateInsertValue(cplx, i, 1, "") return cplx } else if typ.Kind() == types.Complex128 { - r := c.createConst(ssa.NewConst(constant.Real(expr.Value), types.Typ[types.Float64])) - i := c.createConst(ssa.NewConst(constant.Imag(expr.Value), types.Typ[types.Float64])) + r := c.createConst(ssa.NewConst(constant.Real(expr.Value), types.Typ[types.Float64]), pos) + i := c.createConst(ssa.NewConst(constant.Imag(expr.Value), types.Typ[types.Float64]), pos) cplx := llvm.Undef(c.ctx.StructType([]llvm.Type{c.ctx.DoubleType(), c.ctx.DoubleType()}, false)) cplx = c.builder.CreateInsertValue(cplx, r, 0, "") cplx = c.builder.CreateInsertValue(cplx, i, 1, "") @@ -2888,31 +2958,6 @@ func (b *builder) createConvert(typeFrom, typeTo types.Type, value llvm.Value, p if isPtrFrom && !isPtrTo { return b.CreatePtrToInt(value, llvmTypeTo, ""), nil } else if !isPtrFrom && isPtrTo { - if !value.IsABinaryOperator().IsNil() && value.InstructionOpcode() == llvm.Add { - // This is probably a pattern like the following: - // unsafe.Pointer(uintptr(ptr) + index) - // Used in 
functions like memmove etc. for lack of pointer - // arithmetic. Convert it to real pointer arithmatic here. - ptr := value.Operand(0) - index := value.Operand(1) - if !index.IsAPtrToIntInst().IsNil() { - // Swap if necessary, if ptr and index are reversed. - ptr, index = index, ptr - } - if !ptr.IsAPtrToIntInst().IsNil() { - origptr := ptr.Operand(0) - if origptr.Type() == b.i8ptrType { - // This pointer can be calculated from the original - // ptrtoint instruction with a GEP. The leftover inttoptr - // instruction is trivial to optimize away. - // Making it an in bounds GEP even though it's easy to - // create a GEP that is not in bounds. However, we're - // talking about unsafe code here so the programmer has to - // be careful anyway. - return b.CreateInBoundsGEP(b.ctx.Int8Type(), origptr, []llvm.Value{index}, ""), nil - } - } - } return b.CreateIntToPtr(value, llvmTypeTo, ""), nil } @@ -3115,7 +3160,7 @@ func (b *builder) createConvert(typeFrom, typeTo types.Type, value llvm.Value, p // which can all be directly lowered to IR. However, there is also the channel // receive operator which is handled in the runtime directly. func (b *builder) createUnOp(unop *ssa.UnOp) (llvm.Value, error) { - x := b.getValue(unop.X) + x := b.getValue(unop.X, getPos(unop)) switch unop.Op { case token.NOT: // !x return b.CreateNot(x, ""), nil diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index f8221a08e7..74e213dc61 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -49,6 +49,7 @@ func TestCompiler(t *testing.T) { {"goroutine.go", "cortex-m-qemu", "tasks"}, {"channel.go", "", ""}, {"gc.go", "", ""}, + {"zeromap.go", "", ""}, } if goMinor >= 20 { tests = append(tests, testCase{"go1.20.go", "", ""}) diff --git a/compiler/defer.go b/compiler/defer.go index 5ec3ef7e25..a7739c9db1 100644 --- a/compiler/defer.go +++ b/compiler/defer.go @@ -267,13 +267,13 @@ func (b *builder) createDefer(instr *ssa.Defer) { // Collect all values to be put in the struct (starting with // runtime._defer fields, followed by the call parameters). - itf := b.getValue(instr.Call.Value) // interface + itf := b.getValue(instr.Call.Value, getPos(instr)) // interface typecode := b.CreateExtractValue(itf, 0, "invoke.func.typecode") receiverValue := b.CreateExtractValue(itf, 1, "invoke.func.receiver") values = []llvm.Value{callback, next, typecode, receiverValue} - valueTypes = append(valueTypes, b.uintptrType, b.i8ptrType) + valueTypes = append(valueTypes, b.i8ptrType, b.i8ptrType) for _, arg := range instr.Call.Args { - val := b.getValue(arg) + val := b.getValue(arg, getPos(instr)) values = append(values, val) valueTypes = append(valueTypes, val.Type()) } @@ -290,7 +290,7 @@ func (b *builder) createDefer(instr *ssa.Defer) { // runtime._defer fields). values = []llvm.Value{callback, next} for _, param := range instr.Call.Args { - llvmParam := b.getValue(param) + llvmParam := b.getValue(param, getPos(instr)) values = append(values, llvmParam) valueTypes = append(valueTypes, llvmParam.Type()) } @@ -302,7 +302,7 @@ func (b *builder) createDefer(instr *ssa.Defer) { // pointer. // TODO: ignore this closure entirely and put pointers to the free // variables directly in the defer struct, avoiding a memory allocation. - closure := b.getValue(instr.Call.Value) + closure := b.getValue(instr.Call.Value, getPos(instr)) context := b.CreateExtractValue(closure, 0, "") // Get the callback number. @@ -318,7 +318,7 @@ func (b *builder) createDefer(instr *ssa.Defer) { // context pointer). 
values = []llvm.Value{callback, next} for _, param := range instr.Call.Args { - llvmParam := b.getValue(param) + llvmParam := b.getValue(param, getPos(instr)) values = append(values, llvmParam) valueTypes = append(valueTypes, llvmParam.Type()) } @@ -330,7 +330,7 @@ func (b *builder) createDefer(instr *ssa.Defer) { var argValues []llvm.Value for _, arg := range instr.Call.Args { argTypes = append(argTypes, arg.Type()) - argValues = append(argValues, b.getValue(arg)) + argValues = append(argValues, b.getValue(arg, getPos(instr))) } if _, ok := b.deferBuiltinFuncs[instr.Call.Value]; !ok { @@ -353,7 +353,7 @@ func (b *builder) createDefer(instr *ssa.Defer) { } } else { - funcValue := b.getValue(instr.Call.Value) + funcValue := b.getValue(instr.Call.Value, getPos(instr)) if _, ok := b.deferExprFuncs[instr.Call.Value]; !ok { b.deferExprFuncs[instr.Call.Value] = len(b.allDeferFuncs) @@ -368,7 +368,7 @@ func (b *builder) createDefer(instr *ssa.Defer) { values = []llvm.Value{callback, next, funcValue} valueTypes = append(valueTypes, funcValue.Type()) for _, param := range instr.Call.Args { - llvmParam := b.getValue(param) + llvmParam := b.getValue(param, getPos(instr)) values = append(values, llvmParam) valueTypes = append(valueTypes, llvmParam.Type()) } @@ -476,7 +476,7 @@ func (b *builder) createRunDefers() { valueTypes = append(valueTypes, b.getFuncType(callback.Signature())) } else { //Expect typecode - valueTypes = append(valueTypes, b.uintptrType, b.i8ptrType) + valueTypes = append(valueTypes, b.i8ptrType, b.i8ptrType) } for _, arg := range callback.Args { diff --git a/compiler/func.go b/compiler/func.go index 404c731630..743a4f0837 100644 --- a/compiler/func.go +++ b/compiler/func.go @@ -134,7 +134,7 @@ func (b *builder) parseMakeClosure(expr *ssa.MakeClosure) (llvm.Value, error) { boundVars := make([]llvm.Value, len(expr.Bindings)) for i, binding := range expr.Bindings { // The context stores the bound variables. - llvmBoundVar := b.getValue(binding) + llvmBoundVar := b.getValue(binding, getPos(expr)) boundVars[i] = llvmBoundVar } diff --git a/compiler/goroutine.go b/compiler/goroutine.go index da4e32aed0..8feb5e799c 100644 --- a/compiler/goroutine.go +++ b/compiler/goroutine.go @@ -16,7 +16,7 @@ func (b *builder) createGo(instr *ssa.Go) { // Get all function parameters to pass to the goroutine. var params []llvm.Value for _, param := range instr.Call.Args { - params = append(params, b.getValue(param)) + params = append(params, b.getValue(param, getPos(instr))) } var prefix string @@ -33,7 +33,7 @@ func (b *builder) createGo(instr *ssa.Go) { case *ssa.MakeClosure: // A goroutine call on a func value, but the callee is trivial to find. For // example: immediately applied functions. - funcValue := b.getValue(value) + funcValue := b.getValue(value, getPos(instr)) context = b.extractFuncContext(funcValue) default: panic("StaticCallee returned an unexpected value") @@ -70,13 +70,13 @@ func (b *builder) createGo(instr *ssa.Go) { var argValues []llvm.Value for _, arg := range instr.Call.Args { argTypes = append(argTypes, arg.Type()) - argValues = append(argValues, b.getValue(arg)) + argValues = append(argValues, b.getValue(arg, getPos(instr))) } b.createBuiltin(argTypes, argValues, builtin.Name(), instr.Pos()) return } else if instr.Call.IsInvoke() { // This is a method call on an interface value. 
- itf := b.getValue(instr.Call.Value) + itf := b.getValue(instr.Call.Value, getPos(instr)) itfTypeCode := b.CreateExtractValue(itf, 0, "") itfValue := b.CreateExtractValue(itf, 1, "") funcPtr = b.getInvokeFunction(&instr.Call) @@ -90,7 +90,7 @@ func (b *builder) createGo(instr *ssa.Go) { // * The function context, for closures. // * The function pointer (for tasks). var context llvm.Value - funcPtrType, funcPtr, context = b.decodeFuncValue(b.getValue(instr.Call.Value), instr.Call.Value.Type().Underlying().(*types.Signature)) + funcPtrType, funcPtr, context = b.decodeFuncValue(b.getValue(instr.Call.Value, getPos(instr)), instr.Call.Value.Type().Underlying().(*types.Signature)) params = append(params, context, funcPtr) hasContext = true prefix = b.fn.RelString(nil) diff --git a/compiler/inlineasm.go b/compiler/inlineasm.go index 2afb0a161d..72dd68cf3e 100644 --- a/compiler/inlineasm.go +++ b/compiler/inlineasm.go @@ -5,6 +5,7 @@ package compiler import ( "fmt" "go/constant" + "go/token" "regexp" "strconv" "strings" @@ -55,7 +56,7 @@ func (b *builder) createInlineAsmFull(instr *ssa.CallCommon) (llvm.Value, error) return llvm.Value{}, b.makeError(instr.Pos(), "register value map must be created in the same basic block") } key := constant.StringVal(r.Key.(*ssa.Const).Value) - registers[key] = b.getValue(r.Value.(*ssa.MakeInterface).X) + registers[key] = b.getValue(r.Value.(*ssa.MakeInterface).X, getPos(instr)) case *ssa.Call: if r.Common() == instr { break @@ -140,7 +141,7 @@ func (b *builder) createInlineAsmFull(instr *ssa.CallCommon) (llvm.Value, error) // // The num parameter must be a constant. All other parameters may be any scalar // value supported by LLVM inline assembly. -func (b *builder) emitSVCall(args []ssa.Value) (llvm.Value, error) { +func (b *builder) emitSVCall(args []ssa.Value, pos token.Pos) (llvm.Value, error) { num, _ := constant.Uint64Val(args[0].(*ssa.Const).Value) llvmArgs := []llvm.Value{} argTypes := []llvm.Type{} @@ -153,7 +154,7 @@ func (b *builder) emitSVCall(args []ssa.Value) (llvm.Value, error) { } else { constraints += ",{r" + strconv.Itoa(i) + "}" } - llvmValue := b.getValue(arg) + llvmValue := b.getValue(arg, pos) llvmArgs = append(llvmArgs, llvmValue) argTypes = append(argTypes, llvmValue.Type()) } @@ -178,7 +179,7 @@ func (b *builder) emitSVCall(args []ssa.Value) (llvm.Value, error) { // The num parameter must be a constant. All other parameters may be any scalar // value supported by LLVM inline assembly. 
// Same as emitSVCall but for AArch64 -func (b *builder) emitSV64Call(args []ssa.Value) (llvm.Value, error) { +func (b *builder) emitSV64Call(args []ssa.Value, pos token.Pos) (llvm.Value, error) { num, _ := constant.Uint64Val(args[0].(*ssa.Const).Value) llvmArgs := []llvm.Value{} argTypes := []llvm.Type{} @@ -191,7 +192,7 @@ func (b *builder) emitSV64Call(args []ssa.Value) (llvm.Value, error) { } else { constraints += ",{x" + strconv.Itoa(i) + "}" } - llvmValue := b.getValue(arg) + llvmValue := b.getValue(arg, pos) llvmArgs = append(llvmArgs, llvmValue) argTypes = append(argTypes, llvmValue.Type()) } @@ -231,19 +232,19 @@ func (b *builder) emitCSROperation(call *ssa.CallCommon) (llvm.Value, error) { fnType := llvm.FunctionType(b.ctx.VoidType(), []llvm.Type{b.uintptrType}, false) asm := fmt.Sprintf("csrw %d, $0", csr) target := llvm.InlineAsm(fnType, asm, "r", true, false, 0, false) - return b.CreateCall(fnType, target, []llvm.Value{b.getValue(call.Args[1])}, ""), nil + return b.CreateCall(fnType, target, []llvm.Value{b.getValue(call.Args[1], getPos(call))}, ""), nil case "SetBits": // Note: it may be possible to optimize this to csrrsi in many cases. fnType := llvm.FunctionType(b.uintptrType, []llvm.Type{b.uintptrType}, false) asm := fmt.Sprintf("csrrs $0, %d, $1", csr) target := llvm.InlineAsm(fnType, asm, "=r,r", true, false, 0, false) - return b.CreateCall(fnType, target, []llvm.Value{b.getValue(call.Args[1])}, ""), nil + return b.CreateCall(fnType, target, []llvm.Value{b.getValue(call.Args[1], getPos(call))}, ""), nil case "ClearBits": // Note: it may be possible to optimize this to csrrci in many cases. fnType := llvm.FunctionType(b.uintptrType, []llvm.Type{b.uintptrType}, false) asm := fmt.Sprintf("csrrc $0, %d, $1", csr) target := llvm.InlineAsm(fnType, asm, "=r,r", true, false, 0, false) - return b.CreateCall(fnType, target, []llvm.Value{b.getValue(call.Args[1])}, ""), nil + return b.CreateCall(fnType, target, []llvm.Value{b.getValue(call.Args[1], getPos(call))}, ""), nil default: return llvm.Value{}, b.makeError(call.Pos(), "unknown CSR operation: "+name) } diff --git a/compiler/interface.go b/compiler/interface.go index 2007b7d7c6..2d6c4a7f7c 100644 --- a/compiler/interface.go +++ b/compiler/interface.go @@ -6,6 +6,7 @@ package compiler // interface-lowering.go for more details. import ( + "fmt" "go/token" "go/types" "strconv" @@ -15,6 +16,50 @@ import ( "tinygo.org/x/go-llvm" ) +// Type kinds for basic types. +// They must match the constants for the Kind type in src/reflect/type.go. +var basicTypes = [...]uint8{ + types.Bool: 1, + types.Int: 2, + types.Int8: 3, + types.Int16: 4, + types.Int32: 5, + types.Int64: 6, + types.Uint: 7, + types.Uint8: 8, + types.Uint16: 9, + types.Uint32: 10, + types.Uint64: 11, + types.Uintptr: 12, + types.Float32: 13, + types.Float64: 14, + types.Complex64: 15, + types.Complex128: 16, + types.String: 17, + types.UnsafePointer: 18, +} + +// These must also match the constants for the Kind type in src/reflect/type.go. +const ( + typeKindChan = 19 + typeKindInterface = 20 + typeKindPointer = 21 + typeKindSlice = 22 + typeKindArray = 23 + typeKindSignature = 24 + typeKindMap = 25 + typeKindStruct = 26 +) + +// Flags stored in the first byte of the struct field byte array. Must be kept +// up to date with src/reflect/type.go. +const ( + structFieldFlagAnonymous = 1 << iota + structFieldFlagHasTag + structFieldFlagIsExported + structFieldFlagIsEmbedded +) + // createMakeInterface emits the LLVM IR for the *ssa.MakeInterface instruction. 
// It tries to put the type in the interface value, but if that's not possible, // it will do an allocation of the right size and put that in the interface @@ -23,10 +68,9 @@ import ( // An interface value is a {typecode, value} tuple named runtime._interface. func (b *builder) createMakeInterface(val llvm.Value, typ types.Type, pos token.Pos) llvm.Value { itfValue := b.emitPointerPack([]llvm.Value{val}) - itfTypeCodeGlobal := b.getTypeCode(typ) - itfTypeCode := b.CreatePtrToInt(itfTypeCodeGlobal, b.uintptrType, "") + itfType := b.getTypeCode(typ) itf := llvm.Undef(b.getLLVMRuntimeType("_interface")) - itf = b.CreateInsertValue(itf, itfTypeCode, 0, "") + itf = b.CreateInsertValue(itf, itfType, 0, "") itf = b.CreateInsertValue(itf, itfValue, 1, "") return itf } @@ -40,119 +84,310 @@ func (b *builder) extractValueFromInterface(itf llvm.Value, llvmType llvm.Type) return b.emitPointerUnpack(valuePtr, []llvm.Type{llvmType})[0] } +func (c *compilerContext) pkgPathPtr(pkgpath string) llvm.Value { + pkgpathName := "reflect/types.type.pkgpath.empty" + if pkgpath != "" { + pkgpathName = "reflect/types.type.pkgpath:" + pkgpath + } + + pkgpathGlobal := c.mod.NamedGlobal(pkgpathName) + if pkgpathGlobal.IsNil() { + pkgpathInitializer := c.ctx.ConstString(pkgpath+"\x00", false) + pkgpathGlobal = llvm.AddGlobal(c.mod, pkgpathInitializer.Type(), pkgpathName) + pkgpathGlobal.SetInitializer(pkgpathInitializer) + pkgpathGlobal.SetAlignment(1) + pkgpathGlobal.SetUnnamedAddr(true) + pkgpathGlobal.SetLinkage(llvm.LinkOnceODRLinkage) + pkgpathGlobal.SetGlobalConstant(true) + } + pkgPathPtr := llvm.ConstGEP(pkgpathGlobal.GlobalValueType(), pkgpathGlobal, []llvm.Value{ + llvm.ConstInt(c.ctx.Int32Type(), 0, false), + llvm.ConstInt(c.ctx.Int32Type(), 0, false), + }) + + return pkgPathPtr +} + // getTypeCode returns a reference to a type code. -// It returns a pointer to an external global which should be replaced with the -// real type in the interface lowering pass. +// A type code is a pointer to a constant global that describes the type. +// This function returns a pointer to the 'kind' field (which might not be the +// first field in the struct). func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value { + ms := c.program.MethodSets.MethodSet(typ) + hasMethodSet := ms.Len() != 0 + if _, ok := typ.Underlying().(*types.Interface); ok { + hasMethodSet = false + } globalName := "reflect/types.type:" + getTypeCodeName(typ) global := c.mod.NamedGlobal(globalName) if global.IsNil() { - // Create a new typecode global. - global = llvm.AddGlobal(c.mod, c.getLLVMRuntimeType("typecodeID"), globalName) - // Some type classes contain more information for underlying types or - // element types. Store it directly in the typecode global to make - // reflect lowering simpler. - var references llvm.Value - var length int64 - var methodSet llvm.Value - var ptrTo llvm.Value - var typeAssert llvm.Value + var typeFields []llvm.Value + // Define the type fields. These must match the structs in + // src/reflect/type.go (ptrType, arrayType, etc). See the comment at the + // top of src/reflect/type.go for more information on the layout of these structs. 
+ typeFieldTypes := []*types.Var{ + types.NewVar(token.NoPos, nil, "kind", types.Typ[types.Int8]), + } switch typ := typ.(type) { + case *types.Basic: + typeFieldTypes = append(typeFieldTypes, + types.NewVar(token.NoPos, nil, "ptrTo", types.Typ[types.UnsafePointer]), + ) case *types.Named: - references = c.getTypeCode(typ.Underlying()) - case *types.Chan: - references = c.getTypeCode(typ.Elem()) + name := typ.Obj().Name() + var pkgname string + if pkg := typ.Obj().Pkg(); pkg != nil { + pkgname = pkg.Name() + } + typeFieldTypes = append(typeFieldTypes, + types.NewVar(token.NoPos, nil, "numMethods", types.Typ[types.Uint16]), + types.NewVar(token.NoPos, nil, "ptrTo", types.Typ[types.UnsafePointer]), + types.NewVar(token.NoPos, nil, "underlying", types.Typ[types.UnsafePointer]), + types.NewVar(token.NoPos, nil, "pkgpath", types.Typ[types.UnsafePointer]), + types.NewVar(token.NoPos, nil, "name", types.NewArray(types.Typ[types.Int8], int64(len(pkgname)+1+len(name)+1))), + ) + case *types.Chan, *types.Slice: + typeFieldTypes = append(typeFieldTypes, + types.NewVar(token.NoPos, nil, "numMethods", types.Typ[types.Uint16]), + types.NewVar(token.NoPos, nil, "ptrTo", types.Typ[types.UnsafePointer]), + types.NewVar(token.NoPos, nil, "elementType", types.Typ[types.UnsafePointer]), + ) case *types.Pointer: - references = c.getTypeCode(typ.Elem()) - case *types.Slice: - references = c.getTypeCode(typ.Elem()) + typeFieldTypes = append(typeFieldTypes, + types.NewVar(token.NoPos, nil, "numMethods", types.Typ[types.Uint16]), + types.NewVar(token.NoPos, nil, "elementType", types.Typ[types.UnsafePointer]), + ) case *types.Array: - references = c.getTypeCode(typ.Elem()) - length = typ.Len() + typeFieldTypes = append(typeFieldTypes, + types.NewVar(token.NoPos, nil, "numMethods", types.Typ[types.Uint16]), + types.NewVar(token.NoPos, nil, "ptrTo", types.Typ[types.UnsafePointer]), + types.NewVar(token.NoPos, nil, "elementType", types.Typ[types.UnsafePointer]), + types.NewVar(token.NoPos, nil, "length", types.Typ[types.Uintptr]), + ) + case *types.Map: + typeFieldTypes = append(typeFieldTypes, + types.NewVar(token.NoPos, nil, "numMethods", types.Typ[types.Uint16]), + types.NewVar(token.NoPos, nil, "ptrTo", types.Typ[types.UnsafePointer]), + types.NewVar(token.NoPos, nil, "elementType", types.Typ[types.UnsafePointer]), + types.NewVar(token.NoPos, nil, "keyType", types.Typ[types.UnsafePointer]), + ) case *types.Struct: - // Take a pointer to the typecodeID of the first field (if it exists). 
- structGlobal := c.makeStructTypeFields(typ) - references = llvm.ConstBitCast(structGlobal, global.Type()) + typeFieldTypes = append(typeFieldTypes, + types.NewVar(token.NoPos, nil, "numMethods", types.Typ[types.Uint16]), + types.NewVar(token.NoPos, nil, "ptrTo", types.Typ[types.UnsafePointer]), + types.NewVar(token.NoPos, nil, "pkgpath", types.Typ[types.UnsafePointer]), + types.NewVar(token.NoPos, nil, "numFields", types.Typ[types.Uint16]), + types.NewVar(token.NoPos, nil, "fields", types.NewArray(c.getRuntimeType("structField"), int64(typ.NumFields()))), + ) case *types.Interface: - methodSetGlobal := c.getInterfaceMethodSet(typ) - references = llvm.ConstBitCast(methodSetGlobal, global.Type()) - } - if _, ok := typ.Underlying().(*types.Interface); !ok { - methodSet = c.getTypeMethodSet(typ) - } else { - typeAssert = c.getInterfaceImplementsFunc(typ) - typeAssert = llvm.ConstPtrToInt(typeAssert, c.uintptrType) - } - if _, ok := typ.Underlying().(*types.Pointer); !ok { - ptrTo = c.getTypeCode(types.NewPointer(typ)) + typeFieldTypes = append(typeFieldTypes, + types.NewVar(token.NoPos, nil, "ptrTo", types.Typ[types.UnsafePointer]), + ) + // TODO: methods + case *types.Signature: + typeFieldTypes = append(typeFieldTypes, + types.NewVar(token.NoPos, nil, "ptrTo", types.Typ[types.UnsafePointer]), + ) + // TODO: signature params and return values } - globalValue := llvm.ConstNull(global.GlobalValueType()) - if !references.IsNil() { - globalValue = c.builder.CreateInsertValue(globalValue, references, 0, "") + if hasMethodSet { + // This method set is appended at the start of the struct. It is + // removed in the interface lowering pass. + // TODO: don't remove these and instead do what upstream Go is doing + // instead. See: https://research.swtch.com/interfaces. This can + // likely be optimized in LLVM using + // https://llvm.org/docs/TypeMetadata.html. + typeFieldTypes = append([]*types.Var{ + types.NewVar(token.NoPos, nil, "methodSet", types.Typ[types.UnsafePointer]), + }, typeFieldTypes...) 
} - if length != 0 { - lengthValue := llvm.ConstInt(c.uintptrType, uint64(length), false) - globalValue = c.builder.CreateInsertValue(globalValue, lengthValue, 1, "") - } - if !methodSet.IsNil() { - globalValue = c.builder.CreateInsertValue(globalValue, methodSet, 2, "") - } - if !ptrTo.IsNil() { - globalValue = c.builder.CreateInsertValue(globalValue, ptrTo, 3, "") + globalType := types.NewStruct(typeFieldTypes, nil) + global = llvm.AddGlobal(c.mod, c.getLLVMType(globalType), globalName) + metabyte := getTypeKind(typ) + switch typ := typ.(type) { + case *types.Basic: + typeFields = []llvm.Value{c.getTypeCode(types.NewPointer(typ))} + case *types.Named: + name := typ.Obj().Name() + var pkgpath string + var pkgname string + if pkg := typ.Obj().Pkg(); pkg != nil { + pkgpath = pkg.Path() + pkgname = pkg.Name() + } + pkgPathPtr := c.pkgPathPtr(pkgpath) + typeFields = []llvm.Value{ + llvm.ConstInt(c.ctx.Int16Type(), uint64(ms.Len()), false), // numMethods + c.getTypeCode(types.NewPointer(typ)), // ptrTo + c.getTypeCode(typ.Underlying()), // underlying + pkgPathPtr, // pkgpath pointer + c.ctx.ConstString(pkgname+"."+name+"\x00", false), // name + } + metabyte |= 1 << 5 // "named" flag + case *types.Chan: + typeFields = []llvm.Value{ + llvm.ConstInt(c.ctx.Int16Type(), 0, false), // numMethods + c.getTypeCode(types.NewPointer(typ)), // ptrTo + c.getTypeCode(typ.Elem()), // elementType + } + case *types.Slice: + typeFields = []llvm.Value{ + llvm.ConstInt(c.ctx.Int16Type(), 0, false), // numMethods + c.getTypeCode(types.NewPointer(typ)), // ptrTo + c.getTypeCode(typ.Elem()), // elementType + } + case *types.Pointer: + typeFields = []llvm.Value{ + llvm.ConstInt(c.ctx.Int16Type(), uint64(ms.Len()), false), // numMethods + c.getTypeCode(typ.Elem()), + } + case *types.Array: + typeFields = []llvm.Value{ + llvm.ConstInt(c.ctx.Int16Type(), 0, false), // numMethods + c.getTypeCode(types.NewPointer(typ)), // ptrTo + c.getTypeCode(typ.Elem()), // elementType + llvm.ConstInt(c.uintptrType, uint64(typ.Len()), false), // length + } + case *types.Map: + typeFields = []llvm.Value{ + llvm.ConstInt(c.ctx.Int16Type(), 0, false), // numMethods + c.getTypeCode(types.NewPointer(typ)), // ptrTo + c.getTypeCode(typ.Elem()), // elem + c.getTypeCode(typ.Key()), // key + } + case *types.Struct: + var pkgpath string + if typ.NumFields() > 0 { + if pkg := typ.Field(0).Pkg(); pkg != nil { + pkgpath = pkg.Path() + } + } + pkgPathPtr := c.pkgPathPtr(pkgpath) + + typeFields = []llvm.Value{ + llvm.ConstInt(c.ctx.Int16Type(), uint64(ms.Len()), false), // numMethods + c.getTypeCode(types.NewPointer(typ)), // ptrTo + pkgPathPtr, + llvm.ConstInt(c.ctx.Int16Type(), uint64(typ.NumFields()), false), // numFields + } + structFieldType := c.getLLVMRuntimeType("structField") + var fields []llvm.Value + for i := 0; i < typ.NumFields(); i++ { + field := typ.Field(i) + var flags uint8 + if field.Anonymous() { + flags |= structFieldFlagAnonymous + } + if typ.Tag(i) != "" { + flags |= structFieldFlagHasTag + } + if token.IsExported(field.Name()) { + flags |= structFieldFlagIsExported + } + if field.Embedded() { + flags |= structFieldFlagIsEmbedded + } + data := string(flags) + field.Name() + "\x00" + if typ.Tag(i) != "" { + if len(typ.Tag(i)) > 0xff { + c.addError(field.Pos(), fmt.Sprintf("struct tag is %d bytes which is too long, max is 255", len(typ.Tag(i)))) + } + data += string([]byte{byte(len(typ.Tag(i)))}) + typ.Tag(i) + } + dataInitializer := c.ctx.ConstString(data, false) + dataGlobal := llvm.AddGlobal(c.mod, dataInitializer.Type(), 
globalName+"."+field.Name()) + dataGlobal.SetInitializer(dataInitializer) + dataGlobal.SetAlignment(1) + dataGlobal.SetUnnamedAddr(true) + dataGlobal.SetLinkage(llvm.InternalLinkage) + dataGlobal.SetGlobalConstant(true) + fieldType := c.getTypeCode(field.Type()) + fields = append(fields, llvm.ConstNamedStruct(structFieldType, []llvm.Value{ + fieldType, + llvm.ConstGEP(dataGlobal.GlobalValueType(), dataGlobal, []llvm.Value{ + llvm.ConstInt(c.ctx.Int32Type(), 0, false), + llvm.ConstInt(c.ctx.Int32Type(), 0, false), + }), + })) + } + typeFields = append(typeFields, llvm.ConstArray(structFieldType, fields)) + case *types.Interface: + typeFields = []llvm.Value{c.getTypeCode(types.NewPointer(typ))} + // TODO: methods + case *types.Signature: + typeFields = []llvm.Value{c.getTypeCode(types.NewPointer(typ))} + // TODO: params, return values, etc } - if !typeAssert.IsNil() { - globalValue = c.builder.CreateInsertValue(globalValue, typeAssert, 4, "") + // Prepend metadata byte. + typeFields = append([]llvm.Value{ + llvm.ConstInt(c.ctx.Int8Type(), uint64(metabyte), false), + }, typeFields...) + if hasMethodSet { + typeFields = append([]llvm.Value{ + llvm.ConstBitCast(c.getTypeMethodSet(typ), c.i8ptrType), + }, typeFields...) } + alignment := c.targetData.TypeAllocSize(c.i8ptrType) + globalValue := c.ctx.ConstStruct(typeFields, false) global.SetInitializer(globalValue) global.SetLinkage(llvm.LinkOnceODRLinkage) global.SetGlobalConstant(true) + global.SetAlignment(int(alignment)) + if c.Debug { + file := c.getDIFile("") + diglobal := c.dibuilder.CreateGlobalVariableExpression(file, llvm.DIGlobalVariableExpression{ + Name: "type " + typ.String(), + File: file, + Line: 1, + Type: c.getDIType(globalType), + LocalToUnit: false, + Expr: c.dibuilder.CreateExpression(nil), + AlignInBits: uint32(alignment * 8), + }) + global.AddMetadata(0, diglobal) + } } - return global + offset := uint64(0) + if hasMethodSet { + // The pointer to the method set is always the first element of the + // global (if there is a method set). However, the pointer we return + // should point to the 'kind' field not the method set. + offset = 1 + } + return llvm.ConstGEP(global.GlobalValueType(), global, []llvm.Value{ + llvm.ConstInt(llvm.Int32Type(), 0, false), + llvm.ConstInt(llvm.Int32Type(), offset, false), + }) } -// makeStructTypeFields creates a new global that stores all type information -// related to this struct type, and returns the resulting global. This global is -// actually an array of all the fields in the structs. -func (c *compilerContext) makeStructTypeFields(typ *types.Struct) llvm.Value { - // The global is an array of runtime.structField structs. 
- runtimeStructField := c.getLLVMRuntimeType("structField") - structGlobalType := llvm.ArrayType(runtimeStructField, typ.NumFields()) - structGlobal := llvm.AddGlobal(c.mod, structGlobalType, "reflect/types.structFields") - structGlobalValue := llvm.ConstNull(structGlobalType) - for i := 0; i < typ.NumFields(); i++ { - fieldGlobalValue := llvm.ConstNull(runtimeStructField) - fieldGlobalValue = c.builder.CreateInsertValue(fieldGlobalValue, c.getTypeCode(typ.Field(i).Type()), 0, "") - fieldNameType, fieldName := c.makeGlobalArray([]byte(typ.Field(i).Name()), "reflect/types.structFieldName", c.ctx.Int8Type()) - fieldName.SetLinkage(llvm.PrivateLinkage) - fieldName.SetUnnamedAddr(true) - fieldName = llvm.ConstGEP(fieldNameType, fieldName, []llvm.Value{ - llvm.ConstInt(c.ctx.Int32Type(), 0, false), - llvm.ConstInt(c.ctx.Int32Type(), 0, false), - }) - fieldGlobalValue = c.builder.CreateInsertValue(fieldGlobalValue, fieldName, 1, "") - if typ.Tag(i) != "" { - fieldTagType, fieldTag := c.makeGlobalArray([]byte(typ.Tag(i)), "reflect/types.structFieldTag", c.ctx.Int8Type()) - fieldTag.SetLinkage(llvm.PrivateLinkage) - fieldTag.SetUnnamedAddr(true) - fieldTag = llvm.ConstGEP(fieldTagType, fieldTag, []llvm.Value{ - llvm.ConstInt(c.ctx.Int32Type(), 0, false), - llvm.ConstInt(c.ctx.Int32Type(), 0, false), - }) - fieldGlobalValue = c.builder.CreateInsertValue(fieldGlobalValue, fieldTag, 2, "") - } - if typ.Field(i).Embedded() { - fieldEmbedded := llvm.ConstInt(c.ctx.Int1Type(), 1, false) - fieldGlobalValue = c.builder.CreateInsertValue(fieldGlobalValue, fieldEmbedded, 3, "") - } - structGlobalValue = c.builder.CreateInsertValue(structGlobalValue, fieldGlobalValue, i, "") +// getTypeKind returns the type kind for the given type, as defined by +// reflect.Kind. +func getTypeKind(t types.Type) uint8 { + switch t := t.Underlying().(type) { + case *types.Basic: + return basicTypes[t.Kind()] + case *types.Chan: + return typeKindChan + case *types.Interface: + return typeKindInterface + case *types.Pointer: + return typeKindPointer + case *types.Slice: + return typeKindSlice + case *types.Array: + return typeKindArray + case *types.Signature: + return typeKindSignature + case *types.Map: + return typeKindMap + case *types.Struct: + return typeKindStruct + default: + panic("unknown type") } - structGlobal.SetInitializer(structGlobalValue) - structGlobal.SetUnnamedAddr(true) - structGlobal.SetLinkage(llvm.PrivateLinkage) - return structGlobal } -var basicTypes = [...]string{ +var basicTypeNames = [...]string{ types.Bool: "bool", types.Int: "int", types.Int8: "int8", @@ -183,7 +418,7 @@ func getTypeCodeName(t types.Type) string { case *types.Array: return "array:" + strconv.FormatInt(t.Len(), 10) + ":" + getTypeCodeName(t.Elem()) case *types.Basic: - return "basic:" + basicTypes[t.Kind()] + return "basic:" + basicTypeNames[t.Kind()] case *types.Chan: return "chan:" + getTypeCodeName(t.Elem()) case *types.Interface: @@ -235,75 +470,40 @@ func getTypeCodeName(t types.Type) string { // getTypeMethodSet returns a reference (GEP) to a global method set. This // method set should be unreferenced after the interface lowering pass. 
func (c *compilerContext) getTypeMethodSet(typ types.Type) llvm.Value { - global := c.mod.NamedGlobal(typ.String() + "$methodset") - zero := llvm.ConstInt(c.ctx.Int32Type(), 0, false) - if !global.IsNil() { - // the method set already exists - return llvm.ConstGEP(global.GlobalValueType(), global, []llvm.Value{zero, zero}) - } - - ms := c.program.MethodSets.MethodSet(typ) - if ms.Len() == 0 { - // no methods, so can leave that one out - return llvm.ConstPointerNull(llvm.PointerType(c.getLLVMRuntimeType("interfaceMethodInfo"), 0)) - } - - methods := make([]llvm.Value, ms.Len()) - interfaceMethodInfoType := c.getLLVMRuntimeType("interfaceMethodInfo") - for i := 0; i < ms.Len(); i++ { - method := ms.At(i) - signatureGlobal := c.getMethodSignature(method.Obj().(*types.Func)) - fn := c.program.MethodValue(method) - llvmFnType, llvmFn := c.getFunction(fn) - if llvmFn.IsNil() { - // compiler error, so panic - panic("cannot find function: " + c.getFunctionInfo(fn).linkName) + globalName := typ.String() + "$methodset" + global := c.mod.NamedGlobal(globalName) + if global.IsNil() { + ms := c.program.MethodSets.MethodSet(typ) + + // Create method set. + var signatures, wrappers []llvm.Value + for i := 0; i < ms.Len(); i++ { + method := ms.At(i) + signatureGlobal := c.getMethodSignature(method.Obj().(*types.Func)) + signatures = append(signatures, signatureGlobal) + fn := c.program.MethodValue(method) + llvmFnType, llvmFn := c.getFunction(fn) + if llvmFn.IsNil() { + // compiler error, so panic + panic("cannot find function: " + c.getFunctionInfo(fn).linkName) + } + wrapper := c.getInterfaceInvokeWrapper(fn, llvmFnType, llvmFn) + wrappers = append(wrappers, wrapper) } - wrapper := c.getInterfaceInvokeWrapper(fn, llvmFnType, llvmFn) - methodInfo := llvm.ConstNamedStruct(interfaceMethodInfoType, []llvm.Value{ - signatureGlobal, - llvm.ConstPtrToInt(wrapper, c.uintptrType), - }) - methods[i] = methodInfo - } - arrayType := llvm.ArrayType(interfaceMethodInfoType, len(methods)) - value := llvm.ConstArray(interfaceMethodInfoType, methods) - global = llvm.AddGlobal(c.mod, arrayType, typ.String()+"$methodset") - global.SetInitializer(value) - global.SetGlobalConstant(true) - global.SetLinkage(llvm.LinkOnceODRLinkage) - return llvm.ConstGEP(arrayType, global, []llvm.Value{zero, zero}) -} - -// getInterfaceMethodSet returns a global variable with the method set of the -// given named interface type. This method set is used by the interface lowering -// pass. -func (c *compilerContext) getInterfaceMethodSet(typ types.Type) llvm.Value { - name := typ.String() - if _, ok := typ.(*types.Named); !ok { - // Anonymous interface. - name = "reflect/types.interface:" + name - } - global := c.mod.NamedGlobal(name + "$interface") - zero := llvm.ConstInt(c.ctx.Int32Type(), 0, false) - if !global.IsNil() { - // method set already exist, return it - return llvm.ConstGEP(global.GlobalValueType(), global, []llvm.Value{zero, zero}) - } - // Every method is a *i8 reference indicating the signature of this method. - methods := make([]llvm.Value, typ.Underlying().(*types.Interface).NumMethods()) - for i := range methods { - method := typ.Underlying().(*types.Interface).Method(i) - methods[i] = c.getMethodSignature(method) + // Construct global value. 
+ globalValue := c.ctx.ConstStruct([]llvm.Value{ + llvm.ConstInt(c.uintptrType, uint64(ms.Len()), false), + llvm.ConstArray(c.i8ptrType, signatures), + c.ctx.ConstStruct(wrappers, false), + }, false) + global = llvm.AddGlobal(c.mod, globalValue.Type(), globalName) + global.SetInitializer(globalValue) + global.SetGlobalConstant(true) + global.SetUnnamedAddr(true) + global.SetLinkage(llvm.LinkOnceODRLinkage) } - - value := llvm.ConstArray(c.i8ptrType, methods) - global = llvm.AddGlobal(c.mod, value.Type(), name+"$interface") - global.SetInitializer(value) - global.SetGlobalConstant(true) - global.SetLinkage(llvm.LinkOnceODRLinkage) - return llvm.ConstGEP(value.Type(), global, []llvm.Value{zero, zero}) + return global } // getMethodSignatureName returns a unique name (that can be used as the name of @@ -345,7 +545,7 @@ func (c *compilerContext) getMethodSignature(method *types.Func) llvm.Value { // Type asserts on concrete types are trivial: just compare type numbers. Type // asserts on interfaces are more difficult, see the comments in the function. func (b *builder) createTypeAssert(expr *ssa.TypeAssert) llvm.Value { - itf := b.getValue(expr.X) + itf := b.getValue(expr.X, getPos(expr)) assertedType := b.getLLVMType(expr.AssertedType) actualTypeNum := b.CreateExtractValue(itf, 0, "interface.type") @@ -443,7 +643,7 @@ func (c *compilerContext) getInterfaceImplementsFunc(assertedType types.Type) ll fnName := getTypeCodeName(assertedType.Underlying()) + ".$typeassert" llvmFn := c.mod.NamedFunction(fnName) if llvmFn.IsNil() { - llvmFnType := llvm.FunctionType(c.ctx.Int1Type(), []llvm.Type{c.uintptrType}, false) + llvmFnType := llvm.FunctionType(c.ctx.Int1Type(), []llvm.Type{c.i8ptrType}, false) llvmFn = llvm.AddFunction(c.mod, fnName, llvmFnType) c.addStandardDeclaredAttributes(llvmFn) methods := c.getMethodsString(assertedType.Underlying().(*types.Interface)) @@ -464,7 +664,7 @@ func (c *compilerContext) getInvokeFunction(instr *ssa.CallCommon) llvm.Value { for i := 0; i < sig.Params().Len(); i++ { paramTuple = append(paramTuple, sig.Params().At(i)) } - paramTuple = append(paramTuple, types.NewVar(token.NoPos, nil, "$typecode", types.Typ[types.Uintptr])) + paramTuple = append(paramTuple, types.NewVar(token.NoPos, nil, "$typecode", types.Typ[types.UnsafePointer])) llvmFnType := c.getRawFuncType(types.NewSignature(sig.Recv(), types.NewTuple(paramTuple...), sig.Results(), false)) llvmFn = llvm.AddFunction(c.mod, fnName, llvmFnType) c.addStandardDeclaredAttributes(llvmFn) @@ -601,7 +801,7 @@ func typestring(t types.Type) string { case *types.Array: return "[" + strconv.FormatInt(t.Len(), 10) + "]" + typestring(t.Elem()) case *types.Basic: - return basicTypes[t.Kind()] + return basicTypeNames[t.Kind()] case *types.Chan: switch t.Dir() { case types.SendRecv: diff --git a/compiler/interrupt.go b/compiler/interrupt.go index c1f7d69f27..1fb4c22b4c 100644 --- a/compiler/interrupt.go +++ b/compiler/interrupt.go @@ -24,7 +24,7 @@ func (b *builder) createInterruptGlobal(instr *ssa.CallCommon) (llvm.Value, erro // Note that bound functions are allowed if the function has a pointer // receiver and is a global. This is rather strict but still allows for // idiomatic Go code. - funcValue := b.getValue(instr.Args[1]) + funcValue := b.getValue(instr.Args[1], getPos(instr)) if funcValue.IsAConstant().IsNil() { // Try to determine the cause of the non-constantness for a nice error // message. 
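A short aside before the next file: the compiler/interface.go changes above make typecodes pointers into per-type globals (i8*) instead of uintptr IDs, which is what createTypeAssert and the generated ".$typeassert" implements-check consume. For orientation only, this is roughly the kind of source-level code those paths lower; it is an illustrative sketch with invented identifiers, not part of the patch:

	package example

	// Illustrative only: a concrete type assert (lowered to a typecode
	// comparison) and an interface type assert (lowered to a call to the
	// generated ".$typeassert" implements-check function).
	type reader interface {
		Read(p []byte) (n int, err error)
	}

	func classify(v interface{}) string {
		if _, ok := v.(int); ok { // concrete type: compare typecodes
			return "int"
		}
		if _, ok := v.(reader); ok { // interface type: implements check
			return "reader"
		}
		return "other"
	}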
diff --git a/compiler/intrinsics.go b/compiler/intrinsics.go
index a511e518b7..c196b60d8d 100644
--- a/compiler/intrinsics.go
+++ b/compiler/intrinsics.go
@@ -24,6 +24,8 @@ func (b *builder) defineIntrinsicFunction() {
 		b.createMemoryCopyImpl()
 	case name == "runtime.memzero":
 		b.createMemoryZeroImpl()
+	case name == "runtime.KeepAlive":
+		b.createKeepAliveImpl()
 	case strings.HasPrefix(name, "runtime/volatile.Load"):
 		b.createVolatileLoad()
 	case strings.HasPrefix(name, "runtime/volatile.Store"):
@@ -56,7 +58,7 @@ func (b *builder) createMemoryCopyImpl() {
 	}
 	var params []llvm.Value
 	for _, param := range b.fn.Params {
-		params = append(params, b.getValue(param))
+		params = append(params, b.getValue(param, getPos(b.fn)))
 	}
 	params = append(params, llvm.ConstInt(b.ctx.Int1Type(), 0, false))
 	b.CreateCall(llvmFn.GlobalValueType(), llvmFn, params, "")
@@ -78,15 +80,38 @@ func (b *builder) createMemoryZeroImpl() {
 		llvmFn = llvm.AddFunction(b.mod, fnName, fnType)
 	}
 	params := []llvm.Value{
-		b.getValue(b.fn.Params[0]),
+		b.getValue(b.fn.Params[0], getPos(b.fn)),
 		llvm.ConstInt(b.ctx.Int8Type(), 0, false),
-		b.getValue(b.fn.Params[1]),
+		b.getValue(b.fn.Params[1], getPos(b.fn)),
 		llvm.ConstInt(b.ctx.Int1Type(), 0, false),
 	}
 	b.CreateCall(llvmFn.GlobalValueType(), llvmFn, params, "")
 	b.CreateRetVoid()
 }
 
+// createKeepAliveImpl creates the runtime.KeepAlive function. It is
+// implemented using inline assembly.
+func (b *builder) createKeepAliveImpl() {
+	b.createFunctionStart(true)
+
+	// Get the underlying value of the interface value.
+	interfaceValue := b.getValue(b.fn.Params[0], getPos(b.fn))
+	pointerValue := b.CreateExtractValue(interfaceValue, 1, "")
+
+	// Create an equivalent of the following C code, which is basically just a
+	// nop but ensures the pointerValue is kept alive:
+	//
+	//     __asm__ __volatile__("" : : "r"(pointerValue))
+	//
+	// It should be portable to basically everything as the "r" register type
+	// exists basically everywhere.
+	asmType := llvm.FunctionType(b.ctx.VoidType(), []llvm.Type{b.i8ptrType}, false)
+	asmFn := llvm.InlineAsm(asmType, "", "r", true, false, 0, false)
+	b.createCall(asmType, asmFn, []llvm.Value{pointerValue}, "")
+
+	b.CreateRetVoid()
+}
+
 var mathToLLVMMapping = map[string]string{
 	"math.Ceil":  "llvm.ceil.f64",
 	"math.Exp":   "llvm.exp.f64",
@@ -124,7 +149,7 @@ func (b *builder) defineMathOp() {
 	// Create a call to the intrinsic.
 	args := make([]llvm.Value, len(b.fn.Params))
 	for i, param := range b.fn.Params {
-		args[i] = b.getValue(param)
+		args[i] = b.getValue(param, getPos(b.fn))
 	}
 	result := b.CreateCall(llvmFn.GlobalValueType(), llvmFn, args, "")
 	b.CreateRet(result)
diff --git a/compiler/map.go b/compiler/map.go
index 9d162bfc00..9c9a3b5c66 100644
--- a/compiler/map.go
+++ b/compiler/map.go
@@ -46,7 +46,7 @@ func (b *builder) createMakeMap(expr *ssa.MakeMap) (llvm.Value, error) {
 	sizeHint := llvm.ConstInt(b.uintptrType, 8, false)
 	algEnum := llvm.ConstInt(b.ctx.Int8Type(), alg, false)
 	if expr.Reserve != nil {
-		sizeHint = b.getValue(expr.Reserve)
+		sizeHint = b.getValue(expr.Reserve, getPos(expr))
 		var err error
 		sizeHint, err = b.createConvert(expr.Reserve.Type(), types.Typ[types.Uintptr], sizeHint, expr.Pos())
 		if err != nil {
@@ -89,6 +89,7 @@ func (b *builder) createMapLookup(keyType, valueType types.Type, m, key llvm.Val
 		// growth.
 		mapKeyAlloca, mapKeyPtr, mapKeySize := b.createTemporaryAlloca(key.Type(), "hashmap.key")
 		b.CreateStore(key, mapKeyAlloca)
+		b.zeroUndefBytes(b.getLLVMType(keyType), mapKeyAlloca)
 		// Fetch the value from the hashmap.
params := []llvm.Value{m, mapKeyPtr, mapValuePtr, mapValueSize} commaOkValue = b.createRuntimeCall("hashmapBinaryGet", params, "") @@ -133,6 +134,7 @@ func (b *builder) createMapUpdate(keyType types.Type, m, key, value llvm.Value, // key can be compared with runtime.memequal keyAlloca, keyPtr, keySize := b.createTemporaryAlloca(key.Type(), "hashmap.key") b.CreateStore(key, keyAlloca) + b.zeroUndefBytes(b.getLLVMType(keyType), keyAlloca) params := []llvm.Value{m, keyPtr, valuePtr} b.createRuntimeCall("hashmapBinarySet", params, "") b.emitLifetimeEnd(keyPtr, keySize) @@ -161,6 +163,7 @@ func (b *builder) createMapDelete(keyType types.Type, m, key llvm.Value, pos tok } else if hashmapIsBinaryKey(keyType) { keyAlloca, keyPtr, keySize := b.createTemporaryAlloca(key.Type(), "hashmap.key") b.CreateStore(key, keyAlloca) + b.zeroUndefBytes(b.getLLVMType(keyType), keyAlloca) params := []llvm.Value{m, keyPtr} b.createRuntimeCall("hashmapBinaryDelete", params, "") b.emitLifetimeEnd(keyPtr, keySize) @@ -240,7 +243,8 @@ func (b *builder) createMapIteratorNext(rangeVal ssa.Value, llvmRangeVal, it llv } // Returns true if this key type does not contain strings, interfaces etc., so -// can be compared with runtime.memequal. +// can be compared with runtime.memequal. Note that padding bytes are undef, +// which can make two otherwise equal structs compare as unequal under memequal. func hashmapIsBinaryKey(keyType types.Type) bool { switch keyType := keyType.(type) { case *types.Basic: @@ -263,3 +267,76 @@ func hashmapIsBinaryKey(keyType types.Type) bool { return false } } + +func (b *builder) zeroUndefBytes(llvmType llvm.Type, ptr llvm.Value) error { + // We know hashmapIsBinaryKey returned true for this type, so we only have to handle the types that can show up there. + // To zero all undefined bytes we walk the fields of the type: basic types have no internal padding, and for compound + // types we recurse to handle nested fields. After each field we check whether there are padding bytes before the next + // field (or before the end of the struct) and zero those as well.
+ + zero := llvm.ConstInt(b.ctx.Int32Type(), 0, false) + + switch llvmType.TypeKind() { + case llvm.IntegerTypeKind: + // no padding bytes + return nil + case llvm.PointerTypeKind: + // no padding bytes + return nil + case llvm.ArrayTypeKind: + llvmArrayType := llvmType + llvmElemType := llvmType.ElementType() + + for i := 0; i < llvmArrayType.ArrayLength(); i++ { + idx := llvm.ConstInt(b.uintptrType, uint64(i), false) + elemPtr := b.CreateInBoundsGEP(llvmArrayType, ptr, []llvm.Value{zero, idx}, "") + + // zero any padding bytes in this element + b.zeroUndefBytes(llvmElemType, elemPtr) + } + + case llvm.StructTypeKind: + llvmStructType := llvmType + numFields := llvmStructType.StructElementTypesCount() + llvmElementTypes := llvmStructType.StructElementTypes() + + for i := 0; i < numFields; i++ { + idx := llvm.ConstInt(b.ctx.Int32Type(), uint64(i), false) + elemPtr := b.CreateInBoundsGEP(llvmStructType, ptr, []llvm.Value{zero, idx}, "") + + // zero any padding bytes in this field + llvmElemType := llvmElementTypes[i] + b.zeroUndefBytes(llvmElemType, elemPtr) + + // zero any padding bytes before the next field, if any + offset := b.targetData.ElementOffset(llvmStructType, i) + storeSize := b.targetData.TypeStoreSize(llvmElemType) + fieldEndOffset := offset + storeSize + + var nextOffset uint64 + if i < numFields-1 { + nextOffset = b.targetData.ElementOffset(llvmStructType, i+1) + } else { + // Last field? Next offset is the total size of the allocated struct. + nextOffset = b.targetData.TypeAllocSize(llvmStructType) + } + + if fieldEndOffset != nextOffset { + n := llvm.ConstInt(b.uintptrType, nextOffset-fieldEndOffset, false) + llvmStoreSize := llvm.ConstInt(b.uintptrType, storeSize, false) + gepPtr := elemPtr + if gepPtr.Type() != b.i8ptrType { + gepPtr = b.CreateBitCast(gepPtr, b.i8ptrType, "") // LLVM 14 + } + paddingStart := b.CreateInBoundsGEP(b.ctx.Int8Type(), gepPtr, []llvm.Value{llvmStoreSize}, "") + if paddingStart.Type() != b.i8ptrType { + paddingStart = b.CreateBitCast(paddingStart, b.i8ptrType, "") // LLVM 14 + } + b.createRuntimeCall("memzero", []llvm.Value{paddingStart, n}, "") + } + } + } + + return nil +} diff --git a/compiler/symbol.go b/compiler/symbol.go index 87e7b1ed9f..c431e76f39 100644 --- a/compiler/symbol.go +++ b/compiler/symbol.go @@ -279,8 +279,12 @@ func (info *functionInfo) parsePragmas(f *ssa.Function) { info.linkName = parts[2] } case "//go:section": + // Only enable go:section when the package imports "unsafe". + // go:section also implies go:noinline since inlining could + // move the code to a different section than that requested. if len(parts) == 2 && hasUnsafeImport(f.Pkg.Pkg) { info.section = parts[1] + info.inline = inlineNone } case "//go:nobounds": // Skip bounds checking in this function. Useful for some @@ -433,7 +437,6 @@ func (c *compilerContext) getGlobal(g *ssa.Global) llvm.Value { llvmGlobal = llvm.AddGlobal(c.mod, llvmType, info.linkName) // Set alignment from the //go:align comment. - var alignInBits uint32 alignment := c.targetData.ABITypeAlignment(llvmType) if info.align > alignment { alignment = info.align @@ -444,7 +447,6 @@ func (c *compilerContext) getGlobal(g *ssa.Global) llvm.Value { c.addError(g.Pos(), "global variable alignment must be a positive power of two") } else { // Set the alignment only when it is a power of two.
- alignInBits = uint32(alignment) ^ uint32(alignment-1) llvmGlobal.SetAlignment(alignment) } @@ -459,7 +461,7 @@ func (c *compilerContext) getGlobal(g *ssa.Global) llvm.Value { Type: c.getDIType(typ), LocalToUnit: false, Expr: c.dibuilder.CreateExpression(nil), - AlignInBits: alignInBits, + AlignInBits: uint32(alignment) * 8, }) llvmGlobal.AddMetadata(0, diglobal) } diff --git a/compiler/syscall.go b/compiler/syscall.go index 66baf82609..db1ffd7007 100644 --- a/compiler/syscall.go +++ b/compiler/syscall.go @@ -14,7 +14,7 @@ import ( // and returns the result as a single integer (the system call result). The // result is not further interpreted. func (b *builder) createRawSyscall(call *ssa.CallCommon) (llvm.Value, error) { - num := b.getValue(call.Args[0]) + num := b.getValue(call.Args[0], getPos(call)) switch { case b.GOARCH == "amd64" && b.GOOS == "linux": // Sources: @@ -37,7 +37,7 @@ func (b *builder) createRawSyscall(call *ssa.CallCommon) (llvm.Value, error) { "{r12}", "{r13}", }[i] - llvmValue := b.getValue(arg) + llvmValue := b.getValue(arg, getPos(call)) args = append(args, llvmValue) argTypes = append(argTypes, llvmValue.Type()) } @@ -64,7 +64,7 @@ func (b *builder) createRawSyscall(call *ssa.CallCommon) (llvm.Value, error) { "{edi}", "{ebp}", }[i] - llvmValue := b.getValue(arg) + llvmValue := b.getValue(arg, getPos(call)) args = append(args, llvmValue) argTypes = append(argTypes, llvmValue.Type()) } @@ -89,7 +89,7 @@ func (b *builder) createRawSyscall(call *ssa.CallCommon) (llvm.Value, error) { "{r5}", "{r6}", }[i] - llvmValue := b.getValue(arg) + llvmValue := b.getValue(arg, getPos(call)) args = append(args, llvmValue) argTypes = append(argTypes, llvmValue.Type()) } @@ -119,7 +119,7 @@ func (b *builder) createRawSyscall(call *ssa.CallCommon) (llvm.Value, error) { "{x4}", "{x5}", }[i] - llvmValue := b.getValue(arg) + llvmValue := b.getValue(arg, getPos(call)) args = append(args, llvmValue) argTypes = append(argTypes, llvmValue.Type()) } @@ -177,12 +177,12 @@ func (b *builder) createSyscall(call *ssa.CallCommon) (llvm.Value, error) { var paramTypes []llvm.Type var params []llvm.Value for _, val := range call.Args[2:] { - param := b.getValue(val) + param := b.getValue(val, getPos(call)) params = append(params, param) paramTypes = append(paramTypes, param.Type()) } llvmType := llvm.FunctionType(b.uintptrType, paramTypes, false) - fn := b.getValue(call.Args[0]) + fn := b.getValue(call.Args[0], getPos(call)) fnPtr := b.CreateIntToPtr(fn, llvm.PointerType(llvmType, 0), "") // Prepare some functions that will be called later. 
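For reference alongside the register lists in the createRawSyscall hunks above, here is a minimal sketch of the Go-level call this code lowers on linux/amd64, written against the standard-library-style syscall.Syscall and SYS_WRITE names (on that target the syscall number goes in rax and the first three arguments in rdi, rsi and rdx).

package main

import (
	"syscall"
	"unsafe"
)

func main() {
	msg := []byte("hello via raw syscall\n")
	// write(fd=1, buf, count): lowered to a single inline-asm syscall with the
	// number and arguments pinned to the registers named in the constraints.
	syscall.Syscall(syscall.SYS_WRITE, 1,
		uintptr(unsafe.Pointer(&msg[0])), uintptr(len(msg)))
}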
diff --git a/compiler/testdata/defer-cortex-m-qemu.ll b/compiler/testdata/defer-cortex-m-qemu.ll index 0841e25509..a99a64edcd 100644 --- a/compiler/testdata/defer-cortex-m-qemu.ll +++ b/compiler/testdata/defer-cortex-m-qemu.ll @@ -4,7 +4,7 @@ target datalayout = "e-m:e-p:32:32-Fi8-i64:64-v128:64:128-a:0:32-n32-S64" target triple = "thumbv7m-unknown-unknown-eabi" %runtime.deferFrame = type { ptr, ptr, [0 x ptr], ptr, i1, %runtime._interface } -%runtime._interface = type { i32, ptr } +%runtime._interface = type { ptr, ptr } %runtime._defer = type { i32, ptr } declare noalias nonnull ptr @runtime.alloc(i32, ptr, ptr) #0 diff --git a/compiler/testdata/gc.ll b/compiler/testdata/gc.ll index a59a546fb1..160f869072 100644 --- a/compiler/testdata/gc.ll +++ b/compiler/testdata/gc.ll @@ -3,8 +3,7 @@ source_filename = "gc.go" target datalayout = "e-m:e-p:32:32-p10:8:8-p20:8:8-i64:64-n32:64-S128-ni:1:10:20" target triple = "wasm32-unknown-wasi" -%runtime.typecodeID = type { ptr, i32, ptr, ptr, i32 } -%runtime._interface = type { i32, ptr } +%runtime._interface = type { ptr, ptr } @main.scalar1 = hidden global ptr null, align 4 @main.scalar2 = hidden global ptr null, align 4 @@ -22,8 +21,8 @@ target triple = "wasm32-unknown-wasi" @main.slice3 = hidden global { ptr, i32, i32 } zeroinitializer, align 8 @"runtime/gc.layout:62-2000000000000001" = linkonce_odr unnamed_addr constant { i32, [8 x i8] } { i32 62, [8 x i8] c"\01\00\00\00\00\00\00 " } @"runtime/gc.layout:62-0001" = linkonce_odr unnamed_addr constant { i32, [8 x i8] } { i32 62, [8 x i8] c"\01\00\00\00\00\00\00\00" } -@"reflect/types.type:basic:complex128" = linkonce_odr constant %runtime.typecodeID { ptr null, i32 0, ptr null, ptr @"reflect/types.type:pointer:basic:complex128", i32 0 } -@"reflect/types.type:pointer:basic:complex128" = linkonce_odr constant %runtime.typecodeID { ptr @"reflect/types.type:basic:complex128", i32 0, ptr null, ptr null, i32 0 } +@"reflect/types.type:basic:complex128" = linkonce_odr constant { i8, ptr } { i8 16, ptr @"reflect/types.type:pointer:basic:complex128" }, align 4 +@"reflect/types.type:pointer:basic:complex128" = linkonce_odr constant { i8, i16, ptr } { i8 21, i16 0, ptr @"reflect/types.type:basic:complex128" }, align 4 declare noalias nonnull ptr @runtime.alloc(i32, ptr, ptr) #0 @@ -129,7 +128,8 @@ entry: store double %v.r, ptr %0, align 8 %.repack1 = getelementptr inbounds { double, double }, ptr %0, i32 0, i32 1 store double %v.i, ptr %.repack1, align 8 - %1 = insertvalue %runtime._interface { i32 ptrtoint (ptr @"reflect/types.type:basic:complex128" to i32), ptr undef }, ptr %0, 1 + %1 = insertvalue %runtime._interface { ptr @"reflect/types.type:basic:complex128", ptr undef }, ptr %0, 1 + call void @runtime.trackPointer(ptr nonnull @"reflect/types.type:basic:complex128", ptr nonnull %stackalloc, ptr undef) #2 call void @runtime.trackPointer(ptr nonnull %0, ptr nonnull %stackalloc, ptr undef) #2 ret %runtime._interface %1 } diff --git a/compiler/testdata/goroutine-cortex-m-qemu-tasks.ll b/compiler/testdata/goroutine-cortex-m-qemu-tasks.ll index ac1adaffe7..2fe1b06afb 100644 --- a/compiler/testdata/goroutine-cortex-m-qemu-tasks.ll +++ b/compiler/testdata/goroutine-cortex-m-qemu-tasks.ll @@ -145,34 +145,34 @@ entry: declare void @runtime.chanClose(ptr dereferenceable_or_null(32), ptr) #0 ; Function Attrs: nounwind -define hidden void @main.startInterfaceMethod(i32 %itf.typecode, ptr %itf.value, ptr %context) unnamed_addr #1 { +define hidden void @main.startInterfaceMethod(ptr %itf.typecode, ptr %itf.value, ptr 
%context) unnamed_addr #1 { entry: %0 = call ptr @runtime.alloc(i32 16, ptr null, ptr undef) #8 store ptr %itf.value, ptr %0, align 4 - %1 = getelementptr inbounds { ptr, %runtime._string, i32 }, ptr %0, i32 0, i32 1 + %1 = getelementptr inbounds { ptr, %runtime._string, ptr }, ptr %0, i32 0, i32 1 store ptr @"main$string", ptr %1, align 4 - %.repack1 = getelementptr inbounds { ptr, %runtime._string, i32 }, ptr %0, i32 0, i32 1, i32 1 + %.repack1 = getelementptr inbounds { ptr, %runtime._string, ptr }, ptr %0, i32 0, i32 1, i32 1 store i32 4, ptr %.repack1, align 4 - %2 = getelementptr inbounds { ptr, %runtime._string, i32 }, ptr %0, i32 0, i32 2 - store i32 %itf.typecode, ptr %2, align 4 + %2 = getelementptr inbounds { ptr, %runtime._string, ptr }, ptr %0, i32 0, i32 2 + store ptr %itf.typecode, ptr %2, align 4 %stacksize = call i32 @"internal/task.getGoroutineStackSize"(i32 ptrtoint (ptr @"interface:{Print:func:{basic:string}{}}.Print$invoke$gowrapper" to i32), ptr undef) #8 call void @"internal/task.start"(i32 ptrtoint (ptr @"interface:{Print:func:{basic:string}{}}.Print$invoke$gowrapper" to i32), ptr nonnull %0, i32 %stacksize, ptr undef) #8 ret void } -declare void @"interface:{Print:func:{basic:string}{}}.Print$invoke"(ptr, ptr, i32, i32, ptr) #6 +declare void @"interface:{Print:func:{basic:string}{}}.Print$invoke"(ptr, ptr, i32, ptr, ptr) #6 ; Function Attrs: nounwind define linkonce_odr void @"interface:{Print:func:{basic:string}{}}.Print$invoke$gowrapper"(ptr %0) unnamed_addr #7 { entry: %1 = load ptr, ptr %0, align 4 - %2 = getelementptr inbounds { ptr, ptr, i32, i32 }, ptr %0, i32 0, i32 1 + %2 = getelementptr inbounds { ptr, ptr, i32, ptr }, ptr %0, i32 0, i32 1 %3 = load ptr, ptr %2, align 4 - %4 = getelementptr inbounds { ptr, ptr, i32, i32 }, ptr %0, i32 0, i32 2 + %4 = getelementptr inbounds { ptr, ptr, i32, ptr }, ptr %0, i32 0, i32 2 %5 = load i32, ptr %4, align 4 - %6 = getelementptr inbounds { ptr, ptr, i32, i32 }, ptr %0, i32 0, i32 3 - %7 = load i32, ptr %6, align 4 - call void @"interface:{Print:func:{basic:string}{}}.Print$invoke"(ptr %1, ptr %3, i32 %5, i32 %7, ptr undef) #8 + %6 = getelementptr inbounds { ptr, ptr, i32, ptr }, ptr %0, i32 0, i32 3 + %7 = load ptr, ptr %6, align 4 + call void @"interface:{Print:func:{basic:string}{}}.Print$invoke"(ptr %1, ptr %3, i32 %5, ptr %7, ptr undef) #8 ret void } diff --git a/compiler/testdata/goroutine-wasm-asyncify.ll b/compiler/testdata/goroutine-wasm-asyncify.ll index 0f38e181ee..d3ec398a42 100644 --- a/compiler/testdata/goroutine-wasm-asyncify.ll +++ b/compiler/testdata/goroutine-wasm-asyncify.ll @@ -154,35 +154,35 @@ entry: declare void @runtime.chanClose(ptr dereferenceable_or_null(32), ptr) #0 ; Function Attrs: nounwind -define hidden void @main.startInterfaceMethod(i32 %itf.typecode, ptr %itf.value, ptr %context) unnamed_addr #1 { +define hidden void @main.startInterfaceMethod(ptr %itf.typecode, ptr %itf.value, ptr %context) unnamed_addr #1 { entry: %stackalloc = alloca i8, align 1 %0 = call ptr @runtime.alloc(i32 16, ptr null, ptr undef) #8 call void @runtime.trackPointer(ptr nonnull %0, ptr nonnull %stackalloc, ptr undef) #8 store ptr %itf.value, ptr %0, align 4 - %1 = getelementptr inbounds { ptr, %runtime._string, i32 }, ptr %0, i32 0, i32 1 + %1 = getelementptr inbounds { ptr, %runtime._string, ptr }, ptr %0, i32 0, i32 1 store ptr @"main$string", ptr %1, align 4 - %.repack1 = getelementptr inbounds { ptr, %runtime._string, i32 }, ptr %0, i32 0, i32 1, i32 1 + %.repack1 = getelementptr inbounds { ptr, 
%runtime._string, ptr }, ptr %0, i32 0, i32 1, i32 1 store i32 4, ptr %.repack1, align 4 - %2 = getelementptr inbounds { ptr, %runtime._string, i32 }, ptr %0, i32 0, i32 2 - store i32 %itf.typecode, ptr %2, align 4 + %2 = getelementptr inbounds { ptr, %runtime._string, ptr }, ptr %0, i32 0, i32 2 + store ptr %itf.typecode, ptr %2, align 4 call void @"internal/task.start"(i32 ptrtoint (ptr @"interface:{Print:func:{basic:string}{}}.Print$invoke$gowrapper" to i32), ptr nonnull %0, i32 16384, ptr undef) #8 ret void } -declare void @"interface:{Print:func:{basic:string}{}}.Print$invoke"(ptr, ptr, i32, i32, ptr) #6 +declare void @"interface:{Print:func:{basic:string}{}}.Print$invoke"(ptr, ptr, i32, ptr, ptr) #6 ; Function Attrs: nounwind define linkonce_odr void @"interface:{Print:func:{basic:string}{}}.Print$invoke$gowrapper"(ptr %0) unnamed_addr #7 { entry: %1 = load ptr, ptr %0, align 4 - %2 = getelementptr inbounds { ptr, ptr, i32, i32 }, ptr %0, i32 0, i32 1 + %2 = getelementptr inbounds { ptr, ptr, i32, ptr }, ptr %0, i32 0, i32 1 %3 = load ptr, ptr %2, align 4 - %4 = getelementptr inbounds { ptr, ptr, i32, i32 }, ptr %0, i32 0, i32 2 + %4 = getelementptr inbounds { ptr, ptr, i32, ptr }, ptr %0, i32 0, i32 2 %5 = load i32, ptr %4, align 4 - %6 = getelementptr inbounds { ptr, ptr, i32, i32 }, ptr %0, i32 0, i32 3 - %7 = load i32, ptr %6, align 4 - call void @"interface:{Print:func:{basic:string}{}}.Print$invoke"(ptr %1, ptr %3, i32 %5, i32 %7, ptr undef) #8 + %6 = getelementptr inbounds { ptr, ptr, i32, ptr }, ptr %0, i32 0, i32 3 + %7 = load ptr, ptr %6, align 4 + call void @"interface:{Print:func:{basic:string}{}}.Print$invoke"(ptr %1, ptr %3, i32 %5, ptr %7, ptr undef) #8 call void @runtime.deadlock(ptr undef) #8 unreachable } diff --git a/compiler/testdata/interface.ll b/compiler/testdata/interface.ll index 2ddaf4ec37..de37341a37 100644 --- a/compiler/testdata/interface.ll +++ b/compiler/testdata/interface.ll @@ -3,22 +3,18 @@ source_filename = "interface.go" target datalayout = "e-m:e-p:32:32-p10:8:8-p20:8:8-i64:64-n32:64-S128-ni:1:10:20" target triple = "wasm32-unknown-wasi" -%runtime.typecodeID = type { ptr, i32, ptr, ptr, i32 } -%runtime._interface = type { i32, ptr } +%runtime._interface = type { ptr, ptr } %runtime._string = type { ptr, i32 } -@"reflect/types.type:basic:int" = linkonce_odr constant %runtime.typecodeID { ptr null, i32 0, ptr null, ptr @"reflect/types.type:pointer:basic:int", i32 0 } -@"reflect/types.type:pointer:basic:int" = linkonce_odr constant %runtime.typecodeID { ptr @"reflect/types.type:basic:int", i32 0, ptr null, ptr null, i32 0 } -@"reflect/types.type:pointer:named:error" = linkonce_odr constant %runtime.typecodeID { ptr @"reflect/types.type:named:error", i32 0, ptr null, ptr null, i32 0 } -@"reflect/types.type:named:error" = linkonce_odr constant %runtime.typecodeID { ptr @"reflect/types.type:interface:{Error:func:{}{basic:string}}", i32 0, ptr null, ptr @"reflect/types.type:pointer:named:error", i32 ptrtoint (ptr @"interface:{Error:func:{}{basic:string}}.$typeassert" to i32) } -@"reflect/types.type:interface:{Error:func:{}{basic:string}}" = linkonce_odr constant %runtime.typecodeID { ptr @"reflect/types.interface:interface{Error() string}$interface", i32 0, ptr null, ptr @"reflect/types.type:pointer:interface:{Error:func:{}{basic:string}}", i32 ptrtoint (ptr @"interface:{Error:func:{}{basic:string}}.$typeassert" to i32) } -@"reflect/methods.Error() string" = linkonce_odr constant i8 0, align 1 -@"reflect/types.interface:interface{Error() 
string}$interface" = linkonce_odr constant [1 x ptr] [ptr @"reflect/methods.Error() string"] -@"reflect/types.type:pointer:interface:{Error:func:{}{basic:string}}" = linkonce_odr constant %runtime.typecodeID { ptr @"reflect/types.type:interface:{Error:func:{}{basic:string}}", i32 0, ptr null, ptr null, i32 0 } -@"reflect/types.type:pointer:interface:{String:func:{}{basic:string}}" = linkonce_odr constant %runtime.typecodeID { ptr @"reflect/types.type:interface:{String:func:{}{basic:string}}", i32 0, ptr null, ptr null, i32 0 } -@"reflect/types.type:interface:{String:func:{}{basic:string}}" = linkonce_odr constant %runtime.typecodeID { ptr @"reflect/types.interface:interface{String() string}$interface", i32 0, ptr null, ptr @"reflect/types.type:pointer:interface:{String:func:{}{basic:string}}", i32 ptrtoint (ptr @"interface:{String:func:{}{basic:string}}.$typeassert" to i32) } -@"reflect/methods.String() string" = linkonce_odr constant i8 0, align 1 -@"reflect/types.interface:interface{String() string}$interface" = linkonce_odr constant [1 x ptr] [ptr @"reflect/methods.String() string"] +@"reflect/types.type:basic:int" = linkonce_odr constant { i8, ptr } { i8 2, ptr @"reflect/types.type:pointer:basic:int" }, align 4 +@"reflect/types.type:pointer:basic:int" = linkonce_odr constant { i8, i16, ptr } { i8 21, i16 0, ptr @"reflect/types.type:basic:int" }, align 4 +@"reflect/types.type:pointer:named:error" = linkonce_odr constant { i8, i16, ptr } { i8 21, i16 0, ptr @"reflect/types.type:named:error" }, align 4 +@"reflect/types.type:named:error" = linkonce_odr constant { i8, i16, ptr, ptr, ptr, [7 x i8] } { i8 52, i16 1, ptr @"reflect/types.type:pointer:named:error", ptr @"reflect/types.type:interface:{Error:func:{}{basic:string}}", ptr @"reflect/types.type.pkgpath.empty", [7 x i8] c".error\00" }, align 4 +@"reflect/types.type.pkgpath.empty" = linkonce_odr unnamed_addr constant [1 x i8] zeroinitializer, align 1 +@"reflect/types.type:interface:{Error:func:{}{basic:string}}" = linkonce_odr constant { i8, ptr } { i8 20, ptr @"reflect/types.type:pointer:interface:{Error:func:{}{basic:string}}" }, align 4 +@"reflect/types.type:pointer:interface:{Error:func:{}{basic:string}}" = linkonce_odr constant { i8, i16, ptr } { i8 21, i16 0, ptr @"reflect/types.type:interface:{Error:func:{}{basic:string}}" }, align 4 +@"reflect/types.type:pointer:interface:{String:func:{}{basic:string}}" = linkonce_odr constant { i8, i16, ptr } { i8 21, i16 0, ptr @"reflect/types.type:interface:{String:func:{}{basic:string}}" }, align 4 +@"reflect/types.type:interface:{String:func:{}{basic:string}}" = linkonce_odr constant { i8, ptr } { i8 20, ptr @"reflect/types.type:pointer:interface:{String:func:{}{basic:string}}" }, align 4 @"reflect/types.typeid:basic:int" = external constant i8 declare noalias nonnull ptr @runtime.alloc(i32, ptr, ptr) #0 @@ -35,42 +31,42 @@ entry: define hidden %runtime._interface @main.simpleType(ptr %context) unnamed_addr #1 { entry: %stackalloc = alloca i8, align 1 + call void @runtime.trackPointer(ptr nonnull @"reflect/types.type:basic:int", ptr nonnull %stackalloc, ptr undef) #6 call void @runtime.trackPointer(ptr null, ptr nonnull %stackalloc, ptr undef) #6 - ret %runtime._interface { i32 ptrtoint (ptr @"reflect/types.type:basic:int" to i32), ptr null } + ret %runtime._interface { ptr @"reflect/types.type:basic:int", ptr null } } ; Function Attrs: nounwind define hidden %runtime._interface @main.pointerType(ptr %context) unnamed_addr #1 { entry: %stackalloc = alloca i8, align 1 + call void 
@runtime.trackPointer(ptr nonnull @"reflect/types.type:pointer:basic:int", ptr nonnull %stackalloc, ptr undef) #6 call void @runtime.trackPointer(ptr null, ptr nonnull %stackalloc, ptr undef) #6 - ret %runtime._interface { i32 ptrtoint (ptr @"reflect/types.type:pointer:basic:int" to i32), ptr null } + ret %runtime._interface { ptr @"reflect/types.type:pointer:basic:int", ptr null } } ; Function Attrs: nounwind define hidden %runtime._interface @main.interfaceType(ptr %context) unnamed_addr #1 { entry: %stackalloc = alloca i8, align 1 + call void @runtime.trackPointer(ptr nonnull @"reflect/types.type:pointer:named:error", ptr nonnull %stackalloc, ptr undef) #6 call void @runtime.trackPointer(ptr null, ptr nonnull %stackalloc, ptr undef) #6 - ret %runtime._interface { i32 ptrtoint (ptr @"reflect/types.type:pointer:named:error" to i32), ptr null } + ret %runtime._interface { ptr @"reflect/types.type:pointer:named:error", ptr null } } -declare i1 @"interface:{Error:func:{}{basic:string}}.$typeassert"(i32) #2 - ; Function Attrs: nounwind define hidden %runtime._interface @main.anonymousInterfaceType(ptr %context) unnamed_addr #1 { entry: %stackalloc = alloca i8, align 1 + call void @runtime.trackPointer(ptr nonnull @"reflect/types.type:pointer:interface:{String:func:{}{basic:string}}", ptr nonnull %stackalloc, ptr undef) #6 call void @runtime.trackPointer(ptr null, ptr nonnull %stackalloc, ptr undef) #6 - ret %runtime._interface { i32 ptrtoint (ptr @"reflect/types.type:pointer:interface:{String:func:{}{basic:string}}" to i32), ptr null } + ret %runtime._interface { ptr @"reflect/types.type:pointer:interface:{String:func:{}{basic:string}}", ptr null } } -declare i1 @"interface:{String:func:{}{basic:string}}.$typeassert"(i32) #3 - ; Function Attrs: nounwind -define hidden i1 @main.isInt(i32 %itf.typecode, ptr %itf.value, ptr %context) unnamed_addr #1 { +define hidden i1 @main.isInt(ptr %itf.typecode, ptr %itf.value, ptr %context) unnamed_addr #1 { entry: - %typecode = call i1 @runtime.typeAssert(i32 %itf.typecode, ptr nonnull @"reflect/types.typeid:basic:int", ptr undef) #6 + %typecode = call i1 @runtime.typeAssert(ptr %itf.typecode, ptr nonnull @"reflect/types.typeid:basic:int", ptr undef) #6 br i1 %typecode, label %typeassert.ok, label %typeassert.next typeassert.next: ; preds = %typeassert.ok, %entry @@ -80,12 +76,12 @@ typeassert.ok: ; preds = %entry br label %typeassert.next } -declare i1 @runtime.typeAssert(i32, ptr dereferenceable_or_null(1), ptr) #0 +declare i1 @runtime.typeAssert(ptr, ptr dereferenceable_or_null(1), ptr) #0 ; Function Attrs: nounwind -define hidden i1 @main.isError(i32 %itf.typecode, ptr %itf.value, ptr %context) unnamed_addr #1 { +define hidden i1 @main.isError(ptr %itf.typecode, ptr %itf.value, ptr %context) unnamed_addr #1 { entry: - %0 = call i1 @"interface:{Error:func:{}{basic:string}}.$typeassert"(i32 %itf.typecode) #6 + %0 = call i1 @"interface:{Error:func:{}{basic:string}}.$typeassert"(ptr %itf.typecode) #6 br i1 %0, label %typeassert.ok, label %typeassert.next typeassert.next: ; preds = %typeassert.ok, %entry @@ -95,10 +91,12 @@ typeassert.ok: ; preds = %entry br label %typeassert.next } +declare i1 @"interface:{Error:func:{}{basic:string}}.$typeassert"(ptr) #2 + ; Function Attrs: nounwind -define hidden i1 @main.isStringer(i32 %itf.typecode, ptr %itf.value, ptr %context) unnamed_addr #1 { +define hidden i1 @main.isStringer(ptr %itf.typecode, ptr %itf.value, ptr %context) unnamed_addr #1 { entry: - %0 = call i1 
@"interface:{String:func:{}{basic:string}}.$typeassert"(i32 %itf.typecode) #6 + %0 = call i1 @"interface:{String:func:{}{basic:string}}.$typeassert"(ptr %itf.typecode) #6 br i1 %0, label %typeassert.ok, label %typeassert.next typeassert.next: ; preds = %typeassert.ok, %entry @@ -108,26 +106,28 @@ typeassert.ok: ; preds = %entry br label %typeassert.next } +declare i1 @"interface:{String:func:{}{basic:string}}.$typeassert"(ptr) #3 + ; Function Attrs: nounwind -define hidden i8 @main.callFooMethod(i32 %itf.typecode, ptr %itf.value, ptr %context) unnamed_addr #1 { +define hidden i8 @main.callFooMethod(ptr %itf.typecode, ptr %itf.value, ptr %context) unnamed_addr #1 { entry: - %0 = call i8 @"interface:{String:func:{}{basic:string},main.foo:func:{basic:int}{basic:uint8}}.foo$invoke"(ptr %itf.value, i32 3, i32 %itf.typecode, ptr undef) #6 + %0 = call i8 @"interface:{String:func:{}{basic:string},main.foo:func:{basic:int}{basic:uint8}}.foo$invoke"(ptr %itf.value, i32 3, ptr %itf.typecode, ptr undef) #6 ret i8 %0 } -declare i8 @"interface:{String:func:{}{basic:string},main.foo:func:{basic:int}{basic:uint8}}.foo$invoke"(ptr, i32, i32, ptr) #4 +declare i8 @"interface:{String:func:{}{basic:string},main.foo:func:{basic:int}{basic:uint8}}.foo$invoke"(ptr, i32, ptr, ptr) #4 ; Function Attrs: nounwind -define hidden %runtime._string @main.callErrorMethod(i32 %itf.typecode, ptr %itf.value, ptr %context) unnamed_addr #1 { +define hidden %runtime._string @main.callErrorMethod(ptr %itf.typecode, ptr %itf.value, ptr %context) unnamed_addr #1 { entry: %stackalloc = alloca i8, align 1 - %0 = call %runtime._string @"interface:{Error:func:{}{basic:string}}.Error$invoke"(ptr %itf.value, i32 %itf.typecode, ptr undef) #6 + %0 = call %runtime._string @"interface:{Error:func:{}{basic:string}}.Error$invoke"(ptr %itf.value, ptr %itf.typecode, ptr undef) #6 %1 = extractvalue %runtime._string %0, 0 call void @runtime.trackPointer(ptr %1, ptr nonnull %stackalloc, ptr undef) #6 ret %runtime._string %0 } -declare %runtime._string @"interface:{Error:func:{}{basic:string}}.Error$invoke"(ptr, i32, ptr) #5 +declare %runtime._string @"interface:{Error:func:{}{basic:string}}.Error$invoke"(ptr, ptr, ptr) #5 attributes #0 = { "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" } attributes #1 = { nounwind "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" } diff --git a/compiler/testdata/pointer.go b/compiler/testdata/pointer.go index 84a983c3f0..6575dd83a7 100644 --- a/compiler/testdata/pointer.go +++ b/compiler/testdata/pointer.go @@ -24,18 +24,3 @@ func pointerCastToUnsafe(x *int) unsafe.Pointer { func pointerCastToUnsafeNoop(x *byte) unsafe.Pointer { return unsafe.Pointer(x) } - -// The compiler has support for a few special cast+add patterns that are -// transformed into a single GEP. 
- -func pointerUnsafeGEPFixedOffset(ptr *byte) *byte { - return (*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + 10)) -} - -func pointerUnsafeGEPByteOffset(ptr *byte, offset uintptr) *byte { - return (*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + offset)) -} - -func pointerUnsafeGEPIntOffset(ptr *int32, offset uintptr) *int32 { - return (*int32)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + offset*4)) -} diff --git a/compiler/testdata/pointer.ll b/compiler/testdata/pointer.ll index ac1a4bc71f..3cbf70a376 100644 --- a/compiler/testdata/pointer.ll +++ b/compiler/testdata/pointer.ll @@ -43,40 +43,6 @@ entry: ret ptr %x } -; Function Attrs: nounwind -define hidden ptr @main.pointerUnsafeGEPFixedOffset(ptr dereferenceable_or_null(1) %ptr, ptr %context) unnamed_addr #1 { -entry: - %stackalloc = alloca i8, align 1 - call void @runtime.trackPointer(ptr %ptr, ptr nonnull %stackalloc, ptr undef) #2 - %0 = getelementptr inbounds i8, ptr %ptr, i32 10 - call void @runtime.trackPointer(ptr nonnull %0, ptr nonnull %stackalloc, ptr undef) #2 - call void @runtime.trackPointer(ptr nonnull %0, ptr nonnull %stackalloc, ptr undef) #2 - ret ptr %0 -} - -; Function Attrs: nounwind -define hidden ptr @main.pointerUnsafeGEPByteOffset(ptr dereferenceable_or_null(1) %ptr, i32 %offset, ptr %context) unnamed_addr #1 { -entry: - %stackalloc = alloca i8, align 1 - call void @runtime.trackPointer(ptr %ptr, ptr nonnull %stackalloc, ptr undef) #2 - %0 = getelementptr inbounds i8, ptr %ptr, i32 %offset - call void @runtime.trackPointer(ptr %0, ptr nonnull %stackalloc, ptr undef) #2 - call void @runtime.trackPointer(ptr %0, ptr nonnull %stackalloc, ptr undef) #2 - ret ptr %0 -} - -; Function Attrs: nounwind -define hidden ptr @main.pointerUnsafeGEPIntOffset(ptr dereferenceable_or_null(4) %ptr, i32 %offset, ptr %context) unnamed_addr #1 { -entry: - %stackalloc = alloca i8, align 1 - call void @runtime.trackPointer(ptr %ptr, ptr nonnull %stackalloc, ptr undef) #2 - %0 = shl i32 %offset, 2 - %1 = getelementptr inbounds i8, ptr %ptr, i32 %0 - call void @runtime.trackPointer(ptr %1, ptr nonnull %stackalloc, ptr undef) #2 - call void @runtime.trackPointer(ptr %1, ptr nonnull %stackalloc, ptr undef) #2 - ret ptr %1 -} - attributes #0 = { "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" } attributes #1 = { nounwind "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" } attributes #2 = { nounwind } diff --git a/compiler/testdata/pragma.ll b/compiler/testdata/pragma.ll index 6828880433..35afcf7fcd 100644 --- a/compiler/testdata/pragma.ll +++ b/compiler/testdata/pragma.ll @@ -47,13 +47,13 @@ entry: ret void } -; Function Attrs: nounwind -define hidden void @main.functionInSection(ptr %context) unnamed_addr #1 section ".special_function_section" { +; Function Attrs: noinline nounwind +define hidden void @main.functionInSection(ptr %context) unnamed_addr #4 section ".special_function_section" { entry: ret void } -; Function Attrs: nounwind +; Function Attrs: noinline nounwind define void @exportedFunctionInSection() #5 section ".special_function_section" { entry: ret void @@ -66,4 +66,4 @@ attributes #1 = { nounwind "target-features"="+bulk-memory,+nontrapping-fptoint, attributes #2 = { nounwind "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" "wasm-export-name"="extern_func" "wasm-import-module"="env" "wasm-import-name"="extern_func" } attributes #3 = { inlinehint nounwind "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" } attributes #4 = { noinline nounwind 
"target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" } -attributes #5 = { nounwind "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" "wasm-export-name"="exportedFunctionInSection" "wasm-import-module"="env" "wasm-import-name"="exportedFunctionInSection" } +attributes #5 = { noinline nounwind "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" "wasm-export-name"="exportedFunctionInSection" "wasm-import-module"="env" "wasm-import-name"="exportedFunctionInSection" } diff --git a/compiler/testdata/zeromap.go b/compiler/testdata/zeromap.go new file mode 100644 index 0000000000..6cf9f611b2 --- /dev/null +++ b/compiler/testdata/zeromap.go @@ -0,0 +1,37 @@ +package main + +type hasPadding struct { + b1 bool + i int + b2 bool +} + +type nestedPadding struct { + b bool + hasPadding + i int +} + +//go:noinline +func testZeroGet(m map[hasPadding]int, s hasPadding) int { + return m[s] +} + +//go:noinline +func testZeroSet(m map[hasPadding]int, s hasPadding) { + m[s] = 5 +} + +//go:noinline +func testZeroArrayGet(m map[[2]hasPadding]int, s [2]hasPadding) int { + return m[s] +} + +//go:noinline +func testZeroArraySet(m map[[2]hasPadding]int, s [2]hasPadding) { + m[s] = 5 +} + +func main() { + +} diff --git a/compiler/testdata/zeromap.ll b/compiler/testdata/zeromap.ll new file mode 100644 index 0000000000..a04ad242f3 --- /dev/null +++ b/compiler/testdata/zeromap.ll @@ -0,0 +1,170 @@ +; ModuleID = 'zeromap.go' +source_filename = "zeromap.go" +target datalayout = "e-m:e-p:32:32-p10:8:8-p20:8:8-i64:64-n32:64-S128-ni:1:10:20" +target triple = "wasm32-unknown-wasi" + +%main.hasPadding = type { i1, i32, i1 } + +declare noalias nonnull ptr @runtime.alloc(i32, ptr, ptr) #0 + +declare void @runtime.trackPointer(ptr nocapture readonly, ptr, ptr) #0 + +; Function Attrs: nounwind +define hidden void @main.init(ptr %context) unnamed_addr #1 { +entry: + ret void +} + +; Function Attrs: noinline nounwind +define hidden i32 @main.testZeroGet(ptr dereferenceable_or_null(40) %m, i1 %s.b1, i32 %s.i, i1 %s.b2, ptr %context) unnamed_addr #2 { +entry: + %hashmap.key = alloca %main.hasPadding, align 8 + %hashmap.value = alloca i32, align 4 + %s = alloca %main.hasPadding, align 8 + %0 = insertvalue %main.hasPadding zeroinitializer, i1 %s.b1, 0 + %1 = insertvalue %main.hasPadding %0, i32 %s.i, 1 + %2 = insertvalue %main.hasPadding %1, i1 %s.b2, 2 + %stackalloc = alloca i8, align 1 + store %main.hasPadding zeroinitializer, ptr %s, align 8 + call void @runtime.trackPointer(ptr nonnull %s, ptr nonnull %stackalloc, ptr undef) #4 + store %main.hasPadding %2, ptr %s, align 8 + call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value) + call void @llvm.lifetime.start.p0(i64 12, ptr nonnull %hashmap.key) + store %main.hasPadding %2, ptr %hashmap.key, align 8 + %3 = getelementptr inbounds i8, ptr %hashmap.key, i32 1 + call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #4 + %4 = getelementptr inbounds i8, ptr %hashmap.key, i32 9 + call void @runtime.memzero(ptr nonnull %4, i32 3, ptr undef) #4 + %5 = call i1 @runtime.hashmapBinaryGet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, i32 4, ptr undef) #4 + call void @llvm.lifetime.end.p0(i64 12, ptr nonnull %hashmap.key) + %6 = load i32, ptr %hashmap.value, align 4 + call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value) + ret i32 %6 +} + +; Function Attrs: argmemonly nocallback nofree nosync nounwind willreturn +declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture) #3 + +declare void 
@runtime.memzero(ptr, i32, ptr) #0 + +declare i1 @runtime.hashmapBinaryGet(ptr dereferenceable_or_null(40), ptr, ptr, i32, ptr) #0 + +; Function Attrs: argmemonly nocallback nofree nosync nounwind willreturn +declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture) #3 + +; Function Attrs: noinline nounwind +define hidden void @main.testZeroSet(ptr dereferenceable_or_null(40) %m, i1 %s.b1, i32 %s.i, i1 %s.b2, ptr %context) unnamed_addr #2 { +entry: + %hashmap.key = alloca %main.hasPadding, align 8 + %hashmap.value = alloca i32, align 4 + %s = alloca %main.hasPadding, align 8 + %0 = insertvalue %main.hasPadding zeroinitializer, i1 %s.b1, 0 + %1 = insertvalue %main.hasPadding %0, i32 %s.i, 1 + %2 = insertvalue %main.hasPadding %1, i1 %s.b2, 2 + %stackalloc = alloca i8, align 1 + store %main.hasPadding zeroinitializer, ptr %s, align 8 + call void @runtime.trackPointer(ptr nonnull %s, ptr nonnull %stackalloc, ptr undef) #4 + store %main.hasPadding %2, ptr %s, align 8 + call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value) + store i32 5, ptr %hashmap.value, align 4 + call void @llvm.lifetime.start.p0(i64 12, ptr nonnull %hashmap.key) + store %main.hasPadding %2, ptr %hashmap.key, align 8 + %3 = getelementptr inbounds i8, ptr %hashmap.key, i32 1 + call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #4 + %4 = getelementptr inbounds i8, ptr %hashmap.key, i32 9 + call void @runtime.memzero(ptr nonnull %4, i32 3, ptr undef) #4 + call void @runtime.hashmapBinarySet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, ptr undef) #4 + call void @llvm.lifetime.end.p0(i64 12, ptr nonnull %hashmap.key) + call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value) + ret void +} + +declare void @runtime.hashmapBinarySet(ptr dereferenceable_or_null(40), ptr, ptr, ptr) #0 + +; Function Attrs: noinline nounwind +define hidden i32 @main.testZeroArrayGet(ptr dereferenceable_or_null(40) %m, [2 x %main.hasPadding] %s, ptr %context) unnamed_addr #2 { +entry: + %hashmap.key = alloca [2 x %main.hasPadding], align 8 + %hashmap.value = alloca i32, align 4 + %s1 = alloca [2 x %main.hasPadding], align 8 + %stackalloc = alloca i8, align 1 + store %main.hasPadding zeroinitializer, ptr %s1, align 8 + %s1.repack2 = getelementptr inbounds [2 x %main.hasPadding], ptr %s1, i32 0, i32 1 + store %main.hasPadding zeroinitializer, ptr %s1.repack2, align 4 + call void @runtime.trackPointer(ptr nonnull %s1, ptr nonnull %stackalloc, ptr undef) #4 + %s.elt = extractvalue [2 x %main.hasPadding] %s, 0 + store %main.hasPadding %s.elt, ptr %s1, align 8 + %s1.repack3 = getelementptr inbounds [2 x %main.hasPadding], ptr %s1, i32 0, i32 1 + %s.elt4 = extractvalue [2 x %main.hasPadding] %s, 1 + store %main.hasPadding %s.elt4, ptr %s1.repack3, align 4 + call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value) + call void @llvm.lifetime.start.p0(i64 24, ptr nonnull %hashmap.key) + %s.elt7 = extractvalue [2 x %main.hasPadding] %s, 0 + store %main.hasPadding %s.elt7, ptr %hashmap.key, align 8 + %hashmap.key.repack8 = getelementptr inbounds [2 x %main.hasPadding], ptr %hashmap.key, i32 0, i32 1 + %s.elt9 = extractvalue [2 x %main.hasPadding] %s, 1 + store %main.hasPadding %s.elt9, ptr %hashmap.key.repack8, align 4 + %0 = getelementptr inbounds i8, ptr %hashmap.key, i32 1 + call void @runtime.memzero(ptr nonnull %0, i32 3, ptr undef) #4 + %1 = getelementptr inbounds i8, ptr %hashmap.key, i32 9 + call void @runtime.memzero(ptr nonnull %1, i32 3, ptr undef) #4 + %2 = getelementptr inbounds i8, 
ptr %hashmap.key, i32 13 + call void @runtime.memzero(ptr nonnull %2, i32 3, ptr undef) #4 + %3 = getelementptr inbounds i8, ptr %hashmap.key, i32 21 + call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #4 + %4 = call i1 @runtime.hashmapBinaryGet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, i32 4, ptr undef) #4 + call void @llvm.lifetime.end.p0(i64 24, ptr nonnull %hashmap.key) + %5 = load i32, ptr %hashmap.value, align 4 + call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value) + ret i32 %5 +} + +; Function Attrs: noinline nounwind +define hidden void @main.testZeroArraySet(ptr dereferenceable_or_null(40) %m, [2 x %main.hasPadding] %s, ptr %context) unnamed_addr #2 { +entry: + %hashmap.key = alloca [2 x %main.hasPadding], align 8 + %hashmap.value = alloca i32, align 4 + %s1 = alloca [2 x %main.hasPadding], align 8 + %stackalloc = alloca i8, align 1 + store %main.hasPadding zeroinitializer, ptr %s1, align 8 + %s1.repack2 = getelementptr inbounds [2 x %main.hasPadding], ptr %s1, i32 0, i32 1 + store %main.hasPadding zeroinitializer, ptr %s1.repack2, align 4 + call void @runtime.trackPointer(ptr nonnull %s1, ptr nonnull %stackalloc, ptr undef) #4 + %s.elt = extractvalue [2 x %main.hasPadding] %s, 0 + store %main.hasPadding %s.elt, ptr %s1, align 8 + %s1.repack3 = getelementptr inbounds [2 x %main.hasPadding], ptr %s1, i32 0, i32 1 + %s.elt4 = extractvalue [2 x %main.hasPadding] %s, 1 + store %main.hasPadding %s.elt4, ptr %s1.repack3, align 4 + call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value) + store i32 5, ptr %hashmap.value, align 4 + call void @llvm.lifetime.start.p0(i64 24, ptr nonnull %hashmap.key) + %s.elt7 = extractvalue [2 x %main.hasPadding] %s, 0 + store %main.hasPadding %s.elt7, ptr %hashmap.key, align 8 + %hashmap.key.repack8 = getelementptr inbounds [2 x %main.hasPadding], ptr %hashmap.key, i32 0, i32 1 + %s.elt9 = extractvalue [2 x %main.hasPadding] %s, 1 + store %main.hasPadding %s.elt9, ptr %hashmap.key.repack8, align 4 + %0 = getelementptr inbounds i8, ptr %hashmap.key, i32 1 + call void @runtime.memzero(ptr nonnull %0, i32 3, ptr undef) #4 + %1 = getelementptr inbounds i8, ptr %hashmap.key, i32 9 + call void @runtime.memzero(ptr nonnull %1, i32 3, ptr undef) #4 + %2 = getelementptr inbounds i8, ptr %hashmap.key, i32 13 + call void @runtime.memzero(ptr nonnull %2, i32 3, ptr undef) #4 + %3 = getelementptr inbounds i8, ptr %hashmap.key, i32 21 + call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #4 + call void @runtime.hashmapBinarySet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, ptr undef) #4 + call void @llvm.lifetime.end.p0(i64 24, ptr nonnull %hashmap.key) + call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value) + ret void +} + +; Function Attrs: nounwind +define hidden void @main.main(ptr %context) unnamed_addr #1 { +entry: + ret void +} + +attributes #0 = { "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" } +attributes #1 = { nounwind "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" } +attributes #2 = { noinline nounwind "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" } +attributes #3 = { argmemonly nocallback nofree nosync nounwind willreturn } +attributes #4 = { nounwind } diff --git a/compiler/volatile.go b/compiler/volatile.go index 3d3a67fa13..0f7e7b2271 100644 --- a/compiler/volatile.go +++ b/compiler/volatile.go @@ -9,7 +9,7 @@ import "go/types" // runtime/volatile.LoadT(). 
func (b *builder) createVolatileLoad() { b.createFunctionStart(true) - addr := b.getValue(b.fn.Params[0]) + addr := b.getValue(b.fn.Params[0], getPos(b.fn)) b.createNilCheck(b.fn.Params[0], addr, "deref") valType := b.getLLVMType(b.fn.Params[0].Type().(*types.Pointer).Elem()) val := b.CreateLoad(valType, addr, "") @@ -21,8 +21,8 @@ func (b *builder) createVolatileLoad() { // runtime/volatile.StoreT(). func (b *builder) createVolatileStore() { b.createFunctionStart(true) - addr := b.getValue(b.fn.Params[0]) - val := b.getValue(b.fn.Params[1]) + addr := b.getValue(b.fn.Params[0], getPos(b.fn)) + val := b.getValue(b.fn.Params[1], getPos(b.fn)) b.createNilCheck(b.fn.Params[0], addr, "deref") store := b.CreateStore(val, addr) store.SetVolatile(true) diff --git a/interp/interpreter.go b/interp/interpreter.go index c61ce7cf3f..7c58a5d20e 100644 --- a/interp/interpreter.go +++ b/interp/interpreter.go @@ -238,7 +238,7 @@ func (r *runner) run(fn *function, params []value, parentMem *memoryView, indent // which case this call won't even get to this point but will // already be emitted in initAll. continue - case strings.HasPrefix(callFn.name, "runtime.print") || callFn.name == "runtime._panic" || callFn.name == "runtime.hashmapGet" || + case strings.HasPrefix(callFn.name, "runtime.print") || callFn.name == "runtime._panic" || callFn.name == "runtime.hashmapGet" || callFn.name == "runtime.hashmapInterfaceHash" || callFn.name == "os.runtime_args" || callFn.name == "internal/task.start" || callFn.name == "internal/task.Current": // These functions should be run at runtime. Specifically: // * Print and panic functions are best emitted directly without @@ -378,42 +378,6 @@ func (r *runner) run(fn *function, params []value, parentMem *memoryView, indent copy(dstBuf.buf[dst.offset():dst.offset()+nBytes], srcBuf.buf[src.offset():]) dstObj.buffer = dstBuf mem.put(dst.index(), dstObj) - case callFn.name == "(reflect.rawType).elem": - if r.debug { - fmt.Fprintln(os.Stderr, indent+"call (reflect.rawType).elem:", operands[1:]) - } - // Extract the type code global from the first parameter. - typecodeIDPtrToInt, err := operands[1].toLLVMValue(inst.llvmInst.Operand(0).Type(), &mem) - if err != nil { - return nil, mem, r.errorAt(inst, err) - } - typecodeID := typecodeIDPtrToInt.Operand(0) - - // Get the type class. - // See also: getClassAndValueFromTypeCode in transform/reflect.go. - typecodeName := typecodeID.Name() - const prefix = "reflect/types.type:" - if !strings.HasPrefix(typecodeName, prefix) { - panic("unexpected typecode name: " + typecodeName) - } - id := typecodeName[len(prefix):] - class := id[:strings.IndexByte(id, ':')] - value := id[len(class)+1:] - if class == "named" { - // Get the underlying type. - class = value[:strings.IndexByte(value, ':')] - value = value[len(class)+1:] - } - - // Elem() is only valid for certain type classes. - switch class { - case "chan", "pointer", "slice", "array": - elementType := r.builder.CreateExtractValue(typecodeID.Initializer(), 0, "") - uintptrType := r.mod.Context().IntType(int(mem.r.pointerSize) * 8) - locals[inst.localIndex] = r.getValue(llvm.ConstPtrToInt(elementType, uintptrType)) - default: - return nil, mem, r.errorAt(inst, fmt.Errorf("(reflect.Type).Elem() called on %s type", class)) - } case callFn.name == "runtime.typeAssert": // This function must be implemented manually as it is normally // implemented by the interface lowering pass. 
@@ -424,15 +388,22 @@ func (r *runner) run(fn *function, params []value, parentMem *memoryView, indent if err != nil { return nil, mem, r.errorAt(inst, err) } - actualTypePtrToInt, err := operands[1].toLLVMValue(inst.llvmInst.Operand(0).Type(), &mem) + actualType, err := operands[1].toLLVMValue(inst.llvmInst.Operand(0).Type(), &mem) if err != nil { return nil, mem, r.errorAt(inst, err) } - if !actualTypePtrToInt.IsAConstantInt().IsNil() && actualTypePtrToInt.ZExtValue() == 0 { + if !actualType.IsAConstantInt().IsNil() && actualType.ZExtValue() == 0 { locals[inst.localIndex] = literalValue{uint8(0)} break } - actualType := actualTypePtrToInt.Operand(0) + // Strip pointer casts (bitcast, getelementptr). + for !actualType.IsAConstantExpr().IsNil() { + opcode := actualType.Opcode() + if opcode != llvm.GetElementPtr && opcode != llvm.BitCast { + break + } + actualType = actualType.Operand(0) + } if strings.TrimPrefix(actualType.Name(), "reflect/types.type:") == strings.TrimPrefix(assertedType.Name(), "reflect/types.typeid:") { locals[inst.localIndex] = literalValue{uint8(1)} } else { @@ -448,11 +419,12 @@ func (r *runner) run(fn *function, params []value, parentMem *memoryView, indent if err != nil { return nil, mem, r.errorAt(inst, err) } - methodSetPtr, err := mem.load(typecodePtr.addOffset(r.pointerSize*2), r.pointerSize).asPointer(r) + methodSetPtr, err := mem.load(typecodePtr.addOffset(-int64(r.pointerSize)), r.pointerSize).asPointer(r) if err != nil { return nil, mem, r.errorAt(inst, err) } methodSet := mem.get(methodSetPtr.index()).llvmGlobal.Initializer() + numMethods := int(r.builder.CreateExtractValue(methodSet, 0, "").ZExtValue()) llvmFn := inst.llvmInst.CalledValue() methodSetAttr := llvmFn.GetStringAttributeAtIndex(-1, "tinygo-methods") methodSetString := methodSetAttr.GetStringValue() @@ -460,9 +432,9 @@ func (r *runner) run(fn *function, params []value, parentMem *memoryView, indent // Make a set of all the methods on the concrete type, for // easier checking in the next step. concreteTypeMethods := map[string]struct{}{} - for i := 0; i < methodSet.Type().ArrayLength(); i++ { - methodInfo := r.builder.CreateExtractValue(methodSet, i, "") - name := r.builder.CreateExtractValue(methodInfo, 0, "").Name() + for i := 0; i < numMethods; i++ { + methodInfo := r.builder.CreateExtractValue(methodSet, 1, "") + name := r.builder.CreateExtractValue(methodInfo, i, "").Name() concreteTypeMethods[name] = struct{}{} } @@ -488,15 +460,16 @@ func (r *runner) run(fn *function, params []value, parentMem *memoryView, indent fmt.Fprintln(os.Stderr, indent+"invoke method:", operands[1:]) } - // Load the type code of the interface value. - typecodeIDBitCast, err := operands[len(operands)-2].toLLVMValue(inst.llvmInst.Operand(len(operands)-3).Type(), &mem) + // Load the type code and method set of the interface value. + typecodePtr, err := operands[len(operands)-2].asPointer(r) if err != nil { return nil, mem, r.errorAt(inst, err) } - typecodeID := typecodeIDBitCast.Operand(0).Initializer() - - // Load the method set, which is part of the typecodeID object. - methodSet := stripPointerCasts(r.builder.CreateExtractValue(typecodeID, 2, "")).Initializer() + methodSetPtr, err := mem.load(typecodePtr.addOffset(-int64(r.pointerSize)), r.pointerSize).asPointer(r) + if err != nil { + return nil, mem, r.errorAt(inst, err) + } + methodSet := mem.get(methodSetPtr.index()).llvmGlobal.Initializer() // We don't need to load the interface method set. 
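The interp changes around this point read a method set stored as a length followed by parallel signature and method arrays, picking the method at the same index as the matching signature. A stand-alone sketch of that parallel-array lookup follows; the struct layout and names are illustrative, not the runtime's actual data structures.

package main

// methodSet mimics the shape the interpreter walks: methods[i] is the
// implementation for signatures[i].
type methodSet struct {
	signatures []string
	methods    []func() string
}

// lookup returns the method whose signature matches, or nil if the signature
// is not present (the interpreter reports an error in that case).
func (ms *methodSet) lookup(sig string) func() string {
	for i, s := range ms.signatures {
		if s == sig {
			return ms.methods[i]
		}
	}
	return nil
}

func main() {
	ms := &methodSet{
		signatures: []string{"Error() string", "String() string"},
		methods: []func() string{
			func() string { return "some error" },
			func() string { return "some stringer" },
		},
	}
	if m := ms.lookup("String() string"); m != nil {
		println(m())
	}
}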
@@ -508,13 +481,14 @@ func (r *runner) run(fn *function, params []value, parentMem *memoryView, indent // Iterate through all methods, looking for the one method that // should be returned. - numMethods := methodSet.Type().ArrayLength() + numMethods := int(r.builder.CreateExtractValue(methodSet, 0, "").ZExtValue()) var method llvm.Value for i := 0; i < numMethods; i++ { - methodSignatureAgg := r.builder.CreateExtractValue(methodSet, i, "") - methodSignature := r.builder.CreateExtractValue(methodSignatureAgg, 0, "") + methodSignatureAgg := r.builder.CreateExtractValue(methodSet, 1, "") + methodSignature := r.builder.CreateExtractValue(methodSignatureAgg, i, "") if methodSignature == signature { - method = r.builder.CreateExtractValue(methodSignatureAgg, 1, "").Operand(0) + methodAgg := r.builder.CreateExtractValue(methodSet, 2, "") + method = r.builder.CreateExtractValue(methodAgg, i, "") } } if method.IsNil() { @@ -685,7 +659,7 @@ func (r *runner) run(fn *function, params []value, parentMem *memoryView, indent } continue } - ptr = ptr.addOffset(uint32(offset)) + ptr = ptr.addOffset(int64(offset)) locals[inst.localIndex] = ptr if r.debug { fmt.Fprintln(os.Stderr, indent+"gep:", operands, "->", ptr) @@ -784,7 +758,7 @@ func (r *runner) run(fn *function, params []value, parentMem *memoryView, indent case llvm.Add: // This likely means this is part of a // unsafe.Pointer(uintptr(ptr) + offset) pattern. - lhsPtr = lhsPtr.addOffset(uint32(rhs.Uint())) + lhsPtr = lhsPtr.addOffset(int64(rhs.Uint())) locals[inst.localIndex] = lhsPtr continue case llvm.Xor: diff --git a/interp/memory.go b/interp/memory.go index 1f9ed99f3c..9a28f1d49e 100644 --- a/interp/memory.go +++ b/interp/memory.go @@ -501,7 +501,7 @@ func (v pointerValue) offset() uint32 { // addOffset essentially does a GEP operation (pointer arithmetic): it adds the // offset to the pointer. It also checks that the offset doesn't overflow the // maximum offset size (which is 4GB). -func (v pointerValue) addOffset(offset uint32) pointerValue { +func (v pointerValue) addOffset(offset int64) pointerValue { result := pointerValue{v.pointer + uint64(offset)} if checks && v.index() != result.index() { panic("interp: offset out of range") @@ -815,7 +815,7 @@ func (v rawValue) rawLLVMValue(mem *memoryView) (llvm.Value, error) { // as a ptrtoint, so that they can be used in certain // optimizations. 
name := elementType.StructName() - if name == "runtime.typecodeID" || name == "runtime.funcValueWithSignature" { + if name == "runtime.funcValueWithSignature" { uintptrType := ctx.IntType(int(mem.r.pointerSize) * 8) field = llvm.ConstPtrToInt(field, uintptrType) } diff --git a/interp/testdata/interface.ll b/interp/testdata/interface.ll index 6520efc5ca..da27ad8a01 100644 --- a/interp/testdata/interface.ll +++ b/interp/testdata/interface.ll @@ -1,17 +1,16 @@ target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" target triple = "x86_64--linux" -%runtime.typecodeID = type { %runtime.typecodeID*, i64, %runtime.interfaceMethodInfo* } -%runtime.interfaceMethodInfo = type { i8*, i64 } - @main.v1 = global i1 0 @main.v2 = global i1 0 -@"reflect/types.type:named:main.foo" = private constant %runtime.typecodeID { %runtime.typecodeID* @"reflect/types.type:basic:int", i64 0, %runtime.interfaceMethodInfo* null } +@"reflect/types.type:named:main.foo" = private constant { i8, i8*, i8* } { i8 34, i8* getelementptr inbounds ({ i8, i8* }, { i8, i8* }* @"reflect/types.type:pointer:named:main.foo", i32 0, i32 0), i8* getelementptr inbounds ({ i8, i8* }, { i8, i8* }* @"reflect/types.type:basic:int", i32 0, i32 0) }, align 4 +@"reflect/types.type:pointer:named:main.foo" = external constant { i8, i8* } @"reflect/types.typeid:named:main.foo" = external constant i8 -@"reflect/types.type:basic:int" = external constant %runtime.typecodeID +@"reflect/types.type:basic:int" = private constant { i8, i8* } { i8 2, i8* getelementptr inbounds ({ i8, i8* }, { i8, i8* }* @"reflect/types.type:pointer:basic:int", i32 0, i32 0) }, align 4 +@"reflect/types.type:pointer:basic:int" = external constant { i8, i8* } -declare i1 @runtime.typeAssert(i64, i8*, i8*, i8*) +declare i1 @runtime.typeAssert(i8*, i8*, i8*, i8*) define void @runtime.initAll() unnamed_addr { entry: @@ -22,9 +21,9 @@ entry: define internal void @main.init() unnamed_addr { entry: ; Test type asserts. - %typecode = call i1 @runtime.typeAssert(i64 ptrtoint (%runtime.typecodeID* @"reflect/types.type:named:main.foo" to i64), i8* @"reflect/types.typeid:named:main.foo", i8* undef, i8* null) + %typecode = call i1 @runtime.typeAssert(i8* getelementptr inbounds ({ i8, i8*, i8* }, { i8, i8*, i8* }* @"reflect/types.type:named:main.foo", i32 0, i32 0), i8* @"reflect/types.typeid:named:main.foo", i8* undef, i8* null) store i1 %typecode, i1* @main.v1 - %typecode2 = call i1 @runtime.typeAssert(i64 0, i8* @"reflect/types.typeid:named:main.foo", i8* undef, i8* null) + %typecode2 = call i1 @runtime.typeAssert(i8* null, i8* @"reflect/types.typeid:named:main.foo", i8* undef, i8* null) store i1 %typecode2, i1* @main.v2 ret void } diff --git a/loader/goroot.go b/loader/goroot.go index ff2697fc09..d1d8e044dd 100644 --- a/loader/goroot.go +++ b/loader/goroot.go @@ -228,15 +228,17 @@ func pathsToOverride(goMinor int, needsSyscallPackage bool) map[string]bool { "": true, "crypto/": true, "crypto/rand/": false, + "crypto/tls/": false, "device/": false, "examples/": false, "internal/": true, - "internal/fuzz/": false, "internal/bytealg/": false, + "internal/fuzz/": false, "internal/reflectlite/": false, "internal/task/": false, "machine/": false, "net/": true, + "net/http/": false, "os/": true, "reflect/": false, "runtime/": false, diff --git a/main.go b/main.go index 44eacb2cfb..481e92ba7e 100644 --- a/main.go +++ b/main.go @@ -644,7 +644,7 @@ func Debug(debugger, pkgName string, ocdOutput bool, options *compileopts.Option case "qemu-user": port = ":1234" // Run in an emulator. 
- args := append(emulator[1:], "-g", "1234") + args := append([]string{"-g", "1234"}, emulator[1:]...) daemon = executeCommand(config.Options, emulator[0], args...) daemon.Stdout = os.Stdout daemon.Stderr = os.Stderr @@ -1508,6 +1508,15 @@ func main() { defer pprof.StopCPUProfile() } + // Limit the number of threads to one. + // This is an attempted workaround for the crashes we're seeing in CI on + // Windows. If this change helps, it indicates there is a concurrency issue. + // If it doesn't, then there is something else going on. Either way, this + // should be removed once the test is done. + if runtime.GOOS == "windows" { + runtime.GOMAXPROCS(1) + } + switch command { case "build": pkgName := "." diff --git a/main_test.go b/main_test.go index 72f3c4bbc8..4de8fc09c5 100644 --- a/main_test.go +++ b/main_test.go @@ -180,7 +180,8 @@ func runPlatTests(options compileopts.Options, tests []string, t *testing.T) { // Skip the ones that aren't. switch name { case "reflect.go": - // Reflect tests do not work due to type code issues. + // Reflect tests do not run correctly, probably because of the + // limited amount of memory. continue case "gc.go": @@ -188,20 +189,16 @@ func runPlatTests(options compileopts.Options, tests []string, t *testing.T) { continue case "json.go", "stdlib.go", "testing.go": - // Breaks interp. + // Too big for AVR. Doesn't fit in flash/RAM. continue case "math.go": - // Stuck somewhere, not sure what's happening. + // Needs newer picolibc version (for sqrt). continue case "cgo/": - // CGo does not work on AVR. - continue - - case "timers.go": - // Doesn't compile: - // panic: compiler: could not store type code number inside interface type code + // CGo function pointers don't work on AVR (needs LLVM 16 and + // some compiler changes). continue default: diff --git a/src/crypto/tls/common.go b/src/crypto/tls/common.go new file mode 100644 index 0000000000..f97c47e19c --- /dev/null +++ b/src/crypto/tls/common.go @@ -0,0 +1,12 @@ +// TINYGO: The following is copied and modified from Go 1.19.3 official implementation. + +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +// ConnectionState records basic TLS details about the connection. +type ConnectionState struct { + // TINYGO: empty; TLS connection offloaded to device +} diff --git a/src/crypto/tls/tls.go b/src/crypto/tls/tls.go new file mode 100644 index 0000000000..1d1eee105c --- /dev/null +++ b/src/crypto/tls/tls.go @@ -0,0 +1,63 @@ +// TINYGO: The following is copied and modified from Go 1.19.3 official implementation. + +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package tls partially implements TLS 1.2, as specified in RFC 5246, +// and TLS 1.3, as specified in RFC 8446. +package tls + +// BUG(agl): The crypto/tls package only implements some countermeasures +// against Lucky13 attacks on CBC-mode encryption, and only on SHA1 +// variants. See http://www.isg.rhul.ac.uk/tls/TLStiming.pdf and +// https://www.imperialviolet.org/2013/02/04/luckythirteen.html. + +import ( + "fmt" + "net" +) + +// Client returns a new TLS client side connection +// using conn as the underlying transport. +// The config cannot be nil: users must set either ServerName or +// InsecureSkipVerify in the config. 
+func Client(conn net.Conn, config *Config) *net.TLSConn { + panic("tls.Client() not implemented") + return nil +} + +// DialWithDialer connects to the given network address using dialer.Dial and +// then initiates a TLS handshake, returning the resulting TLS connection. Any +// timeout or deadline given in the dialer apply to connection and TLS +// handshake as a whole. +// +// DialWithDialer interprets a nil configuration as equivalent to the zero +// configuration; see the documentation of Config for the defaults. +// +// DialWithDialer uses context.Background internally; to specify the context, +// use Dialer.DialContext with NetDialer set to the desired dialer. +func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*net.TLSConn, error) { + switch network { + case "tcp", "tcp4": + default: + return nil, fmt.Errorf("Network %s not supported", network) + } + + return net.DialTLS(addr) +} + +// Dial connects to the given network address using net.Dial +// and then initiates a TLS handshake, returning the resulting +// TLS connection. +// Dial interprets a nil configuration as equivalent to +// the zero configuration; see the documentation of Config +// for the defaults. +func Dial(network, addr string, config *Config) (*net.TLSConn, error) { + return DialWithDialer(new(net.Dialer), network, addr, config) +} + +// Config is a placeholder for future compatibility with +// tls.Config. +type Config struct { +} diff --git a/src/device/arm/arm.go b/src/device/arm/arm.go index 04637f3175..4b54da8f77 100644 --- a/src/device/arm/arm.go +++ b/src/device/arm/arm.go @@ -29,6 +29,7 @@ // POSSIBILITY OF SUCH DAMAGE. package arm +import "C" import ( "errors" "runtime/volatile" @@ -174,20 +175,15 @@ func SetPriority(irq uint32, priority uint32) { // DisableInterrupts disables all interrupts, and returns the old interrupt // state. -func DisableInterrupts() uintptr { - return AsmFull(` - mrs {}, PRIMASK - cpsid i - `, nil) -} +// +//export DisableInterrupts +func DisableInterrupts() uintptr // EnableInterrupts enables all interrupts again. The value passed in must be // the mask returned by DisableInterrupts. -func EnableInterrupts(mask uintptr) { - AsmFull("msr PRIMASK, {mask}", map[string]interface{}{ - "mask": mask, - }) -} +// +//export EnableInterrupts +func EnableInterrupts(mask uintptr) // Set up the system timer to generate periodic tick events. // This will cause SysTick_Handler to fire once per tick. diff --git a/src/device/arm/interrupts.c b/src/device/arm/interrupts.c new file mode 100644 index 0000000000..d94a313459 --- /dev/null +++ b/src/device/arm/interrupts.c @@ -0,0 +1,22 @@ +#include + +void EnableInterrupts(uintptr_t mask) { + asm volatile( + "msr PRIMASK, %0" + : + : "r"(mask) + : "memory" + ); +} + +uintptr_t DisableInterrupts() { + uintptr_t mask; + asm volatile( + "mrs %0, PRIMASK\n\t" + "cpsid i" + : "=r"(mask) + : + : "memory" + ); + return mask; +} \ No newline at end of file diff --git a/src/device/gba/gba.go b/src/device/gba/gba.go new file mode 100644 index 0000000000..c4236b180e --- /dev/null +++ b/src/device/gba/gba.go @@ -0,0 +1,474 @@ +// Hand written file mostly derived from https://problemkaputt.de/gbatek.htm + +//go:build gameboyadvance + +package gba + +import ( + "runtime/volatile" + "unsafe" +) + +// Interrupt numbers. 
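A minimal sketch of how the DisableInterrupts/EnableInterrupts pair exported from device/arm above is meant to be used on a Cortex-M target: the value returned by DisableInterrupts is the previous PRIMASK state and must be handed back to EnableInterrupts so that nested critical sections restore correctly.

    package main

    import (
        "device/arm"
        "runtime/volatile"
    )

    var ticks volatile.Register32

    // bumpTicks updates shared state inside a short PRIMASK critical section.
    func bumpTicks() {
        mask := arm.DisableInterrupts() // mrs PRIMASK; cpsid i
        ticks.Set(ticks.Get() + 1)      // cannot be preempted by an interrupt here
        arm.EnableInterrupts(mask)      // msr PRIMASK, restoring the saved state
    }

    func main() {
        bumpTicks()
    }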
+const ( + IRQ_VBLANK = 0 + IRQ_HBLANK = 1 + IRQ_VCOUNT = 2 + IRQ_TIMER0 = 3 + IRQ_TIMER1 = 4 + IRQ_TIMER2 = 5 + IRQ_TIMER3 = 6 + IRQ_COM = 7 + IRQ_DMA0 = 8 + IRQ_DMA1 = 9 + IRQ_DMA2 = 10 + IRQ_DMA3 = 11 + IRQ_KEYPAD = 12 + IRQ_GAMEPAK = 13 +) + +// Peripherals +var ( + // Display registers + DISP = (*DISP_Type)(unsafe.Add(unsafe.Pointer(REG_BASE), uintptr(0x0000))) + + // Background control registers + BGCNT0 = (*BGCNT_Type)(unsafe.Add(unsafe.Pointer(REG_BASE), uintptr(0x0008))) + BGCNT1 = (*BGCNT_Type)(unsafe.Add(unsafe.Pointer(REG_BASE), uintptr(0x000A))) + BGCNT2 = (*BGCNT_Type)(unsafe.Add(unsafe.Pointer(REG_BASE), uintptr(0x000C))) + BGCNT3 = (*BGCNT_Type)(unsafe.Add(unsafe.Pointer(REG_BASE), uintptr(0x000E))) + + BG0 = (*BG_Type)(unsafe.Add(unsafe.Pointer(REG_BASE), uintptr(0x0010))) + BG1 = (*BG_Type)(unsafe.Add(unsafe.Pointer(REG_BASE), uintptr(0x0014))) + BG2 = (*BG_Type)(unsafe.Add(unsafe.Pointer(REG_BASE), uintptr(0x0018))) + BG3 = (*BG_Type)(unsafe.Add(unsafe.Pointer(REG_BASE), uintptr(0x001C))) + + BGA2 = (*BG_AFFINE_Type)(unsafe.Add(unsafe.Pointer(REG_BASE), uintptr(0x0020))) + BGA3 = (*BG_AFFINE_Type)(unsafe.Add(unsafe.Pointer(REG_BASE), uintptr(0x0030))) + + WIN = (*WIN_Type)(unsafe.Add(unsafe.Pointer(REG_BASE), uintptr(0x0040))) + + GRAPHICS = (*GRAPHICS_Type)(unsafe.Add(unsafe.Pointer(REG_BASE), uintptr(0x004C))) + + // GBA Sound Channel 1 - Tone & Sweep + SOUND1 = (*SOUND_Type)(unsafe.Add(unsafe.Pointer(REG_BASE), uintptr(0x0060))) + + // GBA Sound Channel 2 - Tone + SOUND2 = (*SOUND_Type)(unsafe.Add(unsafe.Pointer(REG_BASE), uintptr(0x0068))) + + // TODO: Sound channel 3 and 4 + + TM0 = (*TIMER_Type)(unsafe.Add(unsafe.Pointer(REG_BASE), uintptr(0x0100))) + TM1 = (*TIMER_Type)(unsafe.Add(unsafe.Pointer(REG_BASE), uintptr(0x0104))) + TM2 = (*TIMER_Type)(unsafe.Add(unsafe.Pointer(REG_BASE), uintptr(0x0108))) + TM3 = (*TIMER_Type)(unsafe.Add(unsafe.Pointer(REG_BASE), uintptr(0x010C))) + + KEY = (*KEY_Type)(unsafe.Add(unsafe.Pointer(REG_BASE), uintptr(0x0130))) + + INTERRUPT = (*INTERRUPT_Type)(unsafe.Add(unsafe.Pointer(REG_BASE), uintptr(0x0200))) +) + +// Main memory sections +const ( + // External work RAM + MEM_EWRAM uintptr = 0x02000000 + + // Internal work RAM + MEM_IWRAM uintptr = 0x03000000 + + // I/O registers + MEM_IO uintptr = 0x04000000 + + // Palette. Note: no 8bit write !! + MEM_PAL uintptr = 0x05000000 + + // Video RAM. Note: no 8bit write !! + MEM_VRAM uintptr = 0x06000000 + + // Object Attribute Memory (OAM) Note: no 8bit write !! + MEM_OAM uintptr = 0x07000000 + + // ROM. No write at all (duh) + MEM_ROM uintptr = 0x08000000 + + // Static RAM. 
8bit write only + MEM_SRAM uintptr = 0x0E000000 +) + +// Main section sizes +const ( + EWRAM_SIZE uintptr = 0x40000 + IWRAM_SIZE uintptr = 0x08000 + PAL_SIZE uintptr = 0x00400 + VRAM_SIZE uintptr = 0x18000 + OAM_SIZE uintptr = 0x00400 + SRAM_SIZE uintptr = 0x10000 +) + +// Sub section sizes +const ( + // BG palette size + PAL_BG_SIZE = 0x00200 + + // Object palette size + PAL_OBJ_SIZE = 0x00200 + + // Charblock size + CBB_SIZE = 0x04000 + + // Screenblock size + SBB_SIZE = 0x00800 + + // BG VRAM size + VRAM_BG_SIZE = 0x10000 + + // Object VRAM size + VRAM_OBJ_SIZE = 0x08000 + + // Mode 3 buffer size + M3_SIZE = 0x12C00 + + // Mode 4 buffer size + M4_SIZE = 0x09600 + + // Mode 5 buffer size + M5_SIZE = 0x0A000 + + // Bitmap page size + VRAM_PAGE_SIZE = 0x0A000 +) + +// Sub sections +var ( + REG_BASE uintptr = MEM_IO + + // Background palette address + MEM_PAL_BG = MEM_PAL + + // Object palette address + MEM_PAL_OBJ = MEM_PAL + PAL_BG_SIZE + + // Front page address + MEM_VRAM_FRONT = MEM_VRAM + + // Back page address + MEM_VRAM_BACK = MEM_VRAM + VRAM_PAGE_SIZE + + // Object VRAM address + MEM_VRAM_OBJ = MEM_VRAM + VRAM_BG_SIZE +) + +// Display registers +type DISP_Type struct { + DISPCNT volatile.Register16 + _ [2]byte + DISPSTAT volatile.Register16 + VCOUNT volatile.Register16 +} + +// Background control registers +type BGCNT_Type struct { + CNT volatile.Register16 +} + +// Regular background scroll registers. (write only!) +type BG_Type struct { + HOFS volatile.Register16 + VOFS volatile.Register16 +} + +// Affine background parameters. (write only!) +type BG_AFFINE_Type struct { + PA volatile.Register16 + PB volatile.Register16 + PC volatile.Register16 + PD volatile.Register16 + X volatile.Register32 + Y volatile.Register32 +} + +type WIN_Type struct { + // win0 right, left (0xLLRR) + WIN0H volatile.Register16 + + // win1 right, left (0xLLRR) + WIN1H volatile.Register16 + + // win0 bottom, top (0xTTBB) + WIN0V volatile.Register16 + + // win1 bottom, top (0xTTBB) + WIN1V volatile.Register16 + + // win0, win1 control + IN volatile.Register16 + + // winOut, winObj control + OUT volatile.Register16 +} + +type GRAPHICS_Type struct { + // Mosaic control + MOSAIC volatile.Register32 + + // Alpha control + BLDCNT volatile.Register16 + + // Fade level + BLDALPHA volatile.Register16 + + // Blend levels + BLDY volatile.Register16 +} + +type SOUND_Type struct { + // Sweep register + CNT_L volatile.Register16 + + // Duty/Len/Envelope + CNT_H volatile.Register16 + + // Frequency/Control + CNT_X volatile.Register16 +} + +// TODO: DMA + +// TIMER +type TIMER_Type struct { + DATA volatile.Register16 + CNT volatile.Register16 +} + +// TODO: serial + +// Keypad registers +type KEY_Type struct { + INPUT volatile.Register16 + CNT volatile.Register16 +} + +// TODO: Joybus communication + +// Interrupt / System registers +type INTERRUPT_Type struct { + IE volatile.Register16 + IF volatile.Register16 + WAITCNT volatile.Register16 + IME volatile.Register16 + PAUSE volatile.Register16 +} + +// Constants for DISP: display +const ( + // BGMODE: background mode. + // Position of BGMODE field. + DISPCNT_BGMODE_Pos = 0x0 + // Bit mask of BGMODE field. + DISPCNT_BGMODE_Msk = 0x4 + // BG Mode 0. + DISPCNT_BGMODE_0 = 0x0 + // BG Mode 1. + DISPCNT_BGMODE_1 = 0x1 + // BG Mode 2. + DISPCNT_BGMODE_2 = 0x2 + // BG Mode 3. + DISPCNT_BGMODE_3 = 0x3 + // BG Mode 4. + DISPCNT_BGMODE_4 = 0x4 + + // FRAMESELECT: frame select (mode 4 and 5 only). 
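To make the memory map defined above concrete, a small sketch for the gameboyadvance target that treats MEM_VRAM as the BG mode 3 framebuffer, the same layout machine.Display uses later in this diff; the coordinates and colour value are arbitrary.

    package main

    import (
        "device/gba"
        "runtime/volatile"
        "unsafe"
    )

    // In BG mode 3, VRAM holds a 240x160 array of 15-bit BGR pixels.
    var framebuf = (*[160][240]volatile.Register16)(unsafe.Pointer(uintptr(gba.MEM_VRAM)))

    func main() {
        framebuf[80][120].Set(0x7FFF) // one white pixel near the centre

        for {
        }
    }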
+ DISPCNT_FRAMESELECT_Pos = 0x4 + DISPCNT_FRAMESELECT_FRAME0 = 0x0 + DISPCNT_FRAMESELECT_FRAME1 = 0x1 + + // HBLANKINTERVAL: 1=Allow access to OAM during H-Blank + DISPCNT_HBLANKINTERVAL_Pos = 0x5 + DISPCNT_HBLANKINTERVAL_NOALLOW = 0x0 + DISPCNT_HBLANKINTERVAL_ALLOW = 0x1 + + // OBJCHARVRAM: (0=Two dimensional, 1=One dimensional) + DISPCNT_OBJCHARVRAM_Pos = 0x6 + DISPCNT_OBJCHARVRAM_2D = 0x0 + DISPCNT_OBJCHARVRAM_1D = 0x1 + + // FORCEDBLANK: (1=Allow FAST access to VRAM,Palette,OAM) + DISPCNT_FORCEDBLANK_Pos = 0x7 + DISPCNT_FORCEDBLANK_NOALLOW = 0x0 + DISPCNT_FORCEDBLANK_ALLOW = 0x1 + + // Screen Display BG0 + DISPCNT_SCREENDISPLAY_BG0_Pos = 0x8 + DISPCNT_SCREENDISPLAY_BG0_ENABLE = 0x1 + DISPCNT_SCREENDISPLAY_BG0_DISABLE = 0x0 + + // Screen Display BG1 + DISPCNT_SCREENDISPLAY_BG1_Pos = 0x9 + DISPCNT_SCREENDISPLAY_BG1_ENABLE = 0x1 + DISPCNT_SCREENDISPLAY_BG1_DISABLE = 0x0 + + // Screen Display BG2 + DISPCNT_SCREENDISPLAY_BG2_Pos = 0xA + DISPCNT_SCREENDISPLAY_BG2_ENABLE = 0x1 + DISPCNT_SCREENDISPLAY_BG2_DISABLE = 0x0 + + // Screen Display BG3 + DISPCNT_SCREENDISPLAY_BG3_Pos = 0xB + DISPCNT_SCREENDISPLAY_BG3_ENABLE = 0x1 + DISPCNT_SCREENDISPLAY_BG3_DISABLE = 0x0 + + // Screen Display OBJ + DISPCNT_SCREENDISPLAY_OBJ_Pos = 0xC + DISPCNT_SCREENDISPLAY_OBJ_ENABLE = 0x1 + DISPCNT_SCREENDISPLAY_OBJ_DISABLE = 0x0 + + // Window 0 Display Flag (0=Off, 1=On) + DISPCNT_WINDOW0_DISPLAY_Pos = 0xD + DISPCNT_WINDOW0_DISPLAY_ENABLE = 0x1 + DISPCNT_WINDOW0_DISPLAY_DISABLE = 0x0 + + // Window 1 Display Flag (0=Off, 1=On) + DISPCNT_WINDOW1_DISPLAY_Pos = 0xE + DISPCNT_WINDOW1_DISPLAY_ENABLE = 0x1 + DISPCNT_WINDOW1_DISPLAY_DISABLE = 0x0 + + // OBJ Window Display Flag + DISPCNT_WINDOWOBJ_DISPLAY_Pos = 0xF + DISPCNT_WINDOWOBJ_DISPLAY_ENABLE = 0x1 + DISPCNT_WINDOWOBJ_DISPLAY_DISABLE = 0x0 + + // DISPSTAT: display status. + // V-blank + DISPSTAT_VBLANK_Pos = 0x0 + DISPSTAT_VBLANK_ENABLE = 0x1 + DISPSTAT_VBLANK_DISABLE = 0x0 + + // H-blank + DISPSTAT_HBLANK_Pos = 0x1 + DISPSTAT_HBLANK_ENABLE = 0x1 + DISPSTAT_HBLANK_DISABLE = 0x0 + + // V-counter match + DISPSTAT_VCOUNTER_Pos = 0x2 + DISPSTAT_VCOUNTER_MATCH = 0x1 + DISPSTAT_VCOUNTER_NOMATCH = 0x0 + + // V-blank IRQ + DISPSTAT_VBLANK_IRQ_Pos = 0x3 + DISPSTAT_VBLANK_IRQ_ENABLE = 0x1 + DISPSTAT_VBLANK_IRQ_DISABLE = 0x0 + + // H-blank IRQ + DISPSTAT_HBLANK_IRQ_Pos = 0x4 + DISPSTAT_HBLANK_IRQ_ENABLE = 0x1 + DISPSTAT_HBLANK_IRQ_DISABLE = 0x0 + + // V-counter IRQ + DISPSTAT_VCOUNTER_IRQ_Pos = 0x5 + DISPSTAT_VCOUNTER_IRQ_ENABLE = 0x1 + DISPSTAT_VCOUNTER_IRQ_DISABLE = 0x0 + + // V-count setting + DISPSTAT_VCOUNT_SETTING_Pos = 0x8 +) + +// Constants for TIMER +const ( + // PRESCALER: Prescaler Selection (0=F/1, 1=F/64, 2=F/256, 3=F/1024) + // Position of PRESCALER field. + TIMERCNT_PRESCALER_Pos = 0x0 + // Bit mask of PRESCALER field. + TIMERCNT_PRESCALER_Msk = 0x2 + // 0=F/1 + TIMERCNT_PRESCALER_1 = 0x0 + // 1=F/64 + TIMERCNT_PRESCALER_64 = 0x1 + // 2=F/256 + TIMERCNT_PRESCALER_256 = 0x2 + // F/1024 + TIMERCNT_PRESCALER_1024 = 0x3 + + // COUNTUP: Count-up Timing (0=Normal, 1=See below) ;Not used in TM0CNT_H + // Position of COUNTUP_TIMING field. 
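A sketch of how these _Pos constants are meant to be combined with the usual shift-and-or pattern: it selects BG mode 3 with BG2 enabled and requests a V-blank interrupt (the register method names come from runtime/volatile).

    package main

    import "device/gba"

    func main() {
        // Select BG mode 3 and enable BG2 on the display controller.
        gba.DISP.DISPCNT.Set(gba.DISPCNT_BGMODE_3<<gba.DISPCNT_BGMODE_Pos |
            gba.DISPCNT_SCREENDISPLAY_BG2_ENABLE<<gba.DISPCNT_SCREENDISPLAY_BG2_Pos)

        // Ask for an IRQ at the start of every vertical blank and unmask it.
        gba.DISP.DISPSTAT.SetBits(gba.DISPSTAT_VBLANK_IRQ_ENABLE << gba.DISPSTAT_VBLANK_IRQ_Pos)
        gba.INTERRUPT.IE.SetBits(gba.INTERRUPT_IE_ENABLED << gba.INTERRUPT_IE_VBLANK_Pos)
        gba.INTERRUPT.IME.Set(1)

        for {
        }
    }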
+ TIMERCNT_COUNTUP_TIMING_Pos = 0x2 + TIMERCNT_COUNTUP_TIMING_NORMAL = 0x0 + TIMERCNT_COUNTUP_TIMING_ENABLED = 0x1 + + TIMERCNT_TIMER_IRQ_ENABLED_Pos = 0x06 + TIMERCNT_TIMER_IRQ_ENABLED = 0x01 + TIMERCNT_TIMER_IRQ_DISABLED = 0x00 + + TIMERCNT_TIMER_STARTSTOP_Pos = 0x07 + TIMERCNT_TIMER_START = 0x1 + TIMERCNT_TIMER_STOP = 0x0 +) + +// Constants for KEY +const ( + // KEYINPUT + KEYINPUT_PRESSED = 0x0 + KEYINPUT_RELEASED = 0x1 + KEYINPUT_BUTTON_A_Pos = 0x0 + KEYINPUT_BUTTON_B_Pos = 0x1 + KEYINPUT_BUTTON_SELECT_Pos = 0x2 + KEYINPUT_BUTTON_START_Pos = 0x3 + KEYINPUT_BUTTON_RIGHT_Pos = 0x4 + KEYINPUT_BUTTON_LEFT_Pos = 0x5 + KEYINPUT_BUTTON_UP_Pos = 0x6 + KEYINPUT_BUTTON_DOWN_Pos = 0x7 + KEYINPUT_BUTTON_R_Pos = 0x8 + KEYINPUT_BUTTON_L_Pos = 0x9 + + // KEYCNT + KEYCNT_IGNORE = 0x0 + KEYCNT_SELECT = 0x1 + KEYCNT_BUTTON_A_Pos = 0x0 + KEYCNT_BUTTON_B_Pos = 0x1 + KEYCNT_BUTTON_SELECT_Pos = 0x2 + KEYCNT_BUTTON_START_Pos = 0x3 + KEYCNT_BUTTON_RIGHT_Pos = 0x4 + KEYCNT_BUTTON_LEFT_Pos = 0x5 + KEYCNT_BUTTON_UP_Pos = 0x6 + KEYCNT_BUTTON_DOWN_Pos = 0x7 + KEYCNT_BUTTON_R_Pos = 0x8 + KEYCNT_BUTTON_L_Pos = 0x9 +) + +// Constants for INTERRUPT +const ( + // IE + INTERRUPT_IE_ENABLED = 0x1 + INTERRUPT_IE_DISABLED = 0x0 + INTERRUPT_IE_VBLANK_Pos = 0x0 + INTERRUPT_IE_HBLANK_Pos = 0x1 + INTERRUPT_IE_VCOUNTER_MATCH_Pos = 0x2 + INTERRUPT_IE_TIMER0_OVERFLOW_Pos = 0x3 + INTERRUPT_IE_TIMER1_OVERFLOW_Pos = 0x4 + INTERRUPT_IE_TIMER2_OVERFLOW_Pos = 0x5 + INTERRUPT_IE_TIMER3_OVERFLOW_Pos = 0x6 + INTERRUPT_IE_SERIAL_Pos = 0x7 + INTERRUPT_IE_DMA0_Pos = 0x8 + INTERRUPT_IE_DMA1_Pos = 0x9 + INTERRUPT_IE_DMA2_Pos = 0xA + INTERRUPT_IE_DMA3_Pos = 0xB + INTERRUPT_IE_KEYPAD_Pos = 0xC + INTERRUPT_IE_GAMPAK_Pos = 0xD + + // IF + INTERRUPT_IF_ENABLED = 0x1 + INTERRUPT_IF_DISABLED = 0x0 + INTERRUPT_IF_VBLANK_Pos = 0x0 + INTERRUPT_IF_HBLANK_Pos = 0x1 + INTERRUPT_IF_VCOUNTER_MATCH_Pos = 0x2 + INTERRUPT_IF_TIMER0_OVERFLOW_Pos = 0x3 + INTERRUPT_IF_TIMER1_OVERFLOW_Pos = 0x4 + INTERRUPT_IF_TIMER2_OVERFLOW_Pos = 0x5 + INTERRUPT_IF_TIMER3_OVERFLOW_Pos = 0x6 + INTERRUPT_IF_SERIAL_Pos = 0x7 + INTERRUPT_IF_DMA0_Pos = 0x8 + INTERRUPT_IF_DMA1_Pos = 0x9 + INTERRUPT_IF_DMA2_Pos = 0xA + INTERRUPT_IF_DMA3_Pos = 0xB + INTERRUPT_IF_KEYPAD_Pos = 0xC + INTERRUPT_IF_GAMPAK_Pos = 0xD +) diff --git a/src/examples/flash/main.go b/src/examples/flash/main.go new file mode 100644 index 0000000000..923a6f877a --- /dev/null +++ b/src/examples/flash/main.go @@ -0,0 +1,57 @@ +package main + +import ( + "machine" + "time" +) + +var ( + err error + message = "1234567887654321123456788765432112345678876543211234567887654321" +) + +func main() { + time.Sleep(3 * time.Second) + + // Print out general information + println("Flash data start: ", machine.FlashDataStart()) + println("Flash data end: ", machine.FlashDataEnd()) + println("Flash data size, bytes:", machine.Flash.Size()) + println("Flash write block size:", machine.Flash.WriteBlockSize()) + println("Flash erase block size:", machine.Flash.EraseBlockSize()) + println() + + flash := machine.OpenFlashBuffer(machine.Flash, machine.FlashDataStart()) + original := make([]byte, len(message)) + saved := make([]byte, len(message)) + + // Read flash contents on start (data shall survive power off) + print("Reading data from flash: ") + _, err = flash.Read(original) + checkError(err) + println(string(original)) + + // Write the message to flash + print("Writing data to flash: ") + flash.Seek(0, 0) // rewind back to beginning + _, err = flash.Write([]byte(message)) + checkError(err) + println(string(message)) + + // Read back 
flash contents after write (verify data is the same as written) + print("Reading data back from flash: ") + flash.Seek(0, 0) // rewind back to beginning + _, err = flash.Read(saved) + checkError(err) + println(string(saved)) + println() +} + +func checkError(err error) { + if err != nil { + for { + println(err.Error()) + time.Sleep(time.Second) + } + } +} diff --git a/src/examples/ram-func/main.go b/src/examples/ram-func/main.go new file mode 100644 index 0000000000..b4556c1942 --- /dev/null +++ b/src/examples/ram-func/main.go @@ -0,0 +1,93 @@ +package main + +// This example demonstrates how to use go:section to place code into RAM for +// execution. The code is present in flash in the `.data` region and copied +// into the correct place in RAM early in startup sequence (at the same time +// as non-zero global variables are initialized). +// +// This example should work on any ARM Cortex MCU. +// +// For Go code use the pragma "//go:section", for cgo use the "section" and +// "noinline" attributes. The `.ramfuncs` section is explicitly placed into +// the `.data` region by the linker script. +// +// Running the example should print out the program counter from the functions +// below. The program counters should be in different memory regions. +// +// On RP2040, for example, the output is something like this: +// +// Go in RAM: 0x20000DB4 +// Go in flash: 0x10007610 +// cgo in RAM: 0x20000DB8 +// cgo in flash: 0x10002C26 +// +// This can be confirmed using `objdump -t xxx.elf | grep main | sort`: +// +// 00000000 l df *ABS* 00000000 main +// 1000760d l F .text 00000004 main.in_flash +// 10007611 l F .text 0000000c __Thumbv6MABSLongThunk_main.in_ram +// 1000761d l F .text 0000000c __Thumbv6MABSLongThunk__Cgo_static_eea7585d7291176ad3bb_main_c_in_ram +// 1000bdb5 l O .text 00000013 main$string +// 1000bdc8 l O .text 00000013 main$string.1 +// 1000bddb l O .text 00000013 main$string.2 +// 1000bdee l O .text 00000013 main$string.3 +// 20000db1 l F .data 00000004 main.in_ram +// 20000db5 l F .data 00000004 _Cgo_static_eea7585d7291176ad3bb_main_c_in_ram +// + +import ( + "device" + "fmt" + "time" + _ "unsafe" // unsafe is required for "//go:section" +) + +/* + #define ram_func __attribute__((section(".ramfuncs"),noinline)) + + static ram_func void* main_c_in_ram() { + void* p = 0; + + asm( + "MOV %0, PC" + : "=r"(p) + ); + + return p; + } + + static void* main_c_in_flash() { + void* p = 0; + + asm( + "MOV %0, PC" + : "=r"(p) + ); + + return p; + } +*/ +import "C" + +func main() { + time.Sleep(2 * time.Second) + + fmt.Printf("Go in RAM: 0x%X\n", in_ram()) + fmt.Printf("Go in flash: 0x%X\n", in_flash()) + fmt.Printf("cgo in RAM: 0x%X\n", C.main_c_in_ram()) + fmt.Printf("cgo in flash: 0x%X\n", C.main_c_in_flash()) +} + +//go:section .ramfuncs +func in_ram() uintptr { + return device.AsmFull("MOV {}, PC", nil) +} + +// go:noinline used here to prevent function being 'inlined' into main() +// so it appears in objdump output. In normal use, go:inline is not +// required for functions running from flash (flash is the default). +// +//go:noinline +func in_flash() uintptr { + return device.AsmFull("MOV {}, PC", nil) +} diff --git a/src/examples/rtcinterrupt/rtcinterrupt.go b/src/examples/rtcinterrupt/rtcinterrupt.go new file mode 100644 index 0000000000..7211e4c30b --- /dev/null +++ b/src/examples/rtcinterrupt/rtcinterrupt.go @@ -0,0 +1,35 @@ +//go:build rp2040 + +package main + +// This example demonstrates scheduling a delayed interrupt by real time clock. 
+//
+// An interrupt may execute a user callback function, or be used only for its
+// side effects, like waking up from sleep or dormant states.
+//
+// The interrupt can be configured to repeat.
+//
+// There is no separate method to disable the interrupt; use a delay of 0 for that.
+//
+// Unfortunately, it is not possible to use time.Duration to work with the RTC directly,
+// as that would introduce a circular dependency between the "machine" and "time" packages.
+
+import (
+	"fmt"
+	"machine"
+	"time"
+)
+
+func main() {
+
+	// Schedule and enable a recurring interrupt.
+	// The callback function is executed in the context of an interrupt handler,
+	// so the usual restrictions for this sort of code apply: no blocking, no memory allocation, etc.
+	delay := time.Minute + 12*time.Second
+	machine.RTC.SetInterrupt(uint32(delay.Seconds()), true, func() { println("Peekaboo!") })
+
+	for {
+		fmt.Printf("%v\r\n", time.Now().Format(time.RFC3339))
+		time.Sleep(1 * time.Second)
+	}
+}
diff --git a/src/internal/task/task_stack.go b/src/internal/task/task_stack.go
index ed938a63a6..81e0f9ad76 100644
--- a/src/internal/task/task_stack.go
+++ b/src/internal/task/task_stack.go
@@ -2,7 +2,10 @@
 
 package task
 
-import "unsafe"
+import (
+	"runtime/interrupt"
+	"unsafe"
+)
 
 //go:linkname runtimePanic runtime.runtimePanic
 func runtimePanic(str string)
@@ -45,6 +48,9 @@
 func Pause() {
 	if *currentTask.state.canaryPtr != stackCanary {
 		runtimePanic("goroutine stack overflow")
 	}
+	if interrupt.In() {
+		runtimePanic("blocked inside interrupt")
+	}
 	currentTask.state.pause()
 }
diff --git a/src/internal/task/task_stack_arm64.S b/src/internal/task/task_stack_arm64.S
index 93f0027a90..1baacb49f1 100644
--- a/src/internal/task/task_stack_arm64.S
+++ b/src/internal/task/task_stack_arm64.S
@@ -4,7 +4,6 @@ _tinygo_startTask:
 #else
 .section .text.tinygo_startTask
 .global tinygo_startTask
-.type tinygo_startTask, %function
 tinygo_startTask:
 #endif
 .cfi_startproc
@@ -35,7 +34,6 @@ tinygo_startTask:
 #endif
 .cfi_endproc
 #ifndef __MACH__
-.size tinygo_startTask, .-tinygo_startTask
 #endif

@@ -44,7 +42,6 @@ tinygo_startTask:
 _tinygo_swapTask:
 #else
 .global tinygo_swapTask
-.type tinygo_swapTask, %function
 tinygo_swapTask:
 #endif
 // This function gets the following parameters:
@@ -52,12 +49,16 @@ tinygo_swapTask:
 // x1 = oldStack *uintptr

 // Save all callee-saved registers:
-	stp x19, x20, [sp, #-96]!
+	stp x19, x20, [sp, #-160]!
 	stp x21, x22, [sp, #16]
 	stp x23, x24, [sp, #32]
 	stp x25, x26, [sp, #48]
 	stp x27, x28, [sp, #64]
 	stp x29, x30, [sp, #80]
+	stp d8, d9, [sp, #96]
+	stp d10, d11, [sp, #112]
+	stp d12, d13, [sp, #128]
+	stp d14, d15, [sp, #144]

 // Save the current stack pointer in oldStack.
 	mov x8, sp
 	str x8, [x1]

 // Switch to the new stack pointer.
 	mov sp, x0

 // Restore stack state and return.
+ ldp d14, d15, [sp, #144] + ldp d12, d13, [sp, #128] + ldp d10, d11, [sp, #112] + ldp d8, d9, [sp, #96] ldp x29, x30, [sp, #80] ldp x27, x28, [sp, #64] ldp x25, x26, [sp, #48] ldp x23, x24, [sp, #32] ldp x21, x22, [sp, #16] - ldp x19, x20, [sp], #96 + ldp x19, x20, [sp], #160 ret diff --git a/src/internal/task/task_stack_arm64.go b/src/internal/task/task_stack_arm64.go index 4cf500bda7..164d62f186 100644 --- a/src/internal/task/task_stack_arm64.go +++ b/src/internal/task/task_stack_arm64.go @@ -1,4 +1,4 @@ -//go:build scheduler.tasks && arm64 && !windows +//go:build scheduler.tasks && arm64 package task @@ -21,8 +21,16 @@ type calleeSavedRegs struct { x27 uintptr x28 uintptr x29 uintptr + pc uintptr // aka x30 aka LR - pc uintptr // aka x30 aka LR + d8 uintptr + d9 uintptr + d10 uintptr + d11 uintptr + d12 uintptr + d13 uintptr + d14 uintptr + d15 uintptr } // archInit runs architecture-specific setup for the goroutine startup. diff --git a/src/internal/task/task_stack_arm64_windows.S b/src/internal/task/task_stack_arm64_windows.S deleted file mode 100644 index 45c1ec4ddc..0000000000 --- a/src/internal/task/task_stack_arm64_windows.S +++ /dev/null @@ -1,65 +0,0 @@ -.section .text.tinygo_startTask,"ax" -.global tinygo_startTask -tinygo_startTask: - .cfi_startproc - // Small assembly stub for starting a goroutine. This is already run on the - // new stack, with the callee-saved registers already loaded. - // Most importantly, x19 contains the pc of the to-be-started function and - // x20 contains the only argument it is given. Multiple arguments are packed - // into one by storing them in a new allocation. - - // Indicate to the unwinder that there is nothing to unwind, this is the - // root frame. It avoids the following (bogus) error message in GDB: - // Backtrace stopped: previous frame identical to this frame (corrupt stack?) - .cfi_undefined lr - - // Set the first argument of the goroutine start wrapper, which contains all - // the arguments. - mov x0, x20 - - // Branch to the "goroutine start" function. By using blx instead of bx, - // we'll return here instead of tail calling. - blr x19 - - // After return, exit this goroutine. This is a tail call. - b tinygo_pause - .cfi_endproc - - -.global tinygo_swapTask -tinygo_swapTask: - // This function gets the following parameters: - // x0 = newStack uintptr - // x1 = oldStack *uintptr - - // Save all callee-saved registers: - stp x19, x20, [sp, #-160]! - stp x21, x22, [sp, #16] - stp x23, x24, [sp, #32] - stp x25, x26, [sp, #48] - stp x27, x28, [sp, #64] - stp x29, x30, [sp, #80] - stp d8, d9, [sp, #96] - stp d10, d11, [sp, #112] - stp d12, d13, [sp, #128] - stp d14, d15, [sp, #144] - - // Save the current stack pointer in oldStack. - mov x8, sp - str x8, [x1] - - // Switch to the new stack pointer. - mov sp, x0 - - // Restore stack state and return. 
- ldp d14, d15, [sp, #144] - ldp d12, d13, [sp, #128] - ldp d10, d11, [sp, #112] - ldp d8, d9, [sp, #96] - ldp x29, x30, [sp, #80] - ldp x27, x28, [sp, #64] - ldp x25, x26, [sp, #48] - ldp x23, x24, [sp, #32] - ldp x21, x22, [sp, #16] - ldp x19, x20, [sp], #160 - ret diff --git a/src/internal/task/task_stack_arm64_windows.go b/src/internal/task/task_stack_arm64_windows.go deleted file mode 100644 index c3c6e8f024..0000000000 --- a/src/internal/task/task_stack_arm64_windows.go +++ /dev/null @@ -1,72 +0,0 @@ -//go:build scheduler.tasks && arm64 && windows - -package task - -import "unsafe" - -var systemStack uintptr - -// calleeSavedRegs is the list of registers that must be saved and restored -// when switching between tasks. Also see task_stack_arm64_windows.S that -// relies on the exact layout of this struct. -type calleeSavedRegs struct { - x19 uintptr - x20 uintptr - x21 uintptr - x22 uintptr - x23 uintptr - x24 uintptr - x25 uintptr - x26 uintptr - x27 uintptr - x28 uintptr - x29 uintptr - pc uintptr // aka x30 aka LR - - d8 uintptr - d9 uintptr - d10 uintptr - d11 uintptr - d12 uintptr - d13 uintptr - d14 uintptr - d15 uintptr -} - -// archInit runs architecture-specific setup for the goroutine startup. -func (s *state) archInit(r *calleeSavedRegs, fn uintptr, args unsafe.Pointer) { - // Store the initial sp for the startTask function (implemented in assembly). - s.sp = uintptr(unsafe.Pointer(r)) - - // Initialize the registers. - // These will be popped off of the stack on the first resume of the goroutine. - - // Start the function at tinygo_startTask (defined in src/internal/task/task_stack_arm64_windows.S). - // This assembly code calls a function (passed in x19) with a single argument - // (passed in x20). After the function returns, it calls Pause(). - r.pc = uintptr(unsafe.Pointer(&startTask)) - - // Pass the function to call in x19. - // This function is a compiler-generated wrapper which loads arguments out of a struct pointer. - // See createGoroutineStartWrapper (defined in compiler/goroutine.go) for more information. - r.x19 = fn - - // Pass the pointer to the arguments struct in x20. - r.x20 = uintptr(args) -} - -func (s *state) resume() { - swapTask(s.sp, &systemStack) -} - -func (s *state) pause() { - newStack := systemStack - systemStack = 0 - swapTask(newStack, &s.sp) -} - -// SystemStack returns the system stack pointer when called from a task stack. -// When called from the system stack, it returns 0. 
-func SystemStack() uintptr { - return systemStack -} diff --git a/src/machine/board_lorae5.go b/src/machine/board_lorae5.go index 4da7972b89..e42551b5f5 100644 --- a/src/machine/board_lorae5.go +++ b/src/machine/board_lorae5.go @@ -56,6 +56,18 @@ var ( } DefaultUART = UART0 + // Since we treat UART1 as zero, let's also call it by the real name + UART1 = UART0 + + // UART2 + UART2 = &_UART2 + _UART2 = UART{ + Buffer: NewRingBuffer(), + Bus: stm32.USART2, + TxAltFuncSelector: AF7_USART1_2, + RxAltFuncSelector: AF7_USART1_2, + } + // I2C Busses I2C1 = &I2C{ Bus: stm32.I2C1, @@ -72,4 +84,5 @@ var ( func init() { // Enable UARTs Interrupts UART0.Interrupt = interrupt.New(stm32.IRQ_USART1, _UART0.handleInterrupt) + UART2.Interrupt = interrupt.New(stm32.IRQ_USART2, _UART2.handleInterrupt) } diff --git a/src/machine/board_qtpy.go b/src/machine/board_qtpy.go index d0f89e461a..49bb9c97b5 100644 --- a/src/machine/board_qtpy.go +++ b/src/machine/board_qtpy.go @@ -29,7 +29,7 @@ const ( // Analog pins const ( - A0 = D1 + A0 = D0 A1 = D1 A2 = D2 A3 = D3 diff --git a/src/machine/board_wioterminal.go b/src/machine/board_wioterminal.go index 9923c37126..15eefbed1d 100644 --- a/src/machine/board_wioterminal.go +++ b/src/machine/board_wioterminal.go @@ -90,7 +90,7 @@ const ( BCM10 = PB02 // SPI SDO BCM11 = PB03 // SPI SCK BCM12 = PB06 - BCM13 = PA07 + BCM13 = PA04 BCM14 = PB27 // UART Serial1 BCM15 = PB26 // UART Serial1 BCM16 = PB07 diff --git a/src/machine/board_xiao-rp2040.go b/src/machine/board_xiao-rp2040.go index 197ad31fee..272fcc599d 100644 --- a/src/machine/board_xiao-rp2040.go +++ b/src/machine/board_xiao-rp2040.go @@ -50,11 +50,11 @@ const ( // I2C pins const ( - I2C0_SDA_PIN Pin = D4 - I2C0_SCL_PIN Pin = D5 + I2C0_SDA_PIN Pin = D2 + I2C0_SCL_PIN Pin = D3 - I2C1_SDA_PIN Pin = NoPin - I2C1_SCL_PIN Pin = NoPin + I2C1_SDA_PIN Pin = D4 + I2C1_SCL_PIN Pin = D5 ) // SPI pins diff --git a/src/machine/flash.go b/src/machine/flash.go new file mode 100644 index 0000000000..885e5c8872 --- /dev/null +++ b/src/machine/flash.go @@ -0,0 +1,145 @@ +//go:build nrf || nrf51 || nrf52 || nrf528xx || stm32f4 || stm32l4 || stm32wlx || atsamd21 || atsamd51 || atsame5x || rp2040 + +package machine + +import ( + "errors" + "io" + "unsafe" +) + +//go:extern __flash_data_start +var flashDataStart [0]byte + +//go:extern __flash_data_end +var flashDataEnd [0]byte + +// Return the start of the writable flash area, aligned on a page boundary. This +// is usually just after the program and static data. +func FlashDataStart() uintptr { + pagesize := uintptr(eraseBlockSize()) + return (uintptr(unsafe.Pointer(&flashDataStart)) + pagesize - 1) &^ (pagesize - 1) +} + +// Return the end of the writable flash area. Usually this is the address one +// past the end of the on-chip flash. +func FlashDataEnd() uintptr { + return uintptr(unsafe.Pointer(&flashDataEnd)) +} + +var ( + errFlashCannotErasePage = errors.New("cannot erase flash page") + errFlashInvalidWriteLength = errors.New("write flash data must align to correct number of bits") + errFlashNotAllowedWriteData = errors.New("not allowed to write flash data") + errFlashCannotWriteData = errors.New("cannot write flash data") + errFlashCannotReadPastEOF = errors.New("cannot read beyond end of flash data") + errFlashCannotWritePastEOF = errors.New("cannot write beyond end of flash data") + errFlashCannotErasePastEOF = errors.New("cannot erase beyond end of flash data") +) + +// BlockDevice is the raw device that is meant to store flash data. 
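The rounding in FlashDataStart is the standard align-up idiom. A tiny self-contained illustration with assumed numbers (a 2 KiB erase block and an arbitrary unaligned start address):

    package main

    import "fmt"

    func main() {
        const pagesize uintptr = 2048 // assumed erase block size
        start := uintptr(0x0800F811)  // assumed unaligned __flash_data_start

        aligned := (start + pagesize - 1) &^ (pagesize - 1)
        fmt.Printf("%#x rounds up to %#x\n", start, aligned) // 0x800f811 rounds up to 0x8010000
    }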
+type BlockDevice interface {
+	// ReadAt reads the given number of bytes from the block device.
+	io.ReaderAt
+
+	// WriteAt writes the given number of bytes to the block device.
+	io.WriterAt
+
+	// Size returns the number of bytes in this block device.
+	Size() int64
+
+	// WriteBlockSize returns the block size in which data can be written to
+	// memory. It can be used by a client to optimize writes, non-aligned writes
+	// should always work correctly.
+	WriteBlockSize() int64
+
+	// EraseBlockSize returns the smallest erasable area on this particular chip
+	// in bytes. This is used for the block size in EraseBlocks.
+	// It must be a power of two, and may be as small as 1. A typical size is 4096.
+	EraseBlockSize() int64
+
+	// EraseBlocks erases the given number of blocks. An implementation may
+	// transparently coalesce ranges of blocks into larger bundles if the chip
+	// supports this. The start and len parameters are in block numbers, use
+	// EraseBlockSize to map addresses to blocks.
+	EraseBlocks(start, len int64) error
+}
+
+// FlashBuffer implements the ReadWriteCloser interface using the BlockDevice interface.
+type FlashBuffer struct {
+	b BlockDevice
+
+	// start is the actual address
+	start uintptr
+
+	// offset is relative to start
+	offset uintptr
+}
+
+// OpenFlashBuffer opens a FlashBuffer.
+func OpenFlashBuffer(b BlockDevice, address uintptr) *FlashBuffer {
+	return &FlashBuffer{b: b, start: address}
+}
+
+// Read data from a FlashBuffer.
+func (fl *FlashBuffer) Read(p []byte) (n int, err error) {
+	n, err = fl.b.ReadAt(p, int64(fl.offset))
+	fl.offset += uintptr(n)
+
+	return
+}
+
+// Write data to a FlashBuffer.
+func (fl *FlashBuffer) Write(p []byte) (n int, err error) {
+	// any new pages needed?
+	// NOTE probably will not work as expected if you try to write over the page boundary
+	// of pages with different sizes.
+	pagesize := uintptr(fl.b.EraseBlockSize())
+
+	// calculate the erase block numbers relative to FlashDataStart(), meaning that
+	// block 0 -> FlashDataStart()
+	// block 1 -> FlashDataStart() + pagesize
+	// block 2 -> FlashDataStart() + pagesize*2
+	// ...
+	currentPageBlock := (fl.start + fl.offset - FlashDataStart()) / pagesize
+	lastPageBlockNeeded := (fl.start + fl.offset + uintptr(len(p)) - FlashDataStart() + pagesize - 1) / pagesize
+
+	// erase enough blocks to hold the data
+	if err := fl.b.EraseBlocks(int64(currentPageBlock), int64(lastPageBlockNeeded-currentPageBlock)); err != nil {
+		return 0, err
+	}
+
+	// write the data
+	for i := 0; i < len(p); i += int(pagesize) {
+		var last int = i + int(pagesize)
+		if i+int(pagesize) > len(p) {
+			last = len(p)
+		}
+
+		_, err := fl.b.WriteAt(p[i:last], int64(fl.offset))
+		if err != nil {
+			return 0, err
+		}
+		fl.offset += uintptr(len(p[i:last]))
+	}
+
+	return len(p), nil
+}
+
+// Close the FlashBuffer.
+func (fl *FlashBuffer) Close() error {
+	return nil
+}
+
+// Seek implements the io.Seeker interface, but with limitations.
+// You can only seek relative to the start.
+// Also, Seek cannot be used before Write operations, only before Read.
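To make the erase-block arithmetic in Write above concrete, a worked example with assumed numbers: a buffer opened at FlashDataStart(), 100 bytes already written, 4 KiB erase blocks, and another 5000 bytes about to be appended.

    package main

    import "fmt"

    func main() {
        const pagesize = 4096 // assumed erase block size
        offset, n := 100, 5000

        currentPageBlock := offset / pagesize                         // first block touched by the write: 0
        lastPageBlockNeeded := (offset + n + pagesize - 1) / pagesize // one past the last block touched: 2

        // The write spans blocks 0 and 1, so EraseBlocks(0, 2) is issued before programming.
        fmt.Println(currentPageBlock, lastPageBlockNeeded-currentPageBlock) // 0 2
    }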
+func (fl *FlashBuffer) Seek(offset int64, whence int) (int64, error) { + if whence != io.SeekStart { + panic("you can only Seek relative to Start") + } + + fl.offset = uintptr(offset) + + return offset, nil +} diff --git a/src/machine/machine_atsamd21.go b/src/machine/machine_atsamd21.go index 7a5a20c7ea..2f4f7338e9 100644 --- a/src/machine/machine_atsamd21.go +++ b/src/machine/machine_atsamd21.go @@ -7,8 +7,11 @@ package machine import ( + "bytes" "device/arm" "device/sam" + "encoding/binary" + "errors" "runtime/interrupt" "unsafe" ) @@ -1788,3 +1791,167 @@ func syncDAC() { for sam.DAC.STATUS.HasBits(sam.DAC_STATUS_SYNCBUSY) { } } + +// Flash related code +const memoryStart = 0x0 + +// compile-time check for ensuring we fulfill BlockDevice interface +var _ BlockDevice = flashBlockDevice{} + +var Flash flashBlockDevice + +type flashBlockDevice struct { + initComplete bool +} + +// ReadAt reads the given number of bytes from the block device. +func (f flashBlockDevice) ReadAt(p []byte, off int64) (n int, err error) { + if FlashDataStart()+uintptr(off)+uintptr(len(p)) > FlashDataEnd() { + return 0, errFlashCannotReadPastEOF + } + + f.ensureInitComplete() + + waitWhileFlashBusy() + + data := unsafe.Slice((*byte)(unsafe.Add(unsafe.Pointer(FlashDataStart()), uintptr(off))), len(p)) + copy(p, data) + + return len(p), nil +} + +// WriteAt writes the given number of bytes to the block device. +// Only word (32 bits) length data can be programmed. +// See Atmel-42181G–SAM-D21_Datasheet–09/2015 page 359. +// If the length of p is not long enough it will be padded with 0xFF bytes. +// This method assumes that the destination is already erased. +func (f flashBlockDevice) WriteAt(p []byte, off int64) (n int, err error) { + if FlashDataStart()+uintptr(off)+uintptr(len(p)) > FlashDataEnd() { + return 0, errFlashCannotWritePastEOF + } + + f.ensureInitComplete() + + address := FlashDataStart() + uintptr(off) + padded := f.pad(p) + + waitWhileFlashBusy() + + for j := 0; j < len(padded); j += int(f.WriteBlockSize()) { + // write word + *(*uint32)(unsafe.Pointer(address)) = binary.LittleEndian.Uint32(padded[j : j+int(f.WriteBlockSize())]) + + sam.NVMCTRL.SetADDR(uint32(address >> 1)) + sam.NVMCTRL.CTRLA.Set(sam.NVMCTRL_CTRLA_CMD_WP | (sam.NVMCTRL_CTRLA_CMDEX_KEY << sam.NVMCTRL_CTRLA_CMDEX_Pos)) + + waitWhileFlashBusy() + + if err := checkFlashError(); err != nil { + return j, err + } + + address += uintptr(f.WriteBlockSize()) + } + + return len(padded), nil +} + +// Size returns the number of bytes in this block device. +func (f flashBlockDevice) Size() int64 { + return int64(FlashDataEnd() - FlashDataStart()) +} + +const writeBlockSize = 4 + +// WriteBlockSize returns the block size in which data can be written to +// memory. It can be used by a client to optimize writes, non-aligned writes +// should always work correctly. +func (f flashBlockDevice) WriteBlockSize() int64 { + return writeBlockSize +} + +const eraseBlockSizeValue = 256 + +func eraseBlockSize() int64 { + return eraseBlockSizeValue +} + +// EraseBlockSize returns the smallest erasable area on this particular chip +// in bytes. This is used for the block size in EraseBlocks. +func (f flashBlockDevice) EraseBlockSize() int64 { + return eraseBlockSize() +} + +// EraseBlocks erases the given number of blocks. An implementation may +// transparently coalesce ranges of blocks into larger bundles if the chip +// supports this. The start and len parameters are in block numbers, use +// EraseBlockSize to map addresses to blocks. 
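A usage sketch of the raw BlockDevice exposed as machine.Flash, assuming an atsamd21 board: the destination must be erased before programming, offsets are relative to FlashDataStart(), and short writes are padded to the 4-byte write block with 0xFF.

    package main

    import "machine"

    func main() {
        msg := []byte("hello")

        // Erase the first erase block of the data area, program the message at
        // offset 0, then read it back through the same interface.
        if err := machine.Flash.EraseBlocks(0, 1); err != nil {
            println("erase failed:", err.Error())
            return
        }
        if _, err := machine.Flash.WriteAt(msg, 0); err != nil {
            println("write failed:", err.Error())
            return
        }
        buf := make([]byte, len(msg))
        if _, err := machine.Flash.ReadAt(buf, 0); err != nil {
            println("read failed:", err.Error())
            return
        }
        println(string(buf))
    }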
+func (f flashBlockDevice) EraseBlocks(start, len int64) error { + f.ensureInitComplete() + + address := FlashDataStart() + uintptr(start*f.EraseBlockSize()) + waitWhileFlashBusy() + + for i := start; i < start+len; i++ { + sam.NVMCTRL.SetADDR(uint32(address >> 1)) + sam.NVMCTRL.CTRLA.Set(sam.NVMCTRL_CTRLA_CMD_ER | (sam.NVMCTRL_CTRLA_CMDEX_KEY << sam.NVMCTRL_CTRLA_CMDEX_Pos)) + + waitWhileFlashBusy() + + if err := checkFlashError(); err != nil { + return err + } + + address += uintptr(f.EraseBlockSize()) + } + + return nil +} + +// pad data if needed so it is long enough for correct byte alignment on writes. +func (f flashBlockDevice) pad(p []byte) []byte { + paddingNeeded := f.WriteBlockSize() - (int64(len(p)) % f.WriteBlockSize()) + if paddingNeeded == 0 { + return p + } + + padding := bytes.Repeat([]byte{0xff}, int(paddingNeeded)) + return append(p, padding...) +} + +func (f flashBlockDevice) ensureInitComplete() { + if f.initComplete { + return + } + + sam.NVMCTRL.SetCTRLB_READMODE(sam.NVMCTRL_CTRLB_READMODE_NO_MISS_PENALTY) + sam.NVMCTRL.SetCTRLB_SLEEPPRM(sam.NVMCTRL_CTRLB_SLEEPPRM_WAKEONACCESS) + + waitWhileFlashBusy() + + f.initComplete = true +} + +func waitWhileFlashBusy() { + for sam.NVMCTRL.GetINTFLAG_READY() != sam.NVMCTRL_INTFLAG_READY { + } +} + +var ( + errFlashPROGE = errors.New("errFlashPROGE") + errFlashLOCKE = errors.New("errFlashLOCKE") + errFlashNVME = errors.New("errFlashNVME") +) + +func checkFlashError() error { + switch { + case sam.NVMCTRL.GetSTATUS_PROGE() != 0: + return errFlashPROGE + case sam.NVMCTRL.GetSTATUS_LOCKE() != 0: + return errFlashLOCKE + case sam.NVMCTRL.GetSTATUS_NVME() != 0: + return errFlashNVME + } + + return nil +} diff --git a/src/machine/machine_atsamd21g18.go b/src/machine/machine_atsamd21g18.go index 6b60cfc4aa..383f51d910 100644 --- a/src/machine/machine_atsamd21g18.go +++ b/src/machine/machine_atsamd21g18.go @@ -406,61 +406,61 @@ func (p Pin) getPinCfg() uint8 { return uint8(sam.PORT.PINCFG1_0.Get()>>16) & 0xff case 35: // PB03 return uint8(sam.PORT.PINCFG1_0.Get()>>24) & 0xff - case 37: // PB04 + case 36: // PB04 return uint8(sam.PORT.PINCFG1_4.Get()>>0) & 0xff - case 38: // PB05 + case 37: // PB05 return uint8(sam.PORT.PINCFG1_4.Get()>>8) & 0xff - case 39: // PB06 + case 38: // PB06 return uint8(sam.PORT.PINCFG1_4.Get()>>16) & 0xff - case 40: // PB07 + case 39: // PB07 return uint8(sam.PORT.PINCFG1_4.Get()>>24) & 0xff - case 41: // PB08 + case 40: // PB08 return uint8(sam.PORT.PINCFG1_8.Get()>>0) & 0xff - case 42: // PB09 + case 41: // PB09 return uint8(sam.PORT.PINCFG1_8.Get()>>8) & 0xff - case 43: // PB10 + case 42: // PB10 return uint8(sam.PORT.PINCFG1_8.Get()>>16) & 0xff - case 44: // PB11 + case 43: // PB11 return uint8(sam.PORT.PINCFG1_8.Get()>>24) & 0xff - case 45: // PB12 + case 44: // PB12 return uint8(sam.PORT.PINCFG1_12.Get()>>0) & 0xff - case 46: // PB13 + case 45: // PB13 return uint8(sam.PORT.PINCFG1_12.Get()>>8) & 0xff - case 47: // PB14 + case 46: // PB14 return uint8(sam.PORT.PINCFG1_12.Get()>>16) & 0xff - case 48: // PB15 + case 47: // PB15 return uint8(sam.PORT.PINCFG1_12.Get()>>24) & 0xff - case 49: // PB16 + case 48: // PB16 return uint8(sam.PORT.PINCFG1_16.Get()>>0) & 0xff - case 50: // PB17 + case 49: // PB17 return uint8(sam.PORT.PINCFG1_16.Get()>>8) & 0xff - case 51: // PB18 + case 50: // PB18 return uint8(sam.PORT.PINCFG1_16.Get()>>16) & 0xff - case 52: // PB19 + case 51: // PB19 return uint8(sam.PORT.PINCFG1_16.Get()>>24) & 0xff - case 53: // PB20 + case 52: // PB20 return uint8(sam.PORT.PINCFG1_20.Get()>>0) & 
0xff - case 54: // PB21 + case 53: // PB21 return uint8(sam.PORT.PINCFG1_20.Get()>>8) & 0xff - case 55: // PB22 + case 54: // PB22 return uint8(sam.PORT.PINCFG1_20.Get()>>16) & 0xff - case 56: // PB23 + case 55: // PB23 return uint8(sam.PORT.PINCFG1_20.Get()>>24) & 0xff - case 57: // PB24 + case 56: // PB24 return uint8(sam.PORT.PINCFG1_24.Get()>>0) & 0xff - case 58: // PB25 + case 57: // PB25 return uint8(sam.PORT.PINCFG1_24.Get()>>8) & 0xff - case 59: // PB26 + case 58: // PB26 return uint8(sam.PORT.PINCFG1_24.Get()>>16) & 0xff - case 60: // PB27 + case 59: // PB27 return uint8(sam.PORT.PINCFG1_24.Get()>>24) & 0xff - case 61: // PB28 + case 60: // PB28 return uint8(sam.PORT.PINCFG1_28.Get()>>0) & 0xff - case 62: // PB29 + case 61: // PB29 return uint8(sam.PORT.PINCFG1_28.Get()>>8) & 0xff - case 63: // PB30 + case 62: // PB30 return uint8(sam.PORT.PINCFG1_28.Get()>>16) & 0xff - case 64: // PB31 + case 63: // PB31 return uint8(sam.PORT.PINCFG1_28.Get()>>24) & 0xff default: return 0 diff --git a/src/machine/machine_atsamd51.go b/src/machine/machine_atsamd51.go index 21da8d59d3..97bfb28592 100644 --- a/src/machine/machine_atsamd51.go +++ b/src/machine/machine_atsamd51.go @@ -7,8 +7,11 @@ package machine import ( + "bytes" "device/arm" "device/sam" + "encoding/binary" + "errors" "runtime/interrupt" "unsafe" ) @@ -2074,3 +2077,190 @@ func GetRNG() (uint32, error) { ret := sam.TRNG.DATA.Get() return ret, nil } + +// Flash related code +const memoryStart = 0x0 + +// compile-time check for ensuring we fulfill BlockDevice interface +var _ BlockDevice = flashBlockDevice{} + +var Flash flashBlockDevice + +type flashBlockDevice struct { + initComplete bool +} + +// ReadAt reads the given number of bytes from the block device. +func (f flashBlockDevice) ReadAt(p []byte, off int64) (n int, err error) { + if FlashDataStart()+uintptr(off)+uintptr(len(p)) > FlashDataEnd() { + return 0, errFlashCannotReadPastEOF + } + + waitWhileFlashBusy() + + data := unsafe.Slice((*byte)(unsafe.Add(unsafe.Pointer(FlashDataStart()), uintptr(off))), len(p)) + copy(p, data) + + return len(p), nil +} + +// WriteAt writes the given number of bytes to the block device. +// Only word (32 bits) length data can be programmed. +// See SAM-D5x-E5x-Family-Data-Sheet-DS60001507.pdf page 591-592. +// If the length of p is not long enough it will be padded with 0xFF bytes. +// This method assumes that the destination is already erased. 
+func (f flashBlockDevice) WriteAt(p []byte, off int64) (n int, err error) { + if FlashDataStart()+uintptr(off)+uintptr(len(p)) > FlashDataEnd() { + return 0, errFlashCannotWritePastEOF + } + + address := FlashDataStart() + uintptr(off) + padded := f.pad(p) + + settings := disableFlashCache() + defer restoreFlashCache(settings) + + waitWhileFlashBusy() + + sam.NVMCTRL.CTRLB.Set(sam.NVMCTRL_CTRLB_CMD_PBC | (sam.NVMCTRL_CTRLB_CMDEX_KEY << sam.NVMCTRL_CTRLB_CMDEX_Pos)) + + waitWhileFlashBusy() + + for j := 0; j < len(padded); j += int(f.WriteBlockSize()) { + // write first word using double-word low order word + *(*uint32)(unsafe.Pointer(address)) = binary.LittleEndian.Uint32(padded[j : j+int(f.WriteBlockSize()/2)]) + + // write second word using double-word high order word + *(*uint32)(unsafe.Add(unsafe.Pointer(address), uintptr(f.WriteBlockSize())/2)) = binary.LittleEndian.Uint32(padded[j+int(f.WriteBlockSize()/2) : j+int(f.WriteBlockSize())]) + + waitWhileFlashBusy() + + sam.NVMCTRL.SetADDR(uint32(address)) + sam.NVMCTRL.CTRLB.Set(sam.NVMCTRL_CTRLB_CMD_WQW | (sam.NVMCTRL_CTRLB_CMDEX_KEY << sam.NVMCTRL_CTRLB_CMDEX_Pos)) + + waitWhileFlashBusy() + + if err := checkFlashError(); err != nil { + return j, err + } + + address += uintptr(f.WriteBlockSize()) + } + + return len(padded), nil +} + +// Size returns the number of bytes in this block device. +func (f flashBlockDevice) Size() int64 { + return int64(FlashDataEnd() - FlashDataStart()) +} + +const writeBlockSize = 8 + +// WriteBlockSize returns the block size in which data can be written to +// memory. It can be used by a client to optimize writes, non-aligned writes +// should always work correctly. +func (f flashBlockDevice) WriteBlockSize() int64 { + return writeBlockSize +} + +const eraseBlockSizeValue = 8192 + +func eraseBlockSize() int64 { + return eraseBlockSizeValue +} + +// EraseBlockSize returns the smallest erasable area on this particular chip +// in bytes. This is used for the block size in EraseBlocks. +func (f flashBlockDevice) EraseBlockSize() int64 { + return eraseBlockSize() +} + +// EraseBlocks erases the given number of blocks. An implementation may +// transparently coalesce ranges of blocks into larger bundles if the chip +// supports this. The start and len parameters are in block numbers, use +// EraseBlockSize to map addresses to blocks. +func (f flashBlockDevice) EraseBlocks(start, len int64) error { + address := FlashDataStart() + uintptr(start*f.EraseBlockSize()) + + settings := disableFlashCache() + defer restoreFlashCache(settings) + + waitWhileFlashBusy() + + for i := start; i < start+len; i++ { + sam.NVMCTRL.SetADDR(uint32(address)) + sam.NVMCTRL.CTRLB.Set(sam.NVMCTRL_CTRLB_CMD_EB | (sam.NVMCTRL_CTRLB_CMDEX_KEY << sam.NVMCTRL_CTRLB_CMDEX_Pos)) + + waitWhileFlashBusy() + + if err := checkFlashError(); err != nil { + return err + } + + address += uintptr(f.EraseBlockSize()) + } + + return nil +} + +// pad data if needed so it is long enough for correct byte alignment on writes. +func (f flashBlockDevice) pad(p []byte) []byte { + paddingNeeded := f.WriteBlockSize() - (int64(len(p)) % f.WriteBlockSize()) + if paddingNeeded == 0 { + return p + } + + padding := bytes.Repeat([]byte{0xff}, int(paddingNeeded)) + return append(p, padding...) 
+} + +func disableFlashCache() uint16 { + settings := sam.NVMCTRL.CTRLA.Get() + + // disable caches + sam.NVMCTRL.SetCTRLA_CACHEDIS0(1) + sam.NVMCTRL.SetCTRLA_CACHEDIS1(1) + + waitWhileFlashBusy() + + return settings +} + +func restoreFlashCache(settings uint16) { + sam.NVMCTRL.CTRLA.Set(settings) + waitWhileFlashBusy() +} + +func waitWhileFlashBusy() { + for sam.NVMCTRL.GetSTATUS_READY() != sam.NVMCTRL_STATUS_READY { + } +} + +var ( + errFlashADDRE = errors.New("errFlashADDRE") + errFlashPROGE = errors.New("errFlashPROGE") + errFlashLOCKE = errors.New("errFlashLOCKE") + errFlashECCSE = errors.New("errFlashECCSE") + errFlashNVME = errors.New("errFlashNVME") + errFlashSEESOVF = errors.New("errFlashSEESOVF") +) + +func checkFlashError() error { + switch { + case sam.NVMCTRL.GetINTENSET_ADDRE() != 0: + return errFlashADDRE + case sam.NVMCTRL.GetINTENSET_PROGE() != 0: + return errFlashPROGE + case sam.NVMCTRL.GetINTENSET_LOCKE() != 0: + return errFlashLOCKE + case sam.NVMCTRL.GetINTENSET_ECCSE() != 0: + return errFlashECCSE + case sam.NVMCTRL.GetINTENSET_NVME() != 0: + return errFlashNVME + case sam.NVMCTRL.GetINTENSET_SEESOVF() != 0: + return errFlashSEESOVF + } + + return nil +} diff --git a/src/machine/machine_atsamd51_usb.go b/src/machine/machine_atsamd51_usb.go index b3f570ac78..13f522e6a1 100644 --- a/src/machine/machine_atsamd51_usb.go +++ b/src/machine/machine_atsamd51_usb.go @@ -142,8 +142,9 @@ func handleUSBIRQ(intr interrupt.Interrupt) { setup := usb.NewSetup(udd_ep_out_cache_buffer[0][:]) // Clear the Bank 0 ready flag on Control OUT - setEPSTATUSCLR(0, sam.USB_DEVICE_ENDPOINT_EPSTATUSCLR_BK0RDY) + usbEndpointDescriptors[0].DeviceDescBank[0].ADDR.Set(uint32(uintptr(unsafe.Pointer(&udd_ep_out_cache_buffer[0])))) usbEndpointDescriptors[0].DeviceDescBank[0].PCKSIZE.ClearBits(usb_DEVICE_PCKSIZE_BYTE_COUNT_Mask << usb_DEVICE_PCKSIZE_BYTE_COUNT_Pos) + setEPSTATUSCLR(0, sam.USB_DEVICE_ENDPOINT_EPSTATUSCLR_BK0RDY) ok := false if (setup.BmRequestType & usb.REQUEST_TYPE) == usb.REQUEST_STANDARD { @@ -347,15 +348,6 @@ func sendUSBPacket(ep uint32, data []byte, maxsize uint16) { func ReceiveUSBControlPacket() ([cdcLineInfoSize]byte, error) { var b [cdcLineInfoSize]byte - // address - usbEndpointDescriptors[0].DeviceDescBank[0].ADDR.Set(uint32(uintptr(unsafe.Pointer(&udd_ep_out_cache_buffer[0])))) - - // set byte count to zero - usbEndpointDescriptors[0].DeviceDescBank[0].PCKSIZE.ClearBits(usb_DEVICE_PCKSIZE_BYTE_COUNT_Mask << usb_DEVICE_PCKSIZE_BYTE_COUNT_Pos) - - // set ready for next data - setEPSTATUSCLR(0, sam.USB_DEVICE_ENDPOINT_EPSTATUSCLR_BK0RDY) - // Wait until OUT transfer is ready. timeout := 300000 for (getEPSTATUS(0) & sam.USB_DEVICE_ENDPOINT_EPSTATUS_BK0RDY) == 0 { diff --git a/src/machine/machine_esp32.go b/src/machine/machine_esp32.go index adca2bb689..b58cef66ae 100644 --- a/src/machine/machine_esp32.go +++ b/src/machine/machine_esp32.go @@ -143,13 +143,13 @@ func (p Pin) configure(config PinConfig, signal uint32) { // outFunc returns the FUNCx_OUT_SEL_CFG register used for configuring the // output function selection. func (p Pin) outFunc() *volatile.Register32 { - return (*volatile.Register32)(unsafe.Pointer((uintptr(unsafe.Pointer(&esp.GPIO.FUNC0_OUT_SEL_CFG)) + uintptr(p)*4))) + return (*volatile.Register32)(unsafe.Add(unsafe.Pointer(&esp.GPIO.FUNC0_OUT_SEL_CFG), uintptr(p)*4)) } // inFunc returns the FUNCy_IN_SEL_CFG register used for configuring the input // function selection. 
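The register-indexing rewrites in these hunks all follow one pattern; here is a side-by-side sketch with placeholder names showing the old uintptr round-trip next to unsafe.Add (available since Go 1.17), which keeps the value a pointer throughout.

    package main

    import (
        "fmt"
        "unsafe"
    )

    var regs [8]uint32 // stand-in for a block of memory-mapped 32-bit registers

    // old style: pointer -> uintptr -> add offset -> back to pointer
    func nthRegOld(i uintptr) *uint32 {
        return (*uint32)(unsafe.Pointer(uintptr(unsafe.Pointer(&regs[0])) + i*4))
    }

    // new style: unsafe.Add performs the same offset without leaving pointer land
    func nthRegNew(i uintptr) *uint32 {
        return (*uint32)(unsafe.Add(unsafe.Pointer(&regs[0]), i*4))
    }

    func main() {
        regs[3] = 42
        fmt.Println(*nthRegOld(3), *nthRegNew(3)) // 42 42
    }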
func inFunc(signal uint32) *volatile.Register32 { - return (*volatile.Register32)(unsafe.Pointer((uintptr(unsafe.Pointer(&esp.GPIO.FUNC0_IN_SEL_CFG)) + uintptr(signal)*4))) + return (*volatile.Register32)(unsafe.Add(unsafe.Pointer(&esp.GPIO.FUNC0_IN_SEL_CFG), uintptr(signal)*4)) } // Set the pin to high or low. diff --git a/src/machine/machine_esp32c3.go b/src/machine/machine_esp32c3.go index 1c60c19362..f1f646fd5e 100644 --- a/src/machine/machine_esp32c3.go +++ b/src/machine/machine_esp32c3.go @@ -108,24 +108,24 @@ func (p Pin) Configure(config PinConfig) { // outFunc returns the FUNCx_OUT_SEL_CFG register used for configuring the // output function selection. func (p Pin) outFunc() *volatile.Register32 { - return (*volatile.Register32)(unsafe.Pointer((uintptr(unsafe.Pointer(&esp.GPIO.FUNC0_OUT_SEL_CFG)) + uintptr(p)*4))) + return (*volatile.Register32)(unsafe.Add(unsafe.Pointer(&esp.GPIO.FUNC0_OUT_SEL_CFG), uintptr(p)*4)) } // inFunc returns the FUNCy_IN_SEL_CFG register used for configuring the input // function selection. func inFunc(signal uint32) *volatile.Register32 { - return (*volatile.Register32)(unsafe.Pointer((uintptr(unsafe.Pointer(&esp.GPIO.FUNC0_IN_SEL_CFG)) + uintptr(signal)*4))) + return (*volatile.Register32)(unsafe.Add(unsafe.Pointer(&esp.GPIO.FUNC0_IN_SEL_CFG), uintptr(signal)*4)) } // mux returns the I/O mux configuration register corresponding to the given // GPIO pin. func (p Pin) mux() *volatile.Register32 { - return (*volatile.Register32)(unsafe.Pointer((uintptr(unsafe.Pointer(&esp.IO_MUX.GPIO0)) + uintptr(p)*4))) + return (*volatile.Register32)(unsafe.Add(unsafe.Pointer(&esp.IO_MUX.GPIO0), uintptr(p)*4)) } // pin returns the PIN register corresponding to the given GPIO pin. func (p Pin) pin() *volatile.Register32 { - return (*volatile.Register32)(unsafe.Pointer((uintptr(unsafe.Pointer(&esp.GPIO.PIN0)) + uintptr(p)*4))) + return (*volatile.Register32)(unsafe.Add(unsafe.Pointer(&esp.GPIO.PIN0), uintptr(p)*4)) } // Set the pin to high or low. diff --git a/src/machine/machine_gameboyadvance.go b/src/machine/machine_gameboyadvance.go index 0c4cd7cbd8..0b666a4bd2 100644 --- a/src/machine/machine_gameboyadvance.go +++ b/src/machine/machine_gameboyadvance.go @@ -3,8 +3,9 @@ package machine import ( + "device/gba" + "image/color" - "runtime/interrupt" "runtime/volatile" "unsafe" ) @@ -16,40 +17,37 @@ const deviceName = "GBA" // Interrupt numbers as used on the GameBoy Advance. Register them with // runtime/interrupt.New. const ( - IRQ_VBLANK = interrupt.IRQ_VBLANK - IRQ_HBLANK = interrupt.IRQ_HBLANK - IRQ_VCOUNT = interrupt.IRQ_VCOUNT - IRQ_TIMER0 = interrupt.IRQ_TIMER0 - IRQ_TIMER1 = interrupt.IRQ_TIMER1 - IRQ_TIMER2 = interrupt.IRQ_TIMER2 - IRQ_TIMER3 = interrupt.IRQ_TIMER3 - IRQ_COM = interrupt.IRQ_COM - IRQ_DMA0 = interrupt.IRQ_DMA0 - IRQ_DMA1 = interrupt.IRQ_DMA1 - IRQ_DMA2 = interrupt.IRQ_DMA2 - IRQ_DMA3 = interrupt.IRQ_DMA3 - IRQ_KEYPAD = interrupt.IRQ_KEYPAD - IRQ_GAMEPAK = interrupt.IRQ_GAMEPAK + IRQ_VBLANK = gba.IRQ_VBLANK + IRQ_HBLANK = gba.IRQ_HBLANK + IRQ_VCOUNT = gba.IRQ_VCOUNT + IRQ_TIMER0 = gba.IRQ_TIMER0 + IRQ_TIMER1 = gba.IRQ_TIMER1 + IRQ_TIMER2 = gba.IRQ_TIMER2 + IRQ_TIMER3 = gba.IRQ_TIMER3 + IRQ_COM = gba.IRQ_COM + IRQ_DMA0 = gba.IRQ_DMA0 + IRQ_DMA1 = gba.IRQ_DMA1 + IRQ_DMA2 = gba.IRQ_DMA2 + IRQ_DMA3 = gba.IRQ_DMA3 + IRQ_KEYPAD = gba.IRQ_KEYPAD + IRQ_GAMEPAK = gba.IRQ_GAMEPAK ) -// Make it easier to directly write to I/O RAM. -var ioram = (*[0x400]volatile.Register8)(unsafe.Pointer(uintptr(0x04000000))) - // Set has not been implemented. 
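The comment above these constants says to register them with runtime/interrupt.New; a minimal sketch for the V-blank case on the gameboyadvance target (the handler body and counter are illustrative only):

    package main

    import (
        "machine"
        "runtime/interrupt"
        "runtime/volatile"
    )

    var frames volatile.Register32

    func main() {
        // interrupt.New needs a constant interrupt number; IRQ_VBLANK is now
        // re-exported by the machine package from device/gba.
        vblank := interrupt.New(machine.IRQ_VBLANK, func(interrupt.Interrupt) {
            frames.Set(frames.Get() + 1)
        })
        vblank.Enable()

        for {
        }
    }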
func (p Pin) Set(value bool) { // do nothing } -var Display = FramebufDisplay{(*[160][240]volatile.Register16)(unsafe.Pointer(uintptr(0x06000000)))} +var Display = FramebufDisplay{(*[160][240]volatile.Register16)(unsafe.Pointer(uintptr(gba.MEM_VRAM)))} type FramebufDisplay struct { port *[160][240]volatile.Register16 } func (d FramebufDisplay) Configure() { - // Write into the I/O registers, setting video display parameters. - ioram[0].Set(0x03) // Use video mode 3 (in BG2, a 16bpp bitmap in VRAM) - ioram[1].Set(0x04) // Enable BG2 (BG0 = 1, BG1 = 2, BG2 = 4, ...) + // Use video mode 3 (in BG2, a 16bpp bitmap in VRAM) and Enable BG2 + gba.DISP.DISPCNT.Set(gba.DISPCNT_BGMODE_3< FlashDataEnd() { + return 0, errFlashCannotReadPastEOF + } + + data := unsafe.Slice((*byte)(unsafe.Pointer(FlashDataStart()+uintptr(off))), len(p)) + copy(p, data) + + return len(p), nil +} + +// WriteAt writes the given number of bytes to the block device. +// Only double-word (64 bits) length data can be programmed. See rm0461 page 78. +// If the length of p is not long enough it will be padded with 0xFF bytes. +// This method assumes that the destination is already erased. +func (f flashBlockDevice) WriteAt(p []byte, off int64) (n int, err error) { + if FlashDataStart()+uintptr(off)+uintptr(len(p)) > FlashDataEnd() { + return 0, errFlashCannotWritePastEOF + } + + address := FlashDataStart() + uintptr(off) + padded := f.pad(p) + + waitWhileFlashBusy() + + nrf.NVMC.SetCONFIG_WEN(nrf.NVMC_CONFIG_WEN_Wen) + defer nrf.NVMC.SetCONFIG_WEN(nrf.NVMC_CONFIG_WEN_Ren) + + for j := 0; j < len(padded); j += int(f.WriteBlockSize()) { + // write word + *(*uint32)(unsafe.Pointer(address)) = binary.LittleEndian.Uint32(padded[j : j+int(f.WriteBlockSize())]) + address += uintptr(f.WriteBlockSize()) + waitWhileFlashBusy() + } + + return len(padded), nil +} + +// Size returns the number of bytes in this block device. +func (f flashBlockDevice) Size() int64 { + return int64(FlashDataEnd() - FlashDataStart()) +} + +const writeBlockSize = 4 + +// WriteBlockSize returns the block size in which data can be written to +// memory. It can be used by a client to optimize writes, non-aligned writes +// should always work correctly. +func (f flashBlockDevice) WriteBlockSize() int64 { + return writeBlockSize +} + +// EraseBlockSize returns the smallest erasable area on this particular chip +// in bytes. This is used for the block size in EraseBlocks. +// It must be a power of two, and may be as small as 1. A typical size is 4096. +func (f flashBlockDevice) EraseBlockSize() int64 { + return eraseBlockSize() +} + +// EraseBlocks erases the given number of blocks. An implementation may +// transparently coalesce ranges of blocks into larger bundles if the chip +// supports this. The start and len parameters are in block numbers, use +// EraseBlockSize to map addresses to blocks. +func (f flashBlockDevice) EraseBlocks(start, len int64) error { + address := FlashDataStart() + uintptr(start*f.EraseBlockSize()) + waitWhileFlashBusy() + + nrf.NVMC.SetCONFIG_WEN(nrf.NVMC_CONFIG_WEN_Een) + defer nrf.NVMC.SetCONFIG_WEN(nrf.NVMC_CONFIG_WEN_Ren) + + for i := start; i < start+len; i++ { + nrf.NVMC.ERASEPAGE.Set(uint32(address)) + waitWhileFlashBusy() + address += uintptr(f.EraseBlockSize()) + } + + return nil +} + +// pad data if needed so it is long enough for correct byte alignment on writes. 
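As a quick illustration of the padding rule described above, assuming a 4-byte write block: a 5-byte payload is programmed as 8 bytes, with the tail filled with 0xFF so the already-erased cells keep their erased value.

    package main

    import (
        "bytes"
        "fmt"
    )

    func main() {
        const writeBlockSize = 4 // assumed write granularity
        p := []byte("hello")     // 5 bytes

        if r := len(p) % writeBlockSize; r != 0 {
            p = append(p, bytes.Repeat([]byte{0xff}, writeBlockSize-r)...)
        }
        fmt.Printf("% x\n", p) // 68 65 6c 6c 6f ff ff ff
    }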
+func (f flashBlockDevice) pad(p []byte) []byte { + paddingNeeded := f.WriteBlockSize() - (int64(len(p)) % f.WriteBlockSize()) + if paddingNeeded == 0 { + return p + } + + padding := bytes.Repeat([]byte{0xff}, int(paddingNeeded)) + return append(p, padding...) +} + +func waitWhileFlashBusy() { + for nrf.NVMC.GetREADY() != nrf.NVMC_READY_READY_Ready { + } +} diff --git a/src/machine/machine_nrf51.go b/src/machine/machine_nrf51.go index 0fd04f776f..95723cfd06 100644 --- a/src/machine/machine_nrf51.go +++ b/src/machine/machine_nrf51.go @@ -6,6 +6,12 @@ import ( "device/nrf" ) +const eraseBlockSizeValue = 1024 + +func eraseBlockSize() int64 { + return eraseBlockSizeValue +} + // Get peripheral and pin number for this GPIO pin. func (p Pin) getPortPin() (*nrf.GPIO_Type, uint32) { return nrf.GPIO, uint32(p) diff --git a/src/machine/machine_nrf52.go b/src/machine/machine_nrf52.go index 06ca6e01c7..71c534325d 100644 --- a/src/machine/machine_nrf52.go +++ b/src/machine/machine_nrf52.go @@ -63,3 +63,9 @@ var ( PWM1 = &PWM{PWM: nrf.PWM1} PWM2 = &PWM{PWM: nrf.PWM2} ) + +const eraseBlockSizeValue = 4096 + +func eraseBlockSize() int64 { + return eraseBlockSizeValue +} diff --git a/src/machine/machine_nrf52833.go b/src/machine/machine_nrf52833.go index ded3c90cff..60558eb0e4 100644 --- a/src/machine/machine_nrf52833.go +++ b/src/machine/machine_nrf52833.go @@ -84,3 +84,9 @@ var ( PWM2 = &PWM{PWM: nrf.PWM2} PWM3 = &PWM{PWM: nrf.PWM3} ) + +const eraseBlockSizeValue = 4096 + +func eraseBlockSize() int64 { + return eraseBlockSizeValue +} diff --git a/src/machine/machine_nrf52840.go b/src/machine/machine_nrf52840.go index d38ebbe56d..21a4367803 100644 --- a/src/machine/machine_nrf52840.go +++ b/src/machine/machine_nrf52840.go @@ -102,3 +102,9 @@ func (pdm *PDM) Read(buf []int16) (uint32, error) { return uint32(len(buf)), nil } + +const eraseBlockSizeValue = 4096 + +func eraseBlockSize() int64 { + return eraseBlockSizeValue +} diff --git a/src/machine/machine_rp2040_enter_bootloader.go b/src/machine/machine_rp2040_enter_bootloader.go deleted file mode 100644 index 189060756b..0000000000 --- a/src/machine/machine_rp2040_enter_bootloader.go +++ /dev/null @@ -1,52 +0,0 @@ -//go:build rp2040 - -package machine - -/* -// https://github.com/raspberrypi/pico-sdk -// src/rp2_common/pico_bootrom/include/pico/bootrom.h - -#define ROM_FUNC_POPCOUNT32 ROM_TABLE_CODE('P', '3') -#define ROM_FUNC_REVERSE32 ROM_TABLE_CODE('R', '3') -#define ROM_FUNC_CLZ32 ROM_TABLE_CODE('L', '3') -#define ROM_FUNC_CTZ32 ROM_TABLE_CODE('T', '3') -#define ROM_FUNC_MEMSET ROM_TABLE_CODE('M', 'S') -#define ROM_FUNC_MEMSET4 ROM_TABLE_CODE('S', '4') -#define ROM_FUNC_MEMCPY ROM_TABLE_CODE('M', 'C') -#define ROM_FUNC_MEMCPY44 ROM_TABLE_CODE('C', '4') -#define ROM_FUNC_RESET_USB_BOOT ROM_TABLE_CODE('U', 'B') -#define ROM_FUNC_CONNECT_INTERNAL_FLASH ROM_TABLE_CODE('I', 'F') -#define ROM_FUNC_FLASH_EXIT_XIP ROM_TABLE_CODE('E', 'X') -#define ROM_FUNC_FLASH_RANGE_ERASE ROM_TABLE_CODE('R', 'E') -#define ROM_FUNC_FLASH_RANGE_PROGRAM ROM_TABLE_CODE('R', 'P') -#define ROM_FUNC_FLASH_FLUSH_CACHE ROM_TABLE_CODE('F', 'C') -#define ROM_FUNC_FLASH_ENTER_CMD_XIP ROM_TABLE_CODE('C', 'X') - -#define ROM_TABLE_CODE(c1, c2) ((c1) | ((c2) << 8)) - -typedef unsigned short uint16_t; -typedef unsigned long uint32_t; -typedef unsigned long uintptr_t; - -typedef void *(*rom_table_lookup_fn)(uint16_t *table, uint32_t code); -typedef void __attribute__((noreturn)) (*rom_reset_usb_boot_fn)(uint32_t, uint32_t); -#define rom_hword_as_ptr(rom_address) (void 
*)(uintptr_t)(*(uint16_t *)(uintptr_t)(rom_address)) - -void *rom_func_lookup(uint32_t code) { - rom_table_lookup_fn rom_table_lookup = (rom_table_lookup_fn) rom_hword_as_ptr(0x18); - uint16_t *func_table = (uint16_t *) rom_hword_as_ptr(0x14); - return rom_table_lookup(func_table, code); -} - -void reset_usb_boot(uint32_t usb_activity_gpio_pin_mask, uint32_t disable_interface_mask) { - rom_reset_usb_boot_fn func = (rom_reset_usb_boot_fn) rom_func_lookup(ROM_FUNC_RESET_USB_BOOT); - func(usb_activity_gpio_pin_mask, disable_interface_mask); -} -*/ -import "C" - -// EnterBootloader should perform a system reset in preparation -// to switch to the bootloader to flash new firmware. -func EnterBootloader() { - C.reset_usb_boot(0, 0) -} diff --git a/src/machine/machine_rp2040_gpio.go b/src/machine/machine_rp2040_gpio.go index a62a904845..fa2051af86 100644 --- a/src/machine/machine_rp2040_gpio.go +++ b/src/machine/machine_rp2040_gpio.go @@ -9,7 +9,7 @@ import ( "unsafe" ) -type io struct { +type ioType struct { status volatile.Register32 ctrl volatile.Register32 } @@ -21,7 +21,7 @@ type irqCtrl struct { } type ioBank0Type struct { - io [30]io + io [30]ioType intR [4]volatile.Register32 proc0IRQctrl irqCtrl proc1IRQctrl irqCtrl diff --git a/src/machine/machine_rp2040_i2c.go b/src/machine/machine_rp2040_i2c.go index d49e7f5760..0c5b688ef4 100644 --- a/src/machine/machine_rp2040_i2c.go +++ b/src/machine/machine_rp2040_i2c.go @@ -57,6 +57,8 @@ var ( ErrInvalidTgtAddr = errors.New("invalid target i2c address not in 0..0x80 or is reserved") ErrI2CGeneric = errors.New("i2c error") ErrRP2040I2CDisable = errors.New("i2c rp2040 peripheral timeout in disable") + errInvalidI2CSDA = errors.New("invalid I2C SDA pin") + errInvalidI2CSCL = errors.New("invalid I2C SCL pin") ) // Tx performs a write and then a read transfer placing the result in @@ -90,7 +92,7 @@ func (i2c *I2C) Tx(addr uint16, w, r []byte) error { // SCL: 3, 7, 11, 15, 19, 27 func (i2c *I2C) Configure(config I2CConfig) error { const defaultBaud uint32 = 100_000 // 100kHz standard mode - if config.SCL == 0 { + if config.SCL == 0 && config.SDA == 0 { // If config pins are zero valued or clock pin is invalid then we set default values. switch i2c.Bus { case rp.I2C0: @@ -101,6 +103,23 @@ func (i2c *I2C) Configure(config I2CConfig) error { config.SDA = I2C1_SDA_PIN } } + var okSCL, okSDA bool + switch i2c.Bus { + case rp.I2C0: + okSCL = (config.SCL+3)%4 == 0 + okSDA = (config.SDA+4)%4 == 0 + case rp.I2C1: + okSCL = (config.SCL+1)%4 == 0 + okSDA = (config.SDA+2)%4 == 0 + } + + switch { + case !okSCL: + return errInvalidI2CSCL + case !okSDA: + return errInvalidI2CSDA + } + if config.Frequency == 0 { config.Frequency = defaultBaud } diff --git a/src/machine/machine_rp2040_pwm.go b/src/machine/machine_rp2040_pwm.go index c114005e1d..cc2f2f5d93 100644 --- a/src/machine/machine_rp2040_pwm.go +++ b/src/machine/machine_rp2040_pwm.go @@ -50,7 +50,7 @@ type pwmGroup struct { // // 0x14 is the size of a pwmGroup. 
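// For example, getPWMGroup(2) points at rp.PWM plus an offset of 0x28 bytes;
// unsafe.Add performs the same arithmetic as the removed uintptr conversion
// while keeping the value an unsafe.Pointer throughout.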
func getPWMGroup(index uintptr) *pwmGroup { - return (*pwmGroup)(unsafe.Pointer(uintptr(unsafe.Pointer(rp.PWM)) + 0x14*index)) + return (*pwmGroup)(unsafe.Add(unsafe.Pointer(rp.PWM), 0x14*index)) } // Hardware Pulse Width Modulation (PWM) API diff --git a/src/machine/machine_rp2040_rom.go b/src/machine/machine_rp2040_rom.go new file mode 100644 index 0000000000..bf93069cc9 --- /dev/null +++ b/src/machine/machine_rp2040_rom.go @@ -0,0 +1,256 @@ +//go:build rp2040 + +package machine + +import ( + "bytes" + "runtime/interrupt" + "unsafe" +) + +/* +// https://github.com/raspberrypi/pico-sdk +// src/rp2_common/pico_bootrom/include/pico/bootrom.h + +#define ROM_FUNC_POPCOUNT32 ROM_TABLE_CODE('P', '3') +#define ROM_FUNC_REVERSE32 ROM_TABLE_CODE('R', '3') +#define ROM_FUNC_CLZ32 ROM_TABLE_CODE('L', '3') +#define ROM_FUNC_CTZ32 ROM_TABLE_CODE('T', '3') +#define ROM_FUNC_MEMSET ROM_TABLE_CODE('M', 'S') +#define ROM_FUNC_MEMSET4 ROM_TABLE_CODE('S', '4') +#define ROM_FUNC_MEMCPY ROM_TABLE_CODE('M', 'C') +#define ROM_FUNC_MEMCPY44 ROM_TABLE_CODE('C', '4') +#define ROM_FUNC_RESET_USB_BOOT ROM_TABLE_CODE('U', 'B') +#define ROM_FUNC_CONNECT_INTERNAL_FLASH ROM_TABLE_CODE('I', 'F') +#define ROM_FUNC_FLASH_EXIT_XIP ROM_TABLE_CODE('E', 'X') +#define ROM_FUNC_FLASH_RANGE_ERASE ROM_TABLE_CODE('R', 'E') +#define ROM_FUNC_FLASH_RANGE_PROGRAM ROM_TABLE_CODE('R', 'P') +#define ROM_FUNC_FLASH_FLUSH_CACHE ROM_TABLE_CODE('F', 'C') +#define ROM_FUNC_FLASH_ENTER_CMD_XIP ROM_TABLE_CODE('C', 'X') + +#define ROM_TABLE_CODE(c1, c2) ((c1) | ((c2) << 8)) + +typedef unsigned char uint8_t; +typedef unsigned short uint16_t; +typedef unsigned long uint32_t; +typedef unsigned long size_t; +typedef unsigned long uintptr_t; + +#define false 0 +#define true 1 +typedef int bool; + +#define ram_func __attribute__((section(".ramfuncs"),noinline)) + +typedef void *(*rom_table_lookup_fn)(uint16_t *table, uint32_t code); +typedef void __attribute__((noreturn)) (*rom_reset_usb_boot_fn)(uint32_t, uint32_t); +typedef void (*flash_init_boot2_copyout_fn)(void); +typedef void (*flash_enable_xip_via_boot2_fn)(void); +typedef void (*flash_exit_xip_fn)(void); +typedef void (*flash_flush_cache_fn)(void); +typedef void (*flash_connect_internal_fn)(void); +typedef void (*flash_range_erase_fn)(uint32_t, size_t, uint32_t, uint16_t); +typedef void (*flash_range_program_fn)(uint32_t, const uint8_t*, size_t); + +static inline __attribute__((always_inline)) void __compiler_memory_barrier(void) { + __asm__ volatile ("" : : : "memory"); +} + +#define rom_hword_as_ptr(rom_address) (void *)(uintptr_t)(*(uint16_t *)(uintptr_t)(rom_address)) + +void *rom_func_lookup(uint32_t code) { + rom_table_lookup_fn rom_table_lookup = (rom_table_lookup_fn) rom_hword_as_ptr(0x18); + uint16_t *func_table = (uint16_t *) rom_hword_as_ptr(0x14); + return rom_table_lookup(func_table, code); +} + +void reset_usb_boot(uint32_t usb_activity_gpio_pin_mask, uint32_t disable_interface_mask) { + rom_reset_usb_boot_fn func = (rom_reset_usb_boot_fn) rom_func_lookup(ROM_FUNC_RESET_USB_BOOT); + func(usb_activity_gpio_pin_mask, disable_interface_mask); +} + +#define FLASH_BLOCK_ERASE_CMD 0xd8 + +#define FLASH_PAGE_SIZE (1u << 8) +#define FLASH_SECTOR_SIZE (1u << 12) +#define FLASH_BLOCK_SIZE (1u << 16) + +#define BOOT2_SIZE_WORDS 64 +#define XIP_BASE 0x10000000 + +static uint32_t boot2_copyout[BOOT2_SIZE_WORDS]; +static bool boot2_copyout_valid = false; + +static ram_func void flash_init_boot2_copyout() { + if (boot2_copyout_valid) + return; + for (int i = 0; i < BOOT2_SIZE_WORDS; ++i) + 
boot2_copyout[i] = ((uint32_t *)XIP_BASE)[i]; + __compiler_memory_barrier(); + boot2_copyout_valid = true; +} + +static ram_func void flash_enable_xip_via_boot2() { + ((void (*)(void))boot2_copyout+1)(); +} + +// See https://github.com/raspberrypi/pico-sdk/blob/master/src/rp2_common/hardware_flash/flash.c#L86 +void ram_func flash_range_write(uint32_t offset, const uint8_t *data, size_t count) +{ + flash_range_program_fn flash_range_program_func = (flash_range_program_fn) rom_func_lookup(ROM_FUNC_FLASH_RANGE_PROGRAM); + flash_connect_internal_fn flash_connect_internal_func = (flash_connect_internal_fn) rom_func_lookup(ROM_FUNC_CONNECT_INTERNAL_FLASH); + flash_exit_xip_fn flash_exit_xip_func = (flash_exit_xip_fn) rom_func_lookup(ROM_FUNC_FLASH_EXIT_XIP); + flash_flush_cache_fn flash_flush_cache_func = (flash_flush_cache_fn) rom_func_lookup(ROM_FUNC_FLASH_FLUSH_CACHE); + + flash_init_boot2_copyout(); + + __compiler_memory_barrier(); + + flash_connect_internal_func(); + flash_exit_xip_func(); + + flash_range_program_func(offset, data, count); + flash_flush_cache_func(); + flash_enable_xip_via_boot2(); +} + +void ram_func flash_erase_blocks(uint32_t offset, size_t count) +{ + flash_range_erase_fn flash_range_erase_func = (flash_range_erase_fn) rom_func_lookup(ROM_FUNC_FLASH_RANGE_ERASE); + flash_connect_internal_fn flash_connect_internal_func = (flash_connect_internal_fn) rom_func_lookup(ROM_FUNC_CONNECT_INTERNAL_FLASH); + flash_exit_xip_fn flash_exit_xip_func = (flash_exit_xip_fn) rom_func_lookup(ROM_FUNC_FLASH_EXIT_XIP); + flash_flush_cache_fn flash_flush_cache_func = (flash_flush_cache_fn) rom_func_lookup(ROM_FUNC_FLASH_FLUSH_CACHE); + + flash_init_boot2_copyout(); + + __compiler_memory_barrier(); + + flash_connect_internal_func(); + flash_exit_xip_func(); + + flash_range_erase_func(offset, count, FLASH_BLOCK_SIZE, FLASH_BLOCK_ERASE_CMD); + flash_flush_cache_func(); + flash_enable_xip_via_boot2(); +} + +*/ +import "C" + +// EnterBootloader should perform a system reset in preparation +// to switch to the bootloader to flash new firmware. +func EnterBootloader() { + C.reset_usb_boot(0, 0) +} + +// Flash related code +const memoryStart = C.XIP_BASE // memory start for purpose of erase + +// compile-time check for ensuring we fulfill BlockDevice interface +var _ BlockDevice = flashBlockDevice{} + +var Flash flashBlockDevice + +type flashBlockDevice struct { +} + +// ReadAt reads the given number of bytes from the block device. +func (f flashBlockDevice) ReadAt(p []byte, off int64) (n int, err error) { + if readAddress(off) > FlashDataEnd() { + return 0, errFlashCannotReadPastEOF + } + + data := unsafe.Slice((*byte)(unsafe.Pointer(readAddress(off))), len(p)) + copy(p, data) + + return len(p), nil +} + +// WriteAt writes the given number of bytes to the block device. +// Only word (32 bits) length data can be programmed. +// If the length of p is not long enough it will be padded with 0xFF bytes. +// This method assumes that the destination is already erased. +func (f flashBlockDevice) WriteAt(p []byte, off int64) (n int, err error) { + if writeAddress(off)+uintptr(C.XIP_BASE) > FlashDataEnd() { + return 0, errFlashCannotWritePastEOF + } + + state := interrupt.Disable() + defer interrupt.Restore(state) + + // rp2040 writes to offset, not actual address + // e.g. 
real address 0x10003000 is written to at + // 0x00003000 + address := writeAddress(off) + padded := f.pad(p) + + C.flash_range_write(C.uint32_t(address), + (*C.uint8_t)(unsafe.Pointer(&padded[0])), + C.ulong(len(padded))) + + return len(padded), nil +} + +// Size returns the number of bytes in this block device. +func (f flashBlockDevice) Size() int64 { + return int64(FlashDataEnd() - FlashDataStart()) +} + +const writeBlockSize = 4 + +// WriteBlockSize returns the block size in which data can be written to +// memory. It can be used by a client to optimize writes, non-aligned writes +// should always work correctly. +func (f flashBlockDevice) WriteBlockSize() int64 { + return writeBlockSize +} + +const eraseBlockSizeValue = 1 << 12 + +func eraseBlockSize() int64 { + return eraseBlockSizeValue +} + +// EraseBlockSize returns the smallest erasable area on this particular chip +// in bytes. This is used for the block size in EraseBlocks. +func (f flashBlockDevice) EraseBlockSize() int64 { + return eraseBlockSize() +} + +// EraseBlocks erases the given number of blocks. An implementation may +// transparently coalesce ranges of blocks into larger bundles if the chip +// supports this. The start and len parameters are in block numbers, use +// EraseBlockSize to map addresses to blocks. +func (f flashBlockDevice) EraseBlocks(start, length int64) error { + address := writeAddress(start * f.EraseBlockSize()) + if address+uintptr(C.XIP_BASE) > FlashDataEnd() { + return errFlashCannotErasePastEOF + } + + state := interrupt.Disable() + defer interrupt.Restore(state) + + C.flash_erase_blocks(C.uint32_t(address), C.ulong(length)) + + return nil +} + +// pad data if needed so it is long enough for correct byte alignment on writes. +func (f flashBlockDevice) pad(p []byte) []byte { + paddingNeeded := f.WriteBlockSize() - (int64(len(p)) % f.WriteBlockSize()) + if paddingNeeded == 0 { + return p + } + + padding := bytes.Repeat([]byte{0xff}, int(paddingNeeded)) + return append(p, padding...) +} + +// return the correct address to be used for write +func writeAddress(off int64) uintptr { + return readAddress(off) - uintptr(C.XIP_BASE) +} + +// return the correct address to be used for reads +func readAddress(off int64) uintptr { + return FlashDataStart() + uintptr(off) +} diff --git a/src/machine/machine_rp2040_rtc.go b/src/machine/machine_rp2040_rtc.go new file mode 100644 index 0000000000..192e187c0a --- /dev/null +++ b/src/machine/machine_rp2040_rtc.go @@ -0,0 +1,240 @@ +//go:build rp2040 + +// Implementation based on code located here: +// https://github.com/raspberrypi/pico-sdk/blob/master/src/rp2_common/hardware_rtc/rtc.c + +package machine + +import ( + "device/rp" + "errors" + "runtime/interrupt" + "unsafe" +) + +type rtcType rp.RTC_Type + +type rtcTime struct { + Year int16 + Month int8 + Day int8 + Dotw int8 + Hour int8 + Min int8 + Sec int8 +} + +var RTC = (*rtcType)(unsafe.Pointer(rp.RTC)) + +const ( + second = 1 + minute = 60 * second + hour = 60 * minute + day = 24 * hour +) + +var ( + rtcAlarmRepeats bool + rtcCallback func() + rtcEpoch = rtcTime{ + Year: 1970, Month: 1, Day: 1, Dotw: 4, Hour: 0, Min: 0, Sec: 0, + } +) + +var ( + ErrRtcDelayTooSmall = errors.New("RTC interrupt deplay is too small, shall be at least 1 second") + ErrRtcDelayTooLarge = errors.New("RTC interrupt deplay is too large, shall be no more than 1 day") +) + +// SetInterrupt configures delayed and optionally recurring interrupt by real time clock. 
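+//
+// A sketch of typical use (hypothetical values; the callback runs in interrupt
+// context, so it must not block or allocate):
+//
+//	err := machine.RTC.SetInterrupt(60, true, func() { println("RTC alarm") })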
+// +// Delay is specified in whole seconds, allowed range depends on platform. +// Zero delay disables previously configured interrupt, if any. +// +// RP2040 implementation allows delay to be up to 1 day, otherwise a respective error is emitted. +func (rtc *rtcType) SetInterrupt(delay uint32, repeat bool, callback func()) error { + + // Verify delay range + if delay > day { + return ErrRtcDelayTooLarge + } + + // De-configure delayed interrupt if delay is zero + if delay == 0 { + rtc.disableInterruptMatch() + return nil + } + + // Configure delayed interrupt + rtc.setDivider() + + rtcAlarmRepeats = repeat + rtcCallback = callback + + err := rtc.setTime(rtcEpoch) + if err != nil { + return err + } + rtc.setAlarm(toAlarmTime(delay), callback) + + return nil +} + +func toAlarmTime(delay uint32) rtcTime { + result := rtcEpoch + remainder := delay + 1 // needed "+1", otherwise alarm fires one second too early + if remainder >= hour { + result.Hour = int8(remainder / hour) + remainder %= hour + } + if remainder >= minute { + result.Min = int8(remainder / minute) + remainder %= minute + } + result.Sec = int8(remainder) + return result +} + +func (rtc *rtcType) setDivider() { + // Get clk_rtc freq and make sure it is running + rtcFreq := configuredFreq[clkRTC] + if rtcFreq == 0 { + panic("can not set RTC divider, clock is not running") + } + + // Take rtc out of reset now that we know clk_rtc is running + resetBlock(rp.RESETS_RESET_RTC) + unresetBlockWait(rp.RESETS_RESET_RTC) + + // Set up the 1 second divider. + // If rtc_freq is 400 then clkdiv_m1 should be 399 + rtcFreq -= 1 + + // Check the freq is not too big to divide + if rtcFreq > rp.RTC_CLKDIV_M1_CLKDIV_M1_Msk { + panic("can not set RTC divider, clock frequency is too big to divide") + } + + // Write divide value + rtc.CLKDIV_M1.Set(rtcFreq) +} + +// setTime configures RTC with supplied time, initialises and activates it. +func (rtc *rtcType) setTime(t rtcTime) error { + + // Disable RTC and wait while it is still running + rtc.CTRL.Set(0) + for rtc.isActive() { + } + + rtc.SETUP_0.Set((uint32(t.Year) << rp.RTC_SETUP_0_YEAR_Pos) | + (uint32(t.Month) << rp.RTC_SETUP_0_MONTH_Pos) | + (uint32(t.Day) << rp.RTC_SETUP_0_DAY_Pos)) + + rtc.SETUP_1.Set((uint32(t.Dotw) << rp.RTC_SETUP_1_DOTW_Pos) | + (uint32(t.Hour) << rp.RTC_SETUP_1_HOUR_Pos) | + (uint32(t.Min) << rp.RTC_SETUP_1_MIN_Pos) | + (uint32(t.Sec) << rp.RTC_SETUP_1_SEC_Pos)) + + // Load setup values into RTC clock domain + rtc.CTRL.SetBits(rp.RTC_CTRL_LOAD) + + // Enable RTC and wait for it to be running + rtc.CTRL.SetBits(rp.RTC_CTRL_RTC_ENABLE) + for !rtc.isActive() { + } + + return nil +} + +func (rtc *rtcType) isActive() bool { + return rtc.CTRL.HasBits(rp.RTC_CTRL_RTC_ACTIVE) +} + +// setAlarm configures alarm in RTC and arms it. +// The callback is executed in the context of an interrupt handler, +// so regular restructions for this sort of code apply: no blocking, no memory allocation, etc. 
+func (rtc *rtcType) setAlarm(t rtcTime, callback func()) { + + rtc.disableInterruptMatch() + + // Clear all match enable bits + rtc.IRQ_SETUP_0.ClearBits(rp.RTC_IRQ_SETUP_0_YEAR_ENA | rp.RTC_IRQ_SETUP_0_MONTH_ENA | rp.RTC_IRQ_SETUP_0_DAY_ENA) + rtc.IRQ_SETUP_1.ClearBits(rp.RTC_IRQ_SETUP_1_DOTW_ENA | rp.RTC_IRQ_SETUP_1_HOUR_ENA | rp.RTC_IRQ_SETUP_1_MIN_ENA | rp.RTC_IRQ_SETUP_1_SEC_ENA) + + // Only add to setup if it isn't -1 and set the match enable bits for things we care about + if t.Year >= 0 { + rtc.IRQ_SETUP_0.SetBits(uint32(t.Year) << rp.RTC_SETUP_0_YEAR_Pos) + rtc.IRQ_SETUP_0.SetBits(rp.RTC_IRQ_SETUP_0_YEAR_ENA) + } + + if t.Month >= 0 { + rtc.IRQ_SETUP_0.SetBits(uint32(t.Month) << rp.RTC_SETUP_0_MONTH_Pos) + rtc.IRQ_SETUP_0.SetBits(rp.RTC_IRQ_SETUP_0_MONTH_ENA) + } + + if t.Day >= 0 { + rtc.IRQ_SETUP_0.SetBits(uint32(t.Day) << rp.RTC_SETUP_0_DAY_Pos) + rtc.IRQ_SETUP_0.SetBits(rp.RTC_IRQ_SETUP_0_DAY_ENA) + } + + if t.Dotw >= 0 { + rtc.IRQ_SETUP_1.SetBits(uint32(t.Dotw) << rp.RTC_SETUP_1_DOTW_Pos) + rtc.IRQ_SETUP_1.SetBits(rp.RTC_IRQ_SETUP_1_DOTW_ENA) + } + + if t.Hour >= 0 { + rtc.IRQ_SETUP_1.SetBits(uint32(t.Hour) << rp.RTC_SETUP_1_HOUR_Pos) + rtc.IRQ_SETUP_1.SetBits(rp.RTC_IRQ_SETUP_1_HOUR_ENA) + } + + if t.Min >= 0 { + rtc.IRQ_SETUP_1.SetBits(uint32(t.Min) << rp.RTC_SETUP_1_MIN_Pos) + rtc.IRQ_SETUP_1.SetBits(rp.RTC_IRQ_SETUP_1_MIN_ENA) + } + + if t.Sec >= 0 { + rtc.IRQ_SETUP_1.SetBits(uint32(t.Sec) << rp.RTC_SETUP_1_SEC_Pos) + rtc.IRQ_SETUP_1.SetBits(rp.RTC_IRQ_SETUP_1_SEC_ENA) + } + + // Enable the IRQ at the proc + interrupt.New(rp.IRQ_RTC_IRQ, rtcHandleInterrupt).Enable() + + // Enable the IRQ at the peri + rtc.INTE.Set(rp.RTC_INTE_RTC) + + rtc.enableInterruptMatch() +} + +func (rtc *rtcType) enableInterruptMatch() { + // Set matching and wait for it to be enabled + rtc.IRQ_SETUP_0.SetBits(rp.RTC_IRQ_SETUP_0_MATCH_ENA) + for !rtc.IRQ_SETUP_0.HasBits(rp.RTC_IRQ_SETUP_0_MATCH_ACTIVE) { + } +} + +func (rtc *rtcType) disableInterruptMatch() { + // Disable matching and wait for it to stop being active + rtc.IRQ_SETUP_0.ClearBits(rp.RTC_IRQ_SETUP_0_MATCH_ENA) + for rtc.IRQ_SETUP_0.HasBits(rp.RTC_IRQ_SETUP_0_MATCH_ACTIVE) { + } +} + +func rtcHandleInterrupt(itr interrupt.Interrupt) { + // Always disable the alarm to clear the current IRQ. + // Even if it is a repeatable alarm, we don't want it to keep firing. + // If it matches on a second it can keep firing for that second. + RTC.disableInterruptMatch() + + // Call user callback function + if rtcCallback != nil { + rtcCallback() + } + + if rtcAlarmRepeats { + // If it is a repeatable alarm, reset time and re-enable the alarm. + RTC.setTime(rtcEpoch) + RTC.enableInterruptMatch() + } +} diff --git a/src/machine/machine_rp2040_spi.go b/src/machine/machine_rp2040_spi.go index fd1b6ab0f0..62c5c0c075 100644 --- a/src/machine/machine_rp2040_spi.go +++ b/src/machine/machine_rp2040_spi.go @@ -40,6 +40,9 @@ var ( ErrLSBNotSupported = errors.New("SPI LSB unsupported on PL022") ErrSPITimeout = errors.New("SPI timeout") ErrSPIBaud = errors.New("SPI baud too low or above 66.5Mhz") + errSPIInvalidSDI = errors.New("invalid SPI SDI pin") + errSPIInvalidSDO = errors.New("invalid SPI SDO pin") + errSPIInvalidSCK = errors.New("invalid SPI SCK pin") ) type SPI struct { @@ -162,7 +165,7 @@ func (spi SPI) GetBaudRate() uint32 { // No pin configuration is needed of SCK, SDO and SDI needed after calling Configure. 
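// For example, an illustrative configuration of SPI0 on the rp2040 using a pin
// set accepted by the validation below (machine.SPI0 and the GPIOx constants
// are assumed to exist for the target board):
//
//	err := machine.SPI0.Configure(machine.SPIConfig{
//		Frequency: 4_000_000,
//		SCK:       machine.GPIO2,
//		SDO:       machine.GPIO3,
//		SDI:       machine.GPIO4,
//	})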
func (spi SPI) Configure(config SPIConfig) error { const defaultBaud uint32 = 115200 - if config.SCK == 0 { + if config.SCK == 0 && config.SDO == 0 && config.SDI == 0 { // set default pins if config zero valued or invalid clock pin supplied. switch spi.Bus { case rp.SPI0: @@ -175,6 +178,27 @@ func (spi SPI) Configure(config SPIConfig) error { config.SDI = SPI1_SDI_PIN } } + var okSDI, okSDO, okSCK bool + switch spi.Bus { + case rp.SPI0: + okSDI = config.SDI == 0 || config.SDI == 4 || config.SDI == 16 || config.SDI == 20 + okSDO = config.SDO == 3 || config.SDO == 7 || config.SDO == 19 || config.SDO == 23 + okSCK = config.SCK == 2 || config.SCK == 6 || config.SCK == 18 || config.SCK == 22 + case rp.SPI1: + okSDI = config.SDI == 8 || config.SDI == 12 || config.SDI == 24 || config.SDI == 28 + okSDO = config.SDO == 11 || config.SDO == 15 || config.SDO == 27 + okSCK = config.SCK == 10 || config.SCK == 14 || config.SCK == 26 + } + + switch { + case !okSDI: + return errSPIInvalidSDI + case !okSDO: + return errSPIInvalidSDO + case !okSCK: + return errSPIInvalidSCK + } + if config.DataBits < 4 || config.DataBits > 16 { config.DataBits = 8 } diff --git a/src/machine/machine_rp2040_uart.go b/src/machine/machine_rp2040_uart.go index b4053aeb92..e5e4f77de3 100644 --- a/src/machine/machine_rp2040_uart.go +++ b/src/machine/machine_rp2040_uart.go @@ -41,8 +41,12 @@ func (uart *UART) Configure(config UARTConfig) error { rp.UART0_UARTCR_TXE) // set GPIO mux to UART for the pins - config.TX.Configure(PinConfig{Mode: PinUART}) - config.RX.Configure(PinConfig{Mode: PinUART}) + if config.TX != NoPin { + config.TX.Configure(PinConfig{Mode: PinUART}) + } + if config.RX != NoPin { + config.RX.Configure(PinConfig{Mode: PinUART}) + } // Enable RX IRQ. uart.Interrupt.SetPriority(0x80) diff --git a/src/machine/machine_rp2040_usb.go b/src/machine/machine_rp2040_usb.go index cb3bb789d6..6e6fd49623 100644 --- a/src/machine/machine_rp2040_usb.go +++ b/src/machine/machine_rp2040_usb.go @@ -127,10 +127,10 @@ func handleUSBIRQ(intr interrupt.Interrupt) { // Bus is reset if (status & rp.USBCTRL_REGS_INTS_BUS_RESET) > 0 { rp.USBCTRL_REGS.SIE_STATUS.Set(rp.USBCTRL_REGS_SIE_STATUS_BUS_RESET) - rp.USBCTRL_REGS.ADDR_ENDP.Set(0) + fixRP2040UsbDeviceEnumeration() + rp.USBCTRL_REGS.ADDR_ENDP.Set(0) initEndpoint(0, usb.ENDPOINT_TYPE_CONTROL) - fixRP2040UsbDeviceEnumeration() } } diff --git a/src/machine/machine_stm32_flash.go b/src/machine/machine_stm32_flash.go new file mode 100644 index 0000000000..898c50dc00 --- /dev/null +++ b/src/machine/machine_stm32_flash.go @@ -0,0 +1,126 @@ +//go:build stm32f4 || stm32l4 || stm32wlx + +package machine + +import ( + "device/stm32" + + "bytes" + "unsafe" +) + +// compile-time check for ensuring we fulfill BlockDevice interface +var _ BlockDevice = flashBlockDevice{} + +var Flash flashBlockDevice + +type flashBlockDevice struct { +} + +// ReadAt reads the given number of bytes from the block device. +func (f flashBlockDevice) ReadAt(p []byte, off int64) (n int, err error) { + if FlashDataStart()+uintptr(off)+uintptr(len(p)) > FlashDataEnd() { + return 0, errFlashCannotReadPastEOF + } + + data := unsafe.Slice((*byte)(unsafe.Pointer(FlashDataStart()+uintptr(off))), len(p)) + copy(p, data) + + return len(p), nil +} + +// WriteAt writes the given number of bytes to the block device. +// Only double-word (64 bits) length data can be programmed. See rm0461 page 78. +// If the length of p is not long enough it will be padded with 0xFF bytes. 
+// This method assumes that the destination is already erased. +func (f flashBlockDevice) WriteAt(p []byte, off int64) (n int, err error) { + if FlashDataStart()+uintptr(off)+uintptr(len(p)) > FlashDataEnd() { + return 0, errFlashCannotWritePastEOF + } + + unlockFlash() + defer lockFlash() + + return writeFlashData(FlashDataStart()+uintptr(off), f.pad(p)) +} + +// Size returns the number of bytes in this block device. +func (f flashBlockDevice) Size() int64 { + return int64(FlashDataEnd() - FlashDataStart()) +} + +// WriteBlockSize returns the block size in which data can be written to +// memory. It can be used by a client to optimize writes, non-aligned writes +// should always work correctly. +func (f flashBlockDevice) WriteBlockSize() int64 { + return writeBlockSize +} + +func eraseBlockSize() int64 { + return eraseBlockSizeValue +} + +// EraseBlockSize returns the smallest erasable area on this particular chip +// in bytes. This is used for the block size in EraseBlocks. +// It must be a power of two, and may be as small as 1. A typical size is 4096. +// TODO: correctly handle processors that have differently sized blocks +// in different areas of memory like the STM32F40x and STM32F1x. +func (f flashBlockDevice) EraseBlockSize() int64 { + return eraseBlockSize() +} + +// EraseBlocks erases the given number of blocks. An implementation may +// transparently coalesce ranges of blocks into larger bundles if the chip +// supports this. The start and len parameters are in block numbers, use +// EraseBlockSize to map addresses to blocks. +// Note that block 0 should map to the address of FlashDataStart(). +func (f flashBlockDevice) EraseBlocks(start, len int64) error { + var address uintptr = uintptr(start*f.EraseBlockSize()) + FlashDataStart() + blk := int64(address-uintptr(memoryStart)) / f.EraseBlockSize() + + unlockFlash() + defer lockFlash() + + for i := blk; i < blk+len; i++ { + if err := eraseBlock(uint32(i)); err != nil { + return err + } + } + + return nil +} + +// pad data if needed so it is long enough for correct byte alignment on writes. +func (f flashBlockDevice) pad(p []byte) []byte { + paddingNeeded := f.WriteBlockSize() - (int64(len(p)) % f.WriteBlockSize()) + if paddingNeeded == 0 { + return p + } + + padded := bytes.Repeat([]byte{0xff}, int(paddingNeeded)) + return append(p, padded...) +} + +const memoryStart = 0x08000000 + +func unlockFlash() { + // keys as described rm0461 page 76 + var fkey1 uint32 = 0x45670123 + var fkey2 uint32 = 0xCDEF89AB + + // Wait for the flash memory not to be busy + for stm32.FLASH.GetSR_BSY() != 0 { + } + + // Check if the controller is unlocked already + if stm32.FLASH.GetCR_LOCK() != 0 { + // Write the first key + stm32.FLASH.SetKEYR(fkey1) + // Write the second key + stm32.FLASH.SetKEYR(fkey2) + } +} + +func lockFlash() { + stm32.FLASH.SetCR_LOCK(1) +} diff --git a/src/machine/machine_stm32f4.go b/src/machine/machine_stm32f4.go index 829dc7f85b..3b8923cb76 100644 --- a/src/machine/machine_stm32f4.go +++ b/src/machine/machine_stm32f4.go @@ -6,6 +6,8 @@ package machine import ( "device/stm32" + "encoding/binary" + "errors" "math/bits" "runtime/interrupt" "runtime/volatile" @@ -791,3 +793,142 @@ func (i2c *I2C) getSpeed(config I2CConfig) uint32 { } } } + +//---------- Flash related code + +// the block size actually depends on the sector. 
+// TODO: handle this correctly for sectors > 3 +const eraseBlockSizeValue = 16384 + +// see RM0090 page 75 +func sectorNumber(address uintptr) uint32 { + switch { + // 0x0800 0000 - 0x0800 3FFF + case address >= 0x08000000 && address <= 0x08003FFF: + return 0 + // 0x0800 4000 - 0x0800 7FFF + case address >= 0x08004000 && address <= 0x08007FFF: + return 1 + // 0x0800 8000 - 0x0800 BFFF + case address >= 0x08008000 && address <= 0x0800BFFF: + return 2 + // 0x0800 C000 - 0x0800 FFFF + case address >= 0x0800C000 && address <= 0x0800FFFF: + return 3 + // 0x0801 0000 - 0x0801 FFFF + case address >= 0x08010000 && address <= 0x0801FFFF: + return 4 + // 0x0802 0000 - 0x0803 FFFF + case address >= 0x08020000 && address <= 0x0803FFFF: + return 5 + // 0x0804 0000 - 0x0805 FFFF + case address >= 0x08040000 && address <= 0x0805FFFF: + return 6 + case address >= 0x08060000 && address <= 0x0807FFFF: + return 7 + case address >= 0x08080000 && address <= 0x0809FFFF: + return 8 + case address >= 0x080A0000 && address <= 0x080BFFFF: + return 9 + case address >= 0x080C0000 && address <= 0x080DFFFF: + return 10 + case address >= 0x080E0000 && address <= 0x080FFFFF: + return 11 + default: + return 0 + } +} + +// calculate sector number from address +// var sector uint32 = sectorNumber(address) + +// see RM0090 page 85 +// eraseBlock at the passed in block number +func eraseBlock(block uint32) error { + waitUntilFlashDone() + + // clear any previous errors + stm32.FLASH.SR.SetBits(0xF0) + + // set SER bit + stm32.FLASH.SetCR_SER(1) + defer stm32.FLASH.SetCR_SER(0) + + // set the block (aka sector) to be erased + stm32.FLASH.SetCR_SNB(block) + defer stm32.FLASH.SetCR_SNB(0) + + // start the page erase + stm32.FLASH.SetCR_STRT(1) + + waitUntilFlashDone() + + if err := checkError(); err != nil { + return err + } + + return nil +} + +const writeBlockSize = 2 + +// see RM0090 page 86 +// must write data in word-length +func writeFlashData(address uintptr, data []byte) (int, error) { + if len(data)%writeBlockSize != 0 { + return 0, errFlashInvalidWriteLength + } + + waitUntilFlashDone() + + // clear any previous errors + stm32.FLASH.SR.SetBits(0xF0) + + // set parallelism to x32 + stm32.FLASH.SetCR_PSIZE(2) + + for i := 0; i < len(data); i += writeBlockSize { + // start write operation + stm32.FLASH.SetCR_PG(1) + + *(*uint16)(unsafe.Pointer(address)) = binary.LittleEndian.Uint16(data[i : i+writeBlockSize]) + + waitUntilFlashDone() + + if err := checkError(); err != nil { + return i, err + } + + // end write operation + stm32.FLASH.SetCR_PG(0) + } + + return len(data), nil +} + +func waitUntilFlashDone() { + for stm32.FLASH.GetSR_BSY() != 0 { + } +} + +var ( + errFlashPGS = errors.New("errFlashPGS") + errFlashPGP = errors.New("errFlashPGP") + errFlashPGA = errors.New("errFlashPGA") + errFlashWRP = errors.New("errFlashWRP") +) + +func checkError() error { + switch { + case stm32.FLASH.GetSR_PGSERR() != 0: + return errFlashPGS + case stm32.FLASH.GetSR_PGPERR() != 0: + return errFlashPGP + case stm32.FLASH.GetSR_PGAERR() != 0: + return errFlashPGA + case stm32.FLASH.GetSR_WRPERR() != 0: + return errFlashWRP + } + + return nil +} diff --git a/src/machine/machine_stm32l4.go b/src/machine/machine_stm32l4.go index b6da1b4db7..f60a77e700 100644 --- a/src/machine/machine_stm32l4.go +++ b/src/machine/machine_stm32l4.go @@ -4,6 +4,8 @@ package machine import ( "device/stm32" + "encoding/binary" + "errors" "runtime/interrupt" "runtime/volatile" "unsafe" @@ -543,3 +545,104 @@ func initRNG() { 
stm32.RCC.AHB2ENR.SetBits(stm32.RCC_AHB2ENR_RNGEN) stm32.RNG.CR.SetBits(stm32.RNG_CR_RNGEN) } + +//---------- Flash related code + +const eraseBlockSizeValue = 2048 + +// see RM0394 page 83 +// eraseBlock of the passed in block number +func eraseBlock(block uint32) error { + waitUntilFlashDone() + + // clear any previous errors + stm32.FLASH.SR.SetBits(0x3FA) + + // page erase operation + stm32.FLASH.SetCR_PER(1) + defer stm32.FLASH.SetCR_PER(0) + + // set the page to be erased + stm32.FLASH.SetCR_PNB(block) + + // start the page erase + stm32.FLASH.SetCR_START(1) + + waitUntilFlashDone() + + if err := checkError(); err != nil { + return err + } + + return nil +} + +const writeBlockSize = 8 + +// see RM0394 page 84 +// It is only possible to program double word (2 x 32-bit data). +func writeFlashData(address uintptr, data []byte) (int, error) { + if len(data)%writeBlockSize != 0 { + return 0, errFlashInvalidWriteLength + } + + waitUntilFlashDone() + + // clear any previous errors + stm32.FLASH.SR.SetBits(0x3FA) + + for j := 0; j < len(data); j += writeBlockSize { + // start page write operation + stm32.FLASH.SetCR_PG(1) + + // write second word using double-word high order word + *(*uint32)(unsafe.Pointer(address)) = binary.LittleEndian.Uint32(data[j : j+writeBlockSize/2]) + + address += writeBlockSize / 2 + + // write first word using double-word low order word + *(*uint32)(unsafe.Pointer(address)) = binary.LittleEndian.Uint32(data[j+writeBlockSize/2 : j+writeBlockSize]) + + waitUntilFlashDone() + + if err := checkError(); err != nil { + return j, err + } + + // end flash write + stm32.FLASH.SetCR_PG(0) + address += writeBlockSize / 2 + } + + return len(data), nil +} + +func waitUntilFlashDone() { + for stm32.FLASH.GetSR_BSY() != 0 { + } +} + +var ( + errFlashPGS = errors.New("errFlashPGS") + errFlashSIZE = errors.New("errFlashSIZE") + errFlashPGA = errors.New("errFlashPGA") + errFlashWRP = errors.New("errFlashWRP") + errFlashPROG = errors.New("errFlashPROG") +) + +func checkError() error { + switch { + case stm32.FLASH.GetSR_PGSERR() != 0: + return errFlashPGS + case stm32.FLASH.GetSR_SIZERR() != 0: + return errFlashSIZE + case stm32.FLASH.GetSR_PGAERR() != 0: + return errFlashPGA + case stm32.FLASH.GetSR_WRPERR() != 0: + return errFlashWRP + case stm32.FLASH.GetSR_PROGERR() != 0: + return errFlashPROG + } + + return nil +} diff --git a/src/machine/machine_stm32wlx.go b/src/machine/machine_stm32wlx.go index 32b3b58237..010d038e03 100644 --- a/src/machine/machine_stm32wlx.go +++ b/src/machine/machine_stm32wlx.go @@ -6,6 +6,8 @@ package machine import ( "device/stm32" + "encoding/binary" + "errors" "math/bits" "runtime/interrupt" "runtime/volatile" @@ -424,3 +426,115 @@ const ( ARR_MAX = 0x10000 PSC_MAX = 0x10000 ) + +//---------- Flash related code + +const eraseBlockSizeValue = 2048 + +// eraseBlock of the passed in block number +func eraseBlock(block uint32) error { + waitUntilFlashDone() + + // check if operation is allowed. 
+ if stm32.FLASH.GetSR_PESD() != 0 { + return errFlashCannotErasePage + } + + // clear any previous errors + stm32.FLASH.SR.SetBits(0x3FA) + + // page erase operation + stm32.FLASH.SetCR_PER(1) + defer stm32.FLASH.SetCR_PER(0) + + // set the address to the page to be written + stm32.FLASH.SetCR_PNB(block) + defer stm32.FLASH.SetCR_PNB(0) + + // start the page erase + stm32.FLASH.SetCR_STRT(1) + + waitUntilFlashDone() + + if err := checkError(); err != nil { + return err + } + + return nil +} + +const writeBlockSize = 8 + +func writeFlashData(address uintptr, data []byte) (int, error) { + if len(data)%writeBlockSize != 0 { + return 0, errFlashInvalidWriteLength + } + + waitUntilFlashDone() + + // check if operation is allowed + if stm32.FLASH.GetSR_PESD() != 0 { + return 0, errFlashNotAllowedWriteData + } + + // clear any previous errors + stm32.FLASH.SR.SetBits(0x3FA) + + for j := 0; j < len(data); j += writeBlockSize { + // start page write operation + stm32.FLASH.SetCR_PG(1) + + // write first word using double-word high order word + *(*uint32)(unsafe.Pointer(address)) = binary.LittleEndian.Uint32(data[j : j+writeBlockSize/2]) + + address += writeBlockSize / 2 + + // write second word using double-word low order word + *(*uint32)(unsafe.Pointer(address)) = binary.LittleEndian.Uint32(data[j+writeBlockSize/2 : j+writeBlockSize]) + + waitUntilFlashDone() + + if err := checkError(); err != nil { + return j, err + } + + // end flash write + stm32.FLASH.SetCR_PG(0) + address += writeBlockSize / 2 + } + + return len(data), nil +} + +func waitUntilFlashDone() { + for stm32.FLASH.GetSR_BSY() != 0 { + } + + for stm32.FLASH.GetSR_CFGBSY() != 0 { + } +} + +var ( + errFlashPGS = errors.New("errFlashPGS") + errFlashSIZE = errors.New("errFlashSIZE") + errFlashPGA = errors.New("errFlashPGA") + errFlashWRP = errors.New("errFlashWRP") + errFlashPROG = errors.New("errFlashPROG") +) + +func checkError() error { + switch { + case stm32.FLASH.GetSR_PGSERR() != 0: + return errFlashPGS + case stm32.FLASH.GetSR_SIZERR() != 0: + return errFlashSIZE + case stm32.FLASH.GetSR_PGAERR() != 0: + return errFlashPGA + case stm32.FLASH.GetSR_WRPERR() != 0: + return errFlashWRP + case stm32.FLASH.GetSR_PROGERR() != 0: + return errFlashPROG + } + + return nil +} diff --git a/src/machine/usb/descriptor.go b/src/machine/usb/descriptor.go index 3acb2e9a79..d57d0bd8a9 100644 --- a/src/machine/usb/descriptor.go +++ b/src/machine/usb/descriptor.go @@ -67,7 +67,7 @@ var DescriptorCDCHID = Descriptor{ 0x07, 0x05, 0x02, 0x02, 0x40, 0x00, 0x00, 0x07, 0x05, 0x83, 0x02, 0x40, 0x00, 0x00, 0x09, 0x04, 0x02, 0x00, 0x01, 0x03, 0x00, 0x00, 0x00, - 0x09, 0x21, 0x01, 0x01, 0x00, 0x01, 0x22, 0x65, 0x00, + 0x09, 0x21, 0x01, 0x01, 0x00, 0x01, 0x22, 0x7E, 0x00, 0x07, 0x05, 0x84, 0x03, 0x40, 0x00, 0x01, }, HID: map[uint16][]byte{ @@ -80,6 +80,19 @@ var DescriptorCDCHID = Descriptor{ 0x03, 0x15, 0x00, 0x25, 0x01, 0x95, 0x03, 0x75, 0x01, 0x81, 0x02, 0x95, 0x01, 0x75, 0x05, 0x81, 0x03, 0x05, 0x01, 0x09, 0x30, 0x09, 0x31, 0x09, 0x38, 0x15, 0x81, 0x25, 0x7f, 0x75, 0x08, 0x95, 0x03, 0x81, 0x06, 0xc0, 0xc0, + + 0x05, 0x0C, // Usage Page (Consumer) + 0x09, 0x01, // Usage (Consumer Control) + 0xA1, 0x01, // Collection (Application) + 0x85, 0x03, // Report ID (3) + 0x15, 0x00, // Logical Minimum (0) + 0x26, 0xFF, 0x1F, // Logical Maximum (8191) + 0x19, 0x00, // Usage Minimum (Unassigned) + 0x2A, 0xFF, 0x1F, // Usage Maximum (0x1FFF) + 0x75, 0x10, // Report Size (16) + 0x95, 0x01, // Report Count (1) + 0x81, 0x00, // Input (Data,Array,Abs,No 
Wrap,Linear,Preferred State,No Null Position) + 0xC0, // End Collection }, }, } diff --git a/src/machine/usb/hid/keyboard/keyboard.go b/src/machine/usb/hid/keyboard/keyboard.go index ce69939989..6a5bad647b 100644 --- a/src/machine/usb/hid/keyboard/keyboard.go +++ b/src/machine/usb/hid/keyboard/keyboard.go @@ -238,16 +238,26 @@ func (kb *keyboard) sendKey(consumer bool, b []byte) bool { func (kb *keyboard) keyboardSendKeys(consumer bool) bool { var b [9]byte - b[0] = 0x02 - b[1] = kb.mod - b[2] = 0x02 - b[3] = kb.key[0] - b[4] = kb.key[1] - b[5] = kb.key[2] - b[6] = kb.key[3] - b[7] = kb.key[4] - b[8] = kb.key[5] - return kb.sendKey(consumer, b[:]) + + if !consumer { + b[0] = 0x02 // REPORT_ID + b[1] = kb.mod + b[2] = 0x02 + b[3] = kb.key[0] + b[4] = kb.key[1] + b[5] = kb.key[2] + b[6] = kb.key[3] + b[7] = kb.key[4] + b[8] = kb.key[5] + return kb.sendKey(consumer, b[:]) + + } else { + b[0] = 0x03 // REPORT_ID + b[1] = uint8(kb.con[0]) + b[2] = uint8((kb.con[0] & 0x0300) >> 8) + + return kb.sendKey(consumer, b[:3]) + } } // Down transmits a key-down event for the given Keycode. diff --git a/src/machine/usb/hid/keyboard/keycode.go b/src/machine/usb/hid/keyboard/keycode.go index 8b26eaccd8..762f65b9c0 100644 --- a/src/machine/usb/hid/keyboard/keycode.go +++ b/src/machine/usb/hid/keyboard/keycode.go @@ -77,6 +77,7 @@ const ( KeyModifierRightAlt Keycode = 0x40 | 0xE000 KeyModifierRightGUI Keycode = 0x80 | 0xE000 + // KeySystemXXX is not supported now KeySystemPowerDown Keycode = 0x81 | 0xE200 KeySystemSleep Keycode = 0x82 | 0xE200 KeySystemWakeUp Keycode = 0x83 | 0xE200 diff --git a/src/net/README.md b/src/net/README.md new file mode 100644 index 0000000000..bba8fdbba8 --- /dev/null +++ b/src/net/README.md @@ -0,0 +1,107 @@ +This is a port of Go's "net" package. The port offers a subset of Go's "net" +package. The subset maintains Go 1 compatiblity guarantee. + +The "net" package is modified to use netdev, TinyGo's network device driver interface. +Netdev replaces the OS syscall interface for I/O access to the networking +device. + +#### Table of Contents + +- ["net" Package](#net-package) +- [Netdev and Netlink](#netdev-and-netlink) +- [Using "net" and "net/http" Packages](#using-net-and-nethttp-packages) + +## "net" Package + +The "net" package is ported from Go 1.19.3. The tree listings below shows the +files copied. If the file is marked with an '\*', it is copied _and_ modified +to work with netdev. If the file is marked with an '+', the file is new. If +there is no mark, it is a straight copy. + +``` +src/net +├── dial.go * +├── http +│   ├── client.go * +│   ├── clone.go +│   ├── cookie.go +│   ├── fs.go +│   ├── header.go * +│   ├── http.go +│   ├── internal +│   │   ├── ascii +│   │   │   ├── print.go +│   │   │   └── print_test.go +│   │   ├── chunked.go +│   │   └── chunked_test.go +│   ├── jar.go +│   ├── method.go +│   ├── request.go * +│   ├── response.go * +│   ├── server.go * +│   ├── sniff.go +│   ├── status.go +│   ├── transfer.go * +│   └── transport.go * +├── ip.go +├── iprawsock.go * +├── ipsock.go * +├── mac.go +├── mac_test.go +├── netdev.go + +├── net.go * +├── parse.go +├── pipe.go +├── README.md +├── tcpsock.go * +├── tlssock.go + +└── udpsock.go * + +src/crypto/tls/ +├── common.go * +└── tls.go * +``` + +The modifications to "net" are to basically wrap TCPConn, UDPConn, and TLSConn +around netdev socket calls. In Go, these net.Conns call out to OS syscalls for +the socket operations. In TinyGo, the OS syscalls aren't available, so netdev +socket calls are substituted. 
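
To illustrate, application code dials and uses connections just like upstream
Go once a netdev driver has been set up for the board (driver setup is
board-specific; see the drivers repo). A minimal sketch, assuming the device
is already connected to a network:

```go
package main

import "net"

// Sketch only: a netdev driver must already be configured for the board
// before Dial is called (see the drivers repo for driver setup).
func main() {
	conn, err := net.Dial("tcp", "192.0.2.1:80")
	if err != nil {
		println("dial failed:", err.Error())
		return
	}
	defer conn.Close()
	conn.Write([]byte("GET / HTTP/1.0\r\n\r\n"))
}
```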
+ +The modifications to "net/http" are on the client and the server side. On the +client side, the TinyGo code changes remove the back-end round-tripper code and +replaces it with direct calls to TCPConns/TLSConns. All of Go's http +request/response handling code is intact and operational in TinyGo. Same holds +true for the server side. The server side supports the normal server features +like ServeMux and Hijacker (for websockets). + +### Maintaining "net" + +As Go progresses, changes to the "net" package need to be periodically +back-ported to TinyGo's "net" package. This is to pick up any upstream bug +fixes or security fixes. + +Changes "net" package files are marked with // TINYGO comments. + +The files that are marked modified * may contain only a subset of the original +file. Basically only the parts necessary to compile and run the example/net +examples are copied (and maybe modified). + +## Netdev and Netlink + +Netdev is TinyGo's network device driver model. Network drivers implement the +netdever interface, providing a common network I/O interface to TinyGo's "net" +package. The interface is modeled after the BSD socket interface. net.Conn +implementations (TCPConn, UDPConn, and TLSConn) use the netdev interface for +device I/O access. + +Network drivers also (optionally) implement the Netlinker interface. This +interface is not used by TinyGo's "net" package, but rather provides the TinyGo +application direct access to the network device for common settings and control +that fall outside of netdev's socket interface. + +See the README-net.md in drivers repo for more details on netdev and netlink. + +## Using "net" and "net/http" Packages + +See README-net.md in drivers repo to more details on using "net" and "net/http" +packages in a TinyGo application. diff --git a/src/net/conn_test.go b/src/net/conn_test.go deleted file mode 100644 index 4e1ac28c43..0000000000 --- a/src/net/conn_test.go +++ /dev/null @@ -1,478 +0,0 @@ -// The following is copied from x/net official implementation. -// Source: https://cs.opensource.google/go/x/net/+/f15817d1:nettest/conntest.go -// Changes from original the file: -// - Some variables are pulled in from nettest/nettest.go file. -// - The implementation of checkForTimeoutError() function is changed in -// accordance with error returned by the Pipe implementation. - -// Copyright 2016 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package net - -import ( - "bytes" - "encoding/binary" - "io" - "io/ioutil" - "math/rand" - "os" - "runtime" - "sync" - "testing" - "time" -) - -// The following variables are copied from nettest/nettest.go file -var ( - aLongTimeAgo = time.Unix(233431200, 0) - neverTimeout = time.Time{} -) - -// MakePipe creates a connection between two endpoints and returns the pair -// as c1 and c2, such that anything written to c1 is read by c2 and vice-versa. -// The stop function closes all resources, including c1, c2, and the underlying -// Listener (if there is one), and should not be nil. -type MakePipe func() (c1, c2 Conn, stop func(), err error) - -// testConn tests that a Conn implementation properly satisfies the interface. -// The tests should not produce any false positives, but may experience -// false negatives. Thus, some issues may only be detected when the test is -// run multiple times. For maximal effectiveness, run the tests under the -// race detector. 
-func testConn(t *testing.T, mp MakePipe) { - t.Run("BasicIO", func(t *testing.T) { timeoutWrapper(t, mp, testBasicIO) }) - t.Run("PingPong", func(t *testing.T) { timeoutWrapper(t, mp, testPingPong) }) - t.Run("RacyRead", func(t *testing.T) { timeoutWrapper(t, mp, testRacyRead) }) - t.Run("RacyWrite", func(t *testing.T) { timeoutWrapper(t, mp, testRacyWrite) }) - t.Run("ReadTimeout", func(t *testing.T) { timeoutWrapper(t, mp, testReadTimeout) }) - t.Run("WriteTimeout", func(t *testing.T) { timeoutWrapper(t, mp, testWriteTimeout) }) - t.Run("PastTimeout", func(t *testing.T) { timeoutWrapper(t, mp, testPastTimeout) }) - t.Run("PresentTimeout", func(t *testing.T) { timeoutWrapper(t, mp, testPresentTimeout) }) - t.Run("FutureTimeout", func(t *testing.T) { timeoutWrapper(t, mp, testFutureTimeout) }) - t.Run("CloseTimeout", func(t *testing.T) { timeoutWrapper(t, mp, testCloseTimeout) }) - t.Run("ConcurrentMethods", func(t *testing.T) { timeoutWrapper(t, mp, testConcurrentMethods) }) -} - -type connTester func(t *testing.T, c1, c2 Conn) - -func timeoutWrapper(t *testing.T, mp MakePipe, f connTester) { - t.Helper() - c1, c2, stop, err := mp() - if err != nil { - t.Fatalf("unable to make pipe: %v", err) - } - var once sync.Once - defer once.Do(func() { stop() }) - timer := time.AfterFunc(time.Minute, func() { - once.Do(func() { - t.Error("test timed out; terminating pipe") - stop() - }) - }) - defer timer.Stop() - f(t, c1, c2) -} - -// testBasicIO tests that the data sent on c1 is properly received on c2. -func testBasicIO(t *testing.T, c1, c2 Conn) { - want := make([]byte, 1<<20) - rand.New(rand.NewSource(0)).Read(want) - - dataCh := make(chan []byte) - go func() { - rd := bytes.NewReader(want) - if err := chunkedCopy(c1, rd); err != nil { - t.Errorf("unexpected c1.Write error: %v", err) - } - if err := c1.Close(); err != nil { - t.Errorf("unexpected c1.Close error: %v", err) - } - }() - - go func() { - wr := new(bytes.Buffer) - if err := chunkedCopy(wr, c2); err != nil { - t.Errorf("unexpected c2.Read error: %v", err) - } - if err := c2.Close(); err != nil { - t.Errorf("unexpected c2.Close error: %v", err) - } - dataCh <- wr.Bytes() - }() - - if got := <-dataCh; !bytes.Equal(got, want) { - t.Error("transmitted data differs") - } -} - -// testPingPong tests that the two endpoints can synchronously send data to -// each other in a typical request-response pattern. -func testPingPong(t *testing.T, c1, c2 Conn) { - var wg sync.WaitGroup - defer wg.Wait() - - pingPonger := func(c Conn) { - defer wg.Done() - buf := make([]byte, 8) - var prev uint64 - for { - if _, err := io.ReadFull(c, buf); err != nil { - if err == io.EOF { - break - } - t.Errorf("unexpected Read error: %v", err) - } - - v := binary.LittleEndian.Uint64(buf) - binary.LittleEndian.PutUint64(buf, v+1) - if prev != 0 && prev+2 != v { - t.Errorf("mismatching value: got %d, want %d", v, prev+2) - } - prev = v - if v == 1000 { - break - } - - if _, err := c.Write(buf); err != nil { - t.Errorf("unexpected Write error: %v", err) - break - } - } - if err := c.Close(); err != nil { - t.Errorf("unexpected Close error: %v", err) - } - } - - wg.Add(2) - go pingPonger(c1) - go pingPonger(c2) - - // Start off the chain reaction. - if _, err := c1.Write(make([]byte, 8)); err != nil { - t.Errorf("unexpected c1.Write error: %v", err) - } -} - -// testRacyRead tests that it is safe to mutate the input Read buffer -// immediately after cancelation has occurred. 
-func testRacyRead(t *testing.T, c1, c2 Conn) { - go chunkedCopy(c2, rand.New(rand.NewSource(0))) - - var wg sync.WaitGroup - defer wg.Wait() - - c1.SetReadDeadline(time.Now().Add(time.Millisecond)) - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - defer wg.Done() - - b1 := make([]byte, 1024) - b2 := make([]byte, 1024) - for j := 0; j < 100; j++ { - _, err := c1.Read(b1) - copy(b1, b2) // Mutate b1 to trigger potential race - if err != nil { - checkForTimeoutError(t, err) - c1.SetReadDeadline(time.Now().Add(time.Millisecond)) - } - } - }() - } -} - -// testRacyWrite tests that it is safe to mutate the input Write buffer -// immediately after cancelation has occurred. -func testRacyWrite(t *testing.T, c1, c2 Conn) { - go chunkedCopy(ioutil.Discard, c2) - - var wg sync.WaitGroup - defer wg.Wait() - - c1.SetWriteDeadline(time.Now().Add(time.Millisecond)) - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - defer wg.Done() - - b1 := make([]byte, 1024) - b2 := make([]byte, 1024) - for j := 0; j < 100; j++ { - _, err := c1.Write(b1) - copy(b1, b2) // Mutate b1 to trigger potential race - if err != nil { - checkForTimeoutError(t, err) - c1.SetWriteDeadline(time.Now().Add(time.Millisecond)) - } - } - }() - } -} - -// testReadTimeout tests that Read timeouts do not affect Write. -func testReadTimeout(t *testing.T, c1, c2 Conn) { - go chunkedCopy(ioutil.Discard, c2) - - c1.SetReadDeadline(aLongTimeAgo) - _, err := c1.Read(make([]byte, 1024)) - checkForTimeoutError(t, err) - if _, err := c1.Write(make([]byte, 1024)); err != nil { - t.Errorf("unexpected Write error: %v", err) - } -} - -// testWriteTimeout tests that Write timeouts do not affect Read. -func testWriteTimeout(t *testing.T, c1, c2 Conn) { - go chunkedCopy(c2, rand.New(rand.NewSource(0))) - - c1.SetWriteDeadline(aLongTimeAgo) - _, err := c1.Write(make([]byte, 1024)) - checkForTimeoutError(t, err) - if _, err := c1.Read(make([]byte, 1024)); err != nil { - t.Errorf("unexpected Read error: %v", err) - } -} - -// testPastTimeout tests that a deadline set in the past immediately times out -// Read and Write requests. -func testPastTimeout(t *testing.T, c1, c2 Conn) { - go chunkedCopy(c2, c2) - - testRoundtrip(t, c1) - - c1.SetDeadline(aLongTimeAgo) - n, err := c1.Write(make([]byte, 1024)) - if n != 0 { - t.Errorf("unexpected Write count: got %d, want 0", n) - } - checkForTimeoutError(t, err) - n, err = c1.Read(make([]byte, 1024)) - if n != 0 { - t.Errorf("unexpected Read count: got %d, want 0", n) - } - checkForTimeoutError(t, err) - - testRoundtrip(t, c1) -} - -// testPresentTimeout tests that a past deadline set while there are pending -// Read and Write operations immediately times out those operations. 
-func testPresentTimeout(t *testing.T, c1, c2 Conn) { - var wg sync.WaitGroup - defer wg.Wait() - wg.Add(3) - - deadlineSet := make(chan bool, 1) - go func() { - defer wg.Done() - time.Sleep(100 * time.Millisecond) - deadlineSet <- true - c1.SetReadDeadline(aLongTimeAgo) - c1.SetWriteDeadline(aLongTimeAgo) - }() - go func() { - defer wg.Done() - n, err := c1.Read(make([]byte, 1024)) - if n != 0 { - t.Errorf("unexpected Read count: got %d, want 0", n) - } - checkForTimeoutError(t, err) - if len(deadlineSet) == 0 { - t.Error("Read timed out before deadline is set") - } - }() - go func() { - defer wg.Done() - var err error - for err == nil { - _, err = c1.Write(make([]byte, 1024)) - } - checkForTimeoutError(t, err) - if len(deadlineSet) == 0 { - t.Error("Write timed out before deadline is set") - } - }() -} - -// testFutureTimeout tests that a future deadline will eventually time out -// Read and Write operations. -func testFutureTimeout(t *testing.T, c1, c2 Conn) { - var wg sync.WaitGroup - wg.Add(2) - - c1.SetDeadline(time.Now().Add(100 * time.Millisecond)) - go func() { - defer wg.Done() - _, err := c1.Read(make([]byte, 1024)) - checkForTimeoutError(t, err) - }() - go func() { - defer wg.Done() - var err error - for err == nil { - _, err = c1.Write(make([]byte, 1024)) - } - checkForTimeoutError(t, err) - }() - wg.Wait() - - go chunkedCopy(c2, c2) - resyncConn(t, c1) - testRoundtrip(t, c1) -} - -// testCloseTimeout tests that calling Close immediately times out pending -// Read and Write operations. -func testCloseTimeout(t *testing.T, c1, c2 Conn) { - go chunkedCopy(c2, c2) - - var wg sync.WaitGroup - defer wg.Wait() - wg.Add(3) - - // Test for cancelation upon connection closure. - c1.SetDeadline(neverTimeout) - go func() { - defer wg.Done() - time.Sleep(100 * time.Millisecond) - c1.Close() - }() - go func() { - defer wg.Done() - var err error - buf := make([]byte, 1024) - for err == nil { - _, err = c1.Read(buf) - } - }() - go func() { - defer wg.Done() - var err error - buf := make([]byte, 1024) - for err == nil { - _, err = c1.Write(buf) - } - }() -} - -// testConcurrentMethods tests that the methods of Conn can safely -// be called concurrently. -func testConcurrentMethods(t *testing.T, c1, c2 Conn) { - if runtime.GOOS == "plan9" { - t.Skip("skipping on plan9; see https://golang.org/issue/20489") - } - go chunkedCopy(c2, c2) - - // The results of the calls may be nonsensical, but this should - // not trigger a race detector warning. 
- var wg sync.WaitGroup - for i := 0; i < 100; i++ { - wg.Add(7) - go func() { - defer wg.Done() - c1.Read(make([]byte, 1024)) - }() - go func() { - defer wg.Done() - c1.Write(make([]byte, 1024)) - }() - go func() { - defer wg.Done() - c1.SetDeadline(time.Now().Add(10 * time.Millisecond)) - }() - go func() { - defer wg.Done() - c1.SetReadDeadline(aLongTimeAgo) - }() - go func() { - defer wg.Done() - c1.SetWriteDeadline(aLongTimeAgo) - }() - go func() { - defer wg.Done() - c1.LocalAddr() - }() - go func() { - defer wg.Done() - c1.RemoteAddr() - }() - } - wg.Wait() // At worst, the deadline is set 10ms into the future - - resyncConn(t, c1) - testRoundtrip(t, c1) -} - -// checkForTimeoutError checks that the error satisfies the OpError interface -// and that underlying Err is os.ErrDeadlineExceeded -func checkForTimeoutError(t *testing.T, err error) { - t.Helper() - operr, ok := err.(*OpError) - if !ok { - t.Errorf("got %T: %v, want OpError", err, err) - return - } - if operr.Err != os.ErrDeadlineExceeded { - t.Errorf("got %T: %v, want os.ErrDeadlineExceeded", err, err) - } -} - -// testRoundtrip writes something into c and reads it back. -// It assumes that everything written into c is echoed back to itself. -func testRoundtrip(t *testing.T, c Conn) { - t.Helper() - if err := c.SetDeadline(neverTimeout); err != nil { - t.Errorf("roundtrip SetDeadline error: %v", err) - } - - const s = "Hello, world!" - buf := []byte(s) - if _, err := c.Write(buf); err != nil { - t.Errorf("roundtrip Write error: %v", err) - } - if _, err := io.ReadFull(c, buf); err != nil { - t.Errorf("roundtrip Read error: %v", err) - } - if string(buf) != s { - t.Errorf("roundtrip data mismatch: got %q, want %q", buf, s) - } -} - -// resyncConn resynchronizes the connection into a sane state. -// It assumes that everything written into c is echoed back to itself. -// It assumes that 0xff is not currently on the wire or in the read buffer. -func resyncConn(t *testing.T, c Conn) { - t.Helper() - c.SetDeadline(neverTimeout) - errCh := make(chan error) - go func() { - _, err := c.Write([]byte{0xff}) - errCh <- err - }() - buf := make([]byte, 1024) - for { - n, err := c.Read(buf) - if n > 0 && bytes.IndexByte(buf[:n], 0xff) == n-1 { - break - } - if err != nil { - t.Errorf("unexpected Read error: %v", err) - break - } - } - if err := <-errCh; err != nil { - t.Errorf("unexpected Write error: %v", err) - } -} - -// chunkedCopy copies from r to w in fixed-width chunks to avoid -// causing a Write that exceeds the maximum packet size for packet-based -// connections like "unixpacket". -// We assume that the maximum packet size is at least 1024. -func chunkedCopy(w io.Writer, r io.Reader) error { - b := make([]byte, 1024) - _, err := io.CopyBuffer(struct{ io.Writer }{w}, struct{ io.Reader }{r}, b) - return err -} diff --git a/src/net/dial.go b/src/net/dial.go index a1cb75d87f..ac32e62182 100644 --- a/src/net/dial.go +++ b/src/net/dial.go @@ -1,25 +1,169 @@ +// TINYGO: The following is copied and modified from Go 1.19.3 official implementation. + +// TINYGO: Omit DualStack support +// TINYGO: Omit Fast Fallback support +// TINYGO: Don't allow alternate resolver +// TINYGO: Omit DialTimeout + +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ package net import ( "context" + "fmt" "time" ) +// defaultTCPKeepAlive is a default constant value for TCPKeepAlive times +// See golang.org/issue/31510 +const ( + defaultTCPKeepAlive = 15 * time.Second +) + +// A Dialer contains options for connecting to an address. +// +// The zero value for each field is equivalent to dialing +// without that option. Dialing with the zero value of Dialer +// is therefore equivalent to just calling the Dial function. +// +// It is safe to call Dialer's methods concurrently. type Dialer struct { - Timeout time.Duration - Deadline time.Time - DualStack bool + // Timeout is the maximum amount of time a dial will wait for + // a connect to complete. If Deadline is also set, it may fail + // earlier. + // + // The default is no timeout. + // + // When using TCP and dialing a host name with multiple IP + // addresses, the timeout may be divided between them. + // + // With or without a timeout, the operating system may impose + // its own earlier timeout. For instance, TCP timeouts are + // often around 3 minutes. + Timeout time.Duration + + // Deadline is the absolute point in time after which dials + // will fail. If Timeout is set, it may fail earlier. + // Zero means no deadline, or dependent on the operating system + // as with the Timeout option. + Deadline time.Time + + // LocalAddr is the local address to use when dialing an + // address. The address must be of a compatible type for the + // network being dialed. + // If nil, a local address is automatically chosen. + LocalAddr Addr + + // KeepAlive specifies the interval between keep-alive + // probes for an active network connection. + // If zero, keep-alive probes are sent with a default value + // (currently 15 seconds), if supported by the protocol and operating + // system. Network protocols or operating systems that do + // not support keep-alives ignore this field. + // If negative, keep-alive probes are disabled. KeepAlive time.Duration } +// Dial connects to the address on the named network. +// +// See Go "net" package Dial() for more information. +// +// Note: Tinygo Dial supports a subset of networks supported by Go Dial, +// specifically: "tcp", "tcp4", "udp", and "udp4". IP and unix networks are +// not supported. func Dial(network, address string) (Conn, error) { - return nil, ErrNotImplemented + var d Dialer + return d.Dial(network, address) } -func Listen(network, address string) (Listener, error) { - return nil, ErrNotImplemented +// DialTimeout acts like Dial but takes a timeout. +// +// The timeout includes name resolution, if required. +// When using TCP, and the host in the address parameter resolves to +// multiple IP addresses, the timeout is spread over each consecutive +// dial, such that each is given an appropriate fraction of the time +// to connect. +// +// See func Dial for a description of the network and address +// parameters. +func DialTimeout(network, address string, timeout time.Duration) (Conn, error) { + d := Dialer{Timeout: timeout} + return d.Dial(network, address) +} + +// Dial connects to the address on the named network. +// +// See func Dial for a description of the network and address +// parameters. +// +// Dial uses context.Background internally; to specify the context, use +// DialContext. +func (d *Dialer) Dial(network, address string) (Conn, error) { + return d.DialContext(context.Background(), network, address) } +// DialContext connects to the address on the named network using +// the provided context. 
+// +// The provided Context must be non-nil. If the context expires before +// the connection is complete, an error is returned. Once successfully +// connected, any expiration of the context will not affect the +// connection. +// +// When using TCP, and the host in the address parameter resolves to multiple +// network addresses, any dial timeout (from d.Timeout or ctx) is spread +// over each consecutive dial, such that each is given an appropriate +// fraction of the time to connect. +// For example, if a host has 4 IP addresses and the timeout is 1 minute, +// the connect to each single address will be given 15 seconds to complete +// before trying the next one. +// +// See func Dial for a description of the network and address +// parameters. func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) { - return nil, ErrNotImplemented + + // TINYGO: Ignoring context + + switch network { + case "tcp", "tcp4": + raddr, err := ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + return DialTCP(network, nil, raddr) + case "udp", "udp4": + raddr, err := ResolveUDPAddr(network, address) + if err != nil { + return nil, err + } + return DialUDP(network, nil, raddr) + } + + return nil, fmt.Errorf("Network %s not supported", network) +} + +// Listen announces on the local network address. +// +// See Go "net" package Listen() for more information. +// +// Note: Tinygo Listen supports a subset of networks supported by Go Listen, +// specifically: "tcp", "tcp4". "tcp6" and unix networks are not supported. +func Listen(network, address string) (Listener, error) { + + // println("Listen", address) + switch network { + case "tcp", "tcp4": + default: + return nil, fmt.Errorf("Network %s not supported", network) + } + + laddr, err := ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + + return listenTCP(laddr) } diff --git a/src/net/errors.go b/src/net/errors.go deleted file mode 100644 index c1dc7b31c8..0000000000 --- a/src/net/errors.go +++ /dev/null @@ -1,10 +0,0 @@ -package net - -import "errors" - -var ( - // copied from poll.ErrNetClosing - errClosed = errors.New("use of closed network connection") - - ErrNotImplemented = errors.New("operation not implemented") -) diff --git a/src/net/http/client.go b/src/net/http/client.go new file mode 100644 index 0000000000..8bfef71efb --- /dev/null +++ b/src/net/http/client.go @@ -0,0 +1,523 @@ +// TINYGO: The following is copied and modified from Go 1.19.3 official implementation. + +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP client. See RFC 7230 through 7235. +// +// This is the high-level Client interface. +// The low-level implementation is in transport.go. + +package http + +import ( + "bufio" + "crypto/tls" + "encoding/base64" + "errors" + "fmt" + "io" + "net" + "net/http/internal/ascii" + "net/url" + "strings" + "time" + + "golang.org/x/net/http/httpguts" +) + +// A Client is an HTTP client. Its zero value (DefaultClient) is a +// usable client that uses DefaultTransport. +// +// The Client's Transport typically has internal state (cached TCP +// connections), so Clients should be reused instead of created as +// needed. Clients are safe for concurrent use by multiple goroutines. +// +// A Client is higher-level than a RoundTripper (such as Transport) +// and additionally handles HTTP details such as cookies and +// redirects. 
+// +// When following redirects, the Client will forward all headers set on the +// initial Request except: +// +// • when forwarding sensitive headers like "Authorization", +// "WWW-Authenticate", and "Cookie" to untrusted targets. +// These headers will be ignored when following a redirect to a domain +// that is not a subdomain match or exact match of the initial domain. +// For example, a redirect from "foo.com" to either "foo.com" or "sub.foo.com" +// will forward the sensitive headers, but a redirect to "bar.com" will not. +// +// • when forwarding the "Cookie" header with a non-nil cookie Jar. +// Since each redirect may mutate the state of the cookie jar, +// a redirect may possibly alter a cookie set in the initial request. +// When forwarding the "Cookie" header, any mutated cookies will be omitted, +// with the expectation that the Jar will insert those mutated cookies +// with the updated values (assuming the origin matches). +// If Jar is nil, the initial cookies are forwarded without change. +type Client struct { + // Jar specifies the cookie jar. + // + // The Jar is used to insert relevant cookies into every + // outbound Request and is updated with the cookie values + // of every inbound Response. The Jar is consulted for every + // redirect that the Client follows. + // + // If Jar is nil, cookies are only sent if they are explicitly + // set on the Request. + Jar CookieJar + + // Timeout specifies a time limit for requests made by this + // Client. The timeout includes connection time, any + // redirects, and reading the response body. The timer remains + // running after Get, Head, Post, or Do return and will + // interrupt reading of the Response.Body. + // + // A Timeout of zero means no timeout. + // + // The Client cancels requests to the underlying Transport + // as if the Request's Context ended. + // + // For compatibility, the Client will also use the deprecated + // CancelRequest method on Transport if found. New + // RoundTripper implementations should use the Request's Context + // for cancellation instead of implementing CancelRequest. + Timeout time.Duration +} + +// DefaultClient is the default Client and is used by Get, Head, and Post. +var DefaultClient = &Client{} + +// didTimeout is non-nil only if err != nil. +func (c *Client) send(req *Request, deadline time.Time) (resp *Response, didTimeout func() bool, err error) { + if c.Jar != nil { + for _, cookie := range c.Jar.Cookies(req.URL) { + req.AddCookie(cookie) + } + } + resp, didTimeout, err = send(req, deadline) + if err != nil { + return nil, didTimeout, err + } + if c.Jar != nil { + if rc := resp.Cookies(); len(rc) > 0 { + c.Jar.SetCookies(req.URL, rc) + } + } + return resp, nil, nil +} + +func (c *Client) deadline() time.Time { + if c.Timeout > 0 { + return time.Now().Add(c.Timeout) + } + return time.Time{} +} + +// send issues an HTTP request. +// Caller should close resp.Body when done reading from it. +func send(req *Request, deadline time.Time) (resp *Response, didTimeout func() bool, err error) { + + // TINYGO: Removed round tripper + + if req.URL == nil { + req.closeBody() + return nil, alwaysFalse, errors.New("http: nil Request.URL") + } + + if req.RequestURI != "" { + req.closeBody() + return nil, alwaysFalse, errors.New("http: Request.RequestURI can't be set in client requests") + } + + // TINYGO: Removed forkReq stuff + + // Most the callers of send (Get, Post, et al) don't need + // Headers, leaving it uninitialized. 
We guarantee to the + // Transport that this has been initialized, though. + if req.Header == nil { + req.Header = make(Header) + } + + if u := req.URL.User; u != nil && req.Header.Get("Authorization") == "" { + username := u.Username() + password, _ := u.Password() + req.Header.Set("Authorization", "Basic "+basicAuth(username, password)) + } + + resp, err = roundTrip(req) + if err != nil { + + // TINYGO: Remove TLS error check + + return nil, didTimeout, err + } + if resp == nil { + return nil, didTimeout, fmt.Errorf("http: sendit returned a nil *Response with a nil error") + } + + // TINYGO: Skip check for resp.Body == nil since we'll set it in roundTrip + + return resp, nil, nil +} + +func roundTrip(req *Request) (*Response, error) { + + // TINYGO: This is an approximation of Transport.roudTrip() + + if req.URL == nil { + req.closeBody() + return nil, errors.New("http: nil Request.URL") + } + if req.Header == nil { + req.closeBody() + return nil, errors.New("http: nil Request.Header") + } + scheme := req.URL.Scheme + isHTTP := scheme == "http" || scheme == "https" + if isHTTP { + for k, vv := range req.Header { + if !httpguts.ValidHeaderFieldName(k) { + req.closeBody() + return nil, fmt.Errorf("net/http: invalid header field name %q", k) + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + req.closeBody() + // Don't include the value in the error, because it may be sensitive. + return nil, fmt.Errorf("net/http: invalid header field value for %q", k) + } + } + } + } + + // TINYGO: Skipping alternate round tripper + + if !isHTTP { + req.closeBody() + return nil, badStringError("unsupported protocol scheme", scheme) + } + if req.Method != "" && !validMethod(req.Method) { + req.closeBody() + return nil, fmt.Errorf("net/http: invalid method %q", req.Method) + } + if req.URL.Host == "" { + req.closeBody() + return nil, errors.New("http: no Host in request URL") + } + + // TINYGO: From here on just brute force dial a connection, + // TINYGO: send the request, read and return the response. + // TINYGO: The connection is closed when resp body is closed. + + var conn net.Conn + var err error + + host := req.Host + missingPort := !strings.Contains(host, ":") + + switch scheme { + case "http": + if missingPort { + host = host + ":80" + } + conn, err = net.Dial("tcp", host) + case "https": + if missingPort { + host = host + ":443" + } + conn, err = tls.Dial("tcp", host, nil) + } + if err != nil { + req.closeBody() + return nil, err + } + + // TINYGO: TODO handle timeouts + + writer := bufio.NewWriter(conn) + if err = req.Write(writer); err != nil { + req.closeBody() + return nil, err + } + req.closeBody() + if err = writer.Flush(); err != nil { + return nil, err + } + + req.onEOF = func() { conn.Close() } + + reader := bufio.NewReader(conn) + return ReadResponse(reader, req) +} + +// See 2 (end of page 4) https://www.ietf.org/rfc/rfc2617.txt +// "To receive authorization, the client sends the userid and password, +// separated by a single colon (":") character, within a base64 +// encoded string in the credentials." +// It is not meant to be urlencoded. +func basicAuth(username, password string) string { + auth := username + ":" + password + return base64.StdEncoding.EncodeToString([]byte(auth)) +} + +// Get issues a GET to the specified URL. 
If the response is one of +// the following redirect codes, Get follows the redirect, up to a +// maximum of 10 redirects: +// +// 301 (Moved Permanently) +// 302 (Found) +// 303 (See Other) +// 307 (Temporary Redirect) +// 308 (Permanent Redirect) +// +// An error is returned if there were too many redirects or if there +// was an HTTP protocol error. A non-2xx response doesn't cause an +// error. Any returned error will be of type *url.Error. The url.Error +// value's Timeout method will report true if the request timed out. +// +// When err is nil, resp always contains a non-nil resp.Body. +// Caller should close resp.Body when done reading from it. +// +// Get is a wrapper around DefaultClient.Get. +// +// To make a request with custom headers, use NewRequest and +// DefaultClient.Do. +// +// To make a request with a specified context.Context, use NewRequestWithContext +// and DefaultClient.Do. +func Get(url string) (resp *Response, err error) { + return DefaultClient.Get(url) +} + +// Get issues a GET to the specified URL. If the response is one of the +// following redirect codes, Get follows the redirect after calling the +// Client's CheckRedirect function: +// +// 301 (Moved Permanently) +// 302 (Found) +// 303 (See Other) +// 307 (Temporary Redirect) +// 308 (Permanent Redirect) +// +// An error is returned if the Client's CheckRedirect function fails +// or if there was an HTTP protocol error. A non-2xx response doesn't +// cause an error. Any returned error will be of type *url.Error. The +// url.Error value's Timeout method will report true if the request +// timed out. +// +// When err is nil, resp always contains a non-nil resp.Body. +// Caller should close resp.Body when done reading from it. +// +// To make a request with custom headers, use NewRequest and Client.Do. +// +// To make a request with a specified context.Context, use NewRequestWithContext +// and Client.Do. +func (c *Client) Get(url string) (resp *Response, err error) { + req, err := NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + return c.Do(req) +} + +func alwaysFalse() bool { return false } + +// urlErrorOp returns the (*url.Error).Op value to use for the +// provided (*Request).Method value. +func urlErrorOp(method string) string { + if method == "" { + return "Get" + } + if lowerMethod, ok := ascii.ToLower(method); ok { + return method[:1] + lowerMethod[1:] + } + return method +} + +// Do sends an HTTP request and returns an HTTP response, following +// policy (such as redirects, cookies, auth) as configured on the +// client. +// +// An error is returned if caused by client policy (such as +// CheckRedirect), or failure to speak HTTP (such as a network +// connectivity problem). A non-2xx status code doesn't cause an +// error. +// +// If the returned error is nil, the Response will contain a non-nil +// Body which the user is expected to close. If the Body is not both +// read to EOF and closed, the Client's underlying RoundTripper +// (typically Transport) may not be able to re-use a persistent TCP +// connection to the server for a subsequent "keep-alive" request. +// +// The request Body, if non-nil, will be closed by the underlying +// Transport, even on errors. +// +// On error, any Response can be ignored. A non-nil Response with a +// non-nil error only occurs when CheckRedirect fails, and even then +// the returned Response.Body is already closed. +// +// Generally Get, Post, or PostForm will be used instead of Do. 
+// +// If the server replies with a redirect, the Client first uses the +// CheckRedirect function to determine whether the redirect should be +// followed. If permitted, a 301, 302, or 303 redirect causes +// subsequent requests to use HTTP method GET +// (or HEAD if the original request was HEAD), with no body. +// A 307 or 308 redirect preserves the original HTTP method and body, +// provided that the Request.GetBody function is defined. +// The NewRequest function automatically sets GetBody for common +// standard library body types. +// +// Any returned error will be of type *url.Error. The url.Error +// value's Timeout method will report true if the request timed out. +func (c *Client) Do(req *Request) (*Response, error) { + return c.do(req) +} + +func (c *Client) do(req *Request) (retres *Response, reterr error) { + if req.URL == nil { + req.closeBody() + return nil, &url.Error{ + Op: urlErrorOp(req.Method), + Err: errors.New("http: nil Request.URL"), + } + } + + var err error + var didTimeout func() bool + var resp *Response + var deadline = c.deadline() + + // TINYGO: lots removed here, mostly handling multiple requests. + // TINYGO: we just want simple GET, POST, etc. In and out. + + if resp, didTimeout, err = c.send(req, deadline); err != nil { + // c.send() always closes req.Body + if !deadline.IsZero() && didTimeout() { + return nil, fmt.Errorf("%s (Client.Timeout exceeded while awaiting headers)", err.Error()) + } + return nil, err + } + + return resp, nil +} + +// Post issues a POST to the specified URL. +// +// Caller should close resp.Body when done reading from it. +// +// If the provided body is an io.Closer, it is closed after the +// request. +// +// Post is a wrapper around DefaultClient.Post. +// +// To set custom headers, use NewRequest and DefaultClient.Do. +// +// See the Client.Do method documentation for details on how redirects +// are handled. +// +// To make a request with a specified context.Context, use NewRequestWithContext +// and DefaultClient.Do. +func Post(url, contentType string, body io.Reader) (resp *Response, err error) { + return DefaultClient.Post(url, contentType, body) +} + +// Post issues a POST to the specified URL. +// +// Caller should close resp.Body when done reading from it. +// +// If the provided body is an io.Closer, it is closed after the +// request. +// +// To set custom headers, use NewRequest and Client.Do. +// +// To make a request with a specified context.Context, use NewRequestWithContext +// and Client.Do. +// +// See the Client.Do method documentation for details on how redirects +// are handled. +func (c *Client) Post(url, contentType string, body io.Reader) (resp *Response, err error) { + req, err := NewRequest("POST", url, body) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", contentType) + return c.Do(req) +} + +// PostForm issues a POST to the specified URL, with data's keys and +// values URL-encoded as the request body. +// +// The Content-Type header is set to application/x-www-form-urlencoded. +// To set other headers, use NewRequest and DefaultClient.Do. +// +// When err is nil, resp always contains a non-nil resp.Body. +// Caller should close resp.Body when done reading from it. +// +// PostForm is a wrapper around DefaultClient.PostForm. +// +// See the Client.Do method documentation for details on how redirects +// are handled. +// +// To make a request with a specified context.Context, use NewRequestWithContext +// and DefaultClient.Do. 
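As a hedged illustration of the client entry points kept in this port (Get, Post, PostForm, Head, and Do), a short sketch; the URLs are placeholders, and per the TINYGO comment in roundTrip the underlying connection is closed when the response body is closed:

package main

import (
	"io"
	"net/http"
	"strings"
)

func main() {
	// Plain GET; the URL is a placeholder.
	resp, err := http.Get("http://example.com/")
	if err != nil {
		println("get failed:", err.Error())
		return
	}
	body, _ := io.ReadAll(resp.Body)
	resp.Body.Close() // closes the underlying connection in this port
	println("read", len(body), "bytes")

	// POST with an explicit content type, as supported by Client.Post.
	_, err = http.Post("http://example.com/echo", "text/plain",
		strings.NewReader("hello"))
	if err != nil {
		println("post failed:", err.Error())
	}
}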
+func PostForm(url string, data url.Values) (resp *Response, err error) { + return DefaultClient.PostForm(url, data) +} + +// PostForm issues a POST to the specified URL, +// with data's keys and values URL-encoded as the request body. +// +// The Content-Type header is set to application/x-www-form-urlencoded. +// To set other headers, use NewRequest and Client.Do. +// +// When err is nil, resp always contains a non-nil resp.Body. +// Caller should close resp.Body when done reading from it. +// +// See the Client.Do method documentation for details on how redirects +// are handled. +// +// To make a request with a specified context.Context, use NewRequestWithContext +// and Client.Do. +func (c *Client) PostForm(url string, data url.Values) (resp *Response, err error) { + return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) +} + +// Head issues a HEAD to the specified URL. If the response is one of +// the following redirect codes, Head follows the redirect, up to a +// maximum of 10 redirects: +// +// 301 (Moved Permanently) +// 302 (Found) +// 303 (See Other) +// 307 (Temporary Redirect) +// 308 (Permanent Redirect) +// +// Head is a wrapper around DefaultClient.Head. +// +// To make a request with a specified context.Context, use NewRequestWithContext +// and DefaultClient.Do. +func Head(url string) (resp *Response, err error) { + return DefaultClient.Head(url) +} + +// Head issues a HEAD to the specified URL. If the response is one of the +// following redirect codes, Head follows the redirect after calling the +// Client's CheckRedirect function: +// +// 301 (Moved Permanently) +// 302 (Found) +// 303 (See Other) +// 307 (Temporary Redirect) +// 308 (Permanent Redirect) +// +// To make a request with a specified context.Context, use NewRequestWithContext +// and Client.Do. +func (c *Client) Head(url string) (resp *Response, err error) { + req, err := NewRequest("HEAD", url, nil) + if err != nil { + return nil, err + } + return c.Do(req) +} diff --git a/src/net/http/clone.go b/src/net/http/clone.go new file mode 100644 index 0000000000..aa42a7e9c7 --- /dev/null +++ b/src/net/http/clone.go @@ -0,0 +1,76 @@ +// TINYGO: The following is copied from Go 1.19.3 official implementation. + +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package http + +import ( + "mime/multipart" + "net/textproto" + "net/url" +) + +func cloneURLValues(v url.Values) url.Values { + if v == nil { + return nil + } + // http.Header and url.Values have the same representation, so temporarily + // treat it like http.Header, which does have a clone: + return url.Values(Header(v).Clone()) +} + +func cloneURL(u *url.URL) *url.URL { + if u == nil { + return nil + } + u2 := new(url.URL) + *u2 = *u + if u.User != nil { + u2.User = new(url.Userinfo) + *u2.User = *u.User + } + return u2 +} + +func cloneMultipartForm(f *multipart.Form) *multipart.Form { + if f == nil { + return nil + } + f2 := &multipart.Form{ + Value: (map[string][]string)(Header(f.Value).Clone()), + } + if f.File != nil { + m := make(map[string][]*multipart.FileHeader) + for k, vv := range f.File { + vv2 := make([]*multipart.FileHeader, len(vv)) + for i, v := range vv { + vv2[i] = cloneMultipartFileHeader(v) + } + m[k] = vv2 + } + f2.File = m + } + return f2 +} + +func cloneMultipartFileHeader(fh *multipart.FileHeader) *multipart.FileHeader { + if fh == nil { + return nil + } + fh2 := new(multipart.FileHeader) + *fh2 = *fh + fh2.Header = textproto.MIMEHeader(Header(fh.Header).Clone()) + return fh2 +} + +// cloneOrMakeHeader invokes Header.Clone but if the +// result is nil, it'll instead make and return a non-nil Header. +func cloneOrMakeHeader(hdr Header) Header { + clone := hdr.Clone() + if clone == nil { + clone = make(Header) + } + return clone +} diff --git a/src/net/http/cookie.go b/src/net/http/cookie.go new file mode 100644 index 0000000000..24c938c3d4 --- /dev/null +++ b/src/net/http/cookie.go @@ -0,0 +1,470 @@ +// TINYGO: The following is copied and modified from Go 1.19.3 official implementation. + +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "errors" + "fmt" + "log" + "net" + "net/http/internal/ascii" + "net/textproto" + "strconv" + "strings" + "time" +) + +// A Cookie represents an HTTP cookie as sent in the Set-Cookie header of an +// HTTP response or the Cookie header of an HTTP request. +// +// See https://tools.ietf.org/html/rfc6265 for details. +type Cookie struct { + Name string + Value string + + Path string // optional + Domain string // optional + Expires time.Time // optional + RawExpires string // for reading cookies only + + // MaxAge=0 means no 'Max-Age' attribute specified. + // MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0' + // MaxAge>0 means Max-Age attribute present and given in seconds + MaxAge int + Secure bool + HttpOnly bool + SameSite SameSite + Raw string + Unparsed []string // Raw text of unparsed attribute-value pairs +} + +// SameSite allows a server to define a cookie attribute making it impossible for +// the browser to send this cookie along with cross-site requests. The main +// goal is to mitigate the risk of cross-origin information leakage, and provide +// some protection against cross-site request forgery attacks. +// +// See https://tools.ietf.org/html/draft-ietf-httpbis-cookie-same-site-00 for details. +type SameSite int + +const ( + SameSiteDefaultMode SameSite = iota + 1 + SameSiteLaxMode + SameSiteStrictMode + SameSiteNoneMode +) + +// readSetCookies parses all "Set-Cookie" values from +// the header h and returns the successfully parsed Cookies. 
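A brief, hedged sketch of how the cookie types above are exercised from the client side; the URL and cookie values are placeholders, and Response.Cookies and Request.AddCookie are the same helpers Client.send relies on earlier in this patch:

package main

import "net/http"

func main() {
	// The URL is a placeholder; Response.Cookies parses the response's
	// Set-Cookie headers into *Cookie values.
	resp, err := http.Get("http://example.com/login")
	if err != nil {
		println("get failed:", err.Error())
		return
	}
	defer resp.Body.Close()
	for _, c := range resp.Cookies() {
		println("cookie:", c.Name, "=", c.Value)
	}

	// Attaching a cookie to an outgoing request; the cookie shown is
	// hypothetical.
	req, err := http.NewRequest("GET", "http://example.com/", nil)
	if err == nil {
		req.AddCookie(&http.Cookie{Name: "session", Value: "abc123"})
	}
}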
+func readSetCookies(h Header) []*Cookie { + cookieCount := len(h["Set-Cookie"]) + if cookieCount == 0 { + return []*Cookie{} + } + cookies := make([]*Cookie, 0, cookieCount) + for _, line := range h["Set-Cookie"] { + parts := strings.Split(textproto.TrimString(line), ";") + if len(parts) == 1 && parts[0] == "" { + continue + } + parts[0] = textproto.TrimString(parts[0]) + name, value, ok := strings.Cut(parts[0], "=") + if !ok { + continue + } + name = textproto.TrimString(name) + if !isCookieNameValid(name) { + continue + } + value, ok = parseCookieValue(value, true) + if !ok { + continue + } + c := &Cookie{ + Name: name, + Value: value, + Raw: line, + } + for i := 1; i < len(parts); i++ { + parts[i] = textproto.TrimString(parts[i]) + if len(parts[i]) == 0 { + continue + } + + attr, val, _ := strings.Cut(parts[i], "=") + lowerAttr, isASCII := ascii.ToLower(attr) + if !isASCII { + continue + } + val, ok = parseCookieValue(val, false) + if !ok { + c.Unparsed = append(c.Unparsed, parts[i]) + continue + } + + switch lowerAttr { + case "samesite": + lowerVal, ascii := ascii.ToLower(val) + if !ascii { + c.SameSite = SameSiteDefaultMode + continue + } + switch lowerVal { + case "lax": + c.SameSite = SameSiteLaxMode + case "strict": + c.SameSite = SameSiteStrictMode + case "none": + c.SameSite = SameSiteNoneMode + default: + c.SameSite = SameSiteDefaultMode + } + continue + case "secure": + c.Secure = true + continue + case "httponly": + c.HttpOnly = true + continue + case "domain": + c.Domain = val + continue + case "max-age": + secs, err := strconv.Atoi(val) + if err != nil || secs != 0 && val[0] == '0' { + break + } + if secs <= 0 { + secs = -1 + } + c.MaxAge = secs + continue + case "expires": + c.RawExpires = val + exptime, err := time.Parse(time.RFC1123, val) + if err != nil { + exptime, err = time.Parse("Mon, 02-Jan-2006 15:04:05 MST", val) + if err != nil { + c.Expires = time.Time{} + break + } + } + c.Expires = exptime.UTC() + continue + case "path": + c.Path = val + continue + } + c.Unparsed = append(c.Unparsed, parts[i]) + } + cookies = append(cookies, c) + } + return cookies +} + +// SetCookie adds a Set-Cookie header to the provided ResponseWriter's headers. +// The provided cookie must have a valid Name. Invalid cookies may be +// silently dropped. +func SetCookie(w ResponseWriter, cookie *Cookie) { + if v := cookie.String(); v != "" { + w.Header().Add("Set-Cookie", v) + } +} + +// String returns the serialization of the cookie for use in a Cookie +// header (if only Name and Value are set) or a Set-Cookie response +// header (if other fields are set). +// If c is nil or c.Name is invalid, the empty string is returned. +func (c *Cookie) String() string { + if c == nil || !isCookieNameValid(c.Name) { + return "" + } + // extraCookieLength derived from typical length of cookie attributes + // see RFC 6265 Sec 4.1. + const extraCookieLength = 110 + var b strings.Builder + b.Grow(len(c.Name) + len(c.Value) + len(c.Domain) + len(c.Path) + extraCookieLength) + b.WriteString(c.Name) + b.WriteRune('=') + b.WriteString(sanitizeCookieValue(c.Value)) + + if len(c.Path) > 0 { + b.WriteString("; Path=") + b.WriteString(sanitizeCookiePath(c.Path)) + } + if len(c.Domain) > 0 { + if validCookieDomain(c.Domain) { + // A c.Domain containing illegal characters is not + // sanitized but simply dropped which turns the cookie + // into a host-only cookie. A leading dot is okay + // but won't be sent. + d := c.Domain + if d[0] == '.' 
{ + d = d[1:] + } + b.WriteString("; Domain=") + b.WriteString(d) + } else { + log.Printf("net/http: invalid Cookie.Domain %q; dropping domain attribute", c.Domain) + } + } + var buf [len(TimeFormat)]byte + if validCookieExpires(c.Expires) { + b.WriteString("; Expires=") + b.Write(c.Expires.UTC().AppendFormat(buf[:0], TimeFormat)) + } + if c.MaxAge > 0 { + b.WriteString("; Max-Age=") + b.Write(strconv.AppendInt(buf[:0], int64(c.MaxAge), 10)) + } else if c.MaxAge < 0 { + b.WriteString("; Max-Age=0") + } + if c.HttpOnly { + b.WriteString("; HttpOnly") + } + if c.Secure { + b.WriteString("; Secure") + } + switch c.SameSite { + case SameSiteDefaultMode: + // Skip, default mode is obtained by not emitting the attribute. + case SameSiteNoneMode: + b.WriteString("; SameSite=None") + case SameSiteLaxMode: + b.WriteString("; SameSite=Lax") + case SameSiteStrictMode: + b.WriteString("; SameSite=Strict") + } + return b.String() +} + +// Valid reports whether the cookie is valid. +func (c *Cookie) Valid() error { + if c == nil { + return errors.New("http: nil Cookie") + } + if !isCookieNameValid(c.Name) { + return errors.New("http: invalid Cookie.Name") + } + if !c.Expires.IsZero() && !validCookieExpires(c.Expires) { + return errors.New("http: invalid Cookie.Expires") + } + for i := 0; i < len(c.Value); i++ { + if !validCookieValueByte(c.Value[i]) { + return fmt.Errorf("http: invalid byte %q in Cookie.Value", c.Value[i]) + } + } + if len(c.Path) > 0 { + for i := 0; i < len(c.Path); i++ { + if !validCookiePathByte(c.Path[i]) { + return fmt.Errorf("http: invalid byte %q in Cookie.Path", c.Path[i]) + } + } + } + if len(c.Domain) > 0 { + if !validCookieDomain(c.Domain) { + return errors.New("http: invalid Cookie.Domain") + } + } + return nil +} + +// readCookies parses all "Cookie" values from the header h and +// returns the successfully parsed Cookies. +// +// if filter isn't empty, only cookies of that name are returned. +func readCookies(h Header, filter string) []*Cookie { + lines := h["Cookie"] + if len(lines) == 0 { + return []*Cookie{} + } + + cookies := make([]*Cookie, 0, len(lines)+strings.Count(lines[0], ";")) + for _, line := range lines { + line = textproto.TrimString(line) + + var part string + for len(line) > 0 { // continue since we have rest + part, line, _ = strings.Cut(line, ";") + part = textproto.TrimString(part) + if part == "" { + continue + } + name, val, _ := strings.Cut(part, "=") + name = textproto.TrimString(name) + if !isCookieNameValid(name) { + continue + } + if filter != "" && filter != name { + continue + } + val, ok := parseCookieValue(val, true) + if !ok { + continue + } + cookies = append(cookies, &Cookie{Name: name, Value: val}) + } + } + return cookies +} + +// validCookieDomain reports whether v is a valid cookie domain-value. +func validCookieDomain(v string) bool { + if isCookieDomainName(v) { + return true + } + if net.ParseIP(v) != nil && !strings.Contains(v, ":") { + return true + } + return false +} + +// validCookieExpires reports whether v is a valid cookie expires-value. +func validCookieExpires(t time.Time) bool { + // IETF RFC 6265 Section 5.1.1.5, the year must not be less than 1601 + return t.Year() >= 1601 +} + +// isCookieDomainName reports whether s is a valid domain name or a valid +// domain name with a leading dot '.'. It is almost a direct copy of +// package net's isDomainName. +func isCookieDomainName(s string) bool { + if len(s) == 0 { + return false + } + if len(s) > 255 { + return false + } + + if s[0] == '.' 
{ + // A cookie a domain attribute may start with a leading dot. + s = s[1:] + } + last := byte('.') + ok := false // Ok once we've seen a letter. + partlen := 0 + for i := 0; i < len(s); i++ { + c := s[i] + switch { + default: + return false + case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z': + // No '_' allowed here (in contrast to package net). + ok = true + partlen++ + case '0' <= c && c <= '9': + // fine + partlen++ + case c == '-': + // Byte before dash cannot be dot. + if last == '.' { + return false + } + partlen++ + case c == '.': + // Byte before dot cannot be dot, dash. + if last == '.' || last == '-' { + return false + } + if partlen > 63 || partlen == 0 { + return false + } + partlen = 0 + } + last = c + } + if last == '-' || partlen > 63 { + return false + } + + return ok +} + +var cookieNameSanitizer = strings.NewReplacer("\n", "-", "\r", "-") + +func sanitizeCookieName(n string) string { + return cookieNameSanitizer.Replace(n) +} + +// sanitizeCookieValue produces a suitable cookie-value from v. +// https://tools.ietf.org/html/rfc6265#section-4.1.1 +// +// cookie-value = *cookie-octet / ( DQUOTE *cookie-octet DQUOTE ) +// cookie-octet = %x21 / %x23-2B / %x2D-3A / %x3C-5B / %x5D-7E +// ; US-ASCII characters excluding CTLs, +// ; whitespace DQUOTE, comma, semicolon, +// ; and backslash +// +// We loosen this as spaces and commas are common in cookie values +// but we produce a quoted cookie-value if and only if v contains +// commas or spaces. +// See https://golang.org/issue/7243 for the discussion. +func sanitizeCookieValue(v string) string { + v = sanitizeOrWarn("Cookie.Value", validCookieValueByte, v) + if len(v) == 0 { + return v + } + if strings.ContainsAny(v, " ,") { + return `"` + v + `"` + } + return v +} + +func validCookieValueByte(b byte) bool { + return 0x20 <= b && b < 0x7f && b != '"' && b != ';' && b != '\\' +} + +// path-av = "Path=" path-value +// path-value = +func sanitizeCookiePath(v string) string { + return sanitizeOrWarn("Cookie.Path", validCookiePathByte, v) +} + +func validCookiePathByte(b byte) bool { + return 0x20 <= b && b < 0x7f && b != ';' +} + +func sanitizeOrWarn(fieldName string, valid func(byte) bool, v string) string { + ok := true + for i := 0; i < len(v); i++ { + if valid(v[i]) { + continue + } + log.Printf("net/http: invalid byte %q in %s; dropping invalid bytes", v[i], fieldName) + ok = false + break + } + if ok { + return v + } + buf := make([]byte, 0, len(v)) + for i := 0; i < len(v); i++ { + if b := v[i]; valid(b) { + buf = append(buf, b) + } + } + return string(buf) +} + +func parseCookieValue(raw string, allowDoubleQuote bool) (string, bool) { + // Strip the quotes, if present. + if allowDoubleQuote && len(raw) > 1 && raw[0] == '"' && raw[len(raw)-1] == '"' { + raw = raw[1 : len(raw)-1] + } + for i := 0; i < len(raw); i++ { + if !validCookieValueByte(raw[i]) { + return "", false + } + } + return raw, true +} + +func isCookieNameValid(raw string) bool { + if raw == "" { + return false + } + return strings.IndexFunc(raw, isNotToken) < 0 +} diff --git a/src/net/http/fs.go b/src/net/http/fs.go new file mode 100644 index 0000000000..3967045c2f --- /dev/null +++ b/src/net/http/fs.go @@ -0,0 +1,974 @@ +// TINYGO: The following is copied from Go 1.19.3 official implementation. + +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +// HTTP file system request handler + +package http + +import ( + "errors" + "fmt" + "io" + "io/fs" + "mime" + "mime/multipart" + "net/textproto" + "net/url" + "os" + "path" + "path/filepath" + "sort" + "strconv" + "strings" + "time" +) + +// A Dir implements FileSystem using the native file system restricted to a +// specific directory tree. +// +// While the FileSystem.Open method takes '/'-separated paths, a Dir's string +// value is a filename on the native file system, not a URL, so it is separated +// by filepath.Separator, which isn't necessarily '/'. +// +// Note that Dir could expose sensitive files and directories. Dir will follow +// symlinks pointing out of the directory tree, which can be especially dangerous +// if serving from a directory in which users are able to create arbitrary symlinks. +// Dir will also allow access to files and directories starting with a period, +// which could expose sensitive directories like .git or sensitive files like +// .htpasswd. To exclude files with a leading period, remove the files/directories +// from the server or create a custom FileSystem implementation. +// +// An empty Dir is treated as ".". +type Dir string + +// mapOpenError maps the provided non-nil error from opening name +// to a possibly better non-nil error. In particular, it turns OS-specific errors +// about opening files in non-directories into fs.ErrNotExist. See Issues 18984 and 49552. +func mapOpenError(originalErr error, name string, sep rune, stat func(string) (fs.FileInfo, error)) error { + if errors.Is(originalErr, fs.ErrNotExist) || errors.Is(originalErr, fs.ErrPermission) { + return originalErr + } + + parts := strings.Split(name, string(sep)) + for i := range parts { + if parts[i] == "" { + continue + } + fi, err := stat(strings.Join(parts[:i+1], string(sep))) + if err != nil { + return originalErr + } + if !fi.IsDir() { + return fs.ErrNotExist + } + } + return originalErr +} + +// Open implements FileSystem using os.Open, opening files for reading rooted +// and relative to the directory d. +func (d Dir) Open(name string) (File, error) { + if filepath.Separator != '/' && strings.ContainsRune(name, filepath.Separator) { + return nil, errors.New("http: invalid character in file path") + } + dir := string(d) + if dir == "" { + dir = "." + } + fullName := filepath.Join(dir, filepath.FromSlash(path.Clean("/"+name))) + f, err := os.Open(fullName) + if err != nil { + return nil, mapOpenError(err, fullName, filepath.Separator, os.Stat) + } + return f, nil +} + +// A FileSystem implements access to a collection of named files. +// The elements in a file path are separated by slash ('/', U+002F) +// characters, regardless of host operating system convention. +// See the FileServer function to convert a FileSystem to a Handler. +// +// This interface predates the fs.FS interface, which can be used instead: +// the FS adapter function converts an fs.FS to a FileSystem. +type FileSystem interface { + Open(name string) (File, error) +} + +// A File is returned by a FileSystem's Open method and can be +// served by the FileServer implementation. +// +// The methods should behave the same as those on an *os.File. 
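A small sketch of the FileSystem abstraction described above, using the Dir implementation from earlier in this file and the File interface declared just below; the directory and file names are placeholders:

package main

import "net/http"

func main() {
	// Dir serves a native directory tree through the FileSystem interface.
	var fsys http.FileSystem = http.Dir("./public")
	f, err := fsys.Open("/index.html") // '/'-separated, regardless of OS
	if err != nil {
		println("open failed:", err.Error())
		return
	}
	defer f.Close()
	if fi, err := f.Stat(); err == nil {
		println(fi.Name(), fi.Size(), "bytes")
	}
}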
+type File interface { + io.Closer + io.Reader + io.Seeker + Readdir(count int) ([]fs.FileInfo, error) + Stat() (fs.FileInfo, error) +} + +type anyDirs interface { + len() int + name(i int) string + isDir(i int) bool +} + +type fileInfoDirs []fs.FileInfo + +func (d fileInfoDirs) len() int { return len(d) } +func (d fileInfoDirs) isDir(i int) bool { return d[i].IsDir() } +func (d fileInfoDirs) name(i int) string { return d[i].Name() } + +type dirEntryDirs []fs.DirEntry + +func (d dirEntryDirs) len() int { return len(d) } +func (d dirEntryDirs) isDir(i int) bool { return d[i].IsDir() } +func (d dirEntryDirs) name(i int) string { return d[i].Name() } + +func dirList(w ResponseWriter, r *Request, f File) { + // Prefer to use ReadDir instead of Readdir, + // because the former doesn't require calling + // Stat on every entry of a directory on Unix. + var dirs anyDirs + var err error + if d, ok := f.(fs.ReadDirFile); ok { + var list dirEntryDirs + list, err = d.ReadDir(-1) + dirs = list + } else { + var list fileInfoDirs + list, err = f.Readdir(-1) + dirs = list + } + + if err != nil { + logf(r, "http: error reading directory: %v", err) + Error(w, "Error reading directory", StatusInternalServerError) + return + } + sort.Slice(dirs, func(i, j int) bool { return dirs.name(i) < dirs.name(j) }) + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + fmt.Fprintf(w, "
\n")
+	for i, n := 0, dirs.len(); i < n; i++ {
+		name := dirs.name(i)
+		if dirs.isDir(i) {
+			name += "/"
+		}
+		// name may contain '?' or '#', which must be escaped to remain
+		// part of the URL path, and not indicate the start of a query
+		// string or fragment.
+		url := url.URL{Path: name}
+		fmt.Fprintf(w, "<a href=\"%s\">%s</a>\n", url.String(), htmlReplacer.Replace(name))
+	}
+	fmt.Fprintf(w, "</pre>
\n") +} + +// ServeContent replies to the request using the content in the +// provided ReadSeeker. The main benefit of ServeContent over io.Copy +// is that it handles Range requests properly, sets the MIME type, and +// handles If-Match, If-Unmodified-Since, If-None-Match, If-Modified-Since, +// and If-Range requests. +// +// If the response's Content-Type header is not set, ServeContent +// first tries to deduce the type from name's file extension and, +// if that fails, falls back to reading the first block of the content +// and passing it to DetectContentType. +// The name is otherwise unused; in particular it can be empty and is +// never sent in the response. +// +// If modtime is not the zero time or Unix epoch, ServeContent +// includes it in a Last-Modified header in the response. If the +// request includes an If-Modified-Since header, ServeContent uses +// modtime to decide whether the content needs to be sent at all. +// +// The content's Seek method must work: ServeContent uses +// a seek to the end of the content to determine its size. +// +// If the caller has set w's ETag header formatted per RFC 7232, section 2.3, +// ServeContent uses it to handle requests using If-Match, If-None-Match, or If-Range. +// +// Note that *os.File implements the io.ReadSeeker interface. +func ServeContent(w ResponseWriter, req *Request, name string, modtime time.Time, content io.ReadSeeker) { + sizeFunc := func() (int64, error) { + size, err := content.Seek(0, io.SeekEnd) + if err != nil { + return 0, errSeeker + } + _, err = content.Seek(0, io.SeekStart) + if err != nil { + return 0, errSeeker + } + return size, nil + } + serveContent(w, req, name, modtime, sizeFunc, content) +} + +// errSeeker is returned by ServeContent's sizeFunc when the content +// doesn't seek properly. The underlying Seeker's error text isn't +// included in the sizeFunc reply so it's not sent over HTTP to end +// users. +var errSeeker = errors.New("seeker can't seek") + +// errNoOverlap is returned by serveContent's parseRange if first-byte-pos of +// all of the byte-range-spec values is greater than the content size. +var errNoOverlap = errors.New("invalid range: failed to overlap") + +// if name is empty, filename is unknown. (used for mime type, before sniffing) +// if modtime.IsZero(), modtime is unknown. +// content must be seeked to the beginning of the file. +// The sizeFunc is called at most once. Its error, if any, is sent in the HTTP response. +func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, sizeFunc func() (int64, error), content io.ReadSeeker) { + setLastModified(w, modtime) + done, rangeReq := checkPreconditions(w, r, modtime) + if done { + return + } + + code := StatusOK + + // If Content-Type isn't set, use the file's extension to find it, but + // if the Content-Type is unset explicitly, do not sniff the type. 
+ ctypes, haveType := w.Header()["Content-Type"] + var ctype string + if !haveType { + ctype = mime.TypeByExtension(filepath.Ext(name)) + if ctype == "" { + // read a chunk to decide between utf-8 text and binary + var buf [sniffLen]byte + n, _ := io.ReadFull(content, buf[:]) + ctype = DetectContentType(buf[:n]) + _, err := content.Seek(0, io.SeekStart) // rewind to output whole file + if err != nil { + Error(w, "seeker can't seek", StatusInternalServerError) + return + } + } + w.Header().Set("Content-Type", ctype) + } else if len(ctypes) > 0 { + ctype = ctypes[0] + } + + size, err := sizeFunc() + if err != nil { + Error(w, err.Error(), StatusInternalServerError) + return + } + + // handle Content-Range header. + sendSize := size + var sendContent io.Reader = content + if size >= 0 { + ranges, err := parseRange(rangeReq, size) + if err != nil { + if err == errNoOverlap { + w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", size)) + } + Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) + return + } + if sumRangesSize(ranges) > size { + // The total number of bytes in all the ranges + // is larger than the size of the file by + // itself, so this is probably an attack, or a + // dumb client. Ignore the range request. + ranges = nil + } + switch { + case len(ranges) == 1: + // RFC 7233, Section 4.1: + // "If a single part is being transferred, the server + // generating the 206 response MUST generate a + // Content-Range header field, describing what range + // of the selected representation is enclosed, and a + // payload consisting of the range. + // ... + // A server MUST NOT generate a multipart response to + // a request for a single range, since a client that + // does not request multiple parts might not support + // multipart responses." + ra := ranges[0] + if _, err := content.Seek(ra.start, io.SeekStart); err != nil { + Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) + return + } + sendSize = ra.length + code = StatusPartialContent + w.Header().Set("Content-Range", ra.contentRange(size)) + case len(ranges) > 1: + sendSize = rangesMIMESize(ranges, ctype, size) + code = StatusPartialContent + + pr, pw := io.Pipe() + mw := multipart.NewWriter(pw) + w.Header().Set("Content-Type", "multipart/byteranges; boundary="+mw.Boundary()) + sendContent = pr + defer pr.Close() // cause writing goroutine to fail and exit if CopyN doesn't finish. + go func() { + for _, ra := range ranges { + part, err := mw.CreatePart(ra.mimeHeader(ctype, size)) + if err != nil { + pw.CloseWithError(err) + return + } + if _, err := content.Seek(ra.start, io.SeekStart); err != nil { + pw.CloseWithError(err) + return + } + if _, err := io.CopyN(part, content, ra.length); err != nil { + pw.CloseWithError(err) + return + } + } + mw.Close() + pw.Close() + }() + } + + w.Header().Set("Accept-Ranges", "bytes") + if w.Header().Get("Content-Encoding") == "" { + w.Header().Set("Content-Length", strconv.FormatInt(sendSize, 10)) + } + } + + w.WriteHeader(code) + + if r.Method != "HEAD" { + io.CopyN(w, sendContent, sendSize) + } +} + +// scanETag determines if a syntactically valid ETag is present at s. If so, +// the ETag and remaining text after consuming ETag is returned. Otherwise, +// it returns "", "". +func scanETag(s string) (etag string, remain string) { + s = textproto.TrimString(s) + start := 0 + if strings.HasPrefix(s, "W/") { + start = 2 + } + if len(s[start:]) < 2 || s[start] != '"' { + return "", "" + } + // ETag is either W/"text" or "text". + // See RFC 7232 2.3. 
+ for i := start + 1; i < len(s); i++ { + c := s[i] + switch { + // Character values allowed in ETags. + case c == 0x21 || c >= 0x23 && c <= 0x7E || c >= 0x80: + case c == '"': + return s[:i+1], s[i+1:] + default: + return "", "" + } + } + return "", "" +} + +// etagStrongMatch reports whether a and b match using strong ETag comparison. +// Assumes a and b are valid ETags. +func etagStrongMatch(a, b string) bool { + return a == b && a != "" && a[0] == '"' +} + +// etagWeakMatch reports whether a and b match using weak ETag comparison. +// Assumes a and b are valid ETags. +func etagWeakMatch(a, b string) bool { + return strings.TrimPrefix(a, "W/") == strings.TrimPrefix(b, "W/") +} + +// condResult is the result of an HTTP request precondition check. +// See https://tools.ietf.org/html/rfc7232 section 3. +type condResult int + +const ( + condNone condResult = iota + condTrue + condFalse +) + +func checkIfMatch(w ResponseWriter, r *Request) condResult { + im := r.Header.Get("If-Match") + if im == "" { + return condNone + } + for { + im = textproto.TrimString(im) + if len(im) == 0 { + break + } + if im[0] == ',' { + im = im[1:] + continue + } + if im[0] == '*' { + return condTrue + } + etag, remain := scanETag(im) + if etag == "" { + break + } + if etagStrongMatch(etag, w.Header().get("Etag")) { + return condTrue + } + im = remain + } + + return condFalse +} + +func checkIfUnmodifiedSince(r *Request, modtime time.Time) condResult { + ius := r.Header.Get("If-Unmodified-Since") + if ius == "" || isZeroTime(modtime) { + return condNone + } + t, err := ParseTime(ius) + if err != nil { + return condNone + } + + // The Last-Modified header truncates sub-second precision so + // the modtime needs to be truncated too. + modtime = modtime.Truncate(time.Second) + if modtime.Before(t) || modtime.Equal(t) { + return condTrue + } + return condFalse +} + +func checkIfNoneMatch(w ResponseWriter, r *Request) condResult { + inm := r.Header.get("If-None-Match") + if inm == "" { + return condNone + } + buf := inm + for { + buf = textproto.TrimString(buf) + if len(buf) == 0 { + break + } + if buf[0] == ',' { + buf = buf[1:] + continue + } + if buf[0] == '*' { + return condFalse + } + etag, remain := scanETag(buf) + if etag == "" { + break + } + if etagWeakMatch(etag, w.Header().get("Etag")) { + return condFalse + } + buf = remain + } + return condTrue +} + +func checkIfModifiedSince(r *Request, modtime time.Time) condResult { + if r.Method != "GET" && r.Method != "HEAD" { + return condNone + } + ims := r.Header.Get("If-Modified-Since") + if ims == "" || isZeroTime(modtime) { + return condNone + } + t, err := ParseTime(ims) + if err != nil { + return condNone + } + // The Last-Modified header truncates sub-second precision so + // the modtime needs to be truncated too. + modtime = modtime.Truncate(time.Second) + if modtime.Before(t) || modtime.Equal(t) { + return condFalse + } + return condTrue +} + +func checkIfRange(w ResponseWriter, r *Request, modtime time.Time) condResult { + if r.Method != "GET" && r.Method != "HEAD" { + return condNone + } + ir := r.Header.get("If-Range") + if ir == "" { + return condNone + } + etag, _ := scanETag(ir) + if etag != "" { + if etagStrongMatch(etag, w.Header().Get("Etag")) { + return condTrue + } else { + return condFalse + } + } + // The If-Range value is typically the ETag value, but it may also be + // the modtime date. See golang.org/issue/8367. 
+ if modtime.IsZero() { + return condFalse + } + t, err := ParseTime(ir) + if err != nil { + return condFalse + } + if t.Unix() == modtime.Unix() { + return condTrue + } + return condFalse +} + +var unixEpochTime = time.Unix(0, 0) + +// isZeroTime reports whether t is obviously unspecified (either zero or Unix()=0). +func isZeroTime(t time.Time) bool { + return t.IsZero() || t.Equal(unixEpochTime) +} + +func setLastModified(w ResponseWriter, modtime time.Time) { + if !isZeroTime(modtime) { + w.Header().Set("Last-Modified", modtime.UTC().Format(TimeFormat)) + } +} + +func writeNotModified(w ResponseWriter) { + // RFC 7232 section 4.1: + // a sender SHOULD NOT generate representation metadata other than the + // above listed fields unless said metadata exists for the purpose of + // guiding cache updates (e.g., Last-Modified might be useful if the + // response does not have an ETag field). + h := w.Header() + delete(h, "Content-Type") + delete(h, "Content-Length") + delete(h, "Content-Encoding") + if h.Get("Etag") != "" { + delete(h, "Last-Modified") + } + w.WriteHeader(StatusNotModified) +} + +// checkPreconditions evaluates request preconditions and reports whether a precondition +// resulted in sending StatusNotModified or StatusPreconditionFailed. +func checkPreconditions(w ResponseWriter, r *Request, modtime time.Time) (done bool, rangeHeader string) { + // This function carefully follows RFC 7232 section 6. + ch := checkIfMatch(w, r) + if ch == condNone { + ch = checkIfUnmodifiedSince(r, modtime) + } + if ch == condFalse { + w.WriteHeader(StatusPreconditionFailed) + return true, "" + } + switch checkIfNoneMatch(w, r) { + case condFalse: + if r.Method == "GET" || r.Method == "HEAD" { + writeNotModified(w) + return true, "" + } else { + w.WriteHeader(StatusPreconditionFailed) + return true, "" + } + case condNone: + if checkIfModifiedSince(r, modtime) == condFalse { + writeNotModified(w) + return true, "" + } + } + + rangeHeader = r.Header.get("Range") + if rangeHeader != "" && checkIfRange(w, r, modtime) == condFalse { + rangeHeader = "" + } + return false, rangeHeader +} + +// name is '/'-separated, not filepath.Separator. 
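A hedged sketch of a handler built on ServeContent, which exercises the precondition and Range machinery above; HandleFunc and ListenAndServe are assumed to exist in this port as they do in upstream net/http, and the path, content, and modification time are placeholders:

package main

import (
	"net/http"
	"strings"
	"time"
)

func main() {
	http.HandleFunc("/hello.txt", func(w http.ResponseWriter, r *http.Request) {
		content := strings.NewReader("hello from tinygo\n")
		// ServeContent runs the precondition checks (If-None-Match,
		// If-Modified-Since, If-Range) and Range handling shown above.
		http.ServeContent(w, r, "hello.txt", time.Now(), content)
	})
	http.ListenAndServe(":8080", nil)
}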
+func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirect bool) { + const indexPage = "/index.html" + + // redirect .../index.html to .../ + // can't use Redirect() because that would make the path absolute, + // which would be a problem running under StripPrefix + if strings.HasSuffix(r.URL.Path, indexPage) { + localRedirect(w, r, "./") + return + } + + f, err := fs.Open(name) + if err != nil { + msg, code := toHTTPError(err) + Error(w, msg, code) + return + } + defer f.Close() + + d, err := f.Stat() + if err != nil { + msg, code := toHTTPError(err) + Error(w, msg, code) + return + } + + if redirect { + // redirect to canonical path: / at end of directory url + // r.URL.Path always begins with / + url := r.URL.Path + if d.IsDir() { + if url[len(url)-1] != '/' { + localRedirect(w, r, path.Base(url)+"/") + return + } + } else { + if url[len(url)-1] == '/' { + localRedirect(w, r, "../"+path.Base(url)) + return + } + } + } + + if d.IsDir() { + url := r.URL.Path + // redirect if the directory name doesn't end in a slash + if url == "" || url[len(url)-1] != '/' { + localRedirect(w, r, path.Base(url)+"/") + return + } + + // use contents of index.html for directory, if present + index := strings.TrimSuffix(name, "/") + indexPage + ff, err := fs.Open(index) + if err == nil { + defer ff.Close() + dd, err := ff.Stat() + if err == nil { + name = index + d = dd + f = ff + } + } + } + + // Still a directory? (we didn't find an index.html file) + if d.IsDir() { + if checkIfModifiedSince(r, d.ModTime()) == condFalse { + writeNotModified(w) + return + } + setLastModified(w, d.ModTime()) + dirList(w, r, f) + return + } + + // serveContent will check modification time + sizeFunc := func() (int64, error) { return d.Size(), nil } + serveContent(w, r, d.Name(), d.ModTime(), sizeFunc, f) +} + +// toHTTPError returns a non-specific HTTP error message and status code +// for a given non-nil error value. It's important that toHTTPError does not +// actually return err.Error(), since msg and httpStatus are returned to users, +// and historically Go's ServeContent always returned just "404 Not Found" for +// all errors. We don't want to start leaking information in error messages. +func toHTTPError(err error) (msg string, httpStatus int) { + if errors.Is(err, fs.ErrNotExist) { + return "404 page not found", StatusNotFound + } + if errors.Is(err, fs.ErrPermission) { + return "403 Forbidden", StatusForbidden + } + // Default: + return "500 Internal Server Error", StatusInternalServerError +} + +// localRedirect gives a Moved Permanently response. +// It does not convert relative paths to absolute paths like Redirect does. +func localRedirect(w ResponseWriter, r *Request, newPath string) { + if q := r.URL.RawQuery; q != "" { + newPath += "?" + q + } + w.Header().Set("Location", newPath) + w.WriteHeader(StatusMovedPermanently) +} + +// ServeFile replies to the request with the contents of the named +// file or directory. +// +// If the provided file or directory name is a relative path, it is +// interpreted relative to the current directory and may ascend to +// parent directories. If the provided name is constructed from user +// input, it should be sanitized before calling ServeFile. +// +// As a precaution, ServeFile will reject requests where r.URL.Path +// contains a ".." path element; this protects against callers who +// might unsafely use filepath.Join on r.URL.Path without sanitizing +// it and then use that filepath.Join result as the name argument. 
+// +// As another special case, ServeFile redirects any request where r.URL.Path +// ends in "/index.html" to the same path, without the final +// "index.html". To avoid such redirects either modify the path or +// use ServeContent. +// +// Outside of those two special cases, ServeFile does not use +// r.URL.Path for selecting the file or directory to serve; only the +// file or directory provided in the name argument is used. +func ServeFile(w ResponseWriter, r *Request, name string) { + if containsDotDot(r.URL.Path) { + // Too many programs use r.URL.Path to construct the argument to + // serveFile. Reject the request under the assumption that happened + // here and ".." may not be wanted. + // Note that name might not contain "..", for example if code (still + // incorrectly) used filepath.Join(myDir, r.URL.Path). + Error(w, "invalid URL path", StatusBadRequest) + return + } + dir, file := filepath.Split(name) + serveFile(w, r, Dir(dir), file, false) +} + +func containsDotDot(v string) bool { + if !strings.Contains(v, "..") { + return false + } + for _, ent := range strings.FieldsFunc(v, isSlashRune) { + if ent == ".." { + return true + } + } + return false +} + +func isSlashRune(r rune) bool { return r == '/' || r == '\\' } + +type fileHandler struct { + root FileSystem +} + +type ioFS struct { + fsys fs.FS +} + +type ioFile struct { + file fs.File +} + +func (f ioFS) Open(name string) (File, error) { + if name == "/" { + name = "." + } else { + name = strings.TrimPrefix(name, "/") + } + file, err := f.fsys.Open(name) + if err != nil { + return nil, mapOpenError(err, name, '/', func(path string) (fs.FileInfo, error) { + return fs.Stat(f.fsys, path) + }) + } + return ioFile{file}, nil +} + +func (f ioFile) Close() error { return f.file.Close() } +func (f ioFile) Read(b []byte) (int, error) { return f.file.Read(b) } +func (f ioFile) Stat() (fs.FileInfo, error) { return f.file.Stat() } + +var errMissingSeek = errors.New("io.File missing Seek method") +var errMissingReadDir = errors.New("io.File directory missing ReadDir method") + +func (f ioFile) Seek(offset int64, whence int) (int64, error) { + s, ok := f.file.(io.Seeker) + if !ok { + return 0, errMissingSeek + } + return s.Seek(offset, whence) +} + +func (f ioFile) ReadDir(count int) ([]fs.DirEntry, error) { + d, ok := f.file.(fs.ReadDirFile) + if !ok { + return nil, errMissingReadDir + } + return d.ReadDir(count) +} + +func (f ioFile) Readdir(count int) ([]fs.FileInfo, error) { + d, ok := f.file.(fs.ReadDirFile) + if !ok { + return nil, errMissingReadDir + } + var list []fs.FileInfo + for { + dirs, err := d.ReadDir(count - len(list)) + for _, dir := range dirs { + info, err := dir.Info() + if err != nil { + // Pretend it doesn't exist, like (*os.File).Readdir does. + continue + } + list = append(list, info) + } + if err != nil { + return list, err + } + if count < 0 || len(list) >= count { + break + } + } + return list, nil +} + +// FS converts fsys to a FileSystem implementation, +// for use with FileServer and NewFileTransport. +func FS(fsys fs.FS) FileSystem { + return ioFS{fsys} +} + +// FileServer returns a handler that serves HTTP requests +// with the contents of the file system rooted at root. +// +// As a special case, the returned file server redirects any request +// ending in "/index.html" to the same path, without the final +// "index.html". 
+// +// To use the operating system's file system implementation, +// use http.Dir: +// +// http.Handle("/", http.FileServer(http.Dir("/tmp"))) +// +// To use an fs.FS implementation, use http.FS to convert it: +// +// http.Handle("/", http.FileServer(http.FS(fsys))) +func FileServer(root FileSystem) Handler { + return &fileHandler{root} +} + +func (f *fileHandler) ServeHTTP(w ResponseWriter, r *Request) { + upath := r.URL.Path + if !strings.HasPrefix(upath, "/") { + upath = "/" + upath + r.URL.Path = upath + } + serveFile(w, r, f.root, path.Clean(upath), true) +} + +// httpRange specifies the byte range to be sent to the client. +type httpRange struct { + start, length int64 +} + +func (r httpRange) contentRange(size int64) string { + return fmt.Sprintf("bytes %d-%d/%d", r.start, r.start+r.length-1, size) +} + +func (r httpRange) mimeHeader(contentType string, size int64) textproto.MIMEHeader { + return textproto.MIMEHeader{ + "Content-Range": {r.contentRange(size)}, + "Content-Type": {contentType}, + } +} + +// parseRange parses a Range header string as per RFC 7233. +// errNoOverlap is returned if none of the ranges overlap. +func parseRange(s string, size int64) ([]httpRange, error) { + if s == "" { + return nil, nil // header not present + } + const b = "bytes=" + if !strings.HasPrefix(s, b) { + return nil, errors.New("invalid range") + } + var ranges []httpRange + noOverlap := false + for _, ra := range strings.Split(s[len(b):], ",") { + ra = textproto.TrimString(ra) + if ra == "" { + continue + } + start, end, ok := strings.Cut(ra, "-") + if !ok { + return nil, errors.New("invalid range") + } + start, end = textproto.TrimString(start), textproto.TrimString(end) + var r httpRange + if start == "" { + // If no start is specified, end specifies the + // range start relative to the end of the file, + // and we are dealing with + // which has to be a non-negative integer as per + // RFC 7233 Section 2.1 "Byte-Ranges". + if end == "" || end[0] == '-' { + return nil, errors.New("invalid range") + } + i, err := strconv.ParseInt(end, 10, 64) + if i < 0 || err != nil { + return nil, errors.New("invalid range") + } + if i > size { + i = size + } + r.start = size - i + r.length = size - r.start + } else { + i, err := strconv.ParseInt(start, 10, 64) + if err != nil || i < 0 { + return nil, errors.New("invalid range") + } + if i >= size { + // If the range begins after the size of the content, + // then it does not overlap. + noOverlap = true + continue + } + r.start = i + if end == "" { + // If no end is specified, range extends to end of the file. + r.length = size - r.start + } else { + i, err := strconv.ParseInt(end, 10, 64) + if err != nil || r.start > i { + return nil, errors.New("invalid range") + } + if i >= size { + i = size - 1 + } + r.length = i - r.start + 1 + } + } + ranges = append(ranges, r) + } + if noOverlap && len(ranges) == 0 { + // The specified ranges did not overlap with the content. + return nil, errNoOverlap + } + return ranges, nil +} + +// countingWriter counts how many bytes have been written to it. +type countingWriter int64 + +func (w *countingWriter) Write(p []byte) (n int, err error) { + *w += countingWriter(len(p)) + return len(p), nil +} + +// rangesMIMESize returns the number of bytes it takes to encode the +// provided ranges as a multipart response. 
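parseRange above is unexported, so a direct illustration could only live inside package http (in a test file, say). As a hedged sketch of the mapping it performs for a 1000-byte representation (rangeSketch is a made-up name):

package http

import "fmt"

// rangeSketch is illustrative only and not part of the patch.
func rangeSketch() {
	ranges, _ := parseRange("bytes=0-499,500-,-200", 1000)
	for _, r := range ranges {
		fmt.Println(r.start, r.length)
	}
	// Prints:
	//   0 500     "0-499": explicit start and end, length = end-start+1
	//   500 500   "500-":  open-ended, runs to the end of the representation
	//   800 200   "-200":  suffix form, the final 200 bytes
}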
+func rangesMIMESize(ranges []httpRange, contentType string, contentSize int64) (encSize int64) { + var w countingWriter + mw := multipart.NewWriter(&w) + for _, ra := range ranges { + mw.CreatePart(ra.mimeHeader(contentType, contentSize)) + encSize += ra.length + } + mw.Close() + encSize += int64(w) + return +} + +func sumRangesSize(ranges []httpRange) (size int64) { + for _, ra := range ranges { + size += ra.length + } + return +} diff --git a/src/net/http/header.go b/src/net/http/header.go new file mode 100644 index 0000000000..a5779f6132 --- /dev/null +++ b/src/net/http/header.go @@ -0,0 +1,275 @@ +// TINYGO: The following is copied and modified from Go 1.19.3 official implementation. + +// TINYGO: Removed trace stuff + +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "io" + "net/http/internal/ascii" + "net/textproto" + "sort" + "strings" + "sync" + "time" + + "golang.org/x/net/http/httpguts" +) + +// A Header represents the key-value pairs in an HTTP header. +// +// The keys should be in canonical form, as returned by +// CanonicalHeaderKey. +type Header map[string][]string + +// Add adds the key, value pair to the header. +// It appends to any existing values associated with key. +// The key is case insensitive; it is canonicalized by +// CanonicalHeaderKey. +func (h Header) Add(key, value string) { + textproto.MIMEHeader(h).Add(key, value) +} + +// Set sets the header entries associated with key to the +// single element value. It replaces any existing values +// associated with key. The key is case insensitive; it is +// canonicalized by textproto.CanonicalMIMEHeaderKey. +// To use non-canonical keys, assign to the map directly. +func (h Header) Set(key, value string) { + textproto.MIMEHeader(h).Set(key, value) +} + +// Get gets the first value associated with the given key. If +// there are no values associated with the key, Get returns "". +// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +// used to canonicalize the provided key. Get assumes that all +// keys are stored in canonical form. To use non-canonical keys, +// access the map directly. +func (h Header) Get(key string) string { + return textproto.MIMEHeader(h).Get(key) +} + +// Values returns all values associated with the given key. +// It is case insensitive; textproto.CanonicalMIMEHeaderKey is +// used to canonicalize the provided key. To use non-canonical +// keys, access the map directly. +// The returned slice is not a copy. +func (h Header) Values(key string) []string { + return textproto.MIMEHeader(h).Values(key) +} + +// get is like Get, but key must already be in CanonicalHeaderKey form. +func (h Header) get(key string) string { + if v := h[key]; len(v) > 0 { + return v[0] + } + return "" +} + +// has reports whether h has the provided key defined, even if it's +// set to 0-length slice. +func (h Header) has(key string) bool { + _, ok := h[key] + return ok +} + +// Del deletes the values associated with key. +// The key is case insensitive; it is canonicalized by +// CanonicalHeaderKey. +func (h Header) Del(key string) { + textproto.MIMEHeader(h).Del(key) +} + +// Write writes a header in wire format. +func (h Header) Write(w io.Writer) error { + return h.write(w) +} + +func (h Header) write(w io.Writer) error { + return h.writeSubset(w, nil) +} + +// Clone returns a copy of h or nil if h is nil. 
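Since every Header method above delegates to textproto.MIMEHeader, keys are canonicalized on the way in while direct map access is not. A small runnable sketch of that observable behaviour (the header names and values are arbitrary):

package main

import (
	"fmt"
	"net/http"
)

func main() {
	h := make(http.Header)
	h.Add("accept-encoding", "gzip")
	h.Add("Accept-Encoding", "br")
	h.Set("content-type", "text/plain")

	// Both Add calls land under the single canonical key.
	fmt.Println(h.Values("Accept-Encoding")) // [gzip br]
	fmt.Println(h.Get("Content-Type"))       // text/plain

	// Direct map access bypasses canonicalization, as the comments above note.
	h["x-raw"] = []string{"kept as-is"}
	fmt.Println(h.Get("X-Raw") == "") // true: Get only sees canonical keys
}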
+func (h Header) Clone() Header { + if h == nil { + return nil + } + + // Find total number of values. + nv := 0 + for _, vv := range h { + nv += len(vv) + } + sv := make([]string, nv) // shared backing array for headers' values + h2 := make(Header, len(h)) + for k, vv := range h { + if vv == nil { + // Preserve nil values. ReverseProxy distinguishes + // between nil and zero-length header values. + h2[k] = nil + continue + } + n := copy(sv, vv) + h2[k] = sv[:n:n] + sv = sv[n:] + } + return h2 +} + +var timeFormats = []string{ + TimeFormat, + time.RFC850, + time.ANSIC, +} + +// ParseTime parses a time header (such as the Date: header), +// trying each of the three formats allowed by HTTP/1.1: +// TimeFormat, time.RFC850, and time.ANSIC. +func ParseTime(text string) (t time.Time, err error) { + for _, layout := range timeFormats { + t, err = time.Parse(layout, text) + if err == nil { + return + } + } + return +} + +var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ") + +// stringWriter implements WriteString on a Writer. +type stringWriter struct { + w io.Writer +} + +func (w stringWriter) WriteString(s string) (n int, err error) { + return w.w.Write([]byte(s)) +} + +type keyValues struct { + key string + values []string +} + +// A headerSorter implements sort.Interface by sorting a []keyValues +// by key. It's used as a pointer, so it can fit in a sort.Interface +// interface value without allocation. +type headerSorter struct { + kvs []keyValues +} + +func (s *headerSorter) Len() int { return len(s.kvs) } +func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } +func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key } + +var headerSorterPool = sync.Pool{ + New: func() any { return new(headerSorter) }, +} + +// sortedKeyValues returns h's keys sorted in the returned kvs +// slice. The headerSorter used to sort is also returned, for possible +// return to headerSorterCache. +func (h Header) sortedKeyValues(exclude map[string]bool) (kvs []keyValues, hs *headerSorter) { + hs = headerSorterPool.Get().(*headerSorter) + if cap(hs.kvs) < len(h) { + hs.kvs = make([]keyValues, 0, len(h)) + } + kvs = hs.kvs[:0] + for k, vv := range h { + if !exclude[k] { + kvs = append(kvs, keyValues{k, vv}) + } + } + hs.kvs = kvs + sort.Sort(hs) + return kvs, hs +} + +// WriteSubset writes a header in wire format. +// If exclude is not nil, keys where exclude[key] == true are not written. +// Keys are not canonicalized before checking the exclude map. +func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error { + return h.writeSubset(w, exclude) +} + +func (h Header) writeSubset(w io.Writer, exclude map[string]bool) error { + ws, ok := w.(io.StringWriter) + if !ok { + ws = stringWriter{w} + } + kvs, sorter := h.sortedKeyValues(exclude) + for _, kv := range kvs { + if !httpguts.ValidHeaderFieldName(kv.key) { + // This could be an error. In the common case of + // writing response headers, however, we have no good + // way to provide the error back to the server + // handler, so just drop invalid headers instead. + continue + } + for _, v := range kv.values { + v = headerNewlineToSpace.Replace(v) + v = textproto.TrimString(v) + for _, s := range []string{kv.key, ": ", v, "\r\n"} { + if _, err := ws.WriteString(s); err != nil { + headerSorterPool.Put(sorter) + return err + } + } + } + } + headerSorterPool.Put(sorter) + return nil +} + +// CanonicalHeaderKey returns the canonical format of the +// header key s. 
The canonicalization converts the first +// letter and any letter following a hyphen to upper case; +// the rest are converted to lowercase. For example, the +// canonical key for "accept-encoding" is "Accept-Encoding". +// If s contains a space or invalid header field bytes, it is +// returned without modifications. +func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) } + +// hasToken reports whether token appears with v, ASCII +// case-insensitive, with space or comma boundaries. +// token must be all lowercase. +// v may contain mixed cased. +func hasToken(v, token string) bool { + if len(token) > len(v) || token == "" { + return false + } + if v == token { + return true + } + for sp := 0; sp <= len(v)-len(token); sp++ { + // Check that first character is good. + // The token is ASCII, so checking only a single byte + // is sufficient. We skip this potential starting + // position if both the first byte and its potential + // ASCII uppercase equivalent (b|0x20) don't match. + // False positives ('^' => '~') are caught by EqualFold. + if b := v[sp]; b != token[0] && b|0x20 != token[0] { + continue + } + // Check that start pos is on a valid token boundary. + if sp > 0 && !isTokenBoundary(v[sp-1]) { + continue + } + // Check that end pos is on a valid token boundary. + if endPos := sp + len(token); endPos != len(v) && !isTokenBoundary(v[endPos]) { + continue + } + if ascii.EqualFold(v[sp:sp+len(token)], token) { + return true + } + } + return false +} + +func isTokenBoundary(b byte) bool { + return b == ' ' || b == ',' || b == '\t' +} diff --git a/src/net/http/http.go b/src/net/http/http.go new file mode 100644 index 0000000000..fc1db57f38 --- /dev/null +++ b/src/net/http/http.go @@ -0,0 +1,161 @@ +// TINYGO: The following is copied from Go 1.19.3 official implementation. + +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:generate bundle -o=h2_bundle.go -prefix=http2 -tags=!nethttpomithttp2 golang.org/x/net/http2 + +package http + +import ( + "io" + "strconv" + "strings" + "time" + "unicode/utf8" + + "golang.org/x/net/http/httpguts" +) + +// incomparable is a zero-width, non-comparable type. Adding it to a struct +// makes that struct also non-comparable, and generally doesn't add +// any size (as long as it's first). +type incomparable [0]func() + +// maxInt64 is the effective "infinite" value for the Server and +// Transport's byte-limiting readers. +const maxInt64 = 1<<63 - 1 + +// aLongTimeAgo is a non-zero time, far in the past, used for +// immediate cancellation of network operations. +var aLongTimeAgo = time.Unix(1, 0) + +// omitBundledHTTP2 is set by omithttp2.go when the nethttpomithttp2 +// build tag is set. That means h2_bundle.go isn't compiled in and we +// shouldn't try to use it. +var omitBundledHTTP2 bool + +// TODO(bradfitz): move common stuff here. The other files have accumulated +// generic http stuff in random places. + +// contextKey is a value for use with context.WithValue. It's used as +// a pointer so it fits in an interface{} without allocation. +type contextKey struct { + name string +} + +func (k *contextKey) String() string { return "net/http context value " + k.name } + +// Given a string of the form "host", "host:port", or "[ipv6::address]:port", +// return true if the string includes a port. 
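hasPort and removeEmptyPort (defined just below) are unexported, so this sketch could only live inside package http; it simply spells out the behaviour the comment above describes (portSketch is a made-up name):

package http

import "fmt"

// portSketch is illustrative only and not part of the patch.
func portSketch() {
	fmt.Println(hasPort("example.com"))     // false
	fmt.Println(hasPort("example.com:80"))  // true
	fmt.Println(hasPort("[::1]:8080"))      // true
	fmt.Println(hasPort("[fe80::1%25en0]")) // false: the colons sit inside the brackets
	fmt.Println(removeEmptyPort("example.com:")) // example.com (an empty port is dropped)
}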
+func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") } + +// removeEmptyPort strips the empty port in ":port" to "" +// as mandated by RFC 3986 Section 6.2.3. +func removeEmptyPort(host string) string { + if hasPort(host) { + return strings.TrimSuffix(host, ":") + } + return host +} + +func isNotToken(r rune) bool { + return !httpguts.IsTokenRune(r) +} + +// stringContainsCTLByte reports whether s contains any ASCII control character. +func stringContainsCTLByte(s string) bool { + for i := 0; i < len(s); i++ { + b := s[i] + if b < ' ' || b == 0x7f { + return true + } + } + return false +} + +func hexEscapeNonASCII(s string) string { + newLen := 0 + for i := 0; i < len(s); i++ { + if s[i] >= utf8.RuneSelf { + newLen += 3 + } else { + newLen++ + } + } + if newLen == len(s) { + return s + } + b := make([]byte, 0, newLen) + for i := 0; i < len(s); i++ { + if s[i] >= utf8.RuneSelf { + b = append(b, '%') + b = strconv.AppendInt(b, int64(s[i]), 16) + } else { + b = append(b, s[i]) + } + } + return string(b) +} + +// NoBody is an io.ReadCloser with no bytes. Read always returns EOF +// and Close always returns nil. It can be used in an outgoing client +// request to explicitly signal that a request has zero bytes. +// An alternative, however, is to simply set Request.Body to nil. +var NoBody = noBody{} + +type noBody struct{} + +func (noBody) Read([]byte) (int, error) { return 0, io.EOF } +func (noBody) Close() error { return nil } +func (noBody) WriteTo(io.Writer) (int64, error) { return 0, nil } + +var ( + // verify that an io.Copy from NoBody won't require a buffer: + _ io.WriterTo = NoBody + _ io.ReadCloser = NoBody +) + +// PushOptions describes options for Pusher.Push. +type PushOptions struct { + // Method specifies the HTTP method for the promised request. + // If set, it must be "GET" or "HEAD". Empty means "GET". + Method string + + // Header specifies additional promised request headers. This cannot + // include HTTP/2 pseudo header fields like ":path" and ":scheme", + // which will be added automatically. + Header Header +} + +// Pusher is the interface implemented by ResponseWriters that support +// HTTP/2 server push. For more background, see +// https://tools.ietf.org/html/rfc7540#section-8.2. +type Pusher interface { + // Push initiates an HTTP/2 server push. This constructs a synthetic + // request using the given target and options, serializes that request + // into a PUSH_PROMISE frame, then dispatches that request using the + // server's request handler. If opts is nil, default options are used. + // + // The target must either be an absolute path (like "/path") or an absolute + // URL that contains a valid host and the same scheme as the parent request. + // If the target is a path, it will inherit the scheme and host of the + // parent request. + // + // The HTTP/2 spec disallows recursive pushes and cross-authority pushes. + // Push may or may not detect these invalid pushes; however, invalid + // pushes will be detected and canceled by conforming clients. + // + // Handlers that wish to push URL X should call Push before sending any + // data that may trigger a request for URL X. This avoids a race where the + // client issues requests for X before receiving the PUSH_PROMISE for X. + // + // Push will run in a separate goroutine making the order of arrival + // non-deterministic. Any required synchronization needs to be implemented + // by the caller. 
+ // + // Push returns ErrNotSupported if the client has disabled push or if push + // is not supported on the underlying connection. + Push(target string, opts *PushOptions) error +} diff --git a/src/net/http/internal/ascii/print.go b/src/net/http/internal/ascii/print.go new file mode 100644 index 0000000000..c2b3a9bda9 --- /dev/null +++ b/src/net/http/internal/ascii/print.go @@ -0,0 +1,63 @@ +// TINYGO: The following is copied from Go 1.19.3 official implementation. + +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ascii + +import ( + "strings" + "unicode" +) + +// EqualFold is strings.EqualFold, ASCII only. It reports whether s and t +// are equal, ASCII-case-insensitively. +func EqualFold(s, t string) bool { + if len(s) != len(t) { + return false + } + for i := 0; i < len(s); i++ { + if lower(s[i]) != lower(t[i]) { + return false + } + } + return true +} + +// lower returns the ASCII lowercase version of b. +func lower(b byte) byte { + if 'A' <= b && b <= 'Z' { + return b + ('a' - 'A') + } + return b +} + +// IsPrint returns whether s is ASCII and printable according to +// https://tools.ietf.org/html/rfc20#section-4.2. +func IsPrint(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] < ' ' || s[i] > '~' { + return false + } + } + return true +} + +// Is returns whether s is ASCII. +func Is(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] > unicode.MaxASCII { + return false + } + } + return true +} + +// ToLower returns the lowercase version of s if s is ASCII and printable. +func ToLower(s string) (lower string, ok bool) { + if !IsPrint(s) { + return "", false + } + return strings.ToLower(s), true +} diff --git a/src/net/http/internal/ascii/print_test.go b/src/net/http/internal/ascii/print_test.go new file mode 100644 index 0000000000..0b7767ca33 --- /dev/null +++ b/src/net/http/internal/ascii/print_test.go @@ -0,0 +1,95 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package ascii + +import "testing" + +func TestEqualFold(t *testing.T) { + var tests = []struct { + name string + a, b string + want bool + }{ + { + name: "empty", + want: true, + }, + { + name: "simple match", + a: "CHUNKED", + b: "chunked", + want: true, + }, + { + name: "same string", + a: "chunked", + b: "chunked", + want: true, + }, + { + name: "Unicode Kelvin symbol", + a: "chunKed", // This "K" is 'KELVIN SIGN' (\u212A) + b: "chunked", + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := EqualFold(tt.a, tt.b); got != tt.want { + t.Errorf("AsciiEqualFold(%q,%q): got %v want %v", tt.a, tt.b, got, tt.want) + } + }) + } +} + +func TestIsPrint(t *testing.T) { + var tests = []struct { + name string + in string + want bool + }{ + { + name: "empty", + want: true, + }, + { + name: "ASCII low", + in: "This is a space: ' '", + want: true, + }, + { + name: "ASCII high", + in: "This is a tilde: '~'", + want: true, + }, + { + name: "ASCII low non-print", + in: "This is a unit separator: \x1F", + want: false, + }, + { + name: "Ascii high non-print", + in: "This is a Delete: \x7F", + want: false, + }, + { + name: "Unicode letter", + in: "Today it's 280K outside: it's freezing!", // This "K" is 'KELVIN SIGN' (\u212A) + want: false, + }, + { + name: "Unicode emoji", + in: "Gophers like 🧀", + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsPrint(tt.in); got != tt.want { + t.Errorf("IsASCIIPrint(%q): got %v want %v", tt.in, got, tt.want) + } + }) + } +} diff --git a/src/net/http/internal/chunked.go b/src/net/http/internal/chunked.go new file mode 100644 index 0000000000..34b533158d --- /dev/null +++ b/src/net/http/internal/chunked.go @@ -0,0 +1,264 @@ +// TINYGO: The following is copied from Go 1.19.3 official implementation. + +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// The wire protocol for HTTP's "chunked" Transfer-Encoding. + +// Package internal contains HTTP internals shared by net/http and +// net/http/httputil. +package internal + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" +) + +const maxLineLength = 4096 // assumed <= bufio.defaultBufSize + +var ErrLineTooLong = errors.New("header line too long") + +// NewChunkedReader returns a new chunkedReader that translates the data read from r +// out of HTTP "chunked" format before returning it. +// The chunkedReader returns io.EOF when the final 0-length chunk is read. +// +// NewChunkedReader is not needed by normal applications. The http package +// automatically decodes chunking when reading response bodies. 
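Before the reader implementation, it may help to spell out the framing being decoded. The sketch below uses NewChunkedWriter (defined further down in this file) to show the wire format for two writes; it could only compile inside net/http/internal because the package is internal, and chunkSketch is a made-up name:

package internal

import (
	"bytes"
	"fmt"
)

// chunkSketch is illustrative only and not part of the patch.
func chunkSketch() {
	var buf bytes.Buffer
	w := NewChunkedWriter(&buf)
	w.Write([]byte("hello"))
	w.Write([]byte("chunked world"))
	w.Close()

	// Each chunk is hex length, CRLF, payload, CRLF; Close appends the
	// terminating zero-length chunk without a trailing CRLF.
	fmt.Printf("%q\n", buf.String())
	// "5\r\nhello\r\nd\r\nchunked world\r\n0\r\n"
}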
+func NewChunkedReader(r io.Reader) io.Reader { + br, ok := r.(*bufio.Reader) + if !ok { + br = bufio.NewReader(r) + } + return &chunkedReader{r: br} +} + +type chunkedReader struct { + r *bufio.Reader + n uint64 // unread bytes in chunk + err error + buf [2]byte + checkEnd bool // whether need to check for \r\n chunk footer +} + +func (cr *chunkedReader) beginChunk() { + // chunk-size CRLF + var line []byte + line, cr.err = readChunkLine(cr.r) + if cr.err != nil { + return + } + cr.n, cr.err = parseHexUint(line) + if cr.err != nil { + return + } + if cr.n == 0 { + cr.err = io.EOF + } +} + +func (cr *chunkedReader) chunkHeaderAvailable() bool { + n := cr.r.Buffered() + if n > 0 { + peek, _ := cr.r.Peek(n) + return bytes.IndexByte(peek, '\n') >= 0 + } + return false +} + +func (cr *chunkedReader) Read(b []uint8) (n int, err error) { + for cr.err == nil { + if cr.checkEnd { + if n > 0 && cr.r.Buffered() < 2 { + // We have some data. Return early (per the io.Reader + // contract) instead of potentially blocking while + // reading more. + break + } + if _, cr.err = io.ReadFull(cr.r, cr.buf[:2]); cr.err == nil { + if string(cr.buf[:]) != "\r\n" { + cr.err = errors.New("malformed chunked encoding") + break + } + } else { + if cr.err == io.EOF { + cr.err = io.ErrUnexpectedEOF + } + break + } + cr.checkEnd = false + } + if cr.n == 0 { + if n > 0 && !cr.chunkHeaderAvailable() { + // We've read enough. Don't potentially block + // reading a new chunk header. + break + } + cr.beginChunk() + continue + } + if len(b) == 0 { + break + } + rbuf := b + if uint64(len(rbuf)) > cr.n { + rbuf = rbuf[:cr.n] + } + var n0 int + n0, cr.err = cr.r.Read(rbuf) + n += n0 + b = b[n0:] + cr.n -= uint64(n0) + // If we're at the end of a chunk, read the next two + // bytes to verify they are "\r\n". + if cr.n == 0 && cr.err == nil { + cr.checkEnd = true + } else if cr.err == io.EOF { + cr.err = io.ErrUnexpectedEOF + } + } + return n, cr.err +} + +// Read a line of bytes (up to \n) from b. +// Give up if the line exceeds maxLineLength. +// The returned bytes are owned by the bufio.Reader +// so they are only valid until the next bufio read. +func readChunkLine(b *bufio.Reader) ([]byte, error) { + p, err := b.ReadSlice('\n') + if err != nil { + // We always know when EOF is coming. + // If the caller asked for a line, there should be a line. + if err == io.EOF { + err = io.ErrUnexpectedEOF + } else if err == bufio.ErrBufferFull { + err = ErrLineTooLong + } + return nil, err + } + if len(p) >= maxLineLength { + return nil, ErrLineTooLong + } + p = trimTrailingWhitespace(p) + p, err = removeChunkExtension(p) + if err != nil { + return nil, err + } + return p, nil +} + +func trimTrailingWhitespace(b []byte) []byte { + for len(b) > 0 && isASCIISpace(b[len(b)-1]) { + b = b[:len(b)-1] + } + return b +} + +func isASCIISpace(b byte) bool { + return b == ' ' || b == '\t' || b == '\n' || b == '\r' +} + +var semi = []byte(";") + +// removeChunkExtension removes any chunk-extension from p. +// For example, +// +// "0" => "0" +// "0;token" => "0" +// "0;token=val" => "0" +// `0;token="quoted string"` => "0" +func removeChunkExtension(p []byte) ([]byte, error) { + p, _, _ = bytes.Cut(p, semi) + // TODO: care about exact syntax of chunk extensions? We're + // ignoring and stripping them anyway. For now just never + // return an error. + return p, nil +} + +// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP +// "chunked" format before writing them to w. 
Closing the returned chunkedWriter +// sends the final 0-length chunk that marks the end of the stream but does +// not send the final CRLF that appears after trailers; trailers and the last +// CRLF must be written separately. +// +// NewChunkedWriter is not needed by normal applications. The http +// package adds chunking automatically if handlers don't set a +// Content-Length header. Using newChunkedWriter inside a handler +// would result in double chunking or chunking with a Content-Length +// length, both of which are wrong. +func NewChunkedWriter(w io.Writer) io.WriteCloser { + return &chunkedWriter{w} +} + +// Writing to chunkedWriter translates to writing in HTTP chunked Transfer +// Encoding wire format to the underlying Wire chunkedWriter. +type chunkedWriter struct { + Wire io.Writer +} + +// Write the contents of data as one chunk to Wire. +// NOTE: Note that the corresponding chunk-writing procedure in Conn.Write has +// a bug since it does not check for success of io.WriteString +func (cw *chunkedWriter) Write(data []byte) (n int, err error) { + + // Don't send 0-length data. It looks like EOF for chunked encoding. + if len(data) == 0 { + return 0, nil + } + + if _, err = fmt.Fprintf(cw.Wire, "%x\r\n", len(data)); err != nil { + return 0, err + } + if n, err = cw.Wire.Write(data); err != nil { + return + } + if n != len(data) { + err = io.ErrShortWrite + return + } + if _, err = io.WriteString(cw.Wire, "\r\n"); err != nil { + return + } + if bw, ok := cw.Wire.(*FlushAfterChunkWriter); ok { + err = bw.Flush() + } + return +} + +func (cw *chunkedWriter) Close() error { + _, err := io.WriteString(cw.Wire, "0\r\n") + return err +} + +// FlushAfterChunkWriter signals from the caller of NewChunkedWriter +// that each chunk should be followed by a flush. It is used by the +// http.Transport code to keep the buffering behavior for headers and +// trailers, but flush out chunks aggressively in the middle for +// request bodies which may be generated slowly. See Issue 6574. +type FlushAfterChunkWriter struct { + *bufio.Writer +} + +func parseHexUint(v []byte) (n uint64, err error) { + for i, b := range v { + switch { + case '0' <= b && b <= '9': + b = b - '0' + case 'a' <= b && b <= 'f': + b = b - 'a' + 10 + case 'A' <= b && b <= 'F': + b = b - 'A' + 10 + default: + return 0, errors.New("invalid byte in chunk length") + } + if i == 16 { + return 0, errors.New("http chunk length too large") + } + n <<= 4 + n |= uint64(b) + } + return +} diff --git a/src/net/http/internal/chunked_test.go b/src/net/http/internal/chunked_test.go new file mode 100644 index 0000000000..5e29a786dd --- /dev/null +++ b/src/net/http/internal/chunked_test.go @@ -0,0 +1,241 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package internal + +import ( + "bufio" + "bytes" + "fmt" + "io" + "strings" + "testing" + "testing/iotest" +) + +func TestChunk(t *testing.T) { + var b bytes.Buffer + + w := NewChunkedWriter(&b) + const chunk1 = "hello, " + const chunk2 = "world! 0123456789abcdef" + w.Write([]byte(chunk1)) + w.Write([]byte(chunk2)) + w.Close() + + if g, e := b.String(), "7\r\nhello, \r\n17\r\nworld! 
0123456789abcdef\r\n0\r\n"; g != e { + t.Fatalf("chunk writer wrote %q; want %q", g, e) + } + + r := NewChunkedReader(&b) + data, err := io.ReadAll(r) + if err != nil { + t.Logf(`data: "%s"`, data) + t.Fatalf("ReadAll from reader: %v", err) + } + if g, e := string(data), chunk1+chunk2; g != e { + t.Errorf("chunk reader read %q; want %q", g, e) + } +} + +func TestChunkReadMultiple(t *testing.T) { + // Bunch of small chunks, all read together. + { + var b bytes.Buffer + w := NewChunkedWriter(&b) + w.Write([]byte("foo")) + w.Write([]byte("bar")) + w.Close() + + r := NewChunkedReader(&b) + buf := make([]byte, 10) + n, err := r.Read(buf) + if n != 6 || err != io.EOF { + t.Errorf("Read = %d, %v; want 6, EOF", n, err) + } + buf = buf[:n] + if string(buf) != "foobar" { + t.Errorf("Read = %q; want %q", buf, "foobar") + } + } + + // One big chunk followed by a little chunk, but the small bufio.Reader size + // should prevent the second chunk header from being read. + { + var b bytes.Buffer + w := NewChunkedWriter(&b) + // fillBufChunk is 11 bytes + 3 bytes header + 2 bytes footer = 16 bytes, + // the same as the bufio ReaderSize below (the minimum), so even + // though we're going to try to Read with a buffer larger enough to also + // receive "foo", the second chunk header won't be read yet. + const fillBufChunk = "0123456789a" + const shortChunk = "foo" + w.Write([]byte(fillBufChunk)) + w.Write([]byte(shortChunk)) + w.Close() + + r := NewChunkedReader(bufio.NewReaderSize(&b, 16)) + buf := make([]byte, len(fillBufChunk)+len(shortChunk)) + n, err := r.Read(buf) + if n != len(fillBufChunk) || err != nil { + t.Errorf("Read = %d, %v; want %d, nil", n, err, len(fillBufChunk)) + } + buf = buf[:n] + if string(buf) != fillBufChunk { + t.Errorf("Read = %q; want %q", buf, fillBufChunk) + } + + n, err = r.Read(buf) + if n != len(shortChunk) || err != io.EOF { + t.Errorf("Read = %d, %v; want %d, EOF", n, err, len(shortChunk)) + } + } + + // And test that we see an EOF chunk, even though our buffer is already full: + { + r := NewChunkedReader(bufio.NewReader(strings.NewReader("3\r\nfoo\r\n0\r\n"))) + buf := make([]byte, 3) + n, err := r.Read(buf) + if n != 3 || err != io.EOF { + t.Errorf("Read = %d, %v; want 3, EOF", n, err) + } + if string(buf) != "foo" { + t.Errorf("buf = %q; want foo", buf) + } + } +} + +func TestChunkReaderAllocs(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + var buf bytes.Buffer + w := NewChunkedWriter(&buf) + a, b, c := []byte("aaaaaa"), []byte("bbbbbbbbbbbb"), []byte("cccccccccccccccccccccccc") + w.Write(a) + w.Write(b) + w.Write(c) + w.Close() + + readBuf := make([]byte, len(a)+len(b)+len(c)+1) + byter := bytes.NewReader(buf.Bytes()) + bufr := bufio.NewReader(byter) + mallocs := testing.AllocsPerRun(100, func() { + byter.Seek(0, io.SeekStart) + bufr.Reset(byter) + r := NewChunkedReader(bufr) + n, err := io.ReadFull(r, readBuf) + if n != len(readBuf)-1 { + t.Fatalf("read %d bytes; want %d", n, len(readBuf)-1) + } + if err != io.ErrUnexpectedEOF { + t.Fatalf("read error = %v; want ErrUnexpectedEOF", err) + } + }) + if mallocs > 1.5 { + t.Errorf("mallocs = %v; want 1", mallocs) + } +} + +func TestParseHexUint(t *testing.T) { + type testCase struct { + in string + want uint64 + wantErr string + } + tests := []testCase{ + {"x", 0, "invalid byte in chunk length"}, + {"0000000000000000", 0, ""}, + {"0000000000000001", 1, ""}, + {"ffffffffffffffff", 1<<64 - 1, ""}, + {"000000000000bogus", 0, "invalid byte in chunk length"}, + {"00000000000000000", 0, "http chunk 
length too large"}, // could accept if we wanted + {"10000000000000000", 0, "http chunk length too large"}, + {"00000000000000001", 0, "http chunk length too large"}, // could accept if we wanted + } + for i := uint64(0); i <= 1234; i++ { + tests = append(tests, testCase{in: fmt.Sprintf("%x", i), want: i}) + } + for _, tt := range tests { + got, err := parseHexUint([]byte(tt.in)) + if tt.wantErr != "" { + if !strings.Contains(fmt.Sprint(err), tt.wantErr) { + t.Errorf("parseHexUint(%q) = %v, %v; want error %q", tt.in, got, err, tt.wantErr) + } + } else { + if err != nil || got != tt.want { + t.Errorf("parseHexUint(%q) = %v, %v; want %v", tt.in, got, err, tt.want) + } + } + } +} + +func TestChunkReadingIgnoresExtensions(t *testing.T) { + in := "7;ext=\"some quoted string\"\r\n" + // token=quoted string + "hello, \r\n" + + "17;someext\r\n" + // token without value + "world! 0123456789abcdef\r\n" + + "0;someextension=sometoken\r\n" // token=token + data, err := io.ReadAll(NewChunkedReader(strings.NewReader(in))) + if err != nil { + t.Fatalf("ReadAll = %q, %v", data, err) + } + if g, e := string(data), "hello, world! 0123456789abcdef"; g != e { + t.Errorf("read %q; want %q", g, e) + } +} + +// Issue 17355: ChunkedReader shouldn't block waiting for more data +// if it can return something. +func TestChunkReadPartial(t *testing.T) { + pr, pw := io.Pipe() + go func() { + pw.Write([]byte("7\r\n1234567")) + }() + cr := NewChunkedReader(pr) + readBuf := make([]byte, 7) + n, err := cr.Read(readBuf) + if err != nil { + t.Fatal(err) + } + want := "1234567" + if n != 7 || string(readBuf) != want { + t.Fatalf("Read: %v %q; want %d, %q", n, readBuf[:n], len(want), want) + } + go func() { + pw.Write([]byte("xx")) + }() + _, err = cr.Read(readBuf) + if got := fmt.Sprint(err); !strings.Contains(got, "malformed") { + t.Fatalf("second read = %v; want malformed error", err) + } + +} + +// Issue 48861: ChunkedReader should report incomplete chunks +func TestIncompleteChunk(t *testing.T) { + const valid = "4\r\nabcd\r\n" + "5\r\nabc\r\n\r\n" + "0\r\n" + + for i := 0; i < len(valid); i++ { + incomplete := valid[:i] + r := NewChunkedReader(strings.NewReader(incomplete)) + if _, err := io.ReadAll(r); err != io.ErrUnexpectedEOF { + t.Errorf("expected io.ErrUnexpectedEOF for %q, got %v", incomplete, err) + } + } + + r := NewChunkedReader(strings.NewReader(valid)) + if _, err := io.ReadAll(r); err != nil { + t.Errorf("unexpected error for %q: %v", valid, err) + } +} + +func TestChunkEndReadError(t *testing.T) { + readErr := fmt.Errorf("chunk end read error") + + r := NewChunkedReader(io.MultiReader(strings.NewReader("4\r\nabcd"), iotest.ErrReader(readErr))) + if _, err := io.ReadAll(r); err != readErr { + t.Errorf("expected %v, got %v", readErr, err) + } +} diff --git a/src/net/http/jar.go b/src/net/http/jar.go new file mode 100644 index 0000000000..3091c58aaa --- /dev/null +++ b/src/net/http/jar.go @@ -0,0 +1,29 @@ +// TINYGO: The following is copied and modified from Go 1.19.3 official implementation. + +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "net/url" +) + +// A CookieJar manages storage and use of cookies in HTTP requests. +// +// Implementations of CookieJar must be safe for concurrent use by multiple +// goroutines. +// +// The net/http/cookiejar package provides a CookieJar implementation. 
+type CookieJar interface { + // SetCookies handles the receipt of the cookies in a reply for the + // given URL. It may or may not choose to save the cookies, depending + // on the jar's policy and implementation. + SetCookies(u *url.URL, cookies []*Cookie) + + // Cookies returns the cookies to send in a request for the given URL. + // It is up to the implementation to honor the standard cookie use + // restrictions such as in RFC 6265. + Cookies(u *url.URL) []*Cookie +} diff --git a/src/net/http/method.go b/src/net/http/method.go new file mode 100644 index 0000000000..b8a4c33beb --- /dev/null +++ b/src/net/http/method.go @@ -0,0 +1,22 @@ +// TINYGO: The following is copied from Go 1.19.3 official implementation. + +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +// Common HTTP methods. +// +// Unless otherwise noted, these are defined in RFC 7231 section 4.3. +const ( + MethodGet = "GET" + MethodHead = "HEAD" + MethodPost = "POST" + MethodPut = "PUT" + MethodPatch = "PATCH" // RFC 5789 + MethodDelete = "DELETE" + MethodConnect = "CONNECT" + MethodOptions = "OPTIONS" + MethodTrace = "TRACE" +) diff --git a/src/net/http/request.go b/src/net/http/request.go new file mode 100644 index 0000000000..1971ec4253 --- /dev/null +++ b/src/net/http/request.go @@ -0,0 +1,1447 @@ +// TINYGO: The following is copied and modified from Go 1.19.3 official implementation. + +// TINYGO: Removed multipart stuff +// TINYGO: Removed trace stuff + +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP Request reading and parsing. + +package http + +import ( + "bufio" + "bytes" + "context" + "crypto/tls" + "encoding/base64" + "errors" + "fmt" + "io" + "mime" + "mime/multipart" + "net" + "net/http/internal/ascii" + "net/textproto" + "net/url" + urlpkg "net/url" + "strconv" + "strings" + "sync" +) + +const ( + defaultMaxMemory = 32 << 20 // 32 MB +) + +// ErrMissingFile is returned by FormFile when the provided file field name +// is either not present in the request or not a file field. +var ErrMissingFile = errors.New("http: no such file") + +// ProtocolError represents an HTTP protocol error. +// +// Deprecated: Not all errors in the http package related to protocol errors +// are of type ProtocolError. +type ProtocolError struct { + ErrorString string +} + +func (pe *ProtocolError) Error() string { return pe.ErrorString } + +var ( + // ErrNotSupported is returned by the Push method of Pusher + // implementations to indicate that HTTP/2 Push support is not + // available. + ErrNotSupported = &ProtocolError{"feature not supported"} + + // Deprecated: ErrUnexpectedTrailer is no longer returned by + // anything in the net/http package. Callers should not + // compare errors against this variable. + ErrUnexpectedTrailer = &ProtocolError{"trailer header without chunked transfer encoding"} + + // ErrMissingBoundary is returned by Request.MultipartReader when the + // request's Content-Type does not include a "boundary" parameter. + ErrMissingBoundary = &ProtocolError{"no multipart boundary param in Content-Type"} + + // ErrNotMultipart is returned by Request.MultipartReader when the + // request's Content-Type is not multipart/form-data. 
+ ErrNotMultipart = &ProtocolError{"request Content-Type isn't multipart/form-data"} + + // Deprecated: ErrHeaderTooLong is no longer returned by + // anything in the net/http package. Callers should not + // compare errors against this variable. + ErrHeaderTooLong = &ProtocolError{"header too long"} + + // Deprecated: ErrShortBody is no longer returned by + // anything in the net/http package. Callers should not + // compare errors against this variable. + ErrShortBody = &ProtocolError{"entity body too short"} + + // Deprecated: ErrMissingContentLength is no longer returned by + // anything in the net/http package. Callers should not + // compare errors against this variable. + ErrMissingContentLength = &ProtocolError{"missing ContentLength in HEAD response"} +) + +func badStringError(what, val string) error { return fmt.Errorf("%s %q", what, val) } + +// Headers that Request.Write handles itself and should be skipped. +var reqWriteExcludeHeader = map[string]bool{ + "Host": true, // not in Header map anyway + "User-Agent": true, + "Content-Length": true, + "Transfer-Encoding": true, + "Trailer": true, +} + +// A Request represents an HTTP request received by a server +// or to be sent by a client. +// +// The field semantics differ slightly between client and server +// usage. In addition to the notes on the fields below, see the +// documentation for Request.Write and RoundTripper. +type Request struct { + // Method specifies the HTTP method (GET, POST, PUT, etc.). + // For client requests, an empty string means GET. + // + // Go's HTTP client does not support sending a request with + // the CONNECT method. See the documentation on Transport for + // details. + Method string + + // URL specifies either the URI being requested (for server + // requests) or the URL to access (for client requests). + // + // For server requests, the URL is parsed from the URI + // supplied on the Request-Line as stored in RequestURI. For + // most requests, fields other than Path and RawQuery will be + // empty. (See RFC 7230, Section 5.3) + // + // For client requests, the URL's Host specifies the server to + // connect to, while the Request's Host field optionally + // specifies the Host header value to send in the HTTP + // request. + URL *url.URL + + // The protocol version for incoming server requests. + // + // For client requests, these fields are ignored. The HTTP + // client code always uses either HTTP/1.1 or HTTP/2. + // See the docs on Transport for details. + Proto string // "HTTP/1.0" + ProtoMajor int // 1 + ProtoMinor int // 0 + + // Header contains the request header fields either received + // by the server or to be sent by the client. + // + // If a server received a request with header lines, + // + // Host: example.com + // accept-encoding: gzip, deflate + // Accept-Language: en-us + // fOO: Bar + // foo: two + // + // then + // + // Header = map[string][]string{ + // "Accept-Encoding": {"gzip, deflate"}, + // "Accept-Language": {"en-us"}, + // "Foo": {"Bar", "two"}, + // } + // + // For incoming requests, the Host header is promoted to the + // Request.Host field and removed from the Header map. + // + // HTTP defines that header names are case-insensitive. The + // request parser implements this by using CanonicalHeaderKey, + // making the first character and any characters following a + // hyphen uppercase and the rest lowercase. 
+ // + // For client requests, certain headers such as Content-Length + // and Connection are automatically written when needed and + // values in Header may be ignored. See the documentation + // for the Request.Write method. + Header Header + + // Body is the request's body. + // + // For client requests, a nil body means the request has no + // body, such as a GET request. The HTTP Client's Transport + // is responsible for calling the Close method. + // + // For server requests, the Request Body is always non-nil + // but will return EOF immediately when no body is present. + // The Server will close the request body. The ServeHTTP + // Handler does not need to. + // + // Body must allow Read to be called concurrently with Close. + // In particular, calling Close should unblock a Read waiting + // for input. + Body io.ReadCloser + + // GetBody defines an optional func to return a new copy of + // Body. It is used for client requests when a redirect requires + // reading the body more than once. Use of GetBody still + // requires setting Body. + // + // For server requests, it is unused. + GetBody func() (io.ReadCloser, error) + + // ContentLength records the length of the associated content. + // The value -1 indicates that the length is unknown. + // Values >= 0 indicate that the given number of bytes may + // be read from Body. + // + // For client requests, a value of 0 with a non-nil Body is + // also treated as unknown. + ContentLength int64 + + // TransferEncoding lists the transfer encodings from outermost to + // innermost. An empty list denotes the "identity" encoding. + // TransferEncoding can usually be ignored; chunked encoding is + // automatically added and removed as necessary when sending and + // receiving requests. + TransferEncoding []string + + // Close indicates whether to close the connection after + // replying to this request (for servers) or after sending this + // request and reading its response (for clients). + // + // For server requests, the HTTP server handles this automatically + // and this field is not needed by Handlers. + // + // For client requests, setting this field prevents re-use of + // TCP connections between requests to the same hosts, as if + // Transport.DisableKeepAlives were set. + Close bool + + // For server requests, Host specifies the host on which the + // URL is sought. For HTTP/1 (per RFC 7230, section 5.4), this + // is either the value of the "Host" header or the host name + // given in the URL itself. For HTTP/2, it is the value of the + // ":authority" pseudo-header field. + // It may be of the form "host:port". For international domain + // names, Host may be in Punycode or Unicode form. Use + // golang.org/x/net/idna to convert it to either format if + // needed. + // To prevent DNS rebinding attacks, server Handlers should + // validate that the Host header has a value for which the + // Handler considers itself authoritative. The included + // ServeMux supports patterns registered to particular host + // names and thus protects its registered Handlers. + // + // For client requests, Host optionally overrides the Host + // header to send. If empty, the Request.Write method uses + // the value of URL.Host. Host may contain an international + // domain name. + Host string + + // Form contains the parsed form data, including both the URL + // field's query parameters and the PATCH, POST, or PUT form data. + // This field is only available after ParseForm is called. + // The HTTP client ignores Form and uses Body instead. 
+ Form url.Values + + // PostForm contains the parsed form data from PATCH, POST + // or PUT body parameters. + // + // This field is only available after ParseForm is called. + // The HTTP client ignores PostForm and uses Body instead. + PostForm url.Values + + // MultipartForm is the parsed multipart form, including file uploads. + // This field is only available after ParseMultipartForm is called. + // The HTTP client ignores MultipartForm and uses Body instead. + MultipartForm *multipart.Form + + // Trailer specifies additional headers that are sent after the request + // body. + // + // For server requests, the Trailer map initially contains only the + // trailer keys, with nil values. (The client declares which trailers it + // will later send.) While the handler is reading from Body, it must + // not reference Trailer. After reading from Body returns EOF, Trailer + // can be read again and will contain non-nil values, if they were sent + // by the client. + // + // For client requests, Trailer must be initialized to a map containing + // the trailer keys to later send. The values may be nil or their final + // values. The ContentLength must be 0 or -1, to send a chunked request. + // After the HTTP request is sent the map values can be updated while + // the request body is read. Once the body returns EOF, the caller must + // not mutate Trailer. + // + // Few HTTP clients, servers, or proxies support HTTP trailers. + Trailer Header + + // RemoteAddr allows HTTP servers and other software to record + // the network address that sent the request, usually for + // logging. This field is not filled in by ReadRequest and + // has no defined format. The HTTP server in this package + // sets RemoteAddr to an "IP:port" address before invoking a + // handler. + // This field is ignored by the HTTP client. + RemoteAddr string + + // RequestURI is the unmodified request-target of the + // Request-Line (RFC 7230, Section 3.1.1) as sent by the client + // to a server. Usually the URL field should be used instead. + // It is an error to set this field in an HTTP client request. + RequestURI string + + // TLS allows HTTP servers and other software to record + // information about the TLS connection on which the request + // was received. This field is not filled in by ReadRequest. + // The HTTP server in this package sets the field for + // TLS-enabled connections before invoking a handler; + // otherwise it leaves the field nil. + // This field is ignored by the HTTP client. + TLS *tls.ConnectionState + + // Cancel is an optional channel whose closure indicates that the client + // request should be regarded as canceled. Not all implementations of + // RoundTripper may support Cancel. + // + // For server requests, this field is not applicable. + // + // Deprecated: Set the Request's context with NewRequestWithContext + // instead. If a Request's Cancel field and context are both + // set, it is undefined whether Cancel is respected. + Cancel <-chan struct{} + + // Response is the redirect response which caused this request + // to be created. This field is only populated during client + // redirects. + Response *Response + + // ctx is either the client or server context. It should only + // be modified via copying the whole Request using WithContext. + // It is unexported to prevent people from using Context wrong + // and mutating the contexts held by callers of the same request. 
+ ctx context.Context + + // TINYGO: Add onEOF func for callback when response is fully read + // TINYGO: so we can close the connection. + onEOF func() +} + +// Context returns the request's context. To change the context, use +// WithContext. +// +// The returned context is always non-nil; it defaults to the +// background context. +// +// For outgoing client requests, the context controls cancellation. +// +// For incoming server requests, the context is canceled when the +// client's connection closes, the request is canceled (with HTTP/2), +// or when the ServeHTTP method returns. +func (r *Request) Context() context.Context { + if r.ctx != nil { + return r.ctx + } + return context.Background() +} + +// WithContext returns a shallow copy of r with its context changed +// to ctx. The provided ctx must be non-nil. +// +// For outgoing client request, the context controls the entire +// lifetime of a request and its response: obtaining a connection, +// sending the request, and reading the response headers and body. +// +// To create a new request with a context, use NewRequestWithContext. +// To change the context of a request, such as an incoming request you +// want to modify before sending back out, use Request.Clone. Between +// those two uses, it's rare to need WithContext. +func (r *Request) WithContext(ctx context.Context) *Request { + if ctx == nil { + panic("nil context") + } + r2 := new(Request) + *r2 = *r + r2.ctx = ctx + return r2 +} + +// Clone returns a deep copy of r with its context changed to ctx. +// The provided ctx must be non-nil. +// +// For an outgoing client request, the context controls the entire +// lifetime of a request and its response: obtaining a connection, +// sending the request, and reading the response headers and body. +func (r *Request) Clone(ctx context.Context) *Request { + if ctx == nil { + panic("nil context") + } + r2 := new(Request) + *r2 = *r + r2.ctx = ctx + r2.URL = cloneURL(r.URL) + if r.Header != nil { + r2.Header = r.Header.Clone() + } + if r.Trailer != nil { + r2.Trailer = r.Trailer.Clone() + } + if s := r.TransferEncoding; s != nil { + s2 := make([]string, len(s)) + copy(s2, s) + r2.TransferEncoding = s2 + } + r2.Form = cloneURLValues(r.Form) + r2.PostForm = cloneURLValues(r.PostForm) + r2.MultipartForm = cloneMultipartForm(r.MultipartForm) + return r2 +} + +// ProtoAtLeast reports whether the HTTP protocol used +// in the request is at least major.minor. +func (r *Request) ProtoAtLeast(major, minor int) bool { + return r.ProtoMajor > major || + r.ProtoMajor == major && r.ProtoMinor >= minor +} + +// UserAgent returns the client's User-Agent, if sent in the request. +func (r *Request) UserAgent() string { + return r.Header.Get("User-Agent") +} + +// Cookies parses and returns the HTTP cookies sent with the request. +func (r *Request) Cookies() []*Cookie { + return readCookies(r.Header, "") +} + +// ErrNoCookie is returned by Request's Cookie method when a cookie is not found. +var ErrNoCookie = errors.New("http: named cookie not present") + +// Cookie returns the named cookie provided in the request or +// ErrNoCookie if not found. +// If multiple cookies match the given name, only one cookie will +// be returned. +func (r *Request) Cookie(name string) (*Cookie, error) { + for _, c := range readCookies(r.Header, name) { + return c, nil + } + return nil, ErrNoCookie +} + +// AddCookie adds a cookie to the request. Per RFC 6265 section 5.4, +// AddCookie does not attach more than one Cookie header field. 
That +// means all cookies, if any, are written into the same line, +// separated by semicolon. +// AddCookie only sanitizes c's name and value, and does not sanitize +// a Cookie header already present in the request. +func (r *Request) AddCookie(c *Cookie) { + s := fmt.Sprintf("%s=%s", sanitizeCookieName(c.Name), sanitizeCookieValue(c.Value)) + if c := r.Header.Get("Cookie"); c != "" { + r.Header.Set("Cookie", c+"; "+s) + } else { + r.Header.Set("Cookie", s) + } +} + +// Referer returns the referring URL, if sent in the request. +// +// Referer is misspelled as in the request itself, a mistake from the +// earliest days of HTTP. This value can also be fetched from the +// Header map as Header["Referer"]; the benefit of making it available +// as a method is that the compiler can diagnose programs that use the +// alternate (correct English) spelling req.Referrer() but cannot +// diagnose programs that use Header["Referrer"]. +func (r *Request) Referer() string { + return r.Header.Get("Referer") +} + +// multipartByReader is a sentinel value. +// Its presence in Request.MultipartForm indicates that parsing of the request +// body has been handed off to a MultipartReader instead of ParseMultipartForm. +var multipartByReader = &multipart.Form{ + Value: make(map[string][]string), + File: make(map[string][]*multipart.FileHeader), +} + +// MultipartReader returns a MIME multipart reader if this is a +// multipart/form-data or a multipart/mixed POST request, else returns nil and an error. +// Use this function instead of ParseMultipartForm to +// process the request body as a stream. +func (r *Request) MultipartReader() (*multipart.Reader, error) { + if r.MultipartForm == multipartByReader { + return nil, errors.New("http: MultipartReader called twice") + } + if r.MultipartForm != nil { + return nil, errors.New("http: multipart handled by ParseMultipartForm") + } + r.MultipartForm = multipartByReader + return r.multipartReader(true) +} + +func (r *Request) multipartReader(allowMixed bool) (*multipart.Reader, error) { + v := r.Header.Get("Content-Type") + if v == "" { + return nil, ErrNotMultipart + } + if r.Body == nil { + return nil, errors.New("missing form body") + } + d, params, err := mime.ParseMediaType(v) + if err != nil || !(d == "multipart/form-data" || allowMixed && d == "multipart/mixed") { + return nil, ErrNotMultipart + } + boundary, ok := params["boundary"] + if !ok { + return nil, ErrMissingBoundary + } + return multipart.NewReader(r.Body, boundary), nil +} + +// isH2Upgrade reports whether r represents the http2 "client preface" +// magic string. +func (r *Request) isH2Upgrade() bool { + return r.Method == "PRI" && len(r.Header) == 0 && r.URL.Path == "*" && r.Proto == "HTTP/2.0" +} + +// Return value if nonempty, def otherwise. +func valueOrDefault(value, def string) string { + if value != "" { + return value + } + return def +} + +// NOTE: This is not intended to reflect the actual Go version being used. +// It was changed at the time of Go 1.1 release because the former User-Agent +// had ended up blocked by some intrusion detection systems. +// See https://codereview.appspot.com/7532043. +const defaultUserAgent = "Go-http-client/1.1" + +// Write writes an HTTP/1.1 request, which is the header and body, in wire format. 
+// This method consults the following fields of the request: +// +// Host +// URL +// Method (defaults to "GET") +// Header +// ContentLength +// TransferEncoding +// Body +// +// If Body is present, Content-Length is <= 0 and TransferEncoding +// hasn't been set to "identity", Write adds "Transfer-Encoding: +// chunked" to the header. Body is closed after it is sent. +func (r *Request) Write(w io.Writer) error { + return r.write(w, false, nil, nil) +} + +// WriteProxy is like Write but writes the request in the form +// expected by an HTTP proxy. In particular, WriteProxy writes the +// initial Request-URI line of the request with an absolute URI, per +// section 5.3 of RFC 7230, including the scheme and host. +// In either case, WriteProxy also writes a Host header, using +// either r.Host or r.URL.Host. +func (r *Request) WriteProxy(w io.Writer) error { + return r.write(w, true, nil, nil) +} + +// errMissingHost is returned by Write when there is no Host or URL present in +// the Request. +var errMissingHost = errors.New("http: Request.Write on Request with no Host or URL set") + +// extraHeaders may be nil +// waitForContinue may be nil +// always closes body +func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitForContinue func() bool) (err error) { + closed := false + defer func() { + if closed { + return + } + if closeErr := r.closeBody(); closeErr != nil && err == nil { + err = closeErr + } + }() + + // Find the target host. Prefer the Host: header, but if that + // is not given, use the host from the request URL. + // + // Clean the host, in case it arrives with unexpected stuff in it. + host := cleanHost(r.Host) + if host == "" { + if r.URL == nil { + return errMissingHost + } + host = cleanHost(r.URL.Host) + } + + // According to RFC 6874, an HTTP client, proxy, or other + // intermediary must remove any IPv6 zone identifier attached + // to an outgoing URI. + host = removeZone(host) + + ruri := r.URL.RequestURI() + if usingProxy && r.URL.Scheme != "" && r.URL.Opaque == "" { + ruri = r.URL.Scheme + "://" + host + ruri + } else if r.Method == "CONNECT" && r.URL.Path == "" { + // CONNECT requests normally give just the host and port, not a full URL. + ruri = host + if r.URL.Opaque != "" { + ruri = r.URL.Opaque + } + } + if stringContainsCTLByte(ruri) { + return errors.New("net/http: can't write control character in Request.URL") + } + // TODO: validate r.Method too? At least it's less likely to + // come from an attacker (more likely to be a constant in + // code). + + // Wrap the writer in a bufio Writer if it's not already buffered. + // Don't always call NewWriter, as that forces a bytes.Buffer + // and other small bufio Writers to have a minimum 4k buffer + // size. + var bw *bufio.Writer + if _, ok := w.(io.ByteWriter); !ok { + bw = bufio.NewWriter(w) + w = bw + } + + _, err = fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(r.Method, "GET"), ruri) + if err != nil { + return err + } + + // Header lines + _, err = fmt.Fprintf(w, "Host: %s\r\n", host) + if err != nil { + return err + } + + // Use the defaultUserAgent unless the Header contains one, which + // may be blank to not send the header. 
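+	// Editor's illustration (not upstream code), assuming a *Request named req;
+	// callers control both cases described above through the header itself:
+	//
+	//	req.Header.Set("User-Agent", "my-client/1.0") // send a custom value
+	//	req.Header.Set("User-Agent", "")              // suppress the header entirely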
+ userAgent := defaultUserAgent + if r.Header.has("User-Agent") { + userAgent = r.Header.Get("User-Agent") + } + if userAgent != "" { + _, err = fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent) + if err != nil { + return err + } + } + + // Process Body,ContentLength,Close,Trailer + tw, err := newTransferWriter(r) + if err != nil { + return err + } + err = tw.writeHeader(w) + if err != nil { + return err + } + + err = r.Header.writeSubset(w, reqWriteExcludeHeader) + if err != nil { + return err + } + + if extraHeaders != nil { + err = extraHeaders.write(w) + if err != nil { + return err + } + } + + _, err = io.WriteString(w, "\r\n") + if err != nil { + return err + } + + // Flush and wait for 100-continue if expected. + if waitForContinue != nil { + if bw, ok := w.(*bufio.Writer); ok { + err = bw.Flush() + if err != nil { + return err + } + } + if !waitForContinue() { + closed = true + r.closeBody() + return nil + } + } + + if bw, ok := w.(*bufio.Writer); ok && tw.FlushHeaders { + if err := bw.Flush(); err != nil { + return err + } + } + + // Write body and trailer + closed = true + err = tw.writeBody(w) + if err != nil { + if tw.bodyReadError == err { + err = requestBodyReadError{err} + } + return err + } + + if bw != nil { + return bw.Flush() + } + return nil +} + +// requestBodyReadError wraps an error from (*Request).write to indicate +// that the error came from a Read call on the Request.Body. +// This error type should not escape the net/http package to users. +type requestBodyReadError struct{ error } + +// cleanHost cleans up the host sent in request's Host header. +// +// It both strips anything after '/' or ' ', and puts the value +// into Punycode form, if necessary. +// +// Ideally we'd clean the Host header according to the spec: +// +// https://tools.ietf.org/html/rfc7230#section-5.4 (Host = uri-host [ ":" port ]") +// https://tools.ietf.org/html/rfc7230#section-2.7 (uri-host -> rfc3986's host) +// https://tools.ietf.org/html/rfc3986#section-3.2.2 (definition of host) +// +// But practically, what we are trying to avoid is the situation in +// issue 11206, where a malformed Host header used in the proxy context +// would create a bad request. So it is enough to just truncate at the +// first offending character. + +// TINYGO: Removed IDNA checks...it doubled the binary size + +func cleanHost(in string) string { + if i := strings.IndexAny(in, " /"); i != -1 { + in = in[:i] + } + host, port, err := net.SplitHostPort(in) + if err != nil { // input was just a host + return in + } + return net.JoinHostPort(host, port) +} + +// removeZone removes IPv6 zone identifier from host. +// E.g., "[fe80::1%en0]:8080" to "[fe80::1]:8080" +func removeZone(host string) string { + if !strings.HasPrefix(host, "[") { + return host + } + i := strings.LastIndex(host, "]") + if i < 0 { + return host + } + j := strings.LastIndex(host[:i], "%") + if j < 0 { + return host + } + return host[:j] + host[i:] +} + +// ParseHTTPVersion parses an HTTP version string according to RFC 7230, section 2.6. +// "HTTP/1.0" returns (1, 0, true). Note that strings without +// a minor version, such as "HTTP/2", are not valid. +func ParseHTTPVersion(vers string) (major, minor int, ok bool) { + switch vers { + case "HTTP/1.1": + return 1, 1, true + case "HTTP/1.0": + return 1, 0, true + } + if !strings.HasPrefix(vers, "HTTP/") { + return 0, 0, false + } + if len(vers) != len("HTTP/X.Y") { + return 0, 0, false + } + if vers[6] != '.' 
{ + return 0, 0, false + } + maj, err := strconv.ParseUint(vers[5:6], 10, 0) + if err != nil { + return 0, 0, false + } + min, err := strconv.ParseUint(vers[7:8], 10, 0) + if err != nil { + return 0, 0, false + } + return int(maj), int(min), true +} + +func validMethod(method string) bool { + /* + Method = "OPTIONS" ; Section 9.2 + | "GET" ; Section 9.3 + | "HEAD" ; Section 9.4 + | "POST" ; Section 9.5 + | "PUT" ; Section 9.6 + | "DELETE" ; Section 9.7 + | "TRACE" ; Section 9.8 + | "CONNECT" ; Section 9.9 + | extension-method + extension-method = token + token = 1* + */ + return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 +} + +// NewRequest wraps NewRequestWithContext using context.Background. +func NewRequest(method, url string, body io.Reader) (*Request, error) { + return NewRequestWithContext(context.Background(), method, url, body) +} + +// NewRequestWithContext returns a new Request given a method, URL, and +// optional body. +// +// If the provided body is also an io.Closer, the returned +// Request.Body is set to body and will be closed by the Client +// methods Do, Post, and PostForm, and Transport.RoundTrip. +// +// NewRequestWithContext returns a Request suitable for use with +// Client.Do or Transport.RoundTrip. To create a request for use with +// testing a Server Handler, either use the NewRequest function in the +// net/http/httptest package, use ReadRequest, or manually update the +// Request fields. For an outgoing client request, the context +// controls the entire lifetime of a request and its response: +// obtaining a connection, sending the request, and reading the +// response headers and body. See the Request type's documentation for +// the difference between inbound and outbound request fields. +// +// If body is of type *bytes.Buffer, *bytes.Reader, or +// *strings.Reader, the returned request's ContentLength is set to its +// exact value (instead of -1), GetBody is populated (so 307 and 308 +// redirects can replay the body), and Body is set to NoBody if the +// ContentLength is 0. +func NewRequestWithContext(ctx context.Context, method, url string, body io.Reader) (*Request, error) { + if method == "" { + // We document that "" means "GET" for Request.Method, and people have + // relied on that from NewRequest, so keep that working. + // We still enforce validMethod for non-empty methods. + method = "GET" + } + if !validMethod(method) { + return nil, fmt.Errorf("net/http: invalid method %q", method) + } + if ctx == nil { + return nil, errors.New("net/http: nil Context") + } + u, err := urlpkg.Parse(url) + if err != nil { + return nil, err + } + rc, ok := body.(io.ReadCloser) + if !ok && body != nil { + rc = io.NopCloser(body) + } + // The host's colon:port should be normalized. See Issue 14836. 
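+	// Editor's illustration (not upstream code) of the normalization performed
+	// by removeEmptyPort below:
+	//
+	//	removeEmptyPort("example.com:")     // "example.com"
+	//	removeEmptyPort("example.com:8080") // "example.com:8080" (unchanged)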
+ u.Host = removeEmptyPort(u.Host) + req := &Request{ + ctx: ctx, + Method: method, + URL: u, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(Header), + Body: rc, + Host: u.Host, + } + if body != nil { + switch v := body.(type) { + case *bytes.Buffer: + req.ContentLength = int64(v.Len()) + buf := v.Bytes() + req.GetBody = func() (io.ReadCloser, error) { + r := bytes.NewReader(buf) + return io.NopCloser(r), nil + } + case *bytes.Reader: + req.ContentLength = int64(v.Len()) + snapshot := *v + req.GetBody = func() (io.ReadCloser, error) { + r := snapshot + return io.NopCloser(&r), nil + } + case *strings.Reader: + req.ContentLength = int64(v.Len()) + snapshot := *v + req.GetBody = func() (io.ReadCloser, error) { + r := snapshot + return io.NopCloser(&r), nil + } + default: + // This is where we'd set it to -1 (at least + // if body != NoBody) to mean unknown, but + // that broke people during the Go 1.8 testing + // period. People depend on it being 0 I + // guess. Maybe retry later. See Issue 18117. + } + // For client requests, Request.ContentLength of 0 + // means either actually 0, or unknown. The only way + // to explicitly say that the ContentLength is zero is + // to set the Body to nil. But turns out too much code + // depends on NewRequest returning a non-nil Body, + // so we use a well-known ReadCloser variable instead + // and have the http package also treat that sentinel + // variable to mean explicitly zero. + if req.GetBody != nil && req.ContentLength == 0 { + req.Body = NoBody + req.GetBody = func() (io.ReadCloser, error) { return NoBody, nil } + } + } + + return req, nil +} + +// BasicAuth returns the username and password provided in the request's +// Authorization header, if the request uses HTTP Basic Authentication. +// See RFC 2617, Section 2. +func (r *Request) BasicAuth() (username, password string, ok bool) { + auth := r.Header.Get("Authorization") + if auth == "" { + return "", "", false + } + return parseBasicAuth(auth) +} + +// parseBasicAuth parses an HTTP Basic Authentication string. +// "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" returns ("Aladdin", "open sesame", true). +func parseBasicAuth(auth string) (username, password string, ok bool) { + const prefix = "Basic " + // Case insensitive prefix match. See Issue 22736. + if len(auth) < len(prefix) || !ascii.EqualFold(auth[:len(prefix)], prefix) { + return "", "", false + } + c, err := base64.StdEncoding.DecodeString(auth[len(prefix):]) + if err != nil { + return "", "", false + } + cs := string(c) + username, password, ok = strings.Cut(cs, ":") + if !ok { + return "", "", false + } + return username, password, true +} + +// SetBasicAuth sets the request's Authorization header to use HTTP +// Basic Authentication with the provided username and password. +// +// With HTTP Basic Authentication the provided username and password +// are not encrypted. It should generally only be used in an HTTPS +// request. +// +// The username may not contain a colon. Some protocols may impose +// additional requirements on pre-escaping the username and +// password. For instance, when used with OAuth2, both arguments must +// be URL encoded first with url.QueryEscape. +func (r *Request) SetBasicAuth(username, password string) { + r.Header.Set("Authorization", "Basic "+basicAuth(username, password)) +} + +// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts. 
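+//
+// Editor's illustration (not part of the upstream comment):
+//
+//	method, uri, proto, ok := parseRequestLine("GET /foo HTTP/1.1")
+//	// method == "GET", uri == "/foo", proto == "HTTP/1.1", ok == true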
+func parseRequestLine(line string) (method, requestURI, proto string, ok bool) { + method, rest, ok1 := strings.Cut(line, " ") + requestURI, proto, ok2 := strings.Cut(rest, " ") + if !ok1 || !ok2 { + return "", "", "", false + } + return method, requestURI, proto, true +} + +var textprotoReaderPool sync.Pool + +func newTextprotoReader(br *bufio.Reader) *textproto.Reader { + if v := textprotoReaderPool.Get(); v != nil { + tr := v.(*textproto.Reader) + tr.R = br + return tr + } + return textproto.NewReader(br) +} + +func putTextprotoReader(r *textproto.Reader) { + r.R = nil + textprotoReaderPool.Put(r) +} + +// ReadRequest reads and parses an incoming request from b. +// +// ReadRequest is a low-level function and should only be used for +// specialized applications; most code should use the Server to read +// requests and handle them via the Handler interface. ReadRequest +// only supports HTTP/1.x requests. For HTTP/2, use golang.org/x/net/http2. +func ReadRequest(b *bufio.Reader) (*Request, error) { + req, err := readRequest(b) + if err != nil { + return nil, err + } + + delete(req.Header, "Host") + return req, err +} + +func readRequest(b *bufio.Reader) (req *Request, err error) { + tp := newTextprotoReader(b) + req = new(Request) + + // First line: GET /index.html HTTP/1.0 + var s string + if s, err = tp.ReadLine(); err != nil { + return nil, err + } + defer func() { + putTextprotoReader(tp) + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() + + var ok bool + req.Method, req.RequestURI, req.Proto, ok = parseRequestLine(s) + if !ok { + return nil, badStringError("malformed HTTP request", s) + } + if !validMethod(req.Method) { + return nil, badStringError("invalid method", req.Method) + } + rawurl := req.RequestURI + if req.ProtoMajor, req.ProtoMinor, ok = ParseHTTPVersion(req.Proto); !ok { + return nil, badStringError("malformed HTTP version", req.Proto) + } + + // CONNECT requests are used two different ways, and neither uses a full URL: + // The standard use is to tunnel HTTPS through an HTTP proxy. + // It looks like "CONNECT www.google.com:443 HTTP/1.1", and the parameter is + // just the authority section of a URL. This information should go in req.URL.Host. + // + // The net/rpc package also uses CONNECT, but there the parameter is a path + // that starts with a slash. It can be parsed with the regular URL parser, + // and the path will end up in req.URL.Path, where it needs to be in order for + // RPC to work. + justAuthority := req.Method == "CONNECT" && !strings.HasPrefix(rawurl, "/") + if justAuthority { + rawurl = "http://" + rawurl + } + + if req.URL, err = url.ParseRequestURI(rawurl); err != nil { + return nil, err + } + + if justAuthority { + // Strip the bogus "http://" back off. + req.URL.Scheme = "" + } + + // Subsequent lines: Key: value. + mimeHeader, err := tp.ReadMIMEHeader() + if err != nil { + return nil, err + } + req.Header = Header(mimeHeader) + if len(req.Header["Host"]) > 1 { + return nil, fmt.Errorf("too many Host headers") + } + + // RFC 7230, section 5.3: Must treat + // GET /index.html HTTP/1.1 + // Host: www.google.com + // and + // GET http://www.google.com/index.html HTTP/1.1 + // Host: doesntmatter + // the same. In the second case, any Host line is ignored. 
+ req.Host = req.URL.Host + if req.Host == "" { + req.Host = req.Header.get("Host") + } + + fixPragmaCacheControl(req.Header) + + req.Close = shouldClose(req.ProtoMajor, req.ProtoMinor, req.Header, false) + + err = readTransfer(req, b, nil) + if err != nil { + return nil, err + } + + if req.isH2Upgrade() { + // Because it's neither chunked, nor declared: + req.ContentLength = -1 + + // We want to give handlers a chance to hijack the + // connection, but we need to prevent the Server from + // dealing with the connection further if it's not + // hijacked. Set Close to ensure that: + req.Close = true + } + return req, nil +} + +// MaxBytesReader is similar to io.LimitReader but is intended for +// limiting the size of incoming request bodies. In contrast to +// io.LimitReader, MaxBytesReader's result is a ReadCloser, returns a +// non-nil error of type *MaxBytesError for a Read beyond the limit, +// and closes the underlying reader when its Close method is called. +// +// MaxBytesReader prevents clients from accidentally or maliciously +// sending a large request and wasting server resources. If possible, +// it tells the ResponseWriter to close the connection after the limit +// has been reached. +func MaxBytesReader(w ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser { + if n < 0 { // Treat negative limits as equivalent to 0. + n = 0 + } + return &maxBytesReader{w: w, r: r, i: n, n: n} +} + +// MaxBytesError is returned by MaxBytesReader when its read limit is exceeded. +type MaxBytesError struct { + Limit int64 +} + +func (e *MaxBytesError) Error() string { + // Due to Hyrum's law, this text cannot be changed. + return "http: request body too large" +} + +type maxBytesReader struct { + w ResponseWriter + r io.ReadCloser // underlying reader + i int64 // max bytes initially, for MaxBytesError + n int64 // max bytes remaining + err error // sticky error +} + +func (l *maxBytesReader) Read(p []byte) (n int, err error) { + if l.err != nil { + return 0, l.err + } + if len(p) == 0 { + return 0, nil + } + // If they asked for a 32KB byte read but only 5 bytes are + // remaining, no need to read 32KB. 6 bytes will answer the + // question of the whether we hit the limit or go past it. + if int64(len(p)) > l.n+1 { + p = p[:l.n+1] + } + n, err = l.r.Read(p) + + if int64(n) <= l.n { + l.n -= int64(n) + l.err = err + return n, err + } + + n = int(l.n) + l.n = 0 + + // The server code and client code both use + // maxBytesReader. This "requestTooLarge" check is + // only used by the server code. To prevent binaries + // which only using the HTTP Client code (such as + // cmd/go) from also linking in the HTTP server, don't + // use a static type assertion to the server + // "*response" type. Check this interface instead: + type requestTooLarger interface { + requestTooLarge() + } + if res, ok := l.w.(requestTooLarger); ok { + res.requestTooLarge() + } + l.err = &MaxBytesError{l.i} + return n, l.err +} + +func (l *maxBytesReader) Close() error { + return l.r.Close() +} + +func copyValues(dst, src url.Values) { + for k, vs := range src { + dst[k] = append(dst[k], vs...) 
+ } +} + +func parsePostForm(r *Request) (vs url.Values, err error) { + if r.Body == nil { + err = errors.New("missing form body") + return + } + ct := r.Header.Get("Content-Type") + // RFC 7231, section 3.1.1.5 - empty type + // MAY be treated as application/octet-stream + if ct == "" { + ct = "application/octet-stream" + } + ct, _, err = mime.ParseMediaType(ct) + switch { + case ct == "application/x-www-form-urlencoded": + var reader io.Reader = r.Body + maxFormSize := int64(1<<63 - 1) + if _, ok := r.Body.(*maxBytesReader); !ok { + maxFormSize = int64(10 << 20) // 10 MB is a lot of text. + reader = io.LimitReader(r.Body, maxFormSize+1) + } + b, e := io.ReadAll(reader) + if e != nil { + if err == nil { + err = e + } + break + } + if int64(len(b)) > maxFormSize { + err = errors.New("http: POST too large") + return + } + vs, e = url.ParseQuery(string(b)) + if err == nil { + err = e + } + case ct == "multipart/form-data": + // handled by ParseMultipartForm (which is calling us, or should be) + // TODO(bradfitz): there are too many possible + // orders to call too many functions here. + // Clean this up and write more tests. + // request_test.go contains the start of this, + // in TestParseMultipartFormOrder and others. + } + return +} + +// ParseForm populates r.Form and r.PostForm. +// +// For all requests, ParseForm parses the raw query from the URL and updates +// r.Form. +// +// For POST, PUT, and PATCH requests, it also reads the request body, parses it +// as a form and puts the results into both r.PostForm and r.Form. Request body +// parameters take precedence over URL query string values in r.Form. +// +// If the request Body's size has not already been limited by MaxBytesReader, +// the size is capped at 10MB. +// +// For other HTTP methods, or when the Content-Type is not +// application/x-www-form-urlencoded, the request Body is not read, and +// r.PostForm is initialized to a non-nil, empty value. +// +// ParseMultipartForm calls ParseForm automatically. +// ParseForm is idempotent. +func (r *Request) ParseForm() error { + var err error + if r.PostForm == nil { + if r.Method == "POST" || r.Method == "PUT" || r.Method == "PATCH" { + r.PostForm, err = parsePostForm(r) + } + if r.PostForm == nil { + r.PostForm = make(url.Values) + } + } + if r.Form == nil { + if len(r.PostForm) > 0 { + r.Form = make(url.Values) + copyValues(r.Form, r.PostForm) + } + var newValues url.Values + if r.URL != nil { + var e error + newValues, e = url.ParseQuery(r.URL.RawQuery) + if err == nil { + err = e + } + } + if newValues == nil { + newValues = make(url.Values) + } + if r.Form == nil { + r.Form = newValues + } else { + copyValues(r.Form, newValues) + } + } + return err +} + +// ParseMultipartForm parses a request body as multipart/form-data. +// The whole request body is parsed and up to a total of maxMemory bytes of +// its file parts are stored in memory, with the remainder stored on +// disk in temporary files. +// ParseMultipartForm calls ParseForm if necessary. +// If ParseForm returns an error, ParseMultipartForm returns it but also +// continues parsing the request body. +// After one call to ParseMultipartForm, subsequent calls have no effect. +func (r *Request) ParseMultipartForm(maxMemory int64) error { + if r.MultipartForm == multipartByReader { + return errors.New("http: multipart handled by MultipartReader") + } + var parseFormErr error + if r.Form == nil { + // Let errors in ParseForm fall through, and just + // return it at the end. 
+ parseFormErr = r.ParseForm() + } + if r.MultipartForm != nil { + return nil + } + + mr, err := r.multipartReader(false) + if err != nil { + return err + } + + f, err := mr.ReadForm(maxMemory) + if err != nil { + return err + } + + if r.PostForm == nil { + r.PostForm = make(url.Values) + } + for k, v := range f.Value { + r.Form[k] = append(r.Form[k], v...) + // r.PostForm should also be populated. See Issue 9305. + r.PostForm[k] = append(r.PostForm[k], v...) + } + + r.MultipartForm = f + + return parseFormErr +} + +// FormValue returns the first value for the named component of the query. +// POST and PUT body parameters take precedence over URL query string values. +// FormValue calls ParseMultipartForm and ParseForm if necessary and ignores +// any errors returned by these functions. +// If key is not present, FormValue returns the empty string. +// To access multiple values of the same key, call ParseForm and +// then inspect Request.Form directly. +func (r *Request) FormValue(key string) string { + if r.Form == nil { + r.ParseMultipartForm(defaultMaxMemory) + } + if vs := r.Form[key]; len(vs) > 0 { + return vs[0] + } + return "" +} + +// PostFormValue returns the first value for the named component of the POST, +// PATCH, or PUT request body. URL query parameters are ignored. +// PostFormValue calls ParseMultipartForm and ParseForm if necessary and ignores +// any errors returned by these functions. +// If key is not present, PostFormValue returns the empty string. +func (r *Request) PostFormValue(key string) string { + if r.PostForm == nil { + r.ParseMultipartForm(defaultMaxMemory) + } + if vs := r.PostForm[key]; len(vs) > 0 { + return vs[0] + } + return "" +} + +// FormFile returns the first file for the provided form key. +// FormFile calls ParseMultipartForm and ParseForm if necessary. +func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, error) { + if r.MultipartForm == multipartByReader { + return nil, nil, errors.New("http: multipart handled by MultipartReader") + } + if r.MultipartForm == nil { + err := r.ParseMultipartForm(defaultMaxMemory) + if err != nil { + return nil, nil, err + } + } + if r.MultipartForm != nil && r.MultipartForm.File != nil { + if fhs := r.MultipartForm.File[key]; len(fhs) > 0 { + f, err := fhs[0].Open() + return f, fhs[0], err + } + } + return nil, nil, ErrMissingFile +} + +func (r *Request) expectsContinue() bool { + return hasToken(r.Header.get("Expect"), "100-continue") +} + +func (r *Request) wantsHttp10KeepAlive() bool { + if r.ProtoMajor != 1 || r.ProtoMinor != 0 { + return false + } + return hasToken(r.Header.get("Connection"), "keep-alive") +} + +func (r *Request) wantsClose() bool { + if r.Close { + return true + } + return hasToken(r.Header.get("Connection"), "close") +} + +func (r *Request) closeBody() error { + if r.Body == nil { + return nil + } + return r.Body.Close() +} + +func (r *Request) isReplayable() bool { + if r.Body == nil || r.Body == NoBody || r.GetBody != nil { + switch valueOrDefault(r.Method, "GET") { + case "GET", "HEAD", "OPTIONS", "TRACE": + return true + } + // The Idempotency-Key, while non-standard, is widely used to + // mean a POST or other request is idempotent. See + // https://golang.org/issue/19943#issuecomment-421092421 + if r.Header.has("Idempotency-Key") || r.Header.has("X-Idempotency-Key") { + return true + } + } + return false +} + +// outgoingLength reports the Content-Length of this outgoing (Client) request. +// It maps 0 into -1 (unknown) when the Body is non-nil. 
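+//
+// Editor's illustration (not part of the upstream comment), assuming requests
+// built with NewRequest:
+//
+//	r, _ := NewRequest("POST", "http://example.com/", strings.NewReader("hi"))
+//	n := r.outgoingLength() // 2: NewRequest set ContentLength from the reader
+//
+//	r, _ = NewRequest("GET", "http://example.com/", nil)
+//	n = r.outgoingLength() // 0: Body is nil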
+func (r *Request) outgoingLength() int64 { + if r.Body == nil || r.Body == NoBody { + return 0 + } + if r.ContentLength != 0 { + return r.ContentLength + } + return -1 +} + +// requestMethodUsuallyLacksBody reports whether the given request +// method is one that typically does not involve a request body. +// This is used by the Transport (via +// transferWriter.shouldSendChunkedRequestBody) to determine whether +// we try to test-read a byte from a non-nil Request.Body when +// Request.outgoingLength() returns -1. See the comments in +// shouldSendChunkedRequestBody. +func requestMethodUsuallyLacksBody(method string) bool { + switch method { + case "GET", "HEAD", "DELETE", "OPTIONS", "PROPFIND", "SEARCH": + return true + } + return false +} + +// requiresHTTP1 reports whether this request requires being sent on +// an HTTP/1 connection. +func (r *Request) requiresHTTP1() bool { + return hasToken(r.Header.Get("Connection"), "upgrade") && + ascii.EqualFold(r.Header.Get("Upgrade"), "websocket") +} diff --git a/src/net/http/response.go b/src/net/http/response.go new file mode 100644 index 0000000000..980329f9c5 --- /dev/null +++ b/src/net/http/response.go @@ -0,0 +1,373 @@ +// TINYGO: The following is copied and modified from Go 1.19.3 official implementation. + +// TINYGO: Removed TLS connection state +// TINYGO: Added onEOF hook to get callback when response has been read + +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP Response reading and parsing. + +package http + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "net/textproto" + "net/url" + "strconv" + "strings" + + "golang.org/x/net/http/httpguts" +) + +var respExcludeHeader = map[string]bool{ + "Content-Length": true, + "Transfer-Encoding": true, + "Trailer": true, +} + +// Response represents the response from an HTTP request. +// +// The Client and Transport return Responses from servers once +// the response headers have been received. The response body +// is streamed on demand as the Body field is read. +type Response struct { + Status string // e.g. "200 OK" + StatusCode int // e.g. 200 + Proto string // e.g. "HTTP/1.0" + ProtoMajor int // e.g. 1 + ProtoMinor int // e.g. 0 + + // Header maps header keys to values. If the response had multiple + // headers with the same key, they may be concatenated, with comma + // delimiters. (RFC 7230, section 3.2.2 requires that multiple headers + // be semantically equivalent to a comma-delimited sequence.) When + // Header values are duplicated by other fields in this struct (e.g., + // ContentLength, TransferEncoding, Trailer), the field values are + // authoritative. + // + // Keys in the map are canonicalized (see CanonicalHeaderKey). + Header Header + + // Body represents the response body. + // + // The response body is streamed on demand as the Body field + // is read. If the network connection fails or the server + // terminates the response, Body.Read calls return an error. + // + // The http Client and Transport guarantee that Body is always + // non-nil, even on responses without a body or responses with + // a zero-length body. It is the caller's responsibility to + // close Body. The default HTTP client's Transport may not + // reuse HTTP/1.x "keep-alive" TCP connections if the Body is + // not read to completion and closed. + // + // The Body is automatically dechunked if the server replied + // with a "chunked" Transfer-Encoding. 
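+	//
+	// Editor's illustration of the read-to-completion-and-close pattern
+	// described above (not upstream text; assumes the package-level Get
+	// helper provided by this package's client implementation):
+	//
+	//	resp, err := Get("http://example.com/")
+	//	if err != nil {
+	//		// handle the error
+	//	}
+	//	defer resp.Body.Close()
+	//	data, err := io.ReadAll(resp.Body)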
+ // + // As of Go 1.12, the Body will also implement io.Writer + // on a successful "101 Switching Protocols" response, + // as used by WebSockets and HTTP/2's "h2c" mode. + Body io.ReadCloser + + // ContentLength records the length of the associated content. The + // value -1 indicates that the length is unknown. Unless Request.Method + // is "HEAD", values >= 0 indicate that the given number of bytes may + // be read from Body. + ContentLength int64 + + // Contains transfer encodings from outer-most to inner-most. Value is + // nil, means that "identity" encoding is used. + TransferEncoding []string + + // Close records whether the header directed that the connection be + // closed after reading Body. The value is advice for clients: neither + // ReadResponse nor Response.Write ever closes a connection. + Close bool + + // Uncompressed reports whether the response was sent compressed but + // was decompressed by the http package. When true, reading from + // Body yields the uncompressed content instead of the compressed + // content actually set from the server, ContentLength is set to -1, + // and the "Content-Length" and "Content-Encoding" fields are deleted + // from the responseHeader. To get the original response from + // the server, set Transport.DisableCompression to true. + Uncompressed bool + + // Trailer maps trailer keys to values in the same + // format as Header. + // + // The Trailer initially contains only nil values, one for + // each key specified in the server's "Trailer" header + // value. Those values are not added to Header. + // + // Trailer must not be accessed concurrently with Read calls + // on the Body. + // + // After Body.Read has returned io.EOF, Trailer will contain + // any trailer values sent by the server. + Trailer Header + + // Request is the request that was sent to obtain this Response. + // Request's Body is nil (having already been consumed). + // This is only populated for Client requests. + Request *Request +} + +// Cookies parses and returns the cookies set in the Set-Cookie headers. +func (r *Response) Cookies() []*Cookie { + return readSetCookies(r.Header) +} + +// ErrNoLocation is returned by Response's Location method +// when no Location header is present. +var ErrNoLocation = errors.New("http: no Location header in response") + +// Location returns the URL of the response's "Location" header, +// if present. Relative redirects are resolved relative to +// the Response's Request. ErrNoLocation is returned if no +// Location header is present. +func (r *Response) Location() (*url.URL, error) { + lv := r.Header.Get("Location") + if lv == "" { + return nil, ErrNoLocation + } + if r.Request != nil && r.Request.URL != nil { + return r.Request.URL.Parse(lv) + } + return url.Parse(lv) +} + +// ReadResponse reads and returns an HTTP response from r. +// The req parameter optionally specifies the Request that corresponds +// to this Response. If nil, a GET request is assumed. +// Clients must call resp.Body.Close when finished reading resp.Body. +// After that call, clients can inspect resp.Trailer to find key/value +// pairs included in the response trailer. + +// TINYGO: Added onEOF func to be called when response body is closed +// TINYGO: so we can clean up the connection (r) + +func ReadResponse(r *bufio.Reader, req *Request) (*Response, error) { + tp := textproto.NewReader(r) + resp := &Response{ + Request: req, + } + + // Parse the first line of the response. 
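+	// Editor's note: a status line looks like
+	//
+	//	HTTP/1.1 200 OK
+	//
+	// and is split below into proto ("HTTP/1.1") and status ("200 OK").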
+ line, err := tp.ReadLine() + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return nil, err + } + proto, status, ok := strings.Cut(line, " ") + if !ok { + return nil, badStringError("malformed HTTP response", line) + } + resp.Proto = proto + resp.Status = strings.TrimLeft(status, " ") + + statusCode, _, _ := strings.Cut(resp.Status, " ") + if len(statusCode) != 3 { + return nil, badStringError("malformed HTTP status code", statusCode) + } + resp.StatusCode, err = strconv.Atoi(statusCode) + if err != nil || resp.StatusCode < 0 { + return nil, badStringError("malformed HTTP status code", statusCode) + } + if resp.ProtoMajor, resp.ProtoMinor, ok = ParseHTTPVersion(resp.Proto); !ok { + return nil, badStringError("malformed HTTP version", resp.Proto) + } + + // Parse the response headers. + mimeHeader, err := tp.ReadMIMEHeader() + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return nil, err + } + resp.Header = Header(mimeHeader) + + fixPragmaCacheControl(resp.Header) + + err = readTransfer(resp, r, req.onEOF) + if err != nil { + return nil, err + } + + return resp, nil +} + +// RFC 7234, section 5.4: Should treat +// +// Pragma: no-cache +// +// like +// +// Cache-Control: no-cache +func fixPragmaCacheControl(header Header) { + if hp, ok := header["Pragma"]; ok && len(hp) > 0 && hp[0] == "no-cache" { + if _, presentcc := header["Cache-Control"]; !presentcc { + header["Cache-Control"] = []string{"no-cache"} + } + } +} + +// ProtoAtLeast reports whether the HTTP protocol used +// in the response is at least major.minor. +func (r *Response) ProtoAtLeast(major, minor int) bool { + return r.ProtoMajor > major || + r.ProtoMajor == major && r.ProtoMinor >= minor +} + +// Write writes r to w in the HTTP/1.x server response format, +// including the status line, headers, body, and optional trailer. +// +// This method consults the following fields of the response r: +// +// StatusCode +// ProtoMajor +// ProtoMinor +// Request.Method +// TransferEncoding +// Trailer +// Body +// ContentLength +// Header, values for non-canonical keys will have unpredictable behavior +// +// The Response Body is closed after it is sent. +func (r *Response) Write(w io.Writer) error { + // Status line + text := r.Status + if text == "" { + text = StatusText(r.StatusCode) + if text == "" { + text = "status code " + strconv.Itoa(r.StatusCode) + } + } else { + // Just to reduce stutter, if user set r.Status to "200 OK" and StatusCode to 200. + // Not important. + text = strings.TrimPrefix(text, strconv.Itoa(r.StatusCode)+" ") + } + + if _, err := fmt.Fprintf(w, "HTTP/%d.%d %03d %s\r\n", r.ProtoMajor, r.ProtoMinor, r.StatusCode, text); err != nil { + return err + } + + // Clone it, so we can modify r1 as needed. + r1 := new(Response) + *r1 = *r + if r1.ContentLength == 0 && r1.Body != nil { + // Is it actually 0 length? Or just unknown? + var buf [1]byte + n, err := r1.Body.Read(buf[:]) + if err != nil && err != io.EOF { + return err + } + if n == 0 { + // Reset it to a known zero reader, in case underlying one + // is unhappy being read repeatedly. + r1.Body = NoBody + } else { + r1.ContentLength = -1 + r1.Body = struct { + io.Reader + io.Closer + }{ + io.MultiReader(bytes.NewReader(buf[:1]), r.Body), + r.Body, + } + } + } + // If we're sending a non-chunked HTTP/1.1 response without a + // content-length, the only way to do that is the old HTTP/1.0 + // way, by noting the EOF with a connection close, so we need + // to set Close. 
+ if r1.ContentLength == -1 && !r1.Close && r1.ProtoAtLeast(1, 1) && !chunked(r1.TransferEncoding) && !r1.Uncompressed { + r1.Close = true + } + + // Process Body,ContentLength,Close,Trailer + tw, err := newTransferWriter(r1) + if err != nil { + return err + } + err = tw.writeHeader(w) + if err != nil { + return err + } + + // Rest of header + err = r.Header.WriteSubset(w, respExcludeHeader) + if err != nil { + return err + } + + // contentLengthAlreadySent may have been already sent for + // POST/PUT requests, even if zero length. See Issue 8180. + contentLengthAlreadySent := tw.shouldSendContentLength() + if r1.ContentLength == 0 && !chunked(r1.TransferEncoding) && !contentLengthAlreadySent && bodyAllowedForStatus(r.StatusCode) { + if _, err := io.WriteString(w, "Content-Length: 0\r\n"); err != nil { + return err + } + } + + // End-of-header + if _, err := io.WriteString(w, "\r\n"); err != nil { + return err + } + + // Write body and trailer + err = tw.writeBody(w) + if err != nil { + return err + } + + // Success + return nil +} + +func (r *Response) closeBody() { + if r.Body != nil { + r.Body.Close() + } +} + +// bodyIsWritable reports whether the Body supports writing. The +// Transport returns Writable bodies for 101 Switching Protocols +// responses. +// The Transport uses this method to determine whether a persistent +// connection is done being managed from its perspective. Once we +// return a writable response body to a user, the net/http package is +// done managing that connection. +func (r *Response) bodyIsWritable() bool { + _, ok := r.Body.(io.Writer) + return ok +} + +// isProtocolSwitch reports whether the response code and header +// indicate a successful protocol upgrade response. +func (r *Response) isProtocolSwitch() bool { + return isProtocolSwitchResponse(r.StatusCode, r.Header) +} + +// isProtocolSwitchResponse reports whether the response code and +// response header indicate a successful protocol upgrade response. +func isProtocolSwitchResponse(code int, h Header) bool { + return code == StatusSwitchingProtocols && isProtocolSwitchHeader(h) +} + +// isProtocolSwitchHeader reports whether the request or response header +// is for a protocol switch. +func isProtocolSwitchHeader(h Header) bool { + return h.Get("Upgrade") != "" && + httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") +} diff --git a/src/net/http/server.go b/src/net/http/server.go new file mode 100644 index 0000000000..1a4264d17e --- /dev/null +++ b/src/net/http/server.go @@ -0,0 +1,3261 @@ +// TINYGO: The following is copied and modified from Go 1.19.3 official implementation. + +// TINYGO: Removed ALPN protocol support +// TINYGO: Removed some HTTP/2 support +// TINYGO: Removed TimeoutHandler +// TINYGO: Removed ServeTLS and ListenAndServeTLS + +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP server. See RFC 7230 through 7235. + +package http + +import ( + "bufio" + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "log" + "math/rand" + "net" + "net/textproto" + "net/url" + urlpkg "net/url" + "path" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "golang.org/x/net/http/httpguts" +) + +// Errors used by the HTTP server. +var ( + // ErrBodyNotAllowed is returned by ResponseWriter.Write calls + // when the HTTP method or response code does not permit a + // body. 
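+	//
+	// Editor's illustration (not upstream text): a handler that writes a body
+	// after sending a status that forbids one sees this error, e.g.
+	//
+	//	w.WriteHeader(StatusNoContent)
+	//	_, err := w.Write([]byte("x")) // err == ErrBodyNotAllowed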
+ ErrBodyNotAllowed = errors.New("http: request method or response status code does not allow body") + + // ErrHijacked is returned by ResponseWriter.Write calls when + // the underlying connection has been hijacked using the + // Hijacker interface. A zero-byte write on a hijacked + // connection will return ErrHijacked without any other side + // effects. + ErrHijacked = errors.New("http: connection has been hijacked") + + // ErrContentLength is returned by ResponseWriter.Write calls + // when a Handler set a Content-Length response header with a + // declared size and then attempted to write more bytes than + // declared. + ErrContentLength = errors.New("http: wrote more than the declared Content-Length") + + // Deprecated: ErrWriteAfterFlush is no longer returned by + // anything in the net/http package. Callers should not + // compare errors against this variable. + ErrWriteAfterFlush = errors.New("unused") +) + +// A Handler responds to an HTTP request. +// +// ServeHTTP should write reply headers and data to the ResponseWriter +// and then return. Returning signals that the request is finished; it +// is not valid to use the ResponseWriter or read from the +// Request.Body after or concurrently with the completion of the +// ServeHTTP call. +// +// Depending on the HTTP client software, HTTP protocol version, and +// any intermediaries between the client and the Go server, it may not +// be possible to read from the Request.Body after writing to the +// ResponseWriter. Cautious handlers should read the Request.Body +// first, and then reply. +// +// Except for reading the body, handlers should not modify the +// provided Request. +// +// If ServeHTTP panics, the server (the caller of ServeHTTP) assumes +// that the effect of the panic was isolated to the active request. +// It recovers the panic, logs a stack trace to the server error log, +// and either closes the network connection or sends an HTTP/2 +// RST_STREAM, depending on the HTTP protocol. To abort a handler so +// the client sees an interrupted response but the server doesn't log +// an error, panic with the value ErrAbortHandler. +type Handler interface { + ServeHTTP(ResponseWriter, *Request) +} + +// A ResponseWriter interface is used by an HTTP handler to +// construct an HTTP response. +// +// A ResponseWriter may not be used after the Handler.ServeHTTP method +// has returned. +type ResponseWriter interface { + // Header returns the header map that will be sent by + // WriteHeader. The Header map also is the mechanism with which + // Handlers can set HTTP trailers. + // + // Changing the header map after a call to WriteHeader (or + // Write) has no effect unless the HTTP status code was of the + // 1xx class or the modified headers are trailers. + // + // There are two ways to set Trailers. The preferred way is to + // predeclare in the headers which trailers you will later + // send by setting the "Trailer" header to the names of the + // trailer keys which will come later. In this case, those + // keys of the Header map are treated as if they were + // trailers. See the example. The second way, for trailer + // keys not known to the Handler until after the first Write, + // is to prefix the Header map keys with the TrailerPrefix + // constant value. See TrailerPrefix. + // + // To suppress automatic response headers (such as "Date"), set + // their value to nil. + Header() Header + + // Write writes the data to the connection as part of an HTTP reply. 
+ // + // If WriteHeader has not yet been called, Write calls + // WriteHeader(http.StatusOK) before writing the data. If the Header + // does not contain a Content-Type line, Write adds a Content-Type set + // to the result of passing the initial 512 bytes of written data to + // DetectContentType. Additionally, if the total size of all written + // data is under a few KB and there are no Flush calls, the + // Content-Length header is added automatically. + // + // Depending on the HTTP protocol version and the client, calling + // Write or WriteHeader may prevent future reads on the + // Request.Body. For HTTP/1.x requests, handlers should read any + // needed request body data before writing the response. Once the + // headers have been flushed (due to either an explicit Flusher.Flush + // call or writing enough data to trigger a flush), the request body + // may be unavailable. For HTTP/2 requests, the Go HTTP server permits + // handlers to continue to read the request body while concurrently + // writing the response. However, such behavior may not be supported + // by all HTTP/2 clients. Handlers should read before writing if + // possible to maximize compatibility. + Write([]byte) (int, error) + + // WriteHeader sends an HTTP response header with the provided + // status code. + // + // If WriteHeader is not called explicitly, the first call to Write + // will trigger an implicit WriteHeader(http.StatusOK). + // Thus explicit calls to WriteHeader are mainly used to + // send error codes or 1xx informational responses. + // + // The provided code must be a valid HTTP 1xx-5xx status code. + // Any number of 1xx headers may be written, followed by at most + // one 2xx-5xx header. 1xx headers are sent immediately, but 2xx-5xx + // headers may be buffered. Use the Flusher interface to send + // buffered data. The header map is cleared when 2xx-5xx headers are + // sent, but not with 1xx headers. + // + // The server will automatically send a 100 (Continue) header + // on the first read from the request body if the request has + // an "Expect: 100-continue" header. + WriteHeader(statusCode int) +} + +// The Flusher interface is implemented by ResponseWriters that allow +// an HTTP handler to flush buffered data to the client. +// +// The default HTTP/1.x and HTTP/2 ResponseWriter implementations +// support Flusher, but ResponseWriter wrappers may not. Handlers +// should always test for this ability at runtime. +// +// Note that even for ResponseWriters that support Flush, +// if the client is connected through an HTTP proxy, +// the buffered data may not reach the client until the response +// completes. +type Flusher interface { + // Flush sends any buffered data to the client. + Flush() +} + +// The Hijacker interface is implemented by ResponseWriters that allow +// an HTTP handler to take over the connection. +// +// The default ResponseWriter for HTTP/1.x connections supports +// Hijacker, but HTTP/2 connections intentionally do not. +// ResponseWriter wrappers may also not support Hijacker. Handlers +// should always test for this ability at runtime. +type Hijacker interface { + // Hijack lets the caller take over the connection. + // After a call to Hijack the HTTP server library + // will not do anything else with the connection. + // + // It becomes the caller's responsibility to manage + // and close the connection. + // + // The returned net.Conn may have read or write deadlines + // already set, depending on the configuration of the + // Server. 
It is the caller's responsibility to set + // or clear those deadlines as needed. + // + // The returned bufio.Reader may contain unprocessed buffered + // data from the client. + // + // After a call to Hijack, the original Request.Body must not + // be used. The original Request's Context remains valid and + // is not canceled until the Request's ServeHTTP method + // returns. + Hijack() (net.Conn, *bufio.ReadWriter, error) +} + +// The CloseNotifier interface is implemented by ResponseWriters which +// allow detecting when the underlying connection has gone away. +// +// This mechanism can be used to cancel long operations on the server +// if the client has disconnected before the response is ready. +// +// Deprecated: the CloseNotifier interface predates Go's context package. +// New code should use Request.Context instead. +type CloseNotifier interface { + // CloseNotify returns a channel that receives at most a + // single value (true) when the client connection has gone + // away. + // + // CloseNotify may wait to notify until Request.Body has been + // fully read. + // + // After the Handler has returned, there is no guarantee + // that the channel receives a value. + // + // If the protocol is HTTP/1.1 and CloseNotify is called while + // processing an idempotent request (such a GET) while + // HTTP/1.1 pipelining is in use, the arrival of a subsequent + // pipelined request may cause a value to be sent on the + // returned channel. In practice HTTP/1.1 pipelining is not + // enabled in browsers and not seen often in the wild. If this + // is a problem, use HTTP/2 or only use CloseNotify on methods + // such as POST. + CloseNotify() <-chan bool +} + +var ( + // ServerContextKey is a context key. It can be used in HTTP + // handlers with Context.Value to access the server that + // started the handler. The associated value will be of + // type *Server. + ServerContextKey = &contextKey{"http-server"} + + // LocalAddrContextKey is a context key. It can be used in + // HTTP handlers with Context.Value to access the local + // address the connection arrived on. + // The associated value will be of type net.Addr. + LocalAddrContextKey = &contextKey{"local-addr"} +) + +// A conn represents the server side of an HTTP connection. +type conn struct { + // server is the server on which the connection arrived. + // Immutable; never nil. + server *Server + + // cancelCtx cancels the connection-level context. + cancelCtx context.CancelFunc + + // rwc is the underlying network connection. + // This is never wrapped by other types and is the value given out + // to CloseNotifier callers. It is usually of type *net.TCPConn or + // *tls.Conn. + rwc net.Conn + + // remoteAddr is rwc.RemoteAddr().String(). It is not populated synchronously + // inside the Listener's Accept goroutine, as some implementations block. + // It is populated immediately inside the (*conn).serve goroutine. + // This is the value of a Handler's (*Request).RemoteAddr. + remoteAddr string + + // tlsState is the TLS connection state when using TLS. + // nil means not TLS. + tlsState *tls.ConnectionState + + // werr is set to the first write error to rwc. + // It is set via checkConnErrorWriter{w}, where bufw writes. + werr error + + // r is bufr's read source. It's a wrapper around rwc that provides + // io.LimitedReader-style limiting (while reading request headers) + // and functionality to support CloseNotifier. See *connReader docs. + r *connReader + + // bufr reads from r. 
+ bufr *bufio.Reader + + // bufw writes to checkConnErrorWriter{c}, which populates werr on error. + bufw *bufio.Writer + + // lastMethod is the method of the most recent request + // on this connection, if any. + lastMethod string + + curReq atomic.Value // of *response (which has a Request in it) + + curState struct{ atomic uint64 } // packed (unixtime<<8|uint8(ConnState)) + + // mu guards hijackedv + mu sync.Mutex + + // hijackedv is whether this connection has been hijacked + // by a Handler with the Hijacker interface. + // It is guarded by mu. + hijackedv bool +} + +func (c *conn) hijacked() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.hijackedv +} + +// c.mu must be held. +func (c *conn) hijackLocked() (rwc net.Conn, buf *bufio.ReadWriter, err error) { + if c.hijackedv { + return nil, nil, ErrHijacked + } + c.r.abortPendingRead() + + c.hijackedv = true + rwc = c.rwc + rwc.SetDeadline(time.Time{}) + + buf = bufio.NewReadWriter(c.bufr, bufio.NewWriter(rwc)) + if c.r.hasByte { + if _, err := c.bufr.Peek(c.bufr.Buffered() + 1); err != nil { + return nil, nil, fmt.Errorf("unexpected Peek failure reading buffered byte: %v", err) + } + } + c.setState(rwc, StateHijacked, runHooks) + return +} + +// This should be >= 512 bytes for DetectContentType, +// but otherwise it's somewhat arbitrary. +const bufferBeforeChunkingSize = 2048 + +// chunkWriter writes to a response's conn buffer, and is the writer +// wrapped by the response.w buffered writer. +// +// chunkWriter also is responsible for finalizing the Header, including +// conditionally setting the Content-Type and setting a Content-Length +// in cases where the handler's final output is smaller than the buffer +// size. It also conditionally adds chunk headers, when in chunking mode. +// +// See the comment above (*response).Write for the entire write flow. +type chunkWriter struct { + res *response + + // header is either nil or a deep clone of res.handlerHeader + // at the time of res.writeHeader, if res.writeHeader is + // called and extra buffering is being done to calculate + // Content-Type and/or Content-Length. + header Header + + // wroteHeader tells whether the header's been written to "the + // wire" (or rather: w.conn.buf). this is unlike + // (*response).wroteHeader, which tells only whether it was + // logically written. + wroteHeader bool + + // set by the writeHeader method: + chunking bool // using chunked transfer encoding for reply body +} + +var ( + crlf = []byte("\r\n") + colonSpace = []byte(": ") +) + +func (cw *chunkWriter) Write(p []byte) (n int, err error) { + if !cw.wroteHeader { + cw.writeHeader(p) + } + if cw.res.req.Method == "HEAD" { + // Eat writes. 
+ return len(p), nil + } + if cw.chunking { + _, err = fmt.Fprintf(cw.res.conn.bufw, "%x\r\n", len(p)) + if err != nil { + cw.res.conn.rwc.Close() + return + } + } + n, err = cw.res.conn.bufw.Write(p) + if cw.chunking && err == nil { + _, err = cw.res.conn.bufw.Write(crlf) + } + if err != nil { + cw.res.conn.rwc.Close() + } + return +} + +func (cw *chunkWriter) flush() { + if !cw.wroteHeader { + cw.writeHeader(nil) + } + cw.res.conn.bufw.Flush() +} + +func (cw *chunkWriter) close() { + if !cw.wroteHeader { + cw.writeHeader(nil) + } + if cw.chunking { + bw := cw.res.conn.bufw // conn's bufio writer + // zero chunk to mark EOF + bw.WriteString("0\r\n") + if trailers := cw.res.finalTrailers(); trailers != nil { + trailers.Write(bw) // the writer handles noting errors + } + // final blank line after the trailers (whether + // present or not) + bw.WriteString("\r\n") + } +} + +// A response represents the server side of an HTTP response. +type response struct { + conn *conn + req *Request // request for this response + reqBody io.ReadCloser + cancelCtx context.CancelFunc // when ServeHTTP exits + wroteHeader bool // a non-1xx header has been (logically) written + wroteContinue bool // 100 Continue response was written + wants10KeepAlive bool // HTTP/1.0 w/ Connection "keep-alive" + wantsClose bool // HTTP request has Connection "close" + + // canWriteContinue is a boolean value accessed as an atomic int32 + // that says whether or not a 100 Continue header can be written + // to the connection. + // writeContinueMu must be held while writing the header. + // These two fields together synchronize the body reader (the + // expectContinueReader, which wants to write 100 Continue) + // against the main writer. + canWriteContinue atomicBool + writeContinueMu sync.Mutex + + w *bufio.Writer // buffers output in chunks to chunkWriter + cw chunkWriter + + // handlerHeader is the Header that Handlers get access to, + // which may be retained and mutated even after WriteHeader. + // handlerHeader is copied into cw.header at WriteHeader + // time, and privately mutated thereafter. + handlerHeader Header + calledHeader bool // handler accessed handlerHeader via Header + + written int64 // number of bytes written in body + contentLength int64 // explicitly-declared Content-Length; or -1 + status int // status code passed to WriteHeader + + // close connection after this reply. set on request and + // updated after response from handler if there's a + // "Connection: keep-alive" response header and a + // Content-Length. + closeAfterReply bool + + // requestBodyLimitHit is set by requestTooLarge when + // maxBytesReader hits its max size. It is checked in + // WriteHeader, to make sure we don't consume the + // remaining request body to try to advance to the next HTTP + // request. Instead, when this is set, we stop reading + // subsequent requests on this connection and stop reading + // input from it. + requestBodyLimitHit bool + + // trailers are the headers to be sent after the handler + // finishes writing the body. This field is initialized from + // the Trailer response header when the response header is + // written. + trailers []string + + handlerDone atomicBool // set true when the handler exits + + // Buffers for Date, Content-Length, and status code + dateBuf [len(TimeFormat)]byte + clenBuf [10]byte + statusBuf [3]byte + + // closeNotifyCh is the channel returned by CloseNotify. + // TODO(bradfitz): this is currently (for Go 1.8) always + // non-nil. 
Make this lazily-created again as it used to be? + closeNotifyCh chan bool + didCloseNotify int32 // atomic (only 0->1 winner should send) +} + +// TrailerPrefix is a magic prefix for ResponseWriter.Header map keys +// that, if present, signals that the map entry is actually for +// the response trailers, and not the response headers. The prefix +// is stripped after the ServeHTTP call finishes and the values are +// sent in the trailers. +// +// This mechanism is intended only for trailers that are not known +// prior to the headers being written. If the set of trailers is fixed +// or known before the header is written, the normal Go trailers mechanism +// is preferred: +// +// https://pkg.go.dev/net/http#ResponseWriter +// https://pkg.go.dev/net/http#example-ResponseWriter-Trailers +const TrailerPrefix = "Trailer:" + +// finalTrailers is called after the Handler exits and returns a non-nil +// value if the Handler set any trailers. +func (w *response) finalTrailers() Header { + var t Header + for k, vv := range w.handlerHeader { + if strings.HasPrefix(k, TrailerPrefix) { + if t == nil { + t = make(Header) + } + t[strings.TrimPrefix(k, TrailerPrefix)] = vv + } + } + for _, k := range w.trailers { + if t == nil { + t = make(Header) + } + for _, v := range w.handlerHeader[k] { + t.Add(k, v) + } + } + return t +} + +type atomicBool int32 + +func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 } +func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) } +func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) } + +// declareTrailer is called for each Trailer header when the +// response header is written. It notes that a header will need to be +// written in the trailers at the end of the response. +func (w *response) declareTrailer(k string) { + k = CanonicalHeaderKey(k) + if !httpguts.ValidTrailerHeader(k) { + // Forbidden by RFC 7230, section 4.1.2 + return + } + w.trailers = append(w.trailers, k) +} + +// requestTooLarge is called by maxBytesReader when too much input has +// been read from the client. +func (w *response) requestTooLarge() { + w.closeAfterReply = true + w.requestBodyLimitHit = true + if !w.wroteHeader { + w.Header().Set("Connection", "close") + } +} + +// needsSniff reports whether a Content-Type still needs to be sniffed. +func (w *response) needsSniff() bool { + _, haveType := w.handlerHeader["Content-Type"] + return !w.cw.wroteHeader && !haveType && w.written < sniffLen +} + +// writerOnly hides an io.Writer value's optional ReadFrom method +// from io.Copy. +type writerOnly struct { + io.Writer +} + +// ReadFrom is here to optimize copying from an *os.File regular file +// to a *net.TCPConn with sendfile, or from a supported src type such +// as a *net.TCPConn on Linux with splice. +func (w *response) ReadFrom(src io.Reader) (n int64, err error) { + bufp := copyBufPool.Get().(*[]byte) + buf := *bufp + defer copyBufPool.Put(bufp) + + // Our underlying w.conn.rwc is usually a *TCPConn (with its + // own ReadFrom method). If not, just fall back to the normal + // copy method. + rf, ok := w.conn.rwc.(io.ReaderFrom) + if !ok { + return io.CopyBuffer(writerOnly{w}, src, buf) + } + + // Copy the first sniffLen bytes before switching to ReadFrom. + // This ensures we don't start writing the response before the + // source is available (see golang.org/issue/5660) and provides + // enough bytes to perform Content-Type sniffing when required. 
+ if !w.cw.wroteHeader { + n0, err := io.CopyBuffer(writerOnly{w}, io.LimitReader(src, sniffLen), buf) + n += n0 + if err != nil || n0 < sniffLen { + return n, err + } + } + + w.w.Flush() // get rid of any previous writes + w.cw.flush() // make sure Header is written; flush data to rwc + + // Now that cw has been flushed, its chunking field is guaranteed initialized. + if !w.cw.chunking && w.bodyAllowed() { + n0, err := rf.ReadFrom(src) + n += n0 + w.written += n0 + return n, err + } + + n0, err := io.CopyBuffer(writerOnly{w}, src, buf) + n += n0 + return n, err +} + +// debugServerConnections controls whether all server connections are wrapped +// with a verbose logging wrapper. +const debugServerConnections = false + +// Create new connection from rwc. +func (srv *Server) newConn(rwc net.Conn) *conn { + c := &conn{ + server: srv, + rwc: rwc, + } + if debugServerConnections { + c.rwc = newLoggingConn("server", c.rwc) + } + return c +} + +type readResult struct { + _ incomparable + n int + err error + b byte // byte read, if n == 1 +} + +// connReader is the io.Reader wrapper used by *conn. It combines a +// selectively-activated io.LimitedReader (to bound request header +// read sizes) with support for selectively keeping an io.Reader.Read +// call blocked in a background goroutine to wait for activity and +// trigger a CloseNotifier channel. +type connReader struct { + conn *conn + + mu sync.Mutex // guards following + hasByte bool + byteBuf [1]byte + cond *sync.Cond + inRead bool + aborted bool // set true before conn.rwc deadline is set to past + remain int64 // bytes remaining +} + +func (cr *connReader) lock() { + cr.mu.Lock() + if cr.cond == nil { + cr.cond = sync.NewCond(&cr.mu) + } +} + +func (cr *connReader) unlock() { cr.mu.Unlock() } + +func (cr *connReader) startBackgroundRead() { + cr.lock() + defer cr.unlock() + if cr.inRead { + panic("invalid concurrent Body.Read call") + } + if cr.hasByte { + return + } + cr.inRead = true + cr.conn.rwc.SetReadDeadline(time.Time{}) + go cr.backgroundRead() +} + +func (cr *connReader) backgroundRead() { + n, err := cr.conn.rwc.Read(cr.byteBuf[:]) + cr.lock() + if n == 1 { + cr.hasByte = true + // We were past the end of the previous request's body already + // (since we wouldn't be in a background read otherwise), so + // this is a pipelined HTTP request. Prior to Go 1.11 we used to + // send on the CloseNotify channel and cancel the context here, + // but the behavior was documented as only "may", and we only + // did that because that's how CloseNotify accidentally behaved + // in very early Go releases prior to context support. Once we + // added context support, people used a Handler's + // Request.Context() and passed it along. Having that context + // cancel on pipelined HTTP requests caused problems. + // Fortunately, almost nothing uses HTTP/1.x pipelining. + // Unfortunately, apt-get does, or sometimes does. + // New Go 1.11 behavior: don't fire CloseNotify or cancel + // contexts on pipelined requests. Shouldn't affect people, but + // fixes cases like Issue 23921. This does mean that a client + // closing their TCP connection after sending a pipelined + // request won't cancel the context, but we'll catch that on any + // write failure (in checkConnErrorWriter.Write). + // If the server never writes, yes, there are still contrived + // server & client behaviors where this fails to ever cancel the + // context, but that's kinda why HTTP/1.x pipelining died + // anyway. 
+ } + if ne, ok := err.(net.Error); ok && cr.aborted && ne.Timeout() { + // Ignore this error. It's the expected error from + // another goroutine calling abortPendingRead. + } else if err != nil { + cr.handleReadError(err) + } + cr.aborted = false + cr.inRead = false + cr.unlock() + cr.cond.Broadcast() +} + +func (cr *connReader) abortPendingRead() { + cr.lock() + defer cr.unlock() + if !cr.inRead { + return + } + cr.aborted = true + cr.conn.rwc.SetReadDeadline(aLongTimeAgo) + for cr.inRead { + cr.cond.Wait() + } + cr.conn.rwc.SetReadDeadline(time.Time{}) +} + +func (cr *connReader) setReadLimit(remain int64) { cr.remain = remain } +func (cr *connReader) setInfiniteReadLimit() { cr.remain = maxInt64 } +func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 } + +// handleReadError is called whenever a Read from the client returns a +// non-nil error. +// +// The provided non-nil err is almost always io.EOF or a "use of +// closed network connection". In any case, the error is not +// particularly interesting, except perhaps for debugging during +// development. Any error means the connection is dead and we should +// down its context. +// +// It may be called from multiple goroutines. +func (cr *connReader) handleReadError(_ error) { + cr.conn.cancelCtx() + cr.closeNotify() +} + +// may be called from multiple goroutines. +func (cr *connReader) closeNotify() { + res, _ := cr.conn.curReq.Load().(*response) + if res != nil && atomic.CompareAndSwapInt32(&res.didCloseNotify, 0, 1) { + res.closeNotifyCh <- true + } +} + +func (cr *connReader) Read(p []byte) (n int, err error) { + cr.lock() + if cr.inRead { + cr.unlock() + if cr.conn.hijacked() { + panic("invalid Body.Read call. After hijacked, the original Request must not be used") + } + panic("invalid concurrent Body.Read call") + } + if cr.hitReadLimit() { + cr.unlock() + return 0, io.EOF + } + if len(p) == 0 { + cr.unlock() + return 0, nil + } + if int64(len(p)) > cr.remain { + p = p[:cr.remain] + } + if cr.hasByte { + p[0] = cr.byteBuf[0] + cr.hasByte = false + cr.unlock() + return 1, nil + } + cr.inRead = true + cr.unlock() + n, err = cr.conn.rwc.Read(p) + + cr.lock() + cr.inRead = false + if err != nil { + cr.handleReadError(err) + } + cr.remain -= int64(n) + cr.unlock() + + cr.cond.Broadcast() + return n, err +} + +var ( + bufioReaderPool sync.Pool + bufioWriter2kPool sync.Pool + bufioWriter4kPool sync.Pool +) + +var copyBufPool = sync.Pool{ + New: func() any { + b := make([]byte, 32*1024) + return &b + }, +} + +func bufioWriterPool(size int) *sync.Pool { + switch size { + case 2 << 10: + return &bufioWriter2kPool + case 4 << 10: + return &bufioWriter4kPool + } + return nil +} + +func newBufioReader(r io.Reader) *bufio.Reader { + if v := bufioReaderPool.Get(); v != nil { + br := v.(*bufio.Reader) + br.Reset(r) + return br + } + // Note: if this reader size is ever changed, update + // TestHandlerBodyClose's assumptions. + return bufio.NewReader(r) +} + +func putBufioReader(br *bufio.Reader) { + br.Reset(nil) + bufioReaderPool.Put(br) +} + +func newBufioWriterSize(w io.Writer, size int) *bufio.Writer { + pool := bufioWriterPool(size) + if pool != nil { + if v := pool.Get(); v != nil { + bw := v.(*bufio.Writer) + bw.Reset(w) + return bw + } + } + return bufio.NewWriterSize(w, size) +} + +func putBufioWriter(bw *bufio.Writer) { + bw.Reset(nil) + if pool := bufioWriterPool(bw.Available()); pool != nil { + pool.Put(bw) + } +} + +// DefaultMaxHeaderBytes is the maximum permitted size of the headers +// in an HTTP request. 
+// This can be overridden by setting Server.MaxHeaderBytes. +// TINYGO: dropped default from 1 << 20 // 1 MB +const DefaultMaxHeaderBytes = 1 << 12 // 4 KB + +func (srv *Server) maxHeaderBytes() int { + if srv.MaxHeaderBytes > 0 { + return srv.MaxHeaderBytes + } + return DefaultMaxHeaderBytes +} + +func (srv *Server) initialReadLimitSize() int64 { + return int64(srv.maxHeaderBytes()) + 4096 // bufio slop +} + +// tlsHandshakeTimeout returns the time limit permitted for the TLS +// handshake, or zero for unlimited. +// +// It returns the minimum of any positive ReadHeaderTimeout, +// ReadTimeout, or WriteTimeout. +func (srv *Server) tlsHandshakeTimeout() time.Duration { + var ret time.Duration + for _, v := range [...]time.Duration{ + srv.ReadHeaderTimeout, + srv.ReadTimeout, + srv.WriteTimeout, + } { + if v <= 0 { + continue + } + if ret == 0 || v < ret { + ret = v + } + } + return ret +} + +// wrapper around io.ReadCloser which on first read, sends an +// HTTP/1.1 100 Continue header +type expectContinueReader struct { + resp *response + readCloser io.ReadCloser + closed atomicBool + sawEOF atomicBool +} + +func (ecr *expectContinueReader) Read(p []byte) (n int, err error) { + if ecr.closed.isSet() { + return 0, ErrBodyReadAfterClose + } + w := ecr.resp + if !w.wroteContinue && w.canWriteContinue.isSet() && !w.conn.hijacked() { + w.wroteContinue = true + w.writeContinueMu.Lock() + if w.canWriteContinue.isSet() { + w.conn.bufw.WriteString("HTTP/1.1 100 Continue\r\n\r\n") + w.conn.bufw.Flush() + w.canWriteContinue.setFalse() + } + w.writeContinueMu.Unlock() + } + n, err = ecr.readCloser.Read(p) + if err == io.EOF { + ecr.sawEOF.setTrue() + } + return +} + +func (ecr *expectContinueReader) Close() error { + ecr.closed.setTrue() + return ecr.readCloser.Close() +} + +// TimeFormat is the time format to use when generating times in HTTP +// headers. It is like time.RFC1123 but hard-codes GMT as the time +// zone. The time being formatted must be in UTC for Format to +// generate the correct format. +// +// For parsing this time format, see ParseTime. +const TimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT" + +// appendTime is a non-allocating version of []byte(t.UTC().Format(TimeFormat)) +func appendTime(b []byte, t time.Time) []byte { + const days = "SunMonTueWedThuFriSat" + const months = "JanFebMarAprMayJunJulAugSepOctNovDec" + + t = t.UTC() + yy, mm, dd := t.Date() + hh, mn, ss := t.Clock() + day := days[3*t.Weekday():] + mon := months[3*(mm-1):] + + return append(b, + day[0], day[1], day[2], ',', ' ', + byte('0'+dd/10), byte('0'+dd%10), ' ', + mon[0], mon[1], mon[2], ' ', + byte('0'+yy/1000), byte('0'+(yy/100)%10), byte('0'+(yy/10)%10), byte('0'+yy%10), ' ', + byte('0'+hh/10), byte('0'+hh%10), ':', + byte('0'+mn/10), byte('0'+mn%10), ':', + byte('0'+ss/10), byte('0'+ss%10), ' ', + 'G', 'M', 'T') +} + +var errTooLarge = errors.New("http: request too large") + +// Read next request from connection. 
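+//
+// The request line and header bytes read here are capped at
+// srv.maxHeaderBytes() plus bufio slop. With the reduced TinyGo default
+// above, a caller expecting unusually large headers can raise the limit
+// on the Server itself, e.g. (illustrative sketch; mux and ln are placeholders):
+//
+//	srv := &Server{Handler: mux, MaxHeaderBytes: 64 << 10}
+//	err := srv.Serve(ln) // ln is an existing net.Listener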
+func (c *conn) readRequest(ctx context.Context) (w *response, err error) { + if c.hijacked() { + return nil, ErrHijacked + } + + var ( + wholeReqDeadline time.Time // or zero if none + hdrDeadline time.Time // or zero if none + ) + t0 := time.Now() + if d := c.server.readHeaderTimeout(); d > 0 { + hdrDeadline = t0.Add(d) + } + if d := c.server.ReadTimeout; d > 0 { + wholeReqDeadline = t0.Add(d) + } + c.rwc.SetReadDeadline(hdrDeadline) + if d := c.server.WriteTimeout; d > 0 { + defer func() { + c.rwc.SetWriteDeadline(time.Now().Add(d)) + }() + } + + c.r.setReadLimit(c.server.initialReadLimitSize()) + if c.lastMethod == "POST" { + // RFC 7230 section 3 tolerance for old buggy clients. + peek, _ := c.bufr.Peek(4) // ReadRequest will get err below + c.bufr.Discard(numLeadingCRorLF(peek)) + } + req, err := readRequest(c.bufr) + if err != nil { + if c.r.hitReadLimit() { + return nil, errTooLarge + } + return nil, err + } + + if !http1ServerSupportsRequest(req) { + return nil, statusError{StatusHTTPVersionNotSupported, "unsupported protocol version"} + } + + c.lastMethod = req.Method + c.r.setInfiniteReadLimit() + + hosts, haveHost := req.Header["Host"] + isH2Upgrade := req.isH2Upgrade() + if req.ProtoAtLeast(1, 1) && (!haveHost || len(hosts) == 0) && !isH2Upgrade && req.Method != "CONNECT" { + return nil, badRequestError("missing required Host header") + } + if len(hosts) == 1 && !httpguts.ValidHostHeader(hosts[0]) { + return nil, badRequestError("malformed Host header") + } + for k, vv := range req.Header { + if !httpguts.ValidHeaderFieldName(k) { + return nil, badRequestError("invalid header name") + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + return nil, badRequestError("invalid header value") + } + } + } + delete(req.Header, "Host") + + ctx, cancelCtx := context.WithCancel(ctx) + req.ctx = ctx + req.RemoteAddr = c.remoteAddr + req.TLS = c.tlsState + if body, ok := req.Body.(*body); ok { + body.doEarlyClose = true + } + + // Adjust the read deadline if necessary. + if !hdrDeadline.Equal(wholeReqDeadline) { + c.rwc.SetReadDeadline(wholeReqDeadline) + } + + w = &response{ + conn: c, + cancelCtx: cancelCtx, + req: req, + reqBody: req.Body, + handlerHeader: make(Header), + contentLength: -1, + closeNotifyCh: make(chan bool, 1), + + // We populate these ahead of time so we're not + // reading from req.Header after their Handler starts + // and maybe mutates it (Issue 14940) + wants10KeepAlive: req.wantsHttp10KeepAlive(), + wantsClose: req.wantsClose(), + } + if isH2Upgrade { + w.closeAfterReply = true + } + w.cw.res = w + w.w = newBufioWriterSize(&w.cw, bufferBeforeChunkingSize) + return w, nil +} + +// http1ServerSupportsRequest reports whether Go's HTTP/1.x server +// supports the given request. +func http1ServerSupportsRequest(req *Request) bool { + if req.ProtoMajor == 1 { + return true + } + // Accept "PRI * HTTP/2.0" upgrade requests, so Handlers can + // wire up their own HTTP/2 upgrades. + if req.ProtoMajor == 2 && req.ProtoMinor == 0 && + req.Method == "PRI" && req.RequestURI == "*" { + return true + } + // Reject HTTP/0.x, and all other HTTP/2+ requests (which + // aren't encoded in ASCII anyway). + return false +} + +func (w *response) Header() Header { + if w.cw.header == nil && w.wroteHeader && !w.cw.wroteHeader { + // Accessing the header between logically writing it + // and physically writing it means we need to allocate + // a clone to snapshot the logically written state. 
+ w.cw.header = w.handlerHeader.Clone() + } + w.calledHeader = true + return w.handlerHeader +} + +// maxPostHandlerReadBytes is the max number of Request.Body bytes not +// consumed by a handler that the server will read from the client +// in order to keep a connection alive. If there are more bytes than +// this then the server to be paranoid instead sends a "Connection: +// close" response. +// +// This number is approximately what a typical machine's TCP buffer +// size is anyway. (if we have the bytes on the machine, we might as +// well read them) +const maxPostHandlerReadBytes = 256 << 10 + +func checkWriteHeaderCode(code int) { + // Issue 22880: require valid WriteHeader status codes. + // For now we only enforce that it's three digits. + // In the future we might block things over 599 (600 and above aren't defined + // at https://httpwg.org/specs/rfc7231.html#status.codes). + // But for now any three digits. + // + // We used to send "HTTP/1.1 000 0" on the wire in responses but there's + // no equivalent bogus thing we can realistically send in HTTP/2, + // so we'll consistently panic instead and help people find their bugs + // early. (We can't return an error from WriteHeader even if we wanted to.) + if code < 100 || code > 999 { + panic(fmt.Sprintf("invalid WriteHeader code %v", code)) + } +} + +// relevantCaller searches the call stack for the first function outside of net/http. +// The purpose of this function is to provide more helpful error messages. +func relevantCaller() runtime.Frame { + pc := make([]uintptr, 16) + n := runtime.Callers(1, pc) + frames := runtime.CallersFrames(pc[:n]) + var frame runtime.Frame + for { + frame, more := frames.Next() + if !strings.HasPrefix(frame.Function, "net/http.") { + return frame + } + if !more { + break + } + } + return frame +} + +func (w *response) WriteHeader(code int) { + if w.conn.hijacked() { + caller := relevantCaller() + w.conn.server.logf("http: response.WriteHeader on hijacked connection from %s (%s:%d)", caller.Function, path.Base(caller.File), caller.Line) + return + } + if w.wroteHeader { + caller := relevantCaller() + w.conn.server.logf("http: superfluous response.WriteHeader call from %s (%s:%d)", caller.Function, path.Base(caller.File), caller.Line) + return + } + checkWriteHeaderCode(code) + + // Handle informational headers + if code >= 100 && code <= 199 { + // Prevent a potential race with an automatically-sent 100 Continue triggered by Request.Body.Read() + if code == 100 && w.canWriteContinue.isSet() { + w.writeContinueMu.Lock() + w.canWriteContinue.setFalse() + w.writeContinueMu.Unlock() + } + + writeStatusLine(w.conn.bufw, w.req.ProtoAtLeast(1, 1), code, w.statusBuf[:]) + + // Per RFC 8297 we must not clear the current header map + w.handlerHeader.WriteSubset(w.conn.bufw, excludedHeadersNoBody) + w.conn.bufw.Write(crlf) + w.conn.bufw.Flush() + + return + } + + w.wroteHeader = true + w.status = code + + if w.calledHeader && w.cw.header == nil { + w.cw.header = w.handlerHeader.Clone() + } + + if cl := w.handlerHeader.get("Content-Length"); cl != "" { + v, err := strconv.ParseInt(cl, 10, 64) + if err == nil && v >= 0 { + w.contentLength = v + } else { + w.conn.server.logf("http: invalid Content-Length of %q", cl) + w.handlerHeader.Del("Content-Length") + } + } +} + +// extraHeader is the set of headers sometimes added by chunkWriter.writeHeader. +// This type is used to avoid extra allocations from cloning and/or populating +// the response Header map and all its 1-element slices. 
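+// Only the few headers the server itself may add (Date, Content-Length,
+// Content-Type, Connection, Transfer-Encoding) are represented, as plain
+// string and []byte fields rather than map entries, so writing them avoids
+// per-response map and slice allocations.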
+type extraHeader struct { + contentType string + connection string + transferEncoding string + date []byte // written if not nil + contentLength []byte // written if not nil +} + +// Sorted the same as extraHeader.Write's loop. +var extraHeaderKeys = [][]byte{ + []byte("Content-Type"), + []byte("Connection"), + []byte("Transfer-Encoding"), +} + +var ( + headerContentLength = []byte("Content-Length: ") + headerDate = []byte("Date: ") +) + +// Write writes the headers described in h to w. +// +// This method has a value receiver, despite the somewhat large size +// of h, because it prevents an allocation. The escape analysis isn't +// smart enough to realize this function doesn't mutate h. +func (h extraHeader) Write(w *bufio.Writer) { + if h.date != nil { + w.Write(headerDate) + w.Write(h.date) + w.Write(crlf) + } + if h.contentLength != nil { + w.Write(headerContentLength) + w.Write(h.contentLength) + w.Write(crlf) + } + for i, v := range []string{h.contentType, h.connection, h.transferEncoding} { + if v != "" { + w.Write(extraHeaderKeys[i]) + w.Write(colonSpace) + w.WriteString(v) + w.Write(crlf) + } + } +} + +// writeHeader finalizes the header sent to the client and writes it +// to cw.res.conn.bufw. +// +// p is not written by writeHeader, but is the first chunk of the body +// that will be written. It is sniffed for a Content-Type if none is +// set explicitly. It's also used to set the Content-Length, if the +// total body size was small and the handler has already finished +// running. +func (cw *chunkWriter) writeHeader(p []byte) { + if cw.wroteHeader { + return + } + cw.wroteHeader = true + + w := cw.res + keepAlivesEnabled := w.conn.server.doKeepAlives() + isHEAD := w.req.Method == "HEAD" + + // header is written out to w.conn.buf below. Depending on the + // state of the handler, we either own the map or not. If we + // don't own it, the exclude map is created lazily for + // WriteSubset to remove headers. The setHeader struct holds + // headers we need to add. + header := cw.header + owned := header != nil + if !owned { + header = w.handlerHeader + } + var excludeHeader map[string]bool + delHeader := func(key string) { + if owned { + header.Del(key) + return + } + if _, ok := header[key]; !ok { + return + } + if excludeHeader == nil { + excludeHeader = make(map[string]bool) + } + excludeHeader[key] = true + } + var setHeader extraHeader + + // Don't write out the fake "Trailer:foo" keys. See TrailerPrefix. + trailers := false + for k := range cw.header { + if strings.HasPrefix(k, TrailerPrefix) { + if excludeHeader == nil { + excludeHeader = make(map[string]bool) + } + excludeHeader[k] = true + trailers = true + } + } + for _, v := range cw.header["Trailer"] { + trailers = true + foreachHeaderElement(v, cw.res.declareTrailer) + } + + te := header.get("Transfer-Encoding") + hasTE := te != "" + + // If the handler is done but never sent a Content-Length + // response header and this is our first (and last) write, set + // it, even to zero. This helps HTTP/1.0 clients keep their + // "keep-alive" connections alive. + // Exceptions: 304/204/1xx responses never get Content-Length, and if + // it was a HEAD request, we don't know the difference between + // 0 actual bytes and 0 bytes because the handler noticed it + // was a HEAD request and chose not to write anything. So for + // HEAD, the handler should either write the Content-Length or + // write non-zero bytes. 
If it's actually 0 bytes and the + // handler never looked at the Request.Method, we just don't + // send a Content-Length header. + // Further, we don't send an automatic Content-Length if they + // set a Transfer-Encoding, because they're generally incompatible. + if w.handlerDone.isSet() && !trailers && !hasTE && bodyAllowedForStatus(w.status) && header.get("Content-Length") == "" && (!isHEAD || len(p) > 0) { + w.contentLength = int64(len(p)) + setHeader.contentLength = strconv.AppendInt(cw.res.clenBuf[:0], int64(len(p)), 10) + } + + // If this was an HTTP/1.0 request with keep-alive and we sent a + // Content-Length back, we can make this a keep-alive response ... + if w.wants10KeepAlive && keepAlivesEnabled { + sentLength := header.get("Content-Length") != "" + if sentLength && header.get("Connection") == "keep-alive" { + w.closeAfterReply = false + } + } + + // Check for an explicit (and valid) Content-Length header. + hasCL := w.contentLength != -1 + + if w.wants10KeepAlive && (isHEAD || hasCL || !bodyAllowedForStatus(w.status)) { + _, connectionHeaderSet := header["Connection"] + if !connectionHeaderSet { + setHeader.connection = "keep-alive" + } + } else if !w.req.ProtoAtLeast(1, 1) || w.wantsClose { + w.closeAfterReply = true + } + + if header.get("Connection") == "close" || !keepAlivesEnabled { + w.closeAfterReply = true + } + + // If the client wanted a 100-continue but we never sent it to + // them (or, more strictly: we never finished reading their + // request body), don't reuse this connection because it's now + // in an unknown state: we might be sending this response at + // the same time the client is now sending its request body + // after a timeout. (Some HTTP clients send Expect: + // 100-continue but knowing that some servers don't support + // it, the clients set a timer and send the body later anyway) + // If we haven't seen EOF, we can't skip over the unread body + // because we don't know if the next bytes on the wire will be + // the body-following-the-timer or the subsequent request. + // See Issue 11549. + if ecr, ok := w.req.Body.(*expectContinueReader); ok && !ecr.sawEOF.isSet() { + w.closeAfterReply = true + } + + // Per RFC 2616, we should consume the request body before + // replying, if the handler hasn't already done so. But we + // don't want to do an unbounded amount of reading here for + // DoS reasons, so we only try up to a threshold. + // TODO(bradfitz): where does RFC 2616 say that? See Issue 15527 + // about HTTP/1.x Handlers concurrently reading and writing, like + // HTTP/2 handlers can do. Maybe this code should be relaxed? + if w.req.ContentLength != 0 && !w.closeAfterReply { + var discard, tooBig bool + + switch bdy := w.req.Body.(type) { + case *expectContinueReader: + if bdy.resp.wroteContinue { + discard = true + } + case *body: + bdy.mu.Lock() + switch { + case bdy.closed: + if !bdy.sawEOF { + // Body was closed in handler with non-EOF error. + w.closeAfterReply = true + } + case bdy.unreadDataSizeLocked() >= maxPostHandlerReadBytes: + tooBig = true + default: + discard = true + } + bdy.mu.Unlock() + default: + discard = true + } + + if discard { + _, err := io.CopyN(io.Discard, w.reqBody, maxPostHandlerReadBytes+1) + switch err { + case nil: + // There must be even more data left over. + tooBig = true + case ErrBodyReadAfterClose: + // Body was already consumed and closed. + case io.EOF: + // The remaining body was just consumed, close it. 
+ err = w.reqBody.Close() + if err != nil { + w.closeAfterReply = true + } + default: + // Some other kind of error occurred, like a read timeout, or + // corrupt chunked encoding. In any case, whatever remains + // on the wire must not be parsed as another HTTP request. + w.closeAfterReply = true + } + } + + if tooBig { + w.requestTooLarge() + delHeader("Connection") + setHeader.connection = "close" + } + } + + code := w.status + if bodyAllowedForStatus(code) { + // If no content type, apply sniffing algorithm to body. + _, haveType := header["Content-Type"] + + // If the Content-Encoding was set and is non-blank, + // we shouldn't sniff the body. See Issue 31753. + ce := header.Get("Content-Encoding") + hasCE := len(ce) > 0 + if !hasCE && !haveType && !hasTE && len(p) > 0 { + setHeader.contentType = DetectContentType(p) + } + } else { + for _, k := range suppressedHeaders(code) { + delHeader(k) + } + } + + if !header.has("Date") { + setHeader.date = appendTime(cw.res.dateBuf[:0], time.Now()) + } + + if hasCL && hasTE && te != "identity" { + // TODO: return an error if WriteHeader gets a return parameter + // For now just ignore the Content-Length. + w.conn.server.logf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d", + te, w.contentLength) + delHeader("Content-Length") + hasCL = false + } + + if w.req.Method == "HEAD" || !bodyAllowedForStatus(code) || code == StatusNoContent { + // Response has no body. + delHeader("Transfer-Encoding") + } else if hasCL { + // Content-Length has been provided, so no chunking is to be done. + delHeader("Transfer-Encoding") + } else if w.req.ProtoAtLeast(1, 1) { + // HTTP/1.1 or greater: Transfer-Encoding has been set to identity, and no + // content-length has been provided. The connection must be closed after the + // reply is written, and no chunking is to be done. This is the setup + // recommended in the Server-Sent Events candidate recommendation 11, + // section 8. + if hasTE && te == "identity" { + cw.chunking = false + w.closeAfterReply = true + delHeader("Transfer-Encoding") + } else { + // HTTP/1.1 or greater: use chunked transfer encoding + // to avoid closing the connection at EOF. + cw.chunking = true + setHeader.transferEncoding = "chunked" + if hasTE && te == "chunked" { + // We will send the chunked Transfer-Encoding header later. + delHeader("Transfer-Encoding") + } + } + } else { + // HTTP version < 1.1: cannot do chunked transfer + // encoding and we don't know the Content-Length so + // signal EOF by closing connection. + w.closeAfterReply = true + delHeader("Transfer-Encoding") // in case already set + } + + // Cannot use Content-Length with non-identity Transfer-Encoding. + if cw.chunking { + delHeader("Content-Length") + } + if !w.req.ProtoAtLeast(1, 0) { + return + } + + // Only override the Connection header if it is not a successful + // protocol switch response and if KeepAlives are not enabled. + // See https://golang.org/issue/36381. 
+ delConnectionHeader := w.closeAfterReply && + (!keepAlivesEnabled || !hasToken(cw.header.get("Connection"), "close")) && + !isProtocolSwitchResponse(w.status, header) + if delConnectionHeader { + delHeader("Connection") + if w.req.ProtoAtLeast(1, 1) { + setHeader.connection = "close" + } + } + + writeStatusLine(w.conn.bufw, w.req.ProtoAtLeast(1, 1), code, w.statusBuf[:]) + cw.header.WriteSubset(w.conn.bufw, excludeHeader) + setHeader.Write(w.conn.bufw) + w.conn.bufw.Write(crlf) +} + +// foreachHeaderElement splits v according to the "#rule" construction +// in RFC 7230 section 7 and calls fn for each non-empty element. +func foreachHeaderElement(v string, fn func(string)) { + v = textproto.TrimString(v) + if v == "" { + return + } + if !strings.Contains(v, ",") { + fn(v) + return + } + for _, f := range strings.Split(v, ",") { + if f = textproto.TrimString(f); f != "" { + fn(f) + } + } +} + +// writeStatusLine writes an HTTP/1.x Status-Line (RFC 7230 Section 3.1.2) +// to bw. is11 is whether the HTTP request is HTTP/1.1. false means HTTP/1.0. +// code is the response status code. +// scratch is an optional scratch buffer. If it has at least capacity 3, it's used. +func writeStatusLine(bw *bufio.Writer, is11 bool, code int, scratch []byte) { + if is11 { + bw.WriteString("HTTP/1.1 ") + } else { + bw.WriteString("HTTP/1.0 ") + } + if text := StatusText(code); text != "" { + bw.Write(strconv.AppendInt(scratch[:0], int64(code), 10)) + bw.WriteByte(' ') + bw.WriteString(text) + bw.WriteString("\r\n") + } else { + // don't worry about performance + fmt.Fprintf(bw, "%03d status code %d\r\n", code, code) + } +} + +// bodyAllowed reports whether a Write is allowed for this response type. +// It's illegal to call this before the header has been flushed. +func (w *response) bodyAllowed() bool { + if !w.wroteHeader { + panic("") + } + return bodyAllowedForStatus(w.status) +} + +// The Life Of A Write is like this: +// +// Handler starts. No header has been sent. The handler can either +// write a header, or just start writing. Writing before sending a header +// sends an implicitly empty 200 OK header. +// +// If the handler didn't declare a Content-Length up front, we either +// go into chunking mode or, if the handler finishes running before +// the chunking buffer size, we compute a Content-Length and send that +// in the header instead. +// +// Likewise, if the handler didn't set a Content-Type, we sniff that +// from the initial chunk of output. +// +// The Writers are wired together like: +// +// 1. *response (the ResponseWriter) -> +// 2. (*response).w, a *bufio.Writer of bufferBeforeChunkingSize bytes -> +// 3. chunkWriter.Writer (whose writeHeader finalizes Content-Length/Type) +// and which writes the chunk headers, if needed -> +// 4. conn.bufw, a *bufio.Writer of default (4kB) bytes, writing to -> +// 5. checkConnErrorWriter{c}, which notes any non-nil error on Write +// and populates c.werr with it if so, but otherwise writes to -> +// 6. the rwc, the net.Conn. +// +// TODO(bradfitz): short-circuit some of the buffering when the +// initial header contains both a Content-Type and Content-Length. +// Also short-circuit in (1) when the header's been sent and not in +// chunking mode, writing directly to (4) instead, if (2) has no +// buffered data. More generally, we could short-circuit from (1) to +// (3) even in chunking mode if the write size from (1) is over some +// threshold and nothing is in (2). 
The answer might be mostly making +// bufferBeforeChunkingSize smaller and having bufio's fast-paths deal +// with this instead. +func (w *response) Write(data []byte) (n int, err error) { + return w.write(len(data), data, "") +} + +func (w *response) WriteString(data string) (n int, err error) { + return w.write(len(data), nil, data) +} + +// either dataB or dataS is non-zero. +func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err error) { + if w.conn.hijacked() { + if lenData > 0 { + caller := relevantCaller() + w.conn.server.logf("http: response.Write on hijacked connection from %s (%s:%d)", caller.Function, path.Base(caller.File), caller.Line) + } + return 0, ErrHijacked + } + + if w.canWriteContinue.isSet() { + // Body reader wants to write 100 Continue but hasn't yet. + // Tell it not to. The store must be done while holding the lock + // because the lock makes sure that there is not an active write + // this very moment. + w.writeContinueMu.Lock() + w.canWriteContinue.setFalse() + w.writeContinueMu.Unlock() + } + + if !w.wroteHeader { + w.WriteHeader(StatusOK) + } + if lenData == 0 { + return 0, nil + } + if !w.bodyAllowed() { + return 0, ErrBodyNotAllowed + } + + w.written += int64(lenData) // ignoring errors, for errorKludge + if w.contentLength != -1 && w.written > w.contentLength { + return 0, ErrContentLength + } + if dataB != nil { + return w.w.Write(dataB) + } else { + return w.w.WriteString(dataS) + } +} + +func (w *response) finishRequest() { + w.handlerDone.setTrue() + + if !w.wroteHeader { + w.WriteHeader(StatusOK) + } + + w.w.Flush() + putBufioWriter(w.w) + w.cw.close() + w.conn.bufw.Flush() + + w.conn.r.abortPendingRead() + + // Close the body (regardless of w.closeAfterReply) so we can + // re-use its bufio.Reader later safely. + w.reqBody.Close() + + if w.req.MultipartForm != nil { + w.req.MultipartForm.RemoveAll() + } +} + +// shouldReuseConnection reports whether the underlying TCP connection can be reused. +// It must only be called after the handler is done executing. +func (w *response) shouldReuseConnection() bool { + if w.closeAfterReply { + // The request or something set while executing the + // handler indicated we shouldn't reuse this + // connection. + return false + } + + if w.req.Method != "HEAD" && w.contentLength != -1 && w.bodyAllowed() && w.contentLength != w.written { + // Did not write enough. Avoid getting out of sync. + return false + } + + // There was some error writing to the underlying connection + // during the request, so don't re-use this conn. + if w.conn.werr != nil { + return false + } + + if w.closedRequestBodyEarly() { + return false + } + + return true +} + +func (w *response) closedRequestBodyEarly() bool { + body, ok := w.req.Body.(*body) + return ok && body.didEarlyClose() +} + +func (w *response) Flush() { + if !w.wroteHeader { + w.WriteHeader(StatusOK) + } + w.w.Flush() + w.cw.flush() +} + +func (c *conn) finalFlush() { + if c.bufr != nil { + // Steal the bufio.Reader (~4KB worth of memory) and its associated + // reader for a future connection. + putBufioReader(c.bufr) + c.bufr = nil + } + + if c.bufw != nil { + c.bufw.Flush() + // Steal the bufio.Writer (~4KB worth of memory) and its associated + // writer for a future connection. + putBufioWriter(c.bufw) + c.bufw = nil + } +} + +// Close the connection. 
+func (c *conn) close() { + c.finalFlush() + c.rwc.Close() +} + +// rstAvoidanceDelay is the amount of time we sleep after closing the +// write side of a TCP connection before closing the entire socket. +// By sleeping, we increase the chances that the client sees our FIN +// and processes its final data before they process the subsequent RST +// from closing a connection with known unread data. +// This RST seems to occur mostly on BSD systems. (And Windows?) +// This timeout is somewhat arbitrary (~latency around the planet). +const rstAvoidanceDelay = 500 * time.Millisecond + +type closeWriter interface { + CloseWrite() error +} + +var _ closeWriter = (*net.TCPConn)(nil) + +// closeWrite flushes any outstanding data and sends a FIN packet (if +// client is connected via TCP), signaling that we're done. We then +// pause for a bit, hoping the client processes it before any +// subsequent RST. +// +// See https://golang.org/issue/3595 +func (c *conn) closeWriteAndWait() { + c.finalFlush() + if tcp, ok := c.rwc.(closeWriter); ok { + tcp.CloseWrite() + } + time.Sleep(rstAvoidanceDelay) +} + +// validNextProto reports whether the proto is a valid ALPN protocol name. +// Everything is valid except the empty string and built-in protocol types, +// so that those can't be overridden with alternate implementations. +func validNextProto(proto string) bool { + switch proto { + case "", "http/1.1", "http/1.0": + return false + } + return true +} + +const ( + runHooks = true + skipHooks = false +) + +func (c *conn) setState(nc net.Conn, state ConnState, runHook bool) { + srv := c.server + switch state { + case StateNew: + srv.trackConn(c, true) + case StateHijacked, StateClosed: + srv.trackConn(c, false) + } + if state > 0xff || state < 0 { + panic("internal error") + } + packedState := uint64(time.Now().Unix()<<8) | uint64(state) + atomic.StoreUint64(&c.curState.atomic, packedState) + if !runHook { + return + } + if hook := srv.ConnState; hook != nil { + hook(nc, state) + } +} + +func (c *conn) getState() (state ConnState, unixSec int64) { + packedState := atomic.LoadUint64(&c.curState.atomic) + return ConnState(packedState & 0xff), int64(packedState >> 8) +} + +// badRequestError is a literal string (used by in the server in HTML, +// unescaped) to tell the user why their request was bad. It should +// be plain text without user info or other embedded errors. +func badRequestError(e string) error { return statusError{StatusBadRequest, e} } + +// statusError is an error used to respond to a request with an HTTP status. +// The text should be plain text without user info or other embedded errors. +type statusError struct { + code int + text string +} + +func (e statusError) Error() string { return StatusText(e.code) + ": " + e.text } + +// ErrAbortHandler is a sentinel panic value to abort a handler. +// While any panic from ServeHTTP aborts the response to the client, +// panicking with ErrAbortHandler also suppresses logging of a stack +// trace to the server's error log. +var ErrAbortHandler = errors.New("net/http: abort Handler") + +// isCommonNetReadError reports whether err is a common error +// encountered during reading a request off the network when the +// client has gone away or had its read fail somehow. This is used to +// determine which logs are interesting enough to log about. 
+func isCommonNetReadError(err error) bool { + if err == io.EOF { + return true + } + if neterr, ok := err.(net.Error); ok && neterr.Timeout() { + return true + } + if oe, ok := err.(*net.OpError); ok && oe.Op == "read" { + return true + } + return false +} + +// Serve a new connection. +func (c *conn) serve(ctx context.Context) { + c.remoteAddr = c.rwc.RemoteAddr().String() + ctx = context.WithValue(ctx, LocalAddrContextKey, c.rwc.LocalAddr()) + var inFlightResponse *response + defer func() { + if err := recover(); err != nil && err != ErrAbortHandler { + const size = 64 << 10 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + c.server.logf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf) + } + if inFlightResponse != nil { + inFlightResponse.cancelCtx() + } + if !c.hijacked() { + if inFlightResponse != nil { + inFlightResponse.conn.r.abortPendingRead() + inFlightResponse.reqBody.Close() + } + c.close() + c.setState(c.rwc, StateClosed, runHooks) + } + }() + + // TINYGO: Removed TLS conn check + + // HTTP/1.x from here on. + + ctx, cancelCtx := context.WithCancel(ctx) + c.cancelCtx = cancelCtx + defer cancelCtx() + + c.r = &connReader{conn: c} + c.bufr = newBufioReader(c.r) + c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10) + + for { + w, err := c.readRequest(ctx) + if c.r.remain != c.server.initialReadLimitSize() { + // If we read any bytes off the wire, we're active. + c.setState(c.rwc, StateActive, runHooks) + } + if err != nil { + const errorHeaders = "\r\nContent-Type: text/plain; charset=utf-8\r\nConnection: close\r\n\r\n" + + switch { + case err == errTooLarge: + // Their HTTP client may or may not be + // able to read this if we're + // responding to them and hanging up + // while they're still writing their + // request. Undefined behavior. + const publicErr = "431 Request Header Fields Too Large" + fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr) + c.closeWriteAndWait() + return + + case isUnsupportedTEError(err): + // Respond as per RFC 7230 Section 3.3.1 which says, + // A server that receives a request message with a + // transfer coding it does not understand SHOULD + // respond with 501 (Unimplemented). + code := StatusNotImplemented + + // We purposefully aren't echoing back the transfer-encoding's value, + // so as to mitigate the risk of cross side scripting by an attacker. 
+ fmt.Fprintf(c.rwc, "HTTP/1.1 %d %s%sUnsupported transfer encoding", code, StatusText(code), errorHeaders) + return + + case isCommonNetReadError(err): + return // don't reply + + default: + if v, ok := err.(statusError); ok { + fmt.Fprintf(c.rwc, "HTTP/1.1 %d %s: %s%s%d %s: %s", v.code, StatusText(v.code), v.text, errorHeaders, v.code, StatusText(v.code), v.text) + return + } + publicErr := "400 Bad Request" + fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr) + return + } + } + + // Expect 100 Continue support + req := w.req + if req.expectsContinue() { + if req.ProtoAtLeast(1, 1) && req.ContentLength != 0 { + // Wrap the Body reader with one that replies on the connection + req.Body = &expectContinueReader{readCloser: req.Body, resp: w} + w.canWriteContinue.setTrue() + } + } else if req.Header.get("Expect") != "" { + w.sendExpectationFailed() + return + } + + c.curReq.Store(w) + + if requestBodyRemains(req.Body) { + registerOnHitEOF(req.Body, w.conn.r.startBackgroundRead) + } else { + w.conn.r.startBackgroundRead() + } + + // HTTP cannot have multiple simultaneous active requests.[*] + // Until the server replies to this request, it can't read another, + // so we might as well run the handler in this goroutine. + // [*] Not strictly true: HTTP pipelining. We could let them all process + // in parallel even if their responses need to be serialized. + // But we're not going to implement HTTP pipelining because it + // was never deployed in the wild and the answer is HTTP/2. + inFlightResponse = w + serverHandler{c.server}.ServeHTTP(w, w.req) + inFlightResponse = nil + w.cancelCtx() + if c.hijacked() { + return + } + w.finishRequest() + if !w.shouldReuseConnection() { + if w.requestBodyLimitHit || w.closedRequestBodyEarly() { + c.closeWriteAndWait() + } + return + } + c.setState(c.rwc, StateIdle, runHooks) + c.curReq.Store((*response)(nil)) + + if !w.conn.server.doKeepAlives() { + // We're in shutdown mode. We might've replied + // to the user without "Connection: close" and + // they might think they can send another + // request, but such is life with HTTP/1.1. + return + } + + if d := c.server.idleTimeout(); d != 0 { + c.rwc.SetReadDeadline(time.Now().Add(d)) + if _, err := c.bufr.Peek(4); err != nil { + return + } + } + c.rwc.SetReadDeadline(time.Time{}) + } +} + +func (w *response) sendExpectationFailed() { + // TODO(bradfitz): let ServeHTTP handlers handle + // requests with non-standard expectation[s]? Seems + // theoretical at best, and doesn't fit into the + // current ServeHTTP model anyway. We'd need to + // make the ResponseWriter an optional + // "ExpectReplier" interface or something. + // + // For now we'll just obey RFC 7231 5.1.1 which says + // "A server that receives an Expect field-value other + // than 100-continue MAY respond with a 417 (Expectation + // Failed) status code to indicate that the unexpected + // expectation cannot be met." + w.Header().Set("Connection", "close") + w.WriteHeader(StatusExpectationFailed) + w.finishRequest() +} + +// Hijack implements the Hijacker.Hijack method. Our response is both a ResponseWriter +// and a Hijacker. +func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { + if w.handlerDone.isSet() { + panic("net/http: Hijack called after ServeHTTP finished") + } + if w.wroteHeader { + w.cw.flush() + } + + c := w.conn + c.mu.Lock() + defer c.mu.Unlock() + + // Release the bufioWriter that writes to the chunk writer, it is not + // used after a connection has been hijacked. 
+ rwc, buf, err = c.hijackLocked() + if err == nil { + putBufioWriter(w.w) + w.w = nil + } + return rwc, buf, err +} + +func (w *response) CloseNotify() <-chan bool { + if w.handlerDone.isSet() { + panic("net/http: CloseNotify called after ServeHTTP finished") + } + return w.closeNotifyCh +} + +func registerOnHitEOF(rc io.ReadCloser, fn func()) { + switch v := rc.(type) { + case *expectContinueReader: + registerOnHitEOF(v.readCloser, fn) + case *body: + v.registerOnHitEOF(fn) + default: + panic("unexpected type " + fmt.Sprintf("%T", rc)) + } +} + +// requestBodyRemains reports whether future calls to Read +// on rc might yield more data. +func requestBodyRemains(rc io.ReadCloser) bool { + if rc == NoBody { + return false + } + switch v := rc.(type) { + case *expectContinueReader: + return requestBodyRemains(v.readCloser) + case *body: + return v.bodyRemains() + default: + panic("unexpected type " + fmt.Sprintf("%T", rc)) + } +} + +// The HandlerFunc type is an adapter to allow the use of +// ordinary functions as HTTP handlers. If f is a function +// with the appropriate signature, HandlerFunc(f) is a +// Handler that calls f. +type HandlerFunc func(ResponseWriter, *Request) + +// ServeHTTP calls f(w, r). +func (f HandlerFunc) ServeHTTP(w ResponseWriter, r *Request) { + f(w, r) +} + +// Helper handlers + +// Error replies to the request with the specified error message and HTTP code. +// It does not otherwise end the request; the caller should ensure no further +// writes are done to w. +// The error message should be plain text. +func Error(w ResponseWriter, error string, code int) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.WriteHeader(code) + fmt.Fprintln(w, error) +} + +// NotFound replies to the request with an HTTP 404 not found error. +func NotFound(w ResponseWriter, r *Request) { Error(w, "404 page not found", StatusNotFound) } + +// NotFoundHandler returns a simple request handler +// that replies to each request with a “404 page not found” reply. +func NotFoundHandler() Handler { return HandlerFunc(NotFound) } + +// StripPrefix returns a handler that serves HTTP requests by removing the +// given prefix from the request URL's Path (and RawPath if set) and invoking +// the handler h. StripPrefix handles a request for a path that doesn't begin +// with prefix by replying with an HTTP 404 not found error. The prefix must +// match exactly: if the prefix in the request contains escaped characters +// the reply is also an HTTP 404 not found error. +func StripPrefix(prefix string, h Handler) Handler { + if prefix == "" { + return h + } + return HandlerFunc(func(w ResponseWriter, r *Request) { + p := strings.TrimPrefix(r.URL.Path, prefix) + rp := strings.TrimPrefix(r.URL.RawPath, prefix) + if len(p) < len(r.URL.Path) && (r.URL.RawPath == "" || len(rp) < len(r.URL.RawPath)) { + r2 := new(Request) + *r2 = *r + r2.URL = new(url.URL) + *r2.URL = *r.URL + r2.URL.Path = p + r2.URL.RawPath = rp + h.ServeHTTP(w, r2) + } else { + NotFound(w, r) + } + }) +} + +// Redirect replies to the request with a redirect to url, +// which may be a path relative to the request path. +// +// The provided code should be in the 3xx range and is usually +// StatusMovedPermanently, StatusFound or StatusSeeOther. +// +// If the Content-Type header has not been set, Redirect sets it +// to "text/html; charset=utf-8" and writes a small HTML body. +// Setting the Content-Type header to any value, including nil, +// disables that behavior. 
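+//
+// A typical use inside a handler (illustrative sketch only; the handler
+// name and target path are placeholders):
+//
+//	func legacy(w ResponseWriter, r *Request) {
+//		Redirect(w, r, "/new-location", StatusMovedPermanently)
+//	}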
+func Redirect(w ResponseWriter, r *Request, url string, code int) {
+ if u, err := urlpkg.Parse(url); err == nil {
+ // If url was relative, make its path absolute by
+ // combining with request path.
+ // The client would probably do this for us,
+ // but doing it ourselves is more reliable.
+ // See RFC 7231, section 7.1.2
+ if u.Scheme == "" && u.Host == "" {
+ oldpath := r.URL.Path
+ if oldpath == "" { // should not happen, but avoid a crash if it does
+ oldpath = "/"
+ }
+
+ // no leading http://server
+ if url == "" || url[0] != '/' {
+ // make relative path absolute
+ olddir, _ := path.Split(oldpath)
+ url = olddir + url
+ }
+
+ var query string
+ if i := strings.Index(url, "?"); i != -1 {
+ url, query = url[:i], url[i:]
+ }
+
+ // clean up but preserve trailing slash
+ trailing := strings.HasSuffix(url, "/")
+ url = path.Clean(url)
+ if trailing && !strings.HasSuffix(url, "/") {
+ url += "/"
+ }
+ url += query
+ }
+ }
+
+ h := w.Header()
+
+ // RFC 7231 notes that a short HTML body is usually included in
+ // the response because older user agents may not understand 301/307.
+ // Do it only if the request didn't already have a Content-Type header.
+ _, hadCT := h["Content-Type"]
+
+ h.Set("Location", hexEscapeNonASCII(url))
+ if !hadCT && (r.Method == "GET" || r.Method == "HEAD") {
+ h.Set("Content-Type", "text/html; charset=utf-8")
+ }
+ w.WriteHeader(code)
+
+ // Shouldn't send the body for POST or HEAD; that leaves GET.
+ if !hadCT && r.Method == "GET" {
+ body := "<a href=\"" + htmlEscape(url) + "\">" + StatusText(code) + "</a>.\n"
+ fmt.Fprintln(w, body)
+ }
+}
+
+var htmlReplacer = strings.NewReplacer(
+ "&", "&amp;",
+ "<", "&lt;",
+ ">", "&gt;",
+ // "&#34;" is shorter than "&quot;".
+ `"`, "&#34;",
+ // "&#39;" is shorter than "&apos;" and apos was not in HTML until HTML5.
+ "'", "&#39;",
+)
+
+func htmlEscape(s string) string {
+ return htmlReplacer.Replace(s)
+}
+
+// Redirect to a fixed URL
+type redirectHandler struct {
+ url string
+ code int
+}
+
+func (rh *redirectHandler) ServeHTTP(w ResponseWriter, r *Request) {
+ Redirect(w, r, rh.url, rh.code)
+}
+
+// RedirectHandler returns a request handler that redirects
+// each request it receives to the given url using the given
+// status code.
+//
+// The provided code should be in the 3xx range and is usually
+// StatusMovedPermanently, StatusFound or StatusSeeOther.
+func RedirectHandler(url string, code int) Handler {
+ return &redirectHandler{url, code}
+}
+
+// ServeMux is an HTTP request multiplexer.
+// It matches the URL of each incoming request against a list of registered
+// patterns and calls the handler for the pattern that
+// most closely matches the URL.
+//
+// Patterns name fixed, rooted paths, like "/favicon.ico",
+// or rooted subtrees, like "/images/" (note the trailing slash).
+// Longer patterns take precedence over shorter ones, so that
+// if there are handlers registered for both "/images/"
+// and "/images/thumbnails/", the latter handler will be
+// called for paths beginning "/images/thumbnails/" and the
+// former will receive requests for any other paths in the
+// "/images/" subtree.
+//
+// Note that since a pattern ending in a slash names a rooted subtree,
+// the pattern "/" matches all paths not matched by other registered
+// patterns, not just the URL with Path == "/".
+//
+// If a subtree has been registered and a request is received naming the
+// subtree root without its trailing slash, ServeMux redirects that
+// request to the subtree root (adding the trailing slash).
This behavior can +// be overridden with a separate registration for the path without +// the trailing slash. For example, registering "/images/" causes ServeMux +// to redirect a request for "/images" to "/images/", unless "/images" has +// been registered separately. +// +// Patterns may optionally begin with a host name, restricting matches to +// URLs on that host only. Host-specific patterns take precedence over +// general patterns, so that a handler might register for the two patterns +// "/codesearch" and "codesearch.google.com/" without also taking over +// requests for "http://www.google.com/". +// +// ServeMux also takes care of sanitizing the URL request path and the Host +// header, stripping the port number and redirecting any request containing . or +// .. elements or repeated slashes to an equivalent, cleaner URL. +type ServeMux struct { + mu sync.RWMutex + m map[string]muxEntry + es []muxEntry // slice of entries sorted from longest to shortest. + hosts bool // whether any patterns contain hostnames +} + +type muxEntry struct { + h Handler + pattern string +} + +// NewServeMux allocates and returns a new ServeMux. +func NewServeMux() *ServeMux { return new(ServeMux) } + +// DefaultServeMux is the default ServeMux used by Serve. +var DefaultServeMux = &defaultServeMux + +var defaultServeMux ServeMux + +// cleanPath returns the canonical path for p, eliminating . and .. elements. +func cleanPath(p string) string { + if p == "" { + return "/" + } + if p[0] != '/' { + p = "/" + p + } + np := path.Clean(p) + // path.Clean removes trailing slash except for root; + // put the trailing slash back if necessary. + if p[len(p)-1] == '/' && np != "/" { + // Fast path for common case of p being the string we want: + if len(p) == len(np)+1 && strings.HasPrefix(p, np) { + np = p + } else { + np += "/" + } + } + return np +} + +// stripHostPort returns h without any trailing ":". +func stripHostPort(h string) string { + // If no port on host, return unchanged + if !strings.Contains(h, ":") { + return h + } + host, _, err := net.SplitHostPort(h) + if err != nil { + return h // on error, return unchanged + } + return host +} + +// Find a handler on a handler map given a path string. +// Most-specific (longest) pattern wins. +func (mux *ServeMux) match(path string) (h Handler, pattern string) { + // Check for exact match first. + v, ok := mux.m[path] + if ok { + return v.h, v.pattern + } + + // Check for longest valid match. mux.es contains all patterns + // that end in / sorted from longest to shortest. + for _, e := range mux.es { + if strings.HasPrefix(path, e.pattern) { + return e.h, e.pattern + } + } + return nil, "" +} + +// redirectToPathSlash determines if the given path needs appending "/" to it. +// This occurs when a handler for path + "/" was already registered, but +// not for path itself. If the path needs appending to, it creates a new +// URL, setting the path to u.Path + "/" and returning true to indicate so. +func (mux *ServeMux) redirectToPathSlash(host, path string, u *url.URL) (*url.URL, bool) { + mux.mu.RLock() + shouldRedirect := mux.shouldRedirectRLocked(host, path) + mux.mu.RUnlock() + if !shouldRedirect { + return u, false + } + path = path + "/" + u = &url.URL{Path: path, RawQuery: u.RawQuery} + return u, true +} + +// shouldRedirectRLocked reports whether the given path and host should be redirected to +// path+"/". This should happen if a handler is registered for path+"/" but +// not path -- see comments at ServeMux. 
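+//
+// For example, if only "/images/" is registered, a request for "/images"
+// reports true here and is redirected to "/images/"; registering "/images"
+// as well suppresses that redirect.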
+func (mux *ServeMux) shouldRedirectRLocked(host, path string) bool { + p := []string{path, host + path} + + for _, c := range p { + if _, exist := mux.m[c]; exist { + return false + } + } + + n := len(path) + if n == 0 { + return false + } + for _, c := range p { + if _, exist := mux.m[c+"/"]; exist { + return path[n-1] != '/' + } + } + + return false +} + +// Handler returns the handler to use for the given request, +// consulting r.Method, r.Host, and r.URL.Path. It always returns +// a non-nil handler. If the path is not in its canonical form, the +// handler will be an internally-generated handler that redirects +// to the canonical path. If the host contains a port, it is ignored +// when matching handlers. +// +// The path and host are used unchanged for CONNECT requests. +// +// Handler also returns the registered pattern that matches the +// request or, in the case of internally-generated redirects, +// the pattern that will match after following the redirect. +// +// If there is no registered handler that applies to the request, +// Handler returns a “page not found” handler and an empty pattern. +func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) { + + // CONNECT requests are not canonicalized. + if r.Method == "CONNECT" { + // If r.URL.Path is /tree and its handler is not registered, + // the /tree -> /tree/ redirect applies to CONNECT requests + // but the path canonicalization does not. + if u, ok := mux.redirectToPathSlash(r.URL.Host, r.URL.Path, r.URL); ok { + return RedirectHandler(u.String(), StatusMovedPermanently), u.Path + } + + return mux.handler(r.Host, r.URL.Path) + } + + // All other requests have any port stripped and path cleaned + // before passing to mux.handler. + host := stripHostPort(r.Host) + path := cleanPath(r.URL.Path) + + // If the given path is /tree and its handler is not registered, + // redirect for /tree/. + if u, ok := mux.redirectToPathSlash(host, path, r.URL); ok { + return RedirectHandler(u.String(), StatusMovedPermanently), u.Path + } + + if path != r.URL.Path { + _, pattern = mux.handler(host, path) + u := &url.URL{Path: path, RawQuery: r.URL.RawQuery} + return RedirectHandler(u.String(), StatusMovedPermanently), pattern + } + + return mux.handler(host, r.URL.Path) +} + +// handler is the main implementation of Handler. +// The path is known to be in canonical form, except for CONNECT methods. +func (mux *ServeMux) handler(host, path string) (h Handler, pattern string) { + mux.mu.RLock() + defer mux.mu.RUnlock() + + // Host-specific pattern takes precedence over generic ones + if mux.hosts { + h, pattern = mux.match(host + path) + } + if h == nil { + h, pattern = mux.match(path) + } + if h == nil { + h, pattern = NotFoundHandler(), "" + } + return +} + +// ServeHTTP dispatches the request to the handler whose +// pattern most closely matches the request URL. +func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) { + if r.RequestURI == "*" { + if r.ProtoAtLeast(1, 1) { + w.Header().Set("Connection", "close") + } + w.WriteHeader(StatusBadRequest) + return + } + h, _ := mux.Handler(r) + h.ServeHTTP(w, r) +} + +// Handle registers the handler for the given pattern. +// If a handler already exists for pattern, Handle panics. 
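+//
+// A minimal registration sketch (the mux and handler values here are
+// illustrative, not part of this package):
+//
+//	mux := NewServeMux()
+//	mux.Handle("/api/", apiHandler) // rooted subtree
+//	mux.HandleFunc("/healthz", func(w ResponseWriter, r *Request) {
+//		w.WriteHeader(StatusOK)
+//	})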
+func (mux *ServeMux) Handle(pattern string, handler Handler) { + mux.mu.Lock() + defer mux.mu.Unlock() + + if pattern == "" { + panic("http: invalid pattern") + } + if handler == nil { + panic("http: nil handler") + } + if _, exist := mux.m[pattern]; exist { + panic("http: multiple registrations for " + pattern) + } + + if mux.m == nil { + mux.m = make(map[string]muxEntry) + } + e := muxEntry{h: handler, pattern: pattern} + mux.m[pattern] = e + if pattern[len(pattern)-1] == '/' { + mux.es = appendSorted(mux.es, e) + } + + if pattern[0] != '/' { + mux.hosts = true + } +} + +func appendSorted(es []muxEntry, e muxEntry) []muxEntry { + n := len(es) + i := sort.Search(n, func(i int) bool { + return len(es[i].pattern) < len(e.pattern) + }) + if i == n { + return append(es, e) + } + // we now know that i points at where we want to insert + es = append(es, muxEntry{}) // try to grow the slice in place, any entry works. + copy(es[i+1:], es[i:]) // Move shorter entries down + es[i] = e + return es +} + +// HandleFunc registers the handler function for the given pattern. +func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { + if handler == nil { + panic("http: nil handler") + } + mux.Handle(pattern, HandlerFunc(handler)) +} + +// Handle registers the handler for the given pattern +// in the DefaultServeMux. +// The documentation for ServeMux explains how patterns are matched. +func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) } + +// HandleFunc registers the handler function for the given pattern +// in the DefaultServeMux. +// The documentation for ServeMux explains how patterns are matched. +func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { + DefaultServeMux.HandleFunc(pattern, handler) +} + +// Serve accepts incoming HTTP connections on the listener l, +// creating a new service goroutine for each. The service goroutines +// read requests and then call handler to reply to them. +// +// The handler is typically nil, in which case the DefaultServeMux is used. +// +// HTTP/2 support is only enabled if the Listener returns *tls.Conn +// connections and they were configured with "h2" in the TLS +// Config.NextProtos. +// +// Serve always returns a non-nil error. +func Serve(l net.Listener, handler Handler) error { + srv := &Server{Handler: handler} + return srv.Serve(l) +} + +// A Server defines parameters for running an HTTP server. +// The zero value for Server is a valid configuration. +type Server struct { + // Addr optionally specifies the TCP address for the server to listen on, + // in the form "host:port". If empty, ":http" (port 80) is used. + // The service names are defined in RFC 6335 and assigned by IANA. + // See net.Dial for details of the address format. + Addr string + + Handler Handler // handler to invoke, http.DefaultServeMux if nil + + // TLSConfig optionally provides a TLS configuration for use + // by ServeTLS and ListenAndServeTLS. Note that this value is + // cloned by ServeTLS and ListenAndServeTLS, so it's not + // possible to modify the configuration with methods like + // tls.Config.SetSessionTicketKeys. To use + // SetSessionTicketKeys, use Server.Serve with a TLS Listener + // instead. + TLSConfig *tls.Config + + // ReadTimeout is the maximum duration for reading the entire + // request, including the body. A zero or negative value means + // there will be no timeout. 
+ // + // Because ReadTimeout does not let Handlers make per-request + // decisions on each request body's acceptable deadline or + // upload rate, most users will prefer to use + // ReadHeaderTimeout. It is valid to use them both. + ReadTimeout time.Duration + + // ReadHeaderTimeout is the amount of time allowed to read + // request headers. The connection's read deadline is reset + // after reading the headers and the Handler can decide what + // is considered too slow for the body. If ReadHeaderTimeout + // is zero, the value of ReadTimeout is used. If both are + // zero, there is no timeout. + ReadHeaderTimeout time.Duration + + // WriteTimeout is the maximum duration before timing out + // writes of the response. It is reset whenever a new + // request's header is read. Like ReadTimeout, it does not + // let Handlers make decisions on a per-request basis. + // A zero or negative value means there will be no timeout. + WriteTimeout time.Duration + + // IdleTimeout is the maximum amount of time to wait for the + // next request when keep-alives are enabled. If IdleTimeout + // is zero, the value of ReadTimeout is used. If both are + // zero, there is no timeout. + IdleTimeout time.Duration + + // MaxHeaderBytes controls the maximum number of bytes the + // server will read parsing the request header's keys and + // values, including the request line. It does not limit the + // size of the request body. + // If zero, DefaultMaxHeaderBytes is used. + MaxHeaderBytes int + + // ConnState specifies an optional callback function that is + // called when a client connection changes state. See the + // ConnState type and associated constants for details. + ConnState func(net.Conn, ConnState) + + // ErrorLog specifies an optional logger for errors accepting + // connections, unexpected behavior from handlers, and + // underlying FileSystem errors. + // If nil, logging is done via the log package's standard logger. + ErrorLog *log.Logger + + // BaseContext optionally specifies a function that returns + // the base context for incoming requests on this server. + // The provided Listener is the specific Listener that's + // about to start accepting requests. + // If BaseContext is nil, the default is context.Background(). + // If non-nil, it must return a non-nil context. + BaseContext func(net.Listener) context.Context + + // ConnContext optionally specifies a function that modifies + // the context used for a new connection c. The provided ctx + // is derived from the base context and has a ServerContextKey + // value. + ConnContext func(ctx context.Context, c net.Conn) context.Context + + inShutdown atomicBool // true when server is in shutdown + + disableKeepAlives int32 // accessed atomically. + nextProtoOnce sync.Once // guards setupHTTP2_* init + nextProtoErr error // result of http2.ConfigureServer if used + + mu sync.Mutex + listeners map[*net.Listener]struct{} + activeConn map[*conn]struct{} + doneChan chan struct{} + onShutdown []func() + + listenerGroup sync.WaitGroup +} + +func (s *Server) getDoneChan() <-chan struct{} { + s.mu.Lock() + defer s.mu.Unlock() + return s.getDoneChanLocked() +} + +func (s *Server) getDoneChanLocked() chan struct{} { + if s.doneChan == nil { + s.doneChan = make(chan struct{}) + } + return s.doneChan +} + +func (s *Server) closeDoneChanLocked() { + ch := s.getDoneChanLocked() + select { + case <-ch: + // Already closed. Don't close again. + default: + // Safe to close here. We're the only closer, guarded + // by s.mu. 
+ close(ch) + } +} + +// Close immediately closes all active net.Listeners and any +// connections in state StateNew, StateActive, or StateIdle. For a +// graceful shutdown, use Shutdown. +// +// Close does not attempt to close (and does not even know about) +// any hijacked connections, such as WebSockets. +// +// Close returns any error returned from closing the Server's +// underlying Listener(s). +func (srv *Server) Close() error { + srv.inShutdown.setTrue() + srv.mu.Lock() + defer srv.mu.Unlock() + srv.closeDoneChanLocked() + err := srv.closeListenersLocked() + + // Unlock srv.mu while waiting for listenerGroup. + // The group Add and Done calls are made with srv.mu held, + // to avoid adding a new listener in the window between + // us setting inShutdown above and waiting here. + srv.mu.Unlock() + srv.listenerGroup.Wait() + srv.mu.Lock() + + for c := range srv.activeConn { + c.rwc.Close() + delete(srv.activeConn, c) + } + return err +} + +// shutdownPollIntervalMax is the max polling interval when checking +// quiescence during Server.Shutdown. Polling starts with a small +// interval and backs off to the max. +// Ideally we could find a solution that doesn't involve polling, +// but which also doesn't have a high runtime cost (and doesn't +// involve any contentious mutexes), but that is left as an +// exercise for the reader. +const shutdownPollIntervalMax = 500 * time.Millisecond + +// Shutdown gracefully shuts down the server without interrupting any +// active connections. Shutdown works by first closing all open +// listeners, then closing all idle connections, and then waiting +// indefinitely for connections to return to idle and then shut down. +// If the provided context expires before the shutdown is complete, +// Shutdown returns the context's error, otherwise it returns any +// error returned from closing the Server's underlying Listener(s). +// +// When Shutdown is called, Serve, ListenAndServe, and +// ListenAndServeTLS immediately return ErrServerClosed. Make sure the +// program doesn't exit and waits instead for Shutdown to return. +// +// Shutdown does not attempt to close nor wait for hijacked +// connections such as WebSockets. The caller of Shutdown should +// separately notify such long-lived connections of shutdown and wait +// for them to close, if desired. See RegisterOnShutdown for a way to +// register shutdown notification functions. +// +// Once Shutdown has been called on a server, it may not be reused; +// future calls to methods such as Serve will return ErrServerClosed. +func (srv *Server) Shutdown(ctx context.Context) error { + srv.inShutdown.setTrue() + + srv.mu.Lock() + lnerr := srv.closeListenersLocked() + srv.closeDoneChanLocked() + for _, f := range srv.onShutdown { + go f() + } + srv.mu.Unlock() + srv.listenerGroup.Wait() + + pollIntervalBase := time.Millisecond + nextPollInterval := func() time.Duration { + // Add 10% jitter. + interval := pollIntervalBase + time.Duration(rand.Intn(int(pollIntervalBase/10))) + // Double and clamp for next time. + pollIntervalBase *= 2 + if pollIntervalBase > shutdownPollIntervalMax { + pollIntervalBase = shutdownPollIntervalMax + } + return interval + } + + timer := time.NewTimer(nextPollInterval()) + defer timer.Stop() + for { + if srv.closeIdleConns() { + return lnerr + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + timer.Reset(nextPollInterval()) + } + } +} + +// RegisterOnShutdown registers a function to call on Shutdown. 
+// This can be used to gracefully shutdown connections that have +// undergone ALPN protocol upgrade or that have been hijacked. +// This function should start protocol-specific graceful shutdown, +// but should not wait for shutdown to complete. +func (srv *Server) RegisterOnShutdown(f func()) { + srv.mu.Lock() + srv.onShutdown = append(srv.onShutdown, f) + srv.mu.Unlock() +} + +// closeIdleConns closes all idle connections and reports whether the +// server is quiescent. +func (s *Server) closeIdleConns() bool { + s.mu.Lock() + defer s.mu.Unlock() + quiescent := true + for c := range s.activeConn { + st, unixSec := c.getState() + // Issue 22682: treat StateNew connections as if + // they're idle if we haven't read the first request's + // header in over 5 seconds. + if st == StateNew && unixSec < time.Now().Unix()-5 { + st = StateIdle + } + if st != StateIdle || unixSec == 0 { + // Assume unixSec == 0 means it's a very new + // connection, without state set yet. + quiescent = false + continue + } + c.rwc.Close() + delete(s.activeConn, c) + } + return quiescent +} + +func (s *Server) closeListenersLocked() error { + var err error + for ln := range s.listeners { + if cerr := (*ln).Close(); cerr != nil && err == nil { + err = cerr + } + } + return err +} + +// A ConnState represents the state of a client connection to a server. +// It's used by the optional Server.ConnState hook. +type ConnState int + +const ( + // StateNew represents a new connection that is expected to + // send a request immediately. Connections begin at this + // state and then transition to either StateActive or + // StateClosed. + StateNew ConnState = iota + + // StateActive represents a connection that has read 1 or more + // bytes of a request. The Server.ConnState hook for + // StateActive fires before the request has entered a handler + // and doesn't fire again until the request has been + // handled. After the request is handled, the state + // transitions to StateClosed, StateHijacked, or StateIdle. + // For HTTP/2, StateActive fires on the transition from zero + // to one active request, and only transitions away once all + // active requests are complete. That means that ConnState + // cannot be used to do per-request work; ConnState only notes + // the overall state of the connection. + StateActive + + // StateIdle represents a connection that has finished + // handling a request and is in the keep-alive state, waiting + // for a new request. Connections transition from StateIdle + // to either StateActive or StateClosed. + StateIdle + + // StateHijacked represents a hijacked connection. + // This is a terminal state. It does not transition to StateClosed. + StateHijacked + + // StateClosed represents a closed connection. + // This is a terminal state. Hijacked connections do not + // transition to StateClosed. + StateClosed +) + +var stateName = map[ConnState]string{ + StateNew: "new", + StateActive: "active", + StateIdle: "idle", + StateHijacked: "hijacked", + StateClosed: "closed", +} + +func (c ConnState) String() string { + return stateName[c] +} + +// serverHandler delegates to either the server's Handler or +// DefaultServeMux and also handles "OPTIONS *" requests. 
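+// For example, a server started with ListenAndServe(addr, nil) arrives here
+// with a nil Handler, so requests are routed through DefaultServeMux, which
+// is populated by the package-level Handle and HandleFunc.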
+type serverHandler struct { + srv *Server +} + +func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) { + handler := sh.srv.Handler + if handler == nil { + handler = DefaultServeMux + } + if req.RequestURI == "*" && req.Method == "OPTIONS" { + handler = globalOptionsHandler{} + } + + if req.URL != nil && strings.Contains(req.URL.RawQuery, ";") { + var allowQuerySemicolonsInUse int32 + req = req.WithContext(context.WithValue(req.Context(), silenceSemWarnContextKey, func() { + atomic.StoreInt32(&allowQuerySemicolonsInUse, 1) + })) + defer func() { + if atomic.LoadInt32(&allowQuerySemicolonsInUse) == 0 { + sh.srv.logf("http: URL query contains semicolon, which is no longer a supported separator; parts of the query may be stripped when parsed; see golang.org/issue/25192") + } + }() + } + + handler.ServeHTTP(rw, req) +} + +var silenceSemWarnContextKey = &contextKey{"silence-semicolons"} + +// AllowQuerySemicolons returns a handler that serves requests by converting any +// unescaped semicolons in the URL query to ampersands, and invoking the handler h. +// +// This restores the pre-Go 1.17 behavior of splitting query parameters on both +// semicolons and ampersands. (See golang.org/issue/25192). Note that this +// behavior doesn't match that of many proxies, and the mismatch can lead to +// security issues. +// +// AllowQuerySemicolons should be invoked before Request.ParseForm is called. +func AllowQuerySemicolons(h Handler) Handler { + return HandlerFunc(func(w ResponseWriter, r *Request) { + if silenceSemicolonsWarning, ok := r.Context().Value(silenceSemWarnContextKey).(func()); ok { + silenceSemicolonsWarning() + } + if strings.Contains(r.URL.RawQuery, ";") { + r2 := new(Request) + *r2 = *r + r2.URL = new(url.URL) + *r2.URL = *r.URL + r2.URL.RawQuery = strings.ReplaceAll(r.URL.RawQuery, ";", "&") + h.ServeHTTP(w, r2) + } else { + h.ServeHTTP(w, r) + } + }) +} + +// ListenAndServe listens on the TCP network address srv.Addr and then +// calls Serve to handle requests on incoming connections. +// Accepted connections are configured to enable TCP keep-alives. +// +// If srv.Addr is blank, ":http" is used. +// +// ListenAndServe always returns a non-nil error. After Shutdown or Close, +// the returned error is ErrServerClosed. +func (srv *Server) ListenAndServe() error { + if srv.shuttingDown() { + return ErrServerClosed + } + addr := srv.Addr + if addr == "" { + addr = ":http" + } + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + return srv.Serve(ln) +} + +var testHookServerServe func(*Server, net.Listener) // used if non-nil + +// ErrServerClosed is returned by the Server's Serve, ServeTLS, ListenAndServe, +// and ListenAndServeTLS methods after a call to Shutdown or Close. +var ErrServerClosed = errors.New("http: Server closed") + +// Serve accepts incoming connections on the Listener l, creating a +// new service goroutine for each. The service goroutines read requests and +// then call srv.Handler to reply to them. +// +// HTTP/2 support is only enabled if the Listener returns *tls.Conn +// connections and they were configured with "h2" in the TLS +// Config.NextProtos. +// +// Serve always returns a non-nil error and closes l. +// After Shutdown or Close, the returned error is ErrServerClosed. 
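+//
+// A typical start/stop sequence, combining Serve (or ListenAndServe) with
+// the Shutdown method described above, might look like this (the mux,
+// address, and timeout are illustrative):
+//
+//	srv := &Server{Addr: ":8080", Handler: mux}
+//	go func() {
+//		if err := srv.ListenAndServe(); err != ErrServerClosed {
+//			log.Fatal(err)
+//		}
+//	}()
+//	// ... later, on a shutdown signal:
+//	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+//	defer cancel()
+//	if err := srv.Shutdown(ctx); err != nil {
+//		log.Print(err)
+//	}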
+func (srv *Server) Serve(l net.Listener) error { + if fn := testHookServerServe; fn != nil { + fn(srv, l) // call hook with unwrapped listener + } + + origListener := l + l = &onceCloseListener{Listener: l} + defer l.Close() + + if !srv.trackListener(&l, true) { + return ErrServerClosed + } + defer srv.trackListener(&l, false) + + baseCtx := context.Background() + if srv.BaseContext != nil { + baseCtx = srv.BaseContext(origListener) + if baseCtx == nil { + panic("BaseContext returned a nil context") + } + } + + var tempDelay time.Duration // how long to sleep on accept failure + + ctx := context.WithValue(baseCtx, ServerContextKey, srv) + for { + rw, err := l.Accept() + if err != nil { + select { + case <-srv.getDoneChan(): + return ErrServerClosed + default: + } + if ne, ok := err.(net.Error); ok && ne.Temporary() { + if tempDelay == 0 { + tempDelay = 5 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 1 * time.Second; tempDelay > max { + tempDelay = max + } + srv.logf("http: Accept error: %v; retrying in %v", err, tempDelay) + time.Sleep(tempDelay) + continue + } + return err + } + connCtx := ctx + if cc := srv.ConnContext; cc != nil { + connCtx = cc(connCtx, rw) + if connCtx == nil { + panic("ConnContext returned nil") + } + } + tempDelay = 0 + c := srv.newConn(rw) + c.setState(c.rwc, StateNew, runHooks) // before Serve can return + go c.serve(connCtx) + } +} + +// trackListener adds or removes a net.Listener to the set of tracked +// listeners. +// +// We store a pointer to interface in the map set, in case the +// net.Listener is not comparable. This is safe because we only call +// trackListener via Serve and can track+defer untrack the same +// pointer to local variable there. We never need to compare a +// Listener from another caller. +// +// It reports whether the server is still up (not Shutdown or Closed). +func (s *Server) trackListener(ln *net.Listener, add bool) bool { + s.mu.Lock() + defer s.mu.Unlock() + if s.listeners == nil { + s.listeners = make(map[*net.Listener]struct{}) + } + if add { + if s.shuttingDown() { + return false + } + s.listeners[ln] = struct{}{} + s.listenerGroup.Add(1) + } else { + delete(s.listeners, ln) + s.listenerGroup.Done() + } + return true +} + +func (s *Server) trackConn(c *conn, add bool) { + s.mu.Lock() + defer s.mu.Unlock() + if s.activeConn == nil { + s.activeConn = make(map[*conn]struct{}) + } + if add { + s.activeConn[c] = struct{}{} + } else { + delete(s.activeConn, c) + } +} + +func (s *Server) idleTimeout() time.Duration { + if s.IdleTimeout != 0 { + return s.IdleTimeout + } + return s.ReadTimeout +} + +func (s *Server) readHeaderTimeout() time.Duration { + if s.ReadHeaderTimeout != 0 { + return s.ReadHeaderTimeout + } + return s.ReadTimeout +} + +func (s *Server) doKeepAlives() bool { + return atomic.LoadInt32(&s.disableKeepAlives) == 0 && !s.shuttingDown() +} + +func (s *Server) shuttingDown() bool { + return s.inShutdown.isSet() +} + +// SetKeepAlivesEnabled controls whether HTTP keep-alives are enabled. +// By default, keep-alives are always enabled. Only very +// resource-constrained environments or servers in the process of +// shutting down should disable them. +func (srv *Server) SetKeepAlivesEnabled(v bool) { + if v { + atomic.StoreInt32(&srv.disableKeepAlives, 0) + return + } + atomic.StoreInt32(&srv.disableKeepAlives, 1) + + // Close idle HTTP/1 conns: + srv.closeIdleConns() + + // TODO: Issue 26303: close HTTP/2 conns as soon as they become idle. 
+} + +func (s *Server) logf(format string, args ...any) { + if s.ErrorLog != nil { + s.ErrorLog.Printf(format, args...) + } else { + log.Printf(format, args...) + } +} + +// logf prints to the ErrorLog of the *Server associated with request r +// via ServerContextKey. If there's no associated server, or if ErrorLog +// is nil, logging is done via the log package's standard logger. +func logf(r *Request, format string, args ...any) { + s, _ := r.Context().Value(ServerContextKey).(*Server) + if s != nil && s.ErrorLog != nil { + s.ErrorLog.Printf(format, args...) + } else { + log.Printf(format, args...) + } +} + +// ListenAndServe listens on the TCP network address addr and then calls +// Serve with handler to handle requests on incoming connections. +// Accepted connections are configured to enable TCP keep-alives. +// +// The handler is typically nil, in which case the DefaultServeMux is used. +// +// ListenAndServe always returns a non-nil error. +func ListenAndServe(addr string, handler Handler) error { + server := &Server{Addr: addr, Handler: handler} + return server.ListenAndServe() +} + +// onceCloseListener wraps a net.Listener, protecting it from +// multiple Close calls. +type onceCloseListener struct { + net.Listener + once sync.Once + closeErr error +} + +func (oc *onceCloseListener) Close() error { + oc.once.Do(oc.close) + return oc.closeErr +} + +func (oc *onceCloseListener) close() { oc.closeErr = oc.Listener.Close() } + +// globalOptionsHandler responds to "OPTIONS *" requests. +type globalOptionsHandler struct{} + +func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", "0") + if r.ContentLength != 0 { + // Read up to 4KB of OPTIONS body (as mentioned in the + // spec as being reserved for future use), but anything + // over that is considered a waste of server resources + // (or an attack) and we abort and close the connection, + // courtesy of MaxBytesReader's EOF behavior. + mb := MaxBytesReader(w, r.Body, 4<<10) + io.Copy(io.Discard, mb) + } +} + +// loggingConn is used for debugging. +type loggingConn struct { + name string + net.Conn +} + +var ( + uniqNameMu sync.Mutex + uniqNameNext = make(map[string]int) +) + +func newLoggingConn(baseName string, c net.Conn) net.Conn { + uniqNameMu.Lock() + defer uniqNameMu.Unlock() + uniqNameNext[baseName]++ + return &loggingConn{ + name: fmt.Sprintf("%s-%d", baseName, uniqNameNext[baseName]), + Conn: c, + } +} + +func (c *loggingConn) Write(p []byte) (n int, err error) { + log.Printf("%s.Write(%d) = ....", c.name, len(p)) + n, err = c.Conn.Write(p) + log.Printf("%s.Write(%d) = %d, %v", c.name, len(p), n, err) + return +} + +func (c *loggingConn) Read(p []byte) (n int, err error) { + log.Printf("%s.Read(%d) = ....", c.name, len(p)) + n, err = c.Conn.Read(p) + log.Printf("%s.Read(%d) = %d, %v", c.name, len(p), n, err) + return +} + +func (c *loggingConn) Close() (err error) { + log.Printf("%s.Close() = ...", c.name) + err = c.Conn.Close() + log.Printf("%s.Close() = %v", c.name, err) + return +} + +// checkConnErrorWriter writes to c.rwc and records any write errors to c.werr. +// It only contains one field (and a pointer field at that), so it +// fits in an interface value without an extra allocation. 
+type checkConnErrorWriter struct { + c *conn +} + +func (w checkConnErrorWriter) Write(p []byte) (n int, err error) { + n, err = w.c.rwc.Write(p) + if err != nil && w.c.werr == nil { + w.c.werr = err + w.c.cancelCtx() + } + return +} + +func numLeadingCRorLF(v []byte) (n int) { + for _, b := range v { + if b == '\r' || b == '\n' { + n++ + continue + } + break + } + return + +} + +func strSliceContains(ss []string, s string) bool { + for _, v := range ss { + if v == s { + return true + } + } + return false +} + +// tlsRecordHeaderLooksLikeHTTP reports whether a TLS record header +// looks like it might've been a misdirected plaintext HTTP request. +func tlsRecordHeaderLooksLikeHTTP(hdr [5]byte) bool { + switch string(hdr[:]) { + case "GET /", "HEAD ", "POST ", "PUT /", "OPTIO": + return true + } + return false +} + +// MaxBytesHandler returns a Handler that runs h with its ResponseWriter and Request.Body wrapped by a MaxBytesReader. +func MaxBytesHandler(h Handler, n int64) Handler { + return HandlerFunc(func(w ResponseWriter, r *Request) { + r2 := *r + r2.Body = MaxBytesReader(w, r.Body, n) + h.ServeHTTP(w, &r2) + }) +} diff --git a/src/net/http/sniff.go b/src/net/http/sniff.go new file mode 100644 index 0000000000..3fee91284b --- /dev/null +++ b/src/net/http/sniff.go @@ -0,0 +1,306 @@ +// TINYGO: The following is copied from Go 1.19.3 official implementation. + +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bytes" + "encoding/binary" +) + +// The algorithm uses at most sniffLen bytes to make its decision. +const sniffLen = 512 + +// DetectContentType implements the algorithm described +// at https://mimesniff.spec.whatwg.org/ to determine the +// Content-Type of the given data. It considers at most the +// first 512 bytes of data. DetectContentType always returns +// a valid MIME type: if it cannot determine a more specific one, it +// returns "application/octet-stream". +func DetectContentType(data []byte) string { + if len(data) > sniffLen { + data = data[:sniffLen] + } + + // Index of the first non-whitespace byte in data. + firstNonWS := 0 + for ; firstNonWS < len(data) && isWS(data[firstNonWS]); firstNonWS++ { + } + + for _, sig := range sniffSignatures { + if ct := sig.match(data, firstNonWS); ct != "" { + return ct + } + } + + return "application/octet-stream" // fallback +} + +// isWS reports whether the provided byte is a whitespace byte (0xWS) +// as defined in https://mimesniff.spec.whatwg.org/#terminology. +func isWS(b byte) bool { + switch b { + case '\t', '\n', '\x0c', '\r', ' ': + return true + } + return false +} + +// isTT reports whether the provided byte is a tag-terminating byte (0xTT) +// as defined in https://mimesniff.spec.whatwg.org/#terminology. +func isTT(b byte) bool { + switch b { + case ' ', '>': + return true + } + return false +} + +type sniffSig interface { + // match returns the MIME type of the data, or "" if unknown. + match(data []byte, firstNonWS int) string +} + +// Data matching the table in section 6. +var sniffSignatures = []sniffSig{ + htmlSig("= 0 || t.Body == nil { // redundant checks; caller did them + return false + } + if t.Method == "CONNECT" { + return false + } + if requestMethodUsuallyLacksBody(t.Method) { + // Only probe the Request.Body for GET/HEAD/DELETE/etc + // requests, because it's only those types of requests + // that confuse servers. 
+ t.probeRequestBody() // adjusts t.Body, t.ContentLength + return t.Body != nil + } + // For all other request types (PUT, POST, PATCH, or anything + // made-up we've never heard of), assume it's normal and the server + // can deal with a chunked request body. Maybe we'll adjust this + // later. + return true +} + +// probeRequestBody reads a byte from t.Body to see whether it's empty +// (returns io.EOF right away). +// +// But because we've had problems with this blocking users in the past +// (issue 17480) when the body is a pipe (perhaps waiting on the response +// headers before the pipe is fed data), we need to be careful and bound how +// long we wait for it. This delay will only affect users if all the following +// are true: +// - the request body blocks +// - the content length is not set (or set to -1) +// - the method doesn't usually have a body (GET, HEAD, DELETE, ...) +// - there is no transfer-encoding=chunked already set. +// +// In other words, this delay will not normally affect anybody, and there +// are workarounds if it does. +func (t *transferWriter) probeRequestBody() { + t.ByteReadCh = make(chan readResult, 1) + go func(body io.Reader) { + var buf [1]byte + var rres readResult + rres.n, rres.err = body.Read(buf[:]) + if rres.n == 1 { + rres.b = buf[0] + } + t.ByteReadCh <- rres + close(t.ByteReadCh) + }(t.Body) + timer := time.NewTimer(200 * time.Millisecond) + select { + case rres := <-t.ByteReadCh: + timer.Stop() + if rres.n == 0 && rres.err == io.EOF { + // It was empty. + t.Body = nil + t.ContentLength = 0 + } else if rres.n == 1 { + if rres.err != nil { + t.Body = io.MultiReader(&byteReader{b: rres.b}, errorReader{rres.err}) + } else { + t.Body = io.MultiReader(&byteReader{b: rres.b}, t.Body) + } + } else if rres.err != nil { + t.Body = errorReader{rres.err} + } + case <-timer.C: + // Too slow. Don't wait. Read it later, and keep + // assuming that this is ContentLength == -1 + // (unknown), which means we'll send a + // "Transfer-Encoding: chunked" header. + t.Body = io.MultiReader(finishAsyncByteRead{t}, t.Body) + // Request that Request.Write flush the headers to the + // network before writing the body, since our body may not + // become readable until it's seen the response headers. 
+ t.FlushHeaders = true + } +} + +func noResponseBodyExpected(requestMethod string) bool { + return requestMethod == "HEAD" +} + +func (t *transferWriter) shouldSendContentLength() bool { + if chunked(t.TransferEncoding) { + return false + } + if t.ContentLength > 0 { + return true + } + if t.ContentLength < 0 { + return false + } + // Many servers expect a Content-Length for these methods + if t.Method == "POST" || t.Method == "PUT" || t.Method == "PATCH" { + return true + } + if t.ContentLength == 0 && isIdentity(t.TransferEncoding) { + if t.Method == "GET" || t.Method == "HEAD" { + return false + } + return true + } + + return false +} + +func (t *transferWriter) writeHeader(w io.Writer) error { + if t.Close && !hasToken(t.Header.get("Connection"), "close") { + if _, err := io.WriteString(w, "Connection: close\r\n"); err != nil { + return err + } + } + + // Write Content-Length and/or Transfer-Encoding whose values are a + // function of the sanitized field triple (Body, ContentLength, + // TransferEncoding) + if t.shouldSendContentLength() { + if _, err := io.WriteString(w, "Content-Length: "); err != nil { + return err + } + if _, err := io.WriteString(w, strconv.FormatInt(t.ContentLength, 10)+"\r\n"); err != nil { + return err + } + } else if chunked(t.TransferEncoding) { + if _, err := io.WriteString(w, "Transfer-Encoding: chunked\r\n"); err != nil { + return err + } + } + + // Write Trailer header + if t.Trailer != nil { + keys := make([]string, 0, len(t.Trailer)) + for k := range t.Trailer { + k = CanonicalHeaderKey(k) + switch k { + case "Transfer-Encoding", "Trailer", "Content-Length": + return badStringError("invalid Trailer key", k) + } + keys = append(keys, k) + } + if len(keys) > 0 { + sort.Strings(keys) + // TODO: could do better allocation-wise here, but trailers are rare, + // so being lazy for now. + if _, err := io.WriteString(w, "Trailer: "+strings.Join(keys, ",")+"\r\n"); err != nil { + return err + } + } + } + + return nil +} + +// always closes t.BodyCloser +func (t *transferWriter) writeBody(w io.Writer) (err error) { + var ncopy int64 + closed := false + defer func() { + if closed || t.BodyCloser == nil { + return + } + if closeErr := t.BodyCloser.Close(); closeErr != nil && err == nil { + err = closeErr + } + }() + + // Write body. We "unwrap" the body first if it was wrapped in a + // nopCloser or readTrackingBody. This is to ensure that we can take advantage of + // OS-level optimizations in the event that the body is an + // *os.File. 
+ if t.Body != nil { + var body = t.unwrapBody() + if chunked(t.TransferEncoding) { + if bw, ok := w.(*bufio.Writer); ok && !t.IsResponse { + w = &internal.FlushAfterChunkWriter{Writer: bw} + } + cw := internal.NewChunkedWriter(w) + _, err = t.doBodyCopy(cw, body) + if err == nil { + err = cw.Close() + } + } else if t.ContentLength == -1 { + dst := w + if t.Method == "CONNECT" { + dst = bufioFlushWriter{dst} + } + ncopy, err = t.doBodyCopy(dst, body) + } else { + ncopy, err = t.doBodyCopy(w, io.LimitReader(body, t.ContentLength)) + if err != nil { + return err + } + var nextra int64 + nextra, err = t.doBodyCopy(io.Discard, body) + ncopy += nextra + } + if err != nil { + return err + } + } + if t.BodyCloser != nil { + closed = true + if err := t.BodyCloser.Close(); err != nil { + return err + } + } + + if !t.ResponseToHEAD && t.ContentLength != -1 && t.ContentLength != ncopy { + return fmt.Errorf("http: ContentLength=%d with Body length %d", + t.ContentLength, ncopy) + } + + if chunked(t.TransferEncoding) { + // Write Trailer header + if t.Trailer != nil { + if err := t.Trailer.Write(w); err != nil { + return err + } + } + // Last chunk, empty trailer + _, err = io.WriteString(w, "\r\n") + } + return err +} + +// doBodyCopy wraps a copy operation, with any resulting error also +// being saved in bodyReadError. +// +// This function is only intended for use in writeBody. +func (t *transferWriter) doBodyCopy(dst io.Writer, src io.Reader) (n int64, err error) { + n, err = io.Copy(dst, src) + if err != nil && err != io.EOF { + t.bodyReadError = err + } + return +} + +// unwrapBodyReader unwraps the body's inner reader if it's a +// nopCloser. This is to ensure that body writes sourced from local +// files (*os.File types) are properly optimized. +// +// This function is only intended for use in writeBody. +func (t *transferWriter) unwrapBody() io.Reader { + if r, ok := unwrapNopCloser(t.Body); ok { + return r + } + if r, ok := t.Body.(*readTrackingBody); ok { + r.didRead = true + return r.ReadCloser + } + return t.Body +} + +type transferReader struct { + // Input + Header Header + StatusCode int + RequestMethod string + ProtoMajor int + ProtoMinor int + // Output + Body io.ReadCloser + ContentLength int64 + Chunked bool + Close bool + Trailer Header +} + +func (t *transferReader) protoAtLeast(m, n int) bool { + return t.ProtoMajor > m || (t.ProtoMajor == m && t.ProtoMinor >= n) +} + +// bodyAllowedForStatus reports whether a given response status code +// permits a body. See RFC 7230, section 3.3. +func bodyAllowedForStatus(status int) bool { + switch { + case status >= 100 && status <= 199: + return false + case status == 204: + return false + case status == 304: + return false + } + return true +} + +var ( + suppressedHeaders304 = []string{"Content-Type", "Content-Length", "Transfer-Encoding"} + suppressedHeadersNoBody = []string{"Content-Length", "Transfer-Encoding"} + excludedHeadersNoBody = map[string]bool{"Content-Length": true, "Transfer-Encoding": true} +) + +func suppressedHeaders(status int) []string { + switch { + case status == 304: + // RFC 7232 section 4.1 + return suppressedHeaders304 + case !bodyAllowedForStatus(status): + return suppressedHeadersNoBody + } + return nil +} + +// msg is *Request or *Response. 
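+// readTransfer parses the transfer-related headers of msg (Content-Length,
+// Transfer-Encoding, Trailer, Connection) and installs the matching Body
+// reader; onEOF is the TinyGo addition invoked once that body reaches EOF,
+// so the caller can clean up the underlying connection.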
+func readTransfer(msg any, r *bufio.Reader, onEOF func()) (err error) { + t := &transferReader{RequestMethod: "GET"} + + // TINYGO: Added onEOF func to be called when response body is closed + // TINYGO: so we can clean up the connection (r) + + // Unify input + isResponse := false + switch rr := msg.(type) { + case *Response: + t.Header = rr.Header + t.StatusCode = rr.StatusCode + t.ProtoMajor = rr.ProtoMajor + t.ProtoMinor = rr.ProtoMinor + t.Close = shouldClose(t.ProtoMajor, t.ProtoMinor, t.Header, true) + isResponse = true + if rr.Request != nil { + t.RequestMethod = rr.Request.Method + } + case *Request: + t.Header = rr.Header + t.RequestMethod = rr.Method + t.ProtoMajor = rr.ProtoMajor + t.ProtoMinor = rr.ProtoMinor + // Transfer semantics for Requests are exactly like those for + // Responses with status code 200, responding to a GET method + t.StatusCode = 200 + t.Close = rr.Close + default: + panic("unexpected type") + } + + // Default to HTTP/1.1 + if t.ProtoMajor == 0 && t.ProtoMinor == 0 { + t.ProtoMajor, t.ProtoMinor = 1, 1 + } + + // Transfer-Encoding: chunked, and overriding Content-Length. + if err := t.parseTransferEncoding(); err != nil { + return err + } + + realLength, err := fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.Chunked) + if err != nil { + return err + } + if isResponse && t.RequestMethod == "HEAD" { + if n, err := parseContentLength(t.Header.get("Content-Length")); err != nil { + return err + } else { + t.ContentLength = n + } + } else { + t.ContentLength = realLength + } + + // Trailer + t.Trailer, err = fixTrailer(t.Header, t.Chunked) + if err != nil { + return err + } + + // If there is no Content-Length or chunked Transfer-Encoding on a *Response + // and the status is not 1xx, 204 or 304, then the body is unbounded. + // See RFC 7230, section 3.3. + switch msg.(type) { + case *Response: + if realLength == -1 && !t.Chunked && bodyAllowedForStatus(t.StatusCode) { + // Unbounded body. + t.Close = true + } + } + + // Prepare body reader. ContentLength < 0 means chunked encoding + // or close connection when finished, since multipart is not supported yet + switch { + case t.Chunked: + if noResponseBodyExpected(t.RequestMethod) || !bodyAllowedForStatus(t.StatusCode) { + t.Body = NoBody + } else { + t.Body = &body{src: internal.NewChunkedReader(r), hdr: msg, r: r, closing: t.Close, onHitEOF: onEOF} + } + case realLength == 0: + t.Body = NoBody + case realLength > 0: + t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close, onHitEOF: onEOF} + default: + // realLength < 0, i.e. "Content-Length" not mentioned in header + if t.Close { + // Close semantics (i.e. HTTP/1.0) + t.Body = &body{src: r, closing: t.Close, onHitEOF: onEOF} + } else { + // Persistent connection (i.e. HTTP/1.1) + t.Body = NoBody + } + } + + // Unify output + switch rr := msg.(type) { + case *Request: + rr.Body = t.Body + rr.ContentLength = t.ContentLength + if t.Chunked { + rr.TransferEncoding = []string{"chunked"} + } + rr.Close = t.Close + rr.Trailer = t.Trailer + case *Response: + rr.Body = t.Body + rr.ContentLength = t.ContentLength + if t.Chunked { + rr.TransferEncoding = []string{"chunked"} + } + rr.Close = t.Close + rr.Trailer = t.Trailer + } + + return nil +} + +// Checks whether chunked is part of the encodings stack +func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" } + +// Checks whether the encoding is explicitly "identity". 
+func isIdentity(te []string) bool { return len(te) == 1 && te[0] == "identity" } + +// unsupportedTEError reports unsupported transfer-encodings. +type unsupportedTEError struct { + err string +} + +func (uste *unsupportedTEError) Error() string { + return uste.err +} + +// isUnsupportedTEError checks if the error is of type +// unsupportedTEError. It is usually invoked with a non-nil err. +func isUnsupportedTEError(err error) bool { + _, ok := err.(*unsupportedTEError) + return ok +} + +// parseTransferEncoding sets t.Chunked based on the Transfer-Encoding header. +func (t *transferReader) parseTransferEncoding() error { + raw, present := t.Header["Transfer-Encoding"] + if !present { + return nil + } + delete(t.Header, "Transfer-Encoding") + + // Issue 12785; ignore Transfer-Encoding on HTTP/1.0 requests. + if !t.protoAtLeast(1, 1) { + return nil + } + + // Like nginx, we only support a single Transfer-Encoding header field, and + // only if set to "chunked". This is one of the most security sensitive + // surfaces in HTTP/1.1 due to the risk of request smuggling, so we keep it + // strict and simple. + if len(raw) != 1 { + return &unsupportedTEError{fmt.Sprintf("too many transfer encodings: %q", raw)} + } + if !ascii.EqualFold(raw[0], "chunked") { + return &unsupportedTEError{fmt.Sprintf("unsupported transfer encoding: %q", raw[0])} + } + + // RFC 7230 3.3.2 says "A sender MUST NOT send a Content-Length header field + // in any message that contains a Transfer-Encoding header field." + // + // but also: "If a message is received with both a Transfer-Encoding and a + // Content-Length header field, the Transfer-Encoding overrides the + // Content-Length. Such a message might indicate an attempt to perform + // request smuggling (Section 9.5) or response splitting (Section 9.4) and + // ought to be handled as an error. A sender MUST remove the received + // Content-Length field prior to forwarding such a message downstream." + // + // Reportedly, these appear in the wild. + delete(t.Header, "Content-Length") + + t.Chunked = true + return nil +} + +// Determine the expected body length, using RFC 7230 Section 3.3. This +// function is not a method, because ultimately it should be shared by +// ReadResponse and ReadRequest. +func fixLength(isResponse bool, status int, requestMethod string, header Header, chunked bool) (int64, error) { + isRequest := !isResponse + contentLens := header["Content-Length"] + + // Hardening against HTTP request smuggling + if len(contentLens) > 1 { + // Per RFC 7230 Section 3.3.2, prevent multiple + // Content-Length headers if they differ in value. + // If there are dups of the value, remove the dups. + // See Issue 16490. + first := textproto.TrimString(contentLens[0]) + for _, ct := range contentLens[1:] { + if first != textproto.TrimString(ct) { + return 0, fmt.Errorf("http: message cannot contain multiple Content-Length headers; got %q", contentLens) + } + } + + // deduplicate Content-Length + header.Del("Content-Length") + header.Add("Content-Length", first) + + contentLens = header["Content-Length"] + } + + // Logic based on response type or status + if noResponseBodyExpected(requestMethod) { + // For HTTP requests, as part of hardening against request + // smuggling (RFC 7230), don't allow a Content-Length header for + // methods which don't permit bodies. As an exception, allow + // exactly one Content-Length header if its value is "0". 
+ if isRequest && len(contentLens) > 0 && !(len(contentLens) == 1 && contentLens[0] == "0") { + return 0, fmt.Errorf("http: method cannot contain a Content-Length; got %q", contentLens) + } + return 0, nil + } + if status/100 == 1 { + return 0, nil + } + switch status { + case 204, 304: + return 0, nil + } + + // Logic based on Transfer-Encoding + if chunked { + return -1, nil + } + + // Logic based on Content-Length + var cl string + if len(contentLens) == 1 { + cl = textproto.TrimString(contentLens[0]) + } + if cl != "" { + n, err := parseContentLength(cl) + if err != nil { + return -1, err + } + return n, nil + } + header.Del("Content-Length") + + if isRequest { + // RFC 7230 neither explicitly permits nor forbids an + // entity-body on a GET request so we permit one if + // declared, but we default to 0 here (not -1 below) + // if there's no mention of a body. + // Likewise, all other request methods are assumed to have + // no body if neither Transfer-Encoding chunked nor a + // Content-Length are set. + return 0, nil + } + + // Body-EOF logic based on other methods (like closing, or chunked coding) + return -1, nil +} + +// Determine whether to hang up after sending a request and body, or +// receiving a response and body +// 'header' is the request headers. +func shouldClose(major, minor int, header Header, removeCloseHeader bool) bool { + if major < 1 { + return true + } + + conv := header["Connection"] + hasClose := httpguts.HeaderValuesContainsToken(conv, "close") + if major == 1 && minor == 0 { + return hasClose || !httpguts.HeaderValuesContainsToken(conv, "keep-alive") + } + + if hasClose && removeCloseHeader { + header.Del("Connection") + } + + return hasClose +} + +// Parse the trailer header. +func fixTrailer(header Header, chunked bool) (Header, error) { + vv, ok := header["Trailer"] + if !ok { + return nil, nil + } + if !chunked { + // Trailer and no chunking: + // this is an invalid use case for trailer header. + // Nevertheless, no error will be returned and we + // let users decide if this is a valid HTTP message. + // The Trailer header will be kept in Response.Header + // but not populate Response.Trailer. + // See issue #27197. + return nil, nil + } + header.Del("Trailer") + + trailer := make(Header) + var err error + for _, v := range vv { + foreachHeaderElement(v, func(key string) { + key = CanonicalHeaderKey(key) + switch key { + case "Transfer-Encoding", "Trailer", "Content-Length": + if err == nil { + err = badStringError("bad trailer key", key) + return + } + } + trailer[key] = nil + }) + } + if err != nil { + return nil, err + } + if len(trailer) == 0 { + return nil, nil + } + return trailer, nil +} + +// body turns a Reader into a ReadCloser. +// Close ensures that the body has been fully read +// and then reads the trailer if necessary. +type body struct { + src io.Reader + hdr any // non-nil (Response or Request) value means read trailer + r *bufio.Reader // underlying wire-format reader for the trailer + closing bool // is the connection to be closed after reading body? + doEarlyClose bool // whether Close should stop early + + mu sync.Mutex // guards following, and calls to Read and Close + sawEOF bool + closed bool + earlyClose bool // Close called and we didn't read to the end of src + onHitEOF func() // if non-nil, func to call when EOF is Read +} + +// ErrBodyReadAfterClose is returned when reading a Request or Response +// Body after the body has been closed. 
This typically happens when the body is +// read after an HTTP Handler calls WriteHeader or Write on its +// ResponseWriter. +var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed Body") + +func (b *body) Read(p []byte) (n int, err error) { + b.mu.Lock() + defer b.mu.Unlock() + if b.closed { + return 0, ErrBodyReadAfterClose + } + return b.readLocked(p) +} + +// Must hold b.mu. +func (b *body) readLocked(p []byte) (n int, err error) { + if b.sawEOF { + return 0, io.EOF + } + n, err = b.src.Read(p) + + if err == io.EOF { + b.sawEOF = true + // Chunked case. Read the trailer. + if b.hdr != nil { + if e := b.readTrailer(); e != nil { + err = e + // Something went wrong in the trailer, we must not allow any + // further reads of any kind to succeed from body, nor any + // subsequent requests on the server connection. See + // golang.org/issue/12027 + b.sawEOF = false + b.closed = true + } + b.hdr = nil + } else { + // If the server declared the Content-Length, our body is a LimitedReader + // and we need to check whether this EOF arrived early. + if lr, ok := b.src.(*io.LimitedReader); ok && lr.N > 0 { + err = io.ErrUnexpectedEOF + } + } + } + + // If we can return an EOF here along with the read data, do + // so. This is optional per the io.Reader contract, but doing + // so helps the HTTP transport code recycle its connection + // earlier (since it will see this EOF itself), even if the + // client doesn't do future reads or Close. + if err == nil && n > 0 { + if lr, ok := b.src.(*io.LimitedReader); ok && lr.N == 0 { + err = io.EOF + b.sawEOF = true + } + } + + if b.sawEOF && b.onHitEOF != nil { + b.onHitEOF() + } + + return n, err +} + +var ( + singleCRLF = []byte("\r\n") + doubleCRLF = []byte("\r\n\r\n") +) + +func seeUpcomingDoubleCRLF(r *bufio.Reader) bool { + for peekSize := 4; ; peekSize++ { + // This loop stops when Peek returns an error, + // which it does when r's buffer has been filled. + buf, err := r.Peek(peekSize) + if bytes.HasSuffix(buf, doubleCRLF) { + return true + } + if err != nil { + break + } + } + return false +} + +var errTrailerEOF = errors.New("http: unexpected EOF reading trailer") + +func (b *body) readTrailer() error { + // The common case, since nobody uses trailers. + buf, err := b.r.Peek(2) + if bytes.Equal(buf, singleCRLF) { + b.r.Discard(2) + return nil + } + if len(buf) < 2 { + return errTrailerEOF + } + if err != nil { + return err + } + + // Make sure there's a header terminator coming up, to prevent + // a DoS with an unbounded size Trailer. It's not easy to + // slip in a LimitReader here, as textproto.NewReader requires + // a concrete *bufio.Reader. Also, we can't get all the way + // back up to our conn's LimitedReader that *might* be backing + // this bufio.Reader. Instead, a hack: we iteratively Peek up + // to the bufio.Reader's max size, looking for a double CRLF. + // This limits the trailer to the underlying buffer size, typically 4kB. 
+ if !seeUpcomingDoubleCRLF(b.r) { + return errors.New("http: suspiciously long trailer after chunked body") + } + + hdr, err := textproto.NewReader(b.r).ReadMIMEHeader() + if err != nil { + if err == io.EOF { + return errTrailerEOF + } + return err + } + switch rr := b.hdr.(type) { + case *Request: + mergeSetHeader(&rr.Trailer, Header(hdr)) + case *Response: + mergeSetHeader(&rr.Trailer, Header(hdr)) + } + return nil +} + +func mergeSetHeader(dst *Header, src Header) { + if *dst == nil { + *dst = src + return + } + for k, vv := range src { + (*dst)[k] = vv + } +} + +// unreadDataSizeLocked returns the number of bytes of unread input. +// It returns -1 if unknown. +// b.mu must be held. +func (b *body) unreadDataSizeLocked() int64 { + if lr, ok := b.src.(*io.LimitedReader); ok { + return lr.N + } + return -1 +} + +func (b *body) Close() error { + b.mu.Lock() + defer b.mu.Unlock() + if b.closed { + return nil + } + var err error + switch { + case b.sawEOF: + // Already saw EOF, so no need going to look for it. + case b.hdr == nil && b.closing: + // no trailer and closing the connection next. + // no point in reading to EOF. + case b.doEarlyClose: + // Read up to maxPostHandlerReadBytes bytes of the body, looking + // for EOF (and trailers), so we can re-use this connection. + if lr, ok := b.src.(*io.LimitedReader); ok && lr.N > maxPostHandlerReadBytes { + // There was a declared Content-Length, and we have more bytes remaining + // than our maxPostHandlerReadBytes tolerance. So, give up. + b.earlyClose = true + } else { + var n int64 + // Consume the body, or, which will also lead to us reading + // the trailer headers after the body, if present. + n, err = io.CopyN(io.Discard, bodyLocked{b}, maxPostHandlerReadBytes) + if err == io.EOF { + err = nil + } + if n == maxPostHandlerReadBytes { + b.earlyClose = true + } + } + default: + // Fully consume the body, which will also lead to us reading + // the trailer headers after the body, if present. + _, err = io.Copy(io.Discard, bodyLocked{b}) + } + b.closed = true + return err +} + +func (b *body) didEarlyClose() bool { + b.mu.Lock() + defer b.mu.Unlock() + return b.earlyClose +} + +// bodyRemains reports whether future Read calls might +// yield data. +func (b *body) bodyRemains() bool { + b.mu.Lock() + defer b.mu.Unlock() + return !b.sawEOF +} + +func (b *body) registerOnHitEOF(fn func()) { + b.mu.Lock() + defer b.mu.Unlock() + b.onHitEOF = fn +} + +// bodyLocked is an io.Reader reading from a *body when its mutex is +// already held. +type bodyLocked struct { + b *body +} + +func (bl bodyLocked) Read(p []byte) (n int, err error) { + if bl.b.closed { + return 0, ErrBodyReadAfterClose + } + return bl.b.readLocked(p) +} + +// parseContentLength trims whitespace from s and returns -1 if no value +// is set, or the value if it's >= 0. +func parseContentLength(cl string) (int64, error) { + cl = textproto.TrimString(cl) + if cl == "" { + return -1, nil + } + n, err := strconv.ParseUint(cl, 10, 63) + if err != nil { + return 0, badStringError("bad Content-Length", cl) + } + return int64(n), nil + +} + +// finishAsyncByteRead finishes reading the 1-byte sniff +// from the ContentLength==0, Body!=nil case. 
+type finishAsyncByteRead struct { + tw *transferWriter +} + +func (fr finishAsyncByteRead) Read(p []byte) (n int, err error) { + if len(p) == 0 { + return + } + rres := <-fr.tw.ByteReadCh + n, err = rres.n, rres.err + if n == 1 { + p[0] = rres.b + } + if err == nil { + err = io.EOF + } + return +} + +var nopCloserType = reflect.TypeOf(io.NopCloser(nil)) +var nopCloserWriterToType = reflect.TypeOf(io.NopCloser(struct { + io.Reader + io.WriterTo +}{})) + +// unwrapNopCloser return the underlying reader and true if r is a NopCloser +// else it return false. +func unwrapNopCloser(r io.Reader) (underlyingReader io.Reader, isNopCloser bool) { + switch reflect.TypeOf(r) { + case nopCloserType, nopCloserWriterToType: + return reflect.ValueOf(r).Field(0).Interface().(io.Reader), true + default: + return nil, false + } +} + +// isKnownInMemoryReader reports whether r is a type known to not +// block on Read. Its caller uses this as an optional optimization to +// send fewer TCP packets. +func isKnownInMemoryReader(r io.Reader) bool { + switch r.(type) { + case *bytes.Reader, *bytes.Buffer, *strings.Reader: + return true + } + if r, ok := unwrapNopCloser(r); ok { + return isKnownInMemoryReader(r) + } + if r, ok := r.(*readTrackingBody); ok { + return isKnownInMemoryReader(r.ReadCloser) + } + return false +} + +// bufioFlushWriter is an io.Writer wrapper that flushes all writes +// on its wrapped writer if it's a *bufio.Writer. +type bufioFlushWriter struct{ w io.Writer } + +func (fw bufioFlushWriter) Write(p []byte) (n int, err error) { + n, err = fw.w.Write(p) + if bw, ok := fw.w.(*bufio.Writer); n > 0 && ok { + ferr := bw.Flush() + if ferr != nil && err == nil { + err = ferr + } + } + return +} diff --git a/src/net/http/transport.go b/src/net/http/transport.go new file mode 100644 index 0000000000..a5f0f49346 --- /dev/null +++ b/src/net/http/transport.go @@ -0,0 +1,22 @@ +// TINYGO: The following is copied and modified from Go 1.19.3 official implementation. + +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP client implementation. See RFC 7230 through 7235. +// +// This is the low-level Transport implementation of RoundTripper. +// The high-level interface is in client.go. + +package http + +import ( + "io" +) + +type readTrackingBody struct { + io.ReadCloser + didRead bool + didClose bool +} diff --git a/src/net/interface.go b/src/net/interface.go deleted file mode 100644 index 32206f78fc..0000000000 --- a/src/net/interface.go +++ /dev/null @@ -1,253 +0,0 @@ -// The following is copied from Go 1.16 official implementation. - -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package net - -import ( - "errors" - "internal/itoa" - "sync" - "time" -) - -var ( - errInvalidInterface = errors.New("invalid network interface") - errInvalidInterfaceIndex = errors.New("invalid network interface index") - errInvalidInterfaceName = errors.New("invalid network interface name") - errNoSuchInterface = errors.New("no such network interface") - errNoSuchMulticastInterface = errors.New("no such multicast network interface") -) - -// Interface represents a mapping between network interface name -// and index. It also represents network interface facility -// information. 
-type Interface struct { - Index int // positive integer that starts at one, zero is never used - MTU int // maximum transmission unit - Name string // e.g., "en0", "lo0", "eth0.100" - HardwareAddr HardwareAddr // IEEE MAC-48, EUI-48 and EUI-64 form - Flags Flags // e.g., FlagUp, FlagLoopback, FlagMulticast -} - -type Flags uint - -const ( - FlagUp Flags = 1 << iota // interface is up - FlagBroadcast // interface supports broadcast access capability - FlagLoopback // interface is a loopback interface - FlagPointToPoint // interface belongs to a point-to-point link - FlagMulticast // interface supports multicast access capability -) - -var flagNames = []string{ - "up", - "broadcast", - "loopback", - "pointtopoint", - "multicast", -} - -func (f Flags) String() string { - s := "" - for i, name := range flagNames { - if f&(1<", if ip has length 0 // - dotted decimal ("192.0.2.1"), if ip is an IPv4 or IP4-mapped IPv6 address -// - IPv6 ("2001:db8::1"), if ip is a valid IPv6 address +// - IPv6 conforming to RFC 5952 ("2001:db8::1"), if ip is a valid IPv6 address // - the hexadecimal form of ip, without punctuation, if no other cases apply func (ip IP) String() string { p := ip @@ -528,6 +547,9 @@ func (n *IPNet) Network() string { return "ip+net" } // character and a mask expressed as hexadecimal form with no // punctuation like "198.51.100.0/c000ff00". func (n *IPNet) String() string { + if n == nil { + return "" + } nn, m := networkNumberAndMask(n) if nn == nil || m == nil { return "" @@ -557,6 +579,10 @@ func parseIPv4(s string) IP { if !ok || n > 0xFF { return nil } + if c > 1 && s[0] == '0' { + // Reject non-zero components with leading zeroes. + return nil + } s = s[c:] p[i] = byte(n) } diff --git a/src/net/iprawsock.go b/src/net/iprawsock.go index 8fac379160..8f82ec8e34 100644 --- a/src/net/iprawsock.go +++ b/src/net/iprawsock.go @@ -1,4 +1,4 @@ -// The following is copied from Go 1.16 official implementation. +// TINYGO: The following is copied and modified from Go 1.19.3 official implementation. // Copyright 2010 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style @@ -6,8 +6,54 @@ package net +// BUG(mikio): On every POSIX platform, reads from the "ip4" network +// using the ReadFrom or ReadFromIP method might not return a complete +// IPv4 packet, including its header, even if there is space +// available. This can occur even in cases where Read or ReadMsgIP +// could return a complete packet. For this reason, it is recommended +// that you do not use these methods if it is important to receive a +// full packet. +// +// The Go 1 compatibility guidelines make it impossible for us to +// change the behavior of these methods; use Read or ReadMsgIP +// instead. + +// BUG(mikio): On JS and Plan 9, methods and functions related +// to IPConn are not implemented. + +// BUG(mikio): On Windows, the File method of IPConn is not +// implemented. + // IPAddr represents the address of an IP end point. type IPAddr struct { IP IP Zone string // IPv6 scoped addressing zone } + +// Network returns the address's network name, "ip". 
+func (a *IPAddr) Network() string { return "ip" } + +func (a *IPAddr) String() string { + if a == nil { + return "" + } + ip := ipEmptyString(a.IP) + if a.Zone != "" { + return ip + "%" + a.Zone + } + return ip +} + +func (a *IPAddr) isWildcard() bool { + if a == nil || a.IP == nil { + return true + } + return a.IP.IsUnspecified() +} + +func (a *IPAddr) opAddr() Addr { + if a == nil { + return nil + } + return a +} diff --git a/src/net/ipsock.go b/src/net/ipsock.go index 57ceaebf09..52d1f7dc2a 100644 --- a/src/net/ipsock.go +++ b/src/net/ipsock.go @@ -1,4 +1,4 @@ -// The following is copied from Go 1.16 official implementation. +// TINYGO: The following is copied and modified from Go 1.19.3 official implementation. // Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style @@ -6,7 +6,9 @@ package net -import "internal/bytealg" +import ( + "internal/bytealg" +) // SplitHostPort splits a network address of the form "host:port", // "host%zone:port", "[host]:port" or "[host%zone]:port" into host or diff --git a/src/net/mac.go b/src/net/mac.go index 2bad98c462..320b209d64 100644 --- a/src/net/mac.go +++ b/src/net/mac.go @@ -1,4 +1,4 @@ -// The following is copied from Go 1.16 official implementation. +// TINYGO: The following is copied from Go 1.19.3 official implementation. // Copyright 2011 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style @@ -40,49 +40,49 @@ func (a HardwareAddr) String() string { // 0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001 func ParseMAC(s string) (hw HardwareAddr, err error) { if len(s) < 14 { - goto err + goto error } if s[2] == ':' || s[2] == '-' { if (len(s)+1)%3 != 0 { - goto err + goto error } n := (len(s) + 1) / 3 if n != 6 && n != 8 && n != 20 { - goto err + goto error } hw = make(HardwareAddr, n) for x, i := 0, 0; i < n; i++ { var ok bool if hw[i], ok = xtoi2(s[x:], s[2]); !ok { - goto err + goto error } x += 3 } } else if s[4] == '.' { if (len(s)+1)%5 != 0 { - goto err + goto error } n := 2 * (len(s) + 1) / 5 if n != 6 && n != 8 && n != 20 { - goto err + goto error } hw = make(HardwareAddr, n) for x, i := 0, 0; i < n; i += 2 { var ok bool if hw[i], ok = xtoi2(s[x:x+2], 0); !ok { - goto err + goto error } if hw[i+1], ok = xtoi2(s[x+2:], s[4]); !ok { - goto err + goto error } x += 5 } } else { - goto err + goto error } return hw, nil -err: +error: return nil, &AddrError{Err: "invalid MAC address", Addr: s} } diff --git a/src/net/mac_test.go b/src/net/mac_test.go new file mode 100644 index 0000000000..cad884fcf5 --- /dev/null +++ b/src/net/mac_test.go @@ -0,0 +1,109 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "reflect" + "strings" + "testing" +) + +var parseMACTests = []struct { + in string + out HardwareAddr + err string +}{ + // See RFC 7042, Section 2.1.1. + {"00:00:5e:00:53:01", HardwareAddr{0x00, 0x00, 0x5e, 0x00, 0x53, 0x01}, ""}, + {"00-00-5e-00-53-01", HardwareAddr{0x00, 0x00, 0x5e, 0x00, 0x53, 0x01}, ""}, + {"0000.5e00.5301", HardwareAddr{0x00, 0x00, 0x5e, 0x00, 0x53, 0x01}, ""}, + + // See RFC 7042, Section 2.2.2. 
+ {"02:00:5e:10:00:00:00:01", HardwareAddr{0x02, 0x00, 0x5e, 0x10, 0x00, 0x00, 0x00, 0x01}, ""}, + {"02-00-5e-10-00-00-00-01", HardwareAddr{0x02, 0x00, 0x5e, 0x10, 0x00, 0x00, 0x00, 0x01}, ""}, + {"0200.5e10.0000.0001", HardwareAddr{0x02, 0x00, 0x5e, 0x10, 0x00, 0x00, 0x00, 0x01}, ""}, + + // See RFC 4391, Section 9.1.1. + { + "00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01", + HardwareAddr{ + 0x00, 0x00, 0x00, 0x00, + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x5e, 0x10, 0x00, 0x00, 0x00, 0x01, + }, + "", + }, + { + "00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01", + HardwareAddr{ + 0x00, 0x00, 0x00, 0x00, + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x5e, 0x10, 0x00, 0x00, 0x00, 0x01, + }, + "", + }, + { + "0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001", + HardwareAddr{ + 0x00, 0x00, 0x00, 0x00, + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x5e, 0x10, 0x00, 0x00, 0x00, 0x01, + }, + "", + }, + + {"ab:cd:ef:AB:CD:EF", HardwareAddr{0xab, 0xcd, 0xef, 0xab, 0xcd, 0xef}, ""}, + {"ab:cd:ef:AB:CD:EF:ab:cd", HardwareAddr{0xab, 0xcd, 0xef, 0xab, 0xcd, 0xef, 0xab, 0xcd}, ""}, + { + "ab:cd:ef:AB:CD:EF:ab:cd:ef:AB:CD:EF:ab:cd:ef:AB:CD:EF:ab:cd", + HardwareAddr{ + 0xab, 0xcd, 0xef, 0xab, + 0xcd, 0xef, 0xab, 0xcd, 0xef, 0xab, 0xcd, 0xef, + 0xab, 0xcd, 0xef, 0xab, 0xcd, 0xef, 0xab, 0xcd, + }, + "", + }, + + {"01.02.03.04.05.06", nil, "invalid MAC address"}, + {"01:02:03:04:05:06:", nil, "invalid MAC address"}, + {"x1:02:03:04:05:06", nil, "invalid MAC address"}, + {"01002:03:04:05:06", nil, "invalid MAC address"}, + {"01:02003:04:05:06", nil, "invalid MAC address"}, + {"01:02:03004:05:06", nil, "invalid MAC address"}, + {"01:02:03:04005:06", nil, "invalid MAC address"}, + {"01:02:03:04:05006", nil, "invalid MAC address"}, + {"01-02:03:04:05:06", nil, "invalid MAC address"}, + {"01:02-03-04-05-06", nil, "invalid MAC address"}, + {"0123:4567:89AF", nil, "invalid MAC address"}, + {"0123-4567-89AF", nil, "invalid MAC address"}, +} + +func TestParseMAC(t *testing.T) { + match := func(err error, s string) bool { + if s == "" { + return err == nil + } + return err != nil && strings.Contains(err.Error(), s) + } + + for i, tt := range parseMACTests { + out, err := ParseMAC(tt.in) + if !reflect.DeepEqual(out, tt.out) || !match(err, tt.err) { + t.Errorf("ParseMAC(%q) = %v, %v, want %v, %v", tt.in, out, err, tt.out, tt.err) + } + if tt.err == "" { + // Verify that serialization works too, and that it round-trips. + s := out.String() + out2, err := ParseMAC(s) + if err != nil { + t.Errorf("%d. ParseMAC(%q) = %v", i, s, err) + continue + } + if !reflect.DeepEqual(out2, out) { + t.Errorf("%d. ParseMAC(%q) = %v, want %v", i, s, out2, out) + } + } + } +} diff --git a/src/net/net.go b/src/net/net.go index db4d8f117f..2e7f9054c9 100644 --- a/src/net/net.go +++ b/src/net/net.go @@ -1,4 +1,4 @@ -// The following is copied from Go 1.18 official implementation. +// TINYGO: The following is copied and modified from Go 1.19.3 official implementation. // Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style @@ -7,7 +7,6 @@ package net import ( - "io" "time" ) @@ -81,10 +80,6 @@ type Conn interface { SetWriteDeadline(t time.Time) error } -type conn struct { - // -} - // A Listener is a generic network listener for stream-oriented protocols. // // Multiple goroutines may invoke methods on a Listener simultaneously. 
@@ -193,87 +188,3 @@ func (e *AddrError) Error() string { } return s } - -func (e *AddrError) Timeout() bool { return false } -func (e *AddrError) Temporary() bool { return false } - -// ErrClosed is the error returned by an I/O call on a network -// connection that has already been closed, or that is closed by -// another goroutine before the I/O is completed. This may be wrapped -// in another error, and should normally be tested using -// errors.Is(err, net.ErrClosed). -var ErrClosed = errClosed - -// buffersWriter is the interface implemented by Conns that support a -// "writev"-like batch write optimization. -// writeBuffers should fully consume and write all chunks from the -// provided Buffers, else it should report a non-nil error. -type buffersWriter interface { - writeBuffers(*Buffers) (int64, error) -} - -// Buffers contains zero or more runs of bytes to write. -// -// On certain machines, for certain types of connections, this is -// optimized into an OS-specific batch write operation (such as -// "writev"). -type Buffers [][]byte - -var ( - _ io.WriterTo = (*Buffers)(nil) - _ io.Reader = (*Buffers)(nil) -) - -// WriteTo writes contents of the buffers to w. -// -// WriteTo implements io.WriterTo for Buffers. -// -// WriteTo modifies the slice v as well as v[i] for 0 <= i < len(v), -// but does not modify v[i][j] for any i, j. -func (v *Buffers) WriteTo(w io.Writer) (n int64, err error) { - if wv, ok := w.(buffersWriter); ok { - return wv.writeBuffers(v) - } - for _, b := range *v { - nb, err := w.Write(b) - n += int64(nb) - if err != nil { - v.consume(n) - return n, err - } - } - v.consume(n) - return n, nil -} - -// Read from the buffers. -// -// Read implements io.Reader for Buffers. -// -// Read modifies the slice v as well as v[i] for 0 <= i < len(v), -// but does not modify v[i][j] for any i, j. -func (v *Buffers) Read(p []byte) (n int, err error) { - for len(p) > 0 && len(*v) > 0 { - n0 := copy(p, (*v)[0]) - v.consume(int64(n0)) - p = p[n0:] - n += n0 - } - if len(*v) == 0 { - err = io.EOF - } - return -} - -func (v *Buffers) consume(n int64) { - for len(*v) > 0 { - ln0 := int64(len((*v)[0])) - if ln0 > n { - (*v)[0] = (*v)[0][n:] - return - } - n -= ln0 - (*v)[0] = nil - *v = (*v)[1:] - } -} diff --git a/src/net/netdev.go b/src/net/netdev.go new file mode 100644 index 0000000000..2d294c15b3 --- /dev/null +++ b/src/net/netdev.go @@ -0,0 +1,46 @@ +package net + +import ( + "time" +) + +// netdev is the current netdev, set by the application with useNetdev() +var netdev netdever + +// (useNetdev is go:linkname'd from tinygo/drivers package) +func useNetdev(dev netdever) { + netdev = dev +} + +// Netdev is TinyGo's network device driver model. Network drivers implement +// the netdever interface, providing a common network I/O interface to TinyGo's +// "net" package. The interface is modeled after the BSD socket interface. +// net.Conn implementations (TCPConn, UDPConn, and TLSConn) use the netdev +// interface for device I/O access. +// +// A netdever is passed to the "net" package using net.useNetdev(). +// +// Just like a net.Conn, multiple goroutines may invoke methods on a netdever +// simultaneously. +// +// NOTE: The netdever interface is mirrored in drivers/netdev.go. +// NOTE: If making changes to this interface, mirror the changes in +// NOTE: drivers/netdev.go, and visa-versa. 
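To make the driver model described above concrete, here is a minimal, purely illustrative stub (not part of this patch) that satisfies the netdever interface declared just below by failing every call. The names nullNetdev and errNoHardware are invented for this sketch; a real driver in the drivers repository would forward these calls to its hardware and register itself through the go:linkname'd useNetdev hook rather than an init function inside this package:

package net

import (
	"errors"
	"time"
)

// nullNetdev is a stand-in driver that rejects every operation.
type nullNetdev struct{}

var errNoHardware = errors.New("netdev: no network hardware configured")

func (nullNetdev) GetHostByName(name string) (IP, error)                  { return nil, errNoHardware }
func (nullNetdev) Socket(domain, stype, protocol int) (int, error)        { return -1, errNoHardware }
func (nullNetdev) Bind(sockfd int, ip IP, port int) error                 { return errNoHardware }
func (nullNetdev) Connect(sockfd int, host string, ip IP, port int) error { return errNoHardware }
func (nullNetdev) Listen(sockfd int, backlog int) error                   { return errNoHardware }
func (nullNetdev) Accept(sockfd int, ip IP, port int) (int, error)        { return -1, errNoHardware }
func (nullNetdev) Send(sockfd int, buf []byte, flags int, timeout time.Duration) (int, error) {
	return -1, errNoHardware
}
func (nullNetdev) Recv(sockfd int, buf []byte, flags int, timeout time.Duration) (int, error) {
	return -1, errNoHardware
}
func (nullNetdev) Close(sockfd int) error { return errNoHardware }
func (nullNetdev) SetSockOpt(sockfd int, level int, opt int, value interface{}) error {
	return errNoHardware
}

func init() {
	// A real application would never do this here; the driver package
	// installs itself through the useNetdev hook described above.
	useNetdev(nullNetdev{})
}

A concrete driver only has to provide these ten calls; the TCPConn, UDPConn, and TLSConn implementations later in this diff are built entirely on top of them.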
+ +type netdever interface { + + // GetHostByName returns the IP address of either a hostname or IPv4 + // address in standard dot notation + GetHostByName(name string) (IP, error) + + // Berkely Sockets-like interface, Go-ified. See man page for socket(2), etc. + Socket(domain int, stype int, protocol int) (int, error) + Bind(sockfd int, ip IP, port int) error + Connect(sockfd int, host string, ip IP, port int) error + Listen(sockfd int, backlog int) error + Accept(sockfd int, ip IP, port int) (int, error) + Send(sockfd int, buf []byte, flags int, timeout time.Duration) (int, error) + Recv(sockfd int, buf []byte, flags int, timeout time.Duration) (int, error) + Close(sockfd int) error + SetSockOpt(sockfd int, level int, opt int, value interface{}) error +} diff --git a/src/net/parse.go b/src/net/parse.go index f1f2ccb5c9..b263271fc7 100644 --- a/src/net/parse.go +++ b/src/net/parse.go @@ -1,11 +1,121 @@ -// The following is copied from Go 1.16 official implementation. +// TINYGO: The following is copied from Go 1.19.3 official implementation. // Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// Simple file i/o and string manipulation, to avoid +// depending on strconv and bufio and strings. + package net +import ( + "internal/bytealg" + "io" + "os" + "time" +) + +type file struct { + file *os.File + data []byte + atEOF bool +} + +func (f *file) close() { f.file.Close() } + +func (f *file) getLineFromData() (s string, ok bool) { + data := f.data + i := 0 + for i = 0; i < len(data); i++ { + if data[i] == '\n' { + s = string(data[0:i]) + ok = true + // move data + i++ + n := len(data) - i + copy(data[0:], data[i:]) + f.data = data[0:n] + return + } + } + if f.atEOF && len(f.data) > 0 { + // EOF, return all we have + s = string(data) + f.data = f.data[0:0] + ok = true + } + return +} + +func (f *file) readLine() (s string, ok bool) { + if s, ok = f.getLineFromData(); ok { + return + } + if len(f.data) < cap(f.data) { + ln := len(f.data) + n, err := io.ReadFull(f.file, f.data[ln:cap(f.data)]) + if n >= 0 { + f.data = f.data[0 : ln+n] + } + if err == io.EOF || err == io.ErrUnexpectedEOF { + f.atEOF = true + } + } + s, ok = f.getLineFromData() + return +} + +func open(name string) (*file, error) { + fd, err := os.Open(name) + if err != nil { + return nil, err + } + return &file{fd, make([]byte, 0, 64*1024), false}, nil +} + +func stat(name string) (mtime time.Time, size int64, err error) { + st, err := os.Stat(name) + if err != nil { + return time.Time{}, 0, err + } + return st.ModTime(), st.Size(), nil +} + +// Count occurrences in s of any bytes in t. +func countAnyByte(s string, t string) int { + n := 0 + for i := 0; i < len(s); i++ { + if bytealg.IndexByteString(t, s[i]) >= 0 { + n++ + } + } + return n +} + +// Split s at any bytes in t. +func splitAtBytes(s string, t string) []string { + a := make([]string, 1+countAnyByte(s, t)) + n := 0 + last := 0 + for i := 0; i < len(s); i++ { + if bytealg.IndexByteString(t, s[i]) >= 0 { + if last < i { + a[n] = s[last:i] + n++ + } + last = i + 1 + } + } + if last < len(s) { + a[n] = s[last:] + n++ + } + return a[0:n] +} + +func getFields(s string) []string { return splitAtBytes(s, " \r\t\n") } + // Bigger than we need, not too big to worry about overflow const big = 0xFFFFFF @@ -78,6 +188,17 @@ func appendHex(dst []byte, i uint32) []byte { return dst } +// Number of occurrences of b in s. 
+func count(s string, b byte) int { + n := 0 + for i := 0; i < len(s); i++ { + if s[i] == b { + n++ + } + } + return n +} + // Index of rightmost occurrence of b in s. func last(s string, b byte) int { i := len(s) @@ -88,3 +209,137 @@ func last(s string, b byte) int { } return i } + +// hasUpperCase tells whether the given string contains at least one upper-case. +func hasUpperCase(s string) bool { + for i := range s { + if 'A' <= s[i] && s[i] <= 'Z' { + return true + } + } + return false +} + +// lowerASCIIBytes makes x ASCII lowercase in-place. +func lowerASCIIBytes(x []byte) { + for i, b := range x { + if 'A' <= b && b <= 'Z' { + x[i] += 'a' - 'A' + } + } +} + +// lowerASCII returns the ASCII lowercase version of b. +func lowerASCII(b byte) byte { + if 'A' <= b && b <= 'Z' { + return b + ('a' - 'A') + } + return b +} + +// trimSpace returns x without any leading or trailing ASCII whitespace. +func trimSpace(x []byte) []byte { + for len(x) > 0 && isSpace(x[0]) { + x = x[1:] + } + for len(x) > 0 && isSpace(x[len(x)-1]) { + x = x[:len(x)-1] + } + return x +} + +// isSpace reports whether b is an ASCII space character. +func isSpace(b byte) bool { + return b == ' ' || b == '\t' || b == '\n' || b == '\r' +} + +// removeComment returns line, removing any '#' byte and any following +// bytes. +func removeComment(line []byte) []byte { + if i := bytealg.IndexByte(line, '#'); i != -1 { + return line[:i] + } + return line +} + +// foreachLine runs fn on each line of x. +// Each line (except for possibly the last) ends in '\n'. +// It returns the first non-nil error returned by fn. +func foreachLine(x []byte, fn func(line []byte) error) error { + for len(x) > 0 { + nl := bytealg.IndexByte(x, '\n') + if nl == -1 { + return fn(x) + } + line := x[:nl+1] + x = x[nl+1:] + if err := fn(line); err != nil { + return err + } + } + return nil +} + +// foreachField runs fn on each non-empty run of non-space bytes in x. +// It returns the first non-nil error returned by fn. +func foreachField(x []byte, fn func(field []byte) error) error { + x = trimSpace(x) + for len(x) > 0 { + sp := bytealg.IndexByte(x, ' ') + if sp == -1 { + return fn(x) + } + if field := trimSpace(x[:sp]); len(field) > 0 { + if err := fn(field); err != nil { + return err + } + } + x = trimSpace(x[sp+1:]) + } + return nil +} + +// stringsHasSuffix is strings.HasSuffix. It reports whether s ends in +// suffix. +func stringsHasSuffix(s, suffix string) bool { + return len(s) >= len(suffix) && s[len(s)-len(suffix):] == suffix +} + +// stringsHasSuffixFold reports whether s ends in suffix, +// ASCII-case-insensitively. +func stringsHasSuffixFold(s, suffix string) bool { + return len(s) >= len(suffix) && stringsEqualFold(s[len(s)-len(suffix):], suffix) +} + +// stringsHasPrefix is strings.HasPrefix. It reports whether s begins with prefix. +func stringsHasPrefix(s, prefix string) bool { + return len(s) >= len(prefix) && s[:len(prefix)] == prefix +} + +// stringsEqualFold is strings.EqualFold, ASCII only. It reports whether s and t +// are equal, ASCII-case-insensitively. +func stringsEqualFold(s, t string) bool { + if len(s) != len(t) { + return false + } + for i := 0; i < len(s); i++ { + if lowerASCII(s[i]) != lowerASCII(t[i]) { + return false + } + } + return true +} + +func readFull(r io.Reader) (all []byte, err error) { + buf := make([]byte, 1024) + for { + n, err := r.Read(buf) + all = append(all, buf[:n]...) 
+ if err == io.EOF { + return all, nil + } + if err != nil { + return nil, err + } + } +} diff --git a/src/net/pipe.go b/src/net/pipe.go index 02dd07cf9a..238da0c727 100644 --- a/src/net/pipe.go +++ b/src/net/pipe.go @@ -1,4 +1,4 @@ -// The following is copied from Go 1.19.2 official implementation. +// The following is copied from Go 1.19.3 official implementation. // Copyright 2010 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style diff --git a/src/net/pipe_test.go b/src/net/pipe_test.go deleted file mode 100644 index 7978fc6aa0..0000000000 --- a/src/net/pipe_test.go +++ /dev/null @@ -1,48 +0,0 @@ -// The following is copied from Go 1.19.2 official implementation. - -// Copyright 2010 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package net - -import ( - "io" - "testing" - "time" -) - -func TestPipe(t *testing.T) { - testConn(t, func() (c1, c2 Conn, stop func(), err error) { - c1, c2 = Pipe() - stop = func() { - c1.Close() - c2.Close() - } - return - }) -} - -func TestPipeCloseError(t *testing.T) { - c1, c2 := Pipe() - c1.Close() - - if _, err := c1.Read(nil); err != io.ErrClosedPipe { - t.Errorf("c1.Read() = %v, want io.ErrClosedPipe", err) - } - if _, err := c1.Write(nil); err != io.ErrClosedPipe { - t.Errorf("c1.Write() = %v, want io.ErrClosedPipe", err) - } - if err := c1.SetDeadline(time.Time{}); err != io.ErrClosedPipe { - t.Errorf("c1.SetDeadline() = %v, want io.ErrClosedPipe", err) - } - if _, err := c2.Read(nil); err != io.EOF { - t.Errorf("c2.Read() = %v, want io.EOF", err) - } - if _, err := c2.Write(nil); err != io.ErrClosedPipe { - t.Errorf("c2.Write() = %v, want io.ErrClosedPipe", err) - } - if err := c2.SetDeadline(time.Time{}); err != io.ErrClosedPipe { - t.Errorf("c2.SetDeadline() = %v, want io.ErrClosedPipe", err) - } -} diff --git a/src/net/tcpsock.go b/src/net/tcpsock.go index 4af06a857e..f5b1b65eff 100644 --- a/src/net/tcpsock.go +++ b/src/net/tcpsock.go @@ -1,11 +1,294 @@ +// TINYGO: The following is copied and modified from Go 1.19.3 official implementation. + +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package net +import ( + "fmt" + "internal/itoa" + "net/netip" + "strconv" + "syscall" + "time" +) + +// TCPAddr represents the address of a TCP end point. +type TCPAddr struct { + IP IP + Port int + Zone string // IPv6 scoped addressing zone +} + +// AddrPort returns the TCPAddr a as a netip.AddrPort. +// +// If a.Port does not fit in a uint16, it's silently truncated. +// +// If a is nil, a zero value is returned. +func (a *TCPAddr) AddrPort() netip.AddrPort { + if a == nil { + return netip.AddrPort{} + } + na, _ := netip.AddrFromSlice(a.IP) + na = na.WithZone(a.Zone) + return netip.AddrPortFrom(na, uint16(a.Port)) +} + +// Network returns the address's network name, "tcp". +func (a *TCPAddr) Network() string { return "tcp" } + +func (a *TCPAddr) String() string { + if a == nil { + return "" + } + ip := ipEmptyString(a.IP) + if a.Zone != "" { + return JoinHostPort(ip+"%"+a.Zone, itoa.Itoa(a.Port)) + } + return JoinHostPort(ip, itoa.Itoa(a.Port)) +} + +func (a *TCPAddr) isWildcard() bool { + if a == nil || a.IP == nil { + return true + } + return a.IP.IsUnspecified() +} + +func (a *TCPAddr) opAddr() Addr { + if a == nil { + return nil + } + return a +} + +// ResolveTCPAddr returns an address of TCP end point. 
+// +// The network must be a TCP network name. +// +// If the host in the address parameter is not a literal IP address or +// the port is not a literal port number, ResolveTCPAddr resolves the +// address to an address of TCP end point. +// Otherwise, it parses the address as a pair of literal IP address +// and port number. +// The address parameter can use a host name, but this is not +// recommended, because it will return at most one of the host name's +// IP addresses. +// +// See func Dial for a description of the network and address +// parameters. +func ResolveTCPAddr(network, address string) (*TCPAddr, error) { + + switch network { + case "tcp", "tcp4": + default: + return nil, fmt.Errorf("Network '%s' not supported", network) + } + + // TINYGO: Use netdev resolver + + host, sport, err := SplitHostPort(address) + if err != nil { + return nil, err + } + + port, err := strconv.Atoi(sport) + if err != nil { + return nil, fmt.Errorf("Error parsing port '%s' in address: %s", + sport, err) + } + + if host == "" { + return &TCPAddr{Port: port}, nil + } + + ip, err := netdev.GetHostByName(host) + if err != nil { + return nil, fmt.Errorf("Lookup of host name '%s' failed: %s", host, err) + } + + return &TCPAddr{IP: ip, Port: port}, nil +} + // TCPConn is an implementation of the Conn interface for TCP network // connections. type TCPConn struct { - conn + fd int + laddr *TCPAddr + raddr *TCPAddr + readDeadline time.Time + writeDeadline time.Time +} + +// DialTCP acts like Dial for TCP networks. +// +// The network must be a TCP network name; see func Dial for details. +// +// If laddr is nil, a local address is automatically chosen. +// If the IP field of raddr is nil or an unspecified IP address, the +// local system is assumed. +func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) { + + switch network { + case "tcp", "tcp4": + default: + return nil, fmt.Errorf("Network '%s' not supported", network) + } + + // TINYGO: Use netdev to create TCP socket and connect + + if raddr == nil { + raddr = &TCPAddr{} + } + + if raddr.IP.IsUnspecified() { + return nil, fmt.Errorf("Sorry, localhost isn't available on Tinygo") + } + + fd, err := netdev.Socket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP) + if err != nil { + return nil, err + } + + if err = netdev.Connect(fd, "", raddr.IP, raddr.Port); err != nil { + netdev.Close(fd) + return nil, err + } + + return &TCPConn{ + fd: fd, + laddr: laddr, + raddr: raddr, + }, nil +} + +// TINYGO: Use netdev for Conn methods: Read = Recv, Write = Send, etc. 
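As a usage sketch (not part of the patch), the following minimal client ties the pieces above together: ResolveTCPAddr goes through the netdev resolver and DialTCP opens the socket, after which Write and Read map onto netdev Send and Recv as the note above says. It assumes a netdev driver has already been configured and that example.com:80 is reachable:

package main

import "net"

func main() {
	// Resolve the host through the active netdev driver, then dial.
	raddr, err := net.ResolveTCPAddr("tcp", "example.com:80")
	if err != nil {
		println("resolve failed:", err.Error())
		return
	}
	conn, err := net.DialTCP("tcp", nil, raddr)
	if err != nil {
		println("dial failed:", err.Error())
		return
	}
	defer conn.Close()

	// Write maps to netdev.Send and Read to netdev.Recv.
	conn.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\nConnection: close\r\n\r\n"))
	buf := make([]byte, 256)
	if n, err := conn.Read(buf); err == nil {
		println(string(buf[:n]))
	}
}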
+ +func (c *TCPConn) Read(b []byte) (int, error) { + var timeout time.Duration + + now := time.Now() + + if !c.readDeadline.IsZero() { + if c.readDeadline.Before(now) { + return 0, fmt.Errorf("Read deadline expired") + } else { + timeout = c.readDeadline.Sub(now) + } + } + + n, err := netdev.Recv(c.fd, b, 0, timeout) + // Turn the -1 socket error into 0 and let err speak for error + if n < 0 { + n = 0 + } + return n, err +} + +func (c *TCPConn) Write(b []byte) (int, error) { + var timeout time.Duration + + now := time.Now() + + if !c.writeDeadline.IsZero() { + if c.writeDeadline.Before(now) { + return 0, fmt.Errorf("Write deadline expired") + } else { + timeout = c.writeDeadline.Sub(now) + } + } + + n, err := netdev.Send(c.fd, b, 0, timeout) + // Turn the -1 socket error into 0 and let err speak for error + if n < 0 { + n = 0 + } + return n, err +} + +func (c *TCPConn) Close() error { + return netdev.Close(c.fd) +} + +func (c *TCPConn) LocalAddr() Addr { + return c.laddr +} + +func (c *TCPConn) RemoteAddr() Addr { + return c.raddr +} + +func (c *TCPConn) SetDeadline(t time.Time) error { + c.readDeadline = t + c.writeDeadline = t + return nil +} + +func (c *TCPConn) SetKeepAlive(keepalive bool) error { + return netdev.SetSockOpt(c.fd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, keepalive) +} + +func (c *TCPConn) SetKeepAlivePeriod(d time.Duration) error { + // Units are 1/2 seconds + return netdev.SetSockOpt(c.fd, syscall.SOL_TCP, syscall.TCP_KEEPINTVL, 2*d.Seconds()) +} + +func (c *TCPConn) SetReadDeadline(t time.Time) error { + c.readDeadline = t + return nil +} + +func (c *TCPConn) SetWriteDeadline(t time.Time) error { + c.writeDeadline = t + return nil } func (c *TCPConn) CloseWrite() error { - return &OpError{"close", "", nil, nil, ErrNotImplemented} + return fmt.Errorf("CloseWrite not implemented") +} + +type listener struct { + fd int + laddr *TCPAddr +} + +func (l *listener) Accept() (Conn, error) { + fd, err := netdev.Accept(l.fd, IP{}, 0) + if err != nil { + return nil, err + } + + return &TCPConn{ + fd: fd, + laddr: l.laddr, + }, nil +} + +func (l *listener) Close() error { + return netdev.Close(l.fd) +} + +func (l *listener) Addr() Addr { + return l.laddr +} + +func listenTCP(laddr *TCPAddr) (Listener, error) { + fd, err := netdev.Socket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP) + if err != nil { + return nil, err + } + + err = netdev.Bind(fd, laddr.IP, laddr.Port) + if err != nil { + return nil, err + } + + err = netdev.Listen(fd, 5) + if err != nil { + return nil, err + } + + return &listener{fd: fd, laddr: laddr}, nil } diff --git a/src/net/tlssock.go b/src/net/tlssock.go new file mode 100644 index 0000000000..b5653edd6a --- /dev/null +++ b/src/net/tlssock.go @@ -0,0 +1,154 @@ +// TINYGO: The following is copied and modified from Go 1.19.3 official implementation. + +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +// TLS low level connection and record layer + +package net + +import ( + "fmt" + "strconv" + "syscall" + "time" +) + +func DialTLS(addr string) (*TLSConn, error) { + + host, sport, err := SplitHostPort(addr) + if err != nil { + return nil, err + } + + port, err := strconv.Atoi(sport) + if err != nil { + return nil, err + } + + if port == 0 { + port = 443 + } + + fd, err := netdev.Socket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TLS) + if err != nil { + return nil, err + } + + if err = netdev.Connect(fd, host, IP{}, port); err != nil { + netdev.Close(fd) + return nil, err + } + + return &TLSConn{ + fd: fd, + }, nil +} + +// A TLSConn represents a secured connection. +// It implements the net.Conn interface. +type TLSConn struct { + fd int + readDeadline time.Time + writeDeadline time.Time +} + +// Access to net.Conn methods. +// Cannot just embed net.Conn because that would +// export the struct field too. + +// LocalAddr returns the local network address. +func (c *TLSConn) LocalAddr() Addr { + // TODO + return nil +} + +// RemoteAddr returns the remote network address. +func (c *TLSConn) RemoteAddr() Addr { + // TODO + return nil +} + +// SetDeadline sets the read and write deadlines associated with the connection. +// A zero value for t means Read and Write will not time out. +// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. +func (c *TLSConn) SetDeadline(t time.Time) error { + c.readDeadline = t + c.writeDeadline = t + return nil +} + +// SetReadDeadline sets the read deadline on the underlying connection. +// A zero value for t means Read will not time out. +func (c *TLSConn) SetReadDeadline(t time.Time) error { + c.readDeadline = t + return nil +} + +// SetWriteDeadline sets the write deadline on the underlying connection. +// A zero value for t means Write will not time out. +// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. +func (c *TLSConn) SetWriteDeadline(t time.Time) error { + c.writeDeadline = t + return nil +} + +func (c *TLSConn) Read(b []byte) (int, error) { + var timeout time.Duration + + now := time.Now() + + if !c.readDeadline.IsZero() { + if c.readDeadline.Before(now) { + return 0, fmt.Errorf("Read deadline expired") + } else { + timeout = c.readDeadline.Sub(now) + } + } + + n, err := netdev.Recv(c.fd, b, 0, timeout) + // Turn the -1 socket error into 0 and let err speak for error + if n < 0 { + n = 0 + } + return n, err +} + +func (c *TLSConn) Write(b []byte) (int, error) { + var timeout time.Duration + + now := time.Now() + + if !c.writeDeadline.IsZero() { + if c.writeDeadline.Before(now) { + return 0, fmt.Errorf("Write deadline expired") + } else { + timeout = c.writeDeadline.Sub(now) + } + } + + n, err := netdev.Send(c.fd, b, 0, timeout) + // Turn the -1 socket error into 0 and let err speak for error + if n < 0 { + n = 0 + } + return n, err +} + +func (c *TLSConn) Close() error { + return netdev.Close(c.fd) +} + +// Handshake runs the client or server handshake +// protocol if it has not yet been run. +// +// Most uses of this package need not call Handshake explicitly: the +// first Read or Write will call it automatically. +// +// For control over canceling or setting a timeout on a handshake, use +// HandshakeContext or the Dialer's DialContext method instead. 
+func (c *TLSConn) Handshake() error { + panic("TLSConn.Handshake() not implemented") + return nil +} diff --git a/src/net/udpsock.go b/src/net/udpsock.go new file mode 100644 index 0000000000..5ffe697e7f --- /dev/null +++ b/src/net/udpsock.go @@ -0,0 +1,266 @@ +// TINYGO: The following is copied and modified from Go 1.19.3 official implementation. + +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "fmt" + "internal/itoa" + "net/netip" + "strconv" + "syscall" + "time" +) + +// UDPAddr represents the address of a UDP end point. +type UDPAddr struct { + IP IP + Port int + Zone string // IPv6 scoped addressing zone +} + +// AddrPort returns the UDPAddr a as a netip.AddrPort. +// +// If a.Port does not fit in a uint16, it's silently truncated. +// +// If a is nil, a zero value is returned. +func (a *UDPAddr) AddrPort() netip.AddrPort { + if a == nil { + return netip.AddrPort{} + } + na, _ := netip.AddrFromSlice(a.IP) + na = na.WithZone(a.Zone) + return netip.AddrPortFrom(na, uint16(a.Port)) +} + +// Network returns the address's network name, "udp". +func (a *UDPAddr) Network() string { return "udp" } + +func (a *UDPAddr) String() string { + if a == nil { + return "" + } + ip := ipEmptyString(a.IP) + if a.Zone != "" { + return JoinHostPort(ip+"%"+a.Zone, itoa.Itoa(a.Port)) + } + return JoinHostPort(ip, itoa.Itoa(a.Port)) +} + +func (a *UDPAddr) isWildcard() bool { + if a == nil || a.IP == nil { + return true + } + return a.IP.IsUnspecified() +} + +func (a *UDPAddr) opAddr() Addr { + if a == nil { + return nil + } + return a +} + +// ResolveUDPAddr returns an address of UDP end point. +// +// The network must be a UDP network name. +// +// If the host in the address parameter is not a literal IP address or +// the port is not a literal port number, ResolveUDPAddr resolves the +// address to an address of UDP end point. +// Otherwise, it parses the address as a pair of literal IP address +// and port number. +// The address parameter can use a host name, but this is not +// recommended, because it will return at most one of the host name's +// IP addresses. +// +// See func Dial for a description of the network and address +// parameters. +func ResolveUDPAddr(network, address string) (*UDPAddr, error) { + + switch network { + case "udp", "udp4": + default: + return nil, fmt.Errorf("Network '%s' not supported", network) + } + + // TINYGO: Use netdev resolver + + host, sport, err := SplitHostPort(address) + if err != nil { + return nil, err + } + + port, err := strconv.Atoi(sport) + if err != nil { + return nil, fmt.Errorf("Error parsing port '%s' in address: %s", + sport, err) + } + + if host == "" { + return &UDPAddr{Port: port}, nil + } + + ip, err := netdev.GetHostByName(host) + if err != nil { + return nil, fmt.Errorf("Lookup of host name '%s' failed: %s", host, err) + } + + return &UDPAddr{IP: ip, Port: port}, nil +} + +// UDPConn is the implementation of the Conn and PacketConn interfaces +// for UDP network connections. +type UDPConn struct { + fd int + laddr *UDPAddr + raddr *UDPAddr + readDeadline time.Time + writeDeadline time.Time +} + +// Use IANA RFC 6335 port range 49152–65535 for ephemeral (dynamic) ports +var eport = int32(49151) + +func ephemeralPort() int { + // TODO: this is racy, if concurrent DialUDPs; use atomic? 
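// One way to resolve the TODO above, sketched for illustration only and not
// part of this patch: make the counter an atomic so that concurrent DialUDP
// calls cannot hand out the same port. It also keeps the result inside the
// 49152-65535 range, whereas the reset branch below returns 49151 once per
// wrap. Assumes "sync/atomic" is imported.
//
//	func ephemeralPort() int {
//		const min, max = 49152, 65535
//		for {
//			p := atomic.AddInt32(&eport, 1)
//			if p >= min && p <= max {
//				return int(p)
//			}
//			// Wrapped past the top of the range: reset and retry.
//			atomic.CompareAndSwapInt32(&eport, p, min-1)
//		}
//	}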
+ if eport == int32(65535) { + eport = int32(49151) + } else { + eport++ + } + return int(eport) +} + +// DialUDP acts like Dial for UDP networks. +// +// The network must be a UDP network name; see func Dial for details. +// +// If laddr is nil, a local address is automatically chosen. +// If the IP field of raddr is nil or an unspecified IP address, the +// local system is assumed. +func DialUDP(network string, laddr, raddr *UDPAddr) (*UDPConn, error) { + switch network { + case "udp", "udp4": + default: + return nil, fmt.Errorf("Network '%s' not supported", network) + } + + // TINYGO: Use netdev to create UDP socket and connect + + if laddr == nil { + laddr = &UDPAddr{} + } + + if raddr == nil { + raddr = &UDPAddr{} + } + + if raddr.IP.IsUnspecified() { + return nil, fmt.Errorf("Sorry, localhost isn't available on Tinygo") + } + + // If no port was given, grab an ephemeral port + if laddr.Port == 0 { + laddr.Port = ephemeralPort() + } + + fd, err := netdev.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP) + if err != nil { + return nil, err + } + + // Local bind + err = netdev.Bind(fd, laddr.IP, laddr.Port) + if err != nil { + netdev.Close(fd) + return nil, err + } + + // Remote connect + if err = netdev.Connect(fd, "", raddr.IP, raddr.Port); err != nil { + netdev.Close(fd) + return nil, err + } + + return &UDPConn{ + fd: fd, + laddr: laddr, + raddr: raddr, + }, nil +} + +// TINYGO: Use netdev for Conn methods: Read = Recv, Write = Send, etc. + +func (c *UDPConn) Read(b []byte) (int, error) { + var timeout time.Duration + + now := time.Now() + + if !c.readDeadline.IsZero() { + if c.readDeadline.Before(now) { + return 0, fmt.Errorf("Read deadline expired") + } else { + timeout = c.readDeadline.Sub(now) + } + } + + n, err := netdev.Recv(c.fd, b, 0, timeout) + // Turn the -1 socket error into 0 and let err speak for error + if n < 0 { + n = 0 + } + return n, err +} + +func (c *UDPConn) Write(b []byte) (int, error) { + var timeout time.Duration + + now := time.Now() + + if !c.writeDeadline.IsZero() { + if c.writeDeadline.Before(now) { + return 0, fmt.Errorf("Write deadline expired") + } else { + timeout = c.writeDeadline.Sub(now) + } + } + + n, err := netdev.Send(c.fd, b, 0, timeout) + // Turn the -1 socket error into 0 and let err speak for error + if n < 0 { + n = 0 + } + return n, err +} + +func (c *UDPConn) Close() error { + return netdev.Close(c.fd) +} + +func (c *UDPConn) LocalAddr() Addr { + return c.laddr +} + +func (c *UDPConn) RemoteAddr() Addr { + return c.raddr +} + +func (c *UDPConn) SetDeadline(t time.Time) error { + c.readDeadline = t + c.writeDeadline = t + return nil +} + +func (c *UDPConn) SetReadDeadline(t time.Time) error { + c.readDeadline = t + return nil +} + +func (c *UDPConn) SetWriteDeadline(t time.Time) error { + c.writeDeadline = t + return nil +} diff --git a/src/net/writev_test.go b/src/net/writev_test.go deleted file mode 100644 index 3a2c3efa3c..0000000000 --- a/src/net/writev_test.go +++ /dev/null @@ -1,132 +0,0 @@ -// The following is copied from Go 1.17 official implementation and -// modified to accommodate TinyGo. - -// Copyright 2016 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package net - -import ( - "bytes" - "fmt" - "io" - "reflect" - "testing" -) - -func TestBuffers_read(t *testing.T) { - const story = "once upon a time in Gopherland ... 
" - buffers := Buffers{ - []byte("once "), - []byte("upon "), - []byte("a "), - []byte("time "), - []byte("in "), - []byte("Gopherland ... "), - } - got, err := io.ReadAll(&buffers) - if err != nil { - t.Fatal(err) - } - if string(got) != story { - t.Errorf("read %q; want %q", got, story) - } - if len(buffers) != 0 { - t.Errorf("len(buffers) = %d; want 0", len(buffers)) - } -} - -func TestBuffers_consume(t *testing.T) { - tests := []struct { - in Buffers - consume int64 - want Buffers - }{ - { - in: Buffers{[]byte("foo"), []byte("bar")}, - consume: 0, - want: Buffers{[]byte("foo"), []byte("bar")}, - }, - { - in: Buffers{[]byte("foo"), []byte("bar")}, - consume: 2, - want: Buffers{[]byte("o"), []byte("bar")}, - }, - { - in: Buffers{[]byte("foo"), []byte("bar")}, - consume: 3, - want: Buffers{[]byte("bar")}, - }, - { - in: Buffers{[]byte("foo"), []byte("bar")}, - consume: 4, - want: Buffers{[]byte("ar")}, - }, - { - in: Buffers{nil, nil, nil, []byte("bar")}, - consume: 1, - want: Buffers{[]byte("ar")}, - }, - { - in: Buffers{nil, nil, nil, []byte("foo")}, - consume: 0, - want: Buffers{[]byte("foo")}, - }, - { - in: Buffers{nil, nil, nil}, - consume: 0, - want: Buffers{}, - }, - } - for i, tt := range tests { - in := tt.in - in.consume(tt.consume) - if !reflect.DeepEqual(in, tt.want) { - t.Errorf("%d. after consume(%d) = %+v, want %+v", i, tt.consume, in, tt.want) - } - } -} - -func TestBuffers_WriteTo(t *testing.T) { - for _, name := range []string{"WriteTo", "Copy"} { - for _, size := range []int{0, 10, 1023, 1024, 1025} { - t.Run(fmt.Sprintf("%s/%d", name, size), func(t *testing.T) { - testBuffer_writeTo(t, size, name == "Copy") - }) - } - } -} - -func testBuffer_writeTo(t *testing.T, chunks int, useCopy bool) { - var want bytes.Buffer - for i := 0; i < chunks; i++ { - want.WriteByte(byte(i)) - } - - var b bytes.Buffer - buffers := make(Buffers, chunks) - for i := range buffers { - buffers[i] = want.Bytes()[i : i+1] - } - var n int64 - var err error - if useCopy { - n, err = io.Copy(&b, &buffers) - } else { - n, err = buffers.WriteTo(&b) - } - if err != nil { - t.Fatal(err) - } - if len(buffers) != 0 { - t.Fatal(fmt.Errorf("len(buffers) = %d; want 0", len(buffers))) - } - if n != int64(want.Len()) { - t.Fatal(fmt.Errorf("Buffers.WriteTo returned %d; want %d", n, want.Len())) - } - all, err := io.ReadAll(&b) - if !bytes.Equal(all, want.Bytes()) || err != nil { - t.Fatal(fmt.Errorf("read %q, %v; want %q, nil", all, err, want.Bytes())) - } -} diff --git a/src/os/errors.go b/src/os/errors.go index 74c77c902d..23b738c97c 100644 --- a/src/os/errors.go +++ b/src/os/errors.go @@ -92,6 +92,11 @@ func IsPermission(err error) bool { return underlyingErrorIs(err, ErrPermission) } +func IsTimeout(err error) bool { + terr, ok := underlyingError(err).(timeout) + return ok && terr.Timeout() +} + func underlyingErrorIs(err, target error) bool { // Note that this function is not errors.Is: // underlyingError only unwraps the specific error-wrapping types diff --git a/src/os/file_other.go b/src/os/file_other.go index 12b21b838c..68bb114e53 100644 --- a/src/os/file_other.go +++ b/src/os/file_other.go @@ -37,6 +37,14 @@ func NewFile(fd uintptr, name string) *File { return &File{&file{stdioFileHandle(fd), name}} } +// Rename renames (moves) oldpath to newpath. +// If newpath already exists and is not a directory, Rename replaces it. +// OS-specific restrictions may apply when oldpath and newpath are in different directories. +// If there is an error, it will be of type *LinkError. 
+func Rename(oldpath, newpath string) error { + return ErrNotImplemented +} + // Read reads up to len(b) bytes from machine.Serial. // It returns the number of bytes read and any error encountered. func (f stdioFileHandle) Read(b []byte) (n int, err error) { diff --git a/src/os/stat_darwin.go b/src/os/stat_darwin.go index a27a3b6636..74214cefa4 100644 --- a/src/os/stat_darwin.go +++ b/src/os/stat_darwin.go @@ -12,7 +12,7 @@ import ( func fillFileStatFromSys(fs *fileStat, name string) { fs.name = basename(name) fs.size = fs.sys.Size - fs.modTime = timespecToTime(fs.sys.Mtim) + fs.modTime = timespecToTime(fs.sys.Mtimespec) fs.mode = FileMode(fs.sys.Mode & 0777) switch fs.sys.Mode & syscall.S_IFMT { case syscall.S_IFBLK, syscall.S_IFWHT: @@ -47,5 +47,5 @@ func timespecToTime(ts syscall.Timespec) time.Time { // For testing. func atime(fi FileInfo) time.Time { - return timespecToTime(fi.Sys().(*syscall.Stat_t).Atim) + return timespecToTime(fi.Sys().(*syscall.Stat_t).Atimespec) } diff --git a/src/reflect/all_test.go b/src/reflect/all_test.go index a2dc268a30..f85dba27b3 100644 --- a/src/reflect/all_test.go +++ b/src/reflect/all_test.go @@ -71,7 +71,7 @@ var deepEqualTests = []DeepEqualTest{ {&[3]int{1, 2, 3}, &[3]int{1, 2, 3}, true}, {Basic{1, 0.5}, Basic{1, 0.5}, true}, {error(nil), error(nil), true}, - //{map[int]string{1: "one", 2: "two"}, map[int]string{2: "two", 1: "one"}, true}, + {map[int]string{1: "one", 2: "two"}, map[int]string{2: "two", 1: "one"}, true}, {fn1, fn2, true}, {[]byte{1, 2, 3}, []byte{1, 2, 3}, true}, {[]MyByte{1, 2, 3}, []MyByte{1, 2, 3}, true}, @@ -87,10 +87,10 @@ var deepEqualTests = []DeepEqualTest{ {&[3]int{1, 2, 3}, &[3]int{1, 2, 4}, false}, {Basic{1, 0.5}, Basic{1, 0.6}, false}, {Basic{1, 0}, Basic{2, 0}, false}, - //{map[int]string{1: "one", 3: "two"}, map[int]string{2: "two", 1: "one"}, false}, - //{map[int]string{1: "one", 2: "txo"}, map[int]string{2: "two", 1: "one"}, false}, - //{map[int]string{1: "one"}, map[int]string{2: "two", 1: "one"}, false}, - //{map[int]string{2: "two", 1: "one"}, map[int]string{1: "one"}, false}, + {map[int]string{1: "one", 3: "two"}, map[int]string{2: "two", 1: "one"}, false}, + {map[int]string{1: "one", 2: "txo"}, map[int]string{2: "two", 1: "one"}, false}, + {map[int]string{1: "one"}, map[int]string{2: "two", 1: "one"}, false}, + {map[int]string{2: "two", 1: "one"}, map[int]string{1: "one"}, false}, {nil, 1, false}, {1, nil, false}, {fn1, fn3, false}, @@ -104,16 +104,16 @@ var deepEqualTests = []DeepEqualTest{ {&[1]float64{math.NaN()}, self{}, true}, {[]float64{math.NaN()}, []float64{math.NaN()}, false}, {[]float64{math.NaN()}, self{}, true}, - //{map[float64]float64{math.NaN(): 1}, map[float64]float64{1: 2}, false}, - //{map[float64]float64{math.NaN(): 1}, self{}, true}, + {map[float64]float64{math.NaN(): 1}, map[float64]float64{1: 2}, false}, + {map[float64]float64{math.NaN(): 1}, self{}, true}, // Nil vs empty: not the same. {[]int{}, []int(nil), false}, {[]int{}, []int{}, true}, {[]int(nil), []int(nil), true}, - //{map[int]int{}, map[int]int(nil), false}, - //{map[int]int{}, map[int]int{}, true}, - //{map[int]int(nil), map[int]int(nil), true}, + {map[int]int{}, map[int]int(nil), false}, + {map[int]int{}, map[int]int{}, true}, + {map[int]int(nil), map[int]int(nil), true}, // Mismatched types {1, 1.0, false}, @@ -130,8 +130,8 @@ var deepEqualTests = []DeepEqualTest{ // Possible loops. 
{&loopy1, &loopy1, true}, {&loopy1, &loopy2, true}, - //{&cycleMap1, &cycleMap2, true}, - //{&cycleMap1, &cycleMap3, false}, + {&cycleMap1, &cycleMap2, true}, + {&cycleMap1, &cycleMap3, false}, } func TestDeepEqual(t *testing.T) { diff --git a/src/reflect/deepequal.go b/src/reflect/deepequal.go index f84ddc8b5e..18a728458c 100644 --- a/src/reflect/deepequal.go +++ b/src/reflect/deepequal.go @@ -15,7 +15,7 @@ import "unsafe" type visit struct { a1 unsafe.Pointer a2 unsafe.Pointer - typ rawType + typ *rawType } // Tests for deep equality using reflected types. The map argument tracks diff --git a/src/reflect/sidetables.go b/src/reflect/sidetables.go deleted file mode 100644 index ea26ff767f..0000000000 --- a/src/reflect/sidetables.go +++ /dev/null @@ -1,61 +0,0 @@ -package reflect - -import ( - "unsafe" -) - -// This stores a varint for each named type. Named types are identified by their -// name instead of by their type. The named types stored in this struct are -// non-basic types: pointer, struct, and channel. -// -//go:extern reflect.namedNonBasicTypesSidetable -var namedNonBasicTypesSidetable uintptr - -//go:extern reflect.structTypesSidetable -var structTypesSidetable byte - -//go:extern reflect.structNamesSidetable -var structNamesSidetable byte - -//go:extern reflect.arrayTypesSidetable -var arrayTypesSidetable byte - -// readStringSidetable reads a string from the given table (like -// structNamesSidetable) and returns this string. No heap allocation is -// necessary because it makes the string point directly to the raw bytes of the -// table. -func readStringSidetable(table unsafe.Pointer, index uintptr) string { - nameLen, namePtr := readVarint(unsafe.Pointer(uintptr(table) + index)) - return *(*string)(unsafe.Pointer(&stringHeader{ - data: namePtr, - len: nameLen, - })) -} - -// readVarint decodes a varint as used in the encoding/binary package. -// It has an input pointer and returns the read varint and the pointer -// incremented to the next field in the data structure, just after the varint. -// -// Details: -// https://github.com/golang/go/blob/e37a1b1c/src/encoding/binary/varint.go#L7-L25 -func readVarint(buf unsafe.Pointer) (uintptr, unsafe.Pointer) { - var n uintptr - shift := uintptr(0) - for { - // Read the next byte in the buffer. - c := *(*byte)(buf) - - // Decode the bits from this byte and add them to the output number. - n |= uintptr(c&0x7f) << shift - shift += 7 - - // Increment the buf pointer (pointer arithmetic!). - buf = unsafe.Pointer(uintptr(buf) + 1) - - // Check whether this is the last byte of this varint. The upper bit - // (msb) indicates whether any bytes follow. 
- if c>>7 == 0 { - return n, buf - } - } -} diff --git a/src/reflect/swapper.go b/src/reflect/swapper.go index 82842de96c..a2fa44cef0 100644 --- a/src/reflect/swapper.go +++ b/src/reflect/swapper.go @@ -31,8 +31,8 @@ func Swapper(slice interface{}) func(i, j int) { if uint(i) >= uint(header.len) || uint(j) >= uint(header.len) { panic("reflect: slice index out of range") } - val1 := unsafe.Pointer(uintptr(header.data) + uintptr(i)*size) - val2 := unsafe.Pointer(uintptr(header.data) + uintptr(j)*size) + val1 := unsafe.Add(header.data, uintptr(i)*size) + val2 := unsafe.Add(header.data, uintptr(j)*size) memcpy(tmp, val1, size) memcpy(val1, val2, size) memcpy(val2, tmp, size) diff --git a/src/reflect/type.go b/src/reflect/type.go index a5e63f3558..139ef9a47b 100644 --- a/src/reflect/type.go +++ b/src/reflect/type.go @@ -2,36 +2,85 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// Type information of an interface is stored as a pointer to a global in the +// interface type (runtime._interface). This is called a type struct. +// It always starts with a byte that contains both the type kind and a few +// flags. In most cases it also contains a pointer to another type struct +// (ptrTo), that is the pointer type of the current type (for example, type int +// also has a pointer to the type *int). The exception is pointer types, to +// avoid infinite recursion. +// +// The layouts specifically look like this: +// - basic types (Bool..UnsafePointer): +// meta uint8 // actually: kind + flags +// ptrTo *typeStruct +// - channels and slices (see elemType): +// meta uint8 +// nmethods uint16 (0) +// ptrTo *typeStruct +// elementType *typeStruct // the type that you get with .Elem() +// - pointer types (see ptrType, this doesn't include chan, map, etc): +// meta uint8 +// nmethods uint16 +// elementType *typeStruct +// - array types (see arrayType) +// meta uint8 +// nmethods uint16 (0) +// ptrTo *typeStruct +// elem *typeStruct // element type of the array +// arrayLen uintptr // length of the array (this is part of the type) +// - map types (this is still missing the key and element types) +// meta uint8 +// nmethods uint16 (0) +// ptrTo *typeStruct +// elem *typeStruct +// key *typeStruct +// - struct types (see structType): +// meta uint8 +// nmethods uint16 +// ptrTo *typeStruct +// pkgpath *byte // package path; null terminated +// numField uint16 +// fields [...]structField // the remaining fields are all of type structField +// - interface types (this is missing the interface methods): +// meta uint8 +// ptrTo *typeStruct +// - signature types (this is missing input and output parameters): +// meta uint8 +// ptrTo *typeStruct +// - named types +// meta uint8 +// nmethods uint16 // number of methods +// ptrTo *typeStruct +// elem *typeStruct // underlying type +// pkgpath *byte // pkgpath; null terminated +// name [1]byte // actual name; null terminated +// +// The type struct is essentially a union of all the above types. Which it is, +// can be determined by looking at the meta byte. + package reflect import ( + "internal/itoa" "unsafe" ) -// The compiler uses a compact encoding to store type information. Unlike the -// main Go compiler, most of the types are stored directly in the type code. -// -// Type code bit allocation: -// xxxxx0: basic types, where xxxxx is the basic type number (never 0). -// The higher bits indicate the named type, if any. 
-// nxxx1: complex types, where n indicates whether this is a named type (named -// if set) and xxx contains the type kind number: -// 0 (0001): Chan -// 1 (0011): Interface -// 2 (0101): Pointer -// 3 (0111): Slice -// 4 (1001): Array -// 5 (1011): Func -// 6 (1101): Map -// 7 (1111): Struct -// The higher bits are either the contents of the type depending on the -// type (if n is clear) or indicate the number of the named type (if n -// is set). - -type Kind uintptr +// Flags stored in the first byte of the struct field byte array. Must be kept +// up to date with compiler/interface.go. +const ( + structFieldFlagAnonymous = 1 << iota + structFieldFlagHasTag + structFieldFlagIsExported + structFieldFlagIsEmbedded +) + +type Kind uint8 // Copied from reflect/type.go // https://golang.org/src/reflect/type.go?s=8302:8316#L217 +// These constants must match basicTypes and the typeKind* constants in +// compiler/interface.go const ( Invalid Kind = iota Bool @@ -124,11 +173,6 @@ func (k Kind) String() string { } } -// basicType returns a new Type for this kind if Kind is a basic type. -func (k Kind) basicType() rawType { - return rawType(k << 1) -} - // Copied from reflect/type.go // https://go.dev/src/reflect/type.go?#L348 @@ -195,7 +239,7 @@ type Type interface { // // Only exported methods are accessible and they are sorted in // lexicographic order. - //Method(int) Method + Method(int) Method // MethodByName returns the method with that name in the type's // method set and a boolean indicating if the method was found. @@ -346,80 +390,204 @@ type Type interface { Out(i int) Type } -// The typecode as used in an interface{}. -type rawType uintptr +// Constants for the 'meta' byte. +const ( + kindMask = 31 // mask to apply to the meta byte to get the Kind value + flagNamed = 32 // flag that is set if this is a named type +) + +// The base type struct. All type structs start with this. +type rawType struct { + meta uint8 // metadata byte, contains kind and flags (see contants above) +} + +// All types that have an element type: named, chan, slice, array, map (but not +// pointer because it doesn't have ptrTo). +type elemType struct { + rawType + numMethod uint16 + ptrTo *rawType + elem *rawType +} + +type ptrType struct { + rawType + numMethod uint16 + elem *rawType +} + +type arrayType struct { + rawType + numMethod uint16 + ptrTo *rawType + elem *rawType + arrayLen uintptr +} + +type mapType struct { + rawType + numMethod uint16 + ptrTo *rawType + elem *rawType + key *rawType +} + +type namedType struct { + rawType + numMethod uint16 + ptrTo *rawType + elem *rawType + pkg *byte + name [1]byte +} + +// Type for struct types. The numField value is intentionally put before ptrTo +// for better struct packing on 32-bit and 64-bit architectures. On these +// architectures, the ptrTo field still has the same offset as in all the other +// type structs. +// The fields array isn't necessarily 1 structField long, instead it is as long +// as numFields. The array is given a length of 1 to satisfy the Go type +// checker. +type structType struct { + rawType + numMethod uint16 + ptrTo *rawType + pkgpath *byte + numField uint16 + fields [1]structField // the remaining fields are all of type structField +} + +type structField struct { + fieldType *rawType + data unsafe.Pointer // various bits of information, packed in a byte array +} + +// Equivalent to (go/types.Type).Underlying(): if this is a named type return +// the underlying type, else just return the type itself. 
+func (t *rawType) underlying() *rawType { + if t.isNamed() { + return (*elemType)(unsafe.Pointer(t)).elem + } + return t +} + +func (t *rawType) isNamed() bool { + return t.meta&flagNamed != 0 +} func TypeOf(i interface{}) Type { - return ValueOf(i).typecode + if i == nil { + return nil + } + typecode, _ := decomposeInterface(i) + return (*rawType)(typecode) } func PtrTo(t Type) Type { return PointerTo(t) } func PointerTo(t Type) Type { - if t.Kind() == Pointer { - panic("reflect: cannot make **T type") + return pointerTo(t.(*rawType)) +} + +func pointerTo(t *rawType) *rawType { + if t.isNamed() { + return (*elemType)(unsafe.Pointer(t)).ptrTo } - ptrType := t.(rawType)<<5 | 5 // 0b0101 == 5 - if ptrType>>5 != t { - panic("reflect: PointerTo type does not fit") + + switch t.Kind() { + case Pointer: + // TODO(dgryski): This is blocking https://github.com/tinygo-org/tinygo/issues/3131 + // We need to be able to create types that match existing types to prevent typecode equality. + panic("reflect: cannot make **T type") + case Struct: + return (*structType)(unsafe.Pointer(t)).ptrTo + default: + return (*elemType)(unsafe.Pointer(t)).ptrTo } - return ptrType } -func (t rawType) String() string { - return "T" +func (t *rawType) String() string { + if t.isNamed() { + s := t.name() + if s[0] == '.' { + return s[1:] + } + return s + } + + switch t.Kind() { + case Chan: + return "chan " + t.elem().String() + case Pointer: + return "*" + t.elem().String() + case Slice: + return "[]" + t.elem().String() + case Array: + return "[" + itoa.Itoa(t.Len()) + "]" + t.elem().String() + case Map: + return "map[" + t.key().String() + "]" + t.elem().String() + case Struct: + numField := t.NumField() + if numField == 0 { + return "struct {}" + } + s := "struct {" + for i := 0; i < numField; i++ { + f := t.rawField(i) + s += " " + f.Name + " " + f.Type.String() + // every field except the last needs a semicolon + if i < numField-1 { + s += ";" + } + } + s += " }" + return s + case Interface: + // TODO(dgryski): Needs actual method set info + return "interface {}" + default: + return t.Kind().String() + } + + return t.Kind().String() } -func (t rawType) Kind() Kind { - if t%2 == 0 { - // basic type - return Kind((t >> 1) % 32) - } else { - return Kind(t>>1)%8 + 19 +func (t *rawType) Kind() Kind { + if t == nil { + return Invalid } + return Kind(t.meta & kindMask) } // Elem returns the element type for channel, slice and array types, the // pointed-to value for pointer types, and the key type for map types. -func (t rawType) Elem() Type { +func (t *rawType) Elem() Type { return t.elem() } -func (t rawType) elem() rawType { - switch t.Kind() { - case Chan, Pointer, Slice: - return t.stripPrefix() - case Array: - index := t.stripPrefix() - elem, _ := readVarint(unsafe.Pointer(uintptr(unsafe.Pointer(&arrayTypesSidetable)) + uintptr(index))) - return rawType(elem) - default: // not implemented: Map - panic("unimplemented: (reflect.Type).Elem()") +func (t *rawType) elem() *rawType { + underlying := t.underlying() + switch underlying.Kind() { + case Pointer: + return (*ptrType)(unsafe.Pointer(underlying)).elem + case Chan, Slice, Array, Map: + return (*elemType)(unsafe.Pointer(underlying)).elem + default: + panic(&TypeError{"Elem"}) } } -// stripPrefix removes the "prefix" (the low 5 bits of the type code) from -// the type code. If this is a named type, it will resolve the underlying type -// (which is the data for this named type). If it is not, the lower bits are -// simply shifted off. 
-// -// The behavior is only defined for non-basic types. -func (t rawType) stripPrefix() rawType { - // Look at the 'n' bit in the type code (see the top of this file) to see - // whether this is a named type. - if (t>>4)%2 != 0 { - // This is a named type. The data is stored in a sidetable. - namedTypeNum := t >> 5 - n := *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&namedNonBasicTypesSidetable)) + uintptr(namedTypeNum)*unsafe.Sizeof(uintptr(0)))) - return rawType(n) +func (t *rawType) key() *rawType { + underlying := t.underlying() + if underlying.Kind() != Map { + panic(&TypeError{"Key"}) } - // Not a named type, so the value is stored directly in the type code. - return t >> 5 + return (*mapType)(unsafe.Pointer(underlying)).key } // Field returns the type of the i'th field of this struct type. It panics if t // is not a struct type. -func (t rawType) Field(i int) StructField { +func (t *rawType) Field(i int) StructField { field := t.rawField(i) return StructField{ Name: field.Name, @@ -428,6 +596,38 @@ func (t rawType) Field(i int) StructField { Tag: field.Tag, Anonymous: field.Anonymous, Offset: field.Offset, + Index: []int{i}, + } +} + +func rawStructFieldFromPointer(descriptor *structType, fieldType *rawType, data unsafe.Pointer, flagsByte uint8, name string, offset uintptr) rawStructField { + // Read the field tag, if there is one. + var tag string + if flagsByte&structFieldFlagHasTag != 0 { + data = unsafe.Add(data, 1) // C: data+1 + tagLen := uintptr(*(*byte)(data)) + data = unsafe.Add(data, 1) // C: data+1 + tag = *(*string)(unsafe.Pointer(&stringHeader{ + data: data, + len: tagLen, + })) + } + + // Set the PkgPath to some (arbitrary) value if the package path is not + // exported. + pkgPath := "" + if flagsByte&structFieldFlagIsExported == 0 { + // This field is unexported. + pkgPath = readStringZ(unsafe.Pointer(descriptor.pkgpath)) + } + + return rawStructField{ + Name: name, + PkgPath: pkgPath, + Type: fieldType, + Tag: StructTag(tag), + Anonymous: flagsByte&structFieldFlagAnonymous != 0, + Offset: offset, } } @@ -435,82 +635,146 @@ func (t rawType) Field(i int) StructField { // Type member to an interface. // // For internal use only. -func (t rawType) rawField(i int) rawStructField { +func (t *rawType) rawField(n int) rawStructField { if t.Kind() != Struct { panic(&TypeError{"Field"}) } - structIdentifier := t.stripPrefix() - - numField, p := readVarint(unsafe.Pointer(uintptr(unsafe.Pointer(&structTypesSidetable)) + uintptr(structIdentifier))) - if uint(i) >= uint(numField) { + descriptor := (*structType)(unsafe.Pointer(t.underlying())) + if uint(n) >= uint(descriptor.numField) { panic("reflect: field index out of range") } - // Iterate over every field in the struct and update the StructField each - // time, until the target field has been reached. This is very much not - // efficient, but it is easy to implement. - // Adding a jump table at the start to jump to the field directly would - // make this much faster, but that would also impact code size. - field := rawStructField{} - offset := uintptr(0) - for fieldNum := 0; fieldNum <= i; fieldNum++ { - // Read some flags of this field, like whether the field is an - // embedded field. - flagsByte := *(*uint8)(p) - p = unsafe.Pointer(uintptr(p) + 1) - - // Read the type of this struct field. - var fieldTypeVal uintptr - fieldTypeVal, p = readVarint(p) - fieldType := rawType(fieldTypeVal) - field.Type = fieldType - - // Move Offset forward to align it to this field's alignment. - // Assume alignment is a power of two. 
- offset = align(offset, uintptr(fieldType.Align())) - field.Offset = offset - offset += fieldType.Size() // starting (unaligned) offset for next field - - // Read the field name. - var nameNum uintptr - nameNum, p = readVarint(p) - field.Name = readStringSidetable(unsafe.Pointer(&structNamesSidetable), nameNum) - - // The first bit in the flagsByte indicates whether this is an embedded - // field. - field.Anonymous = flagsByte&1 != 0 - - // The second bit indicates whether there is a tag. - if flagsByte&2 != 0 { - // There is a tag. - var tagNum uintptr - tagNum, p = readVarint(p) - field.Tag = StructTag(readStringSidetable(unsafe.Pointer(&structNamesSidetable), tagNum)) - } else { - // There is no tag. - field.Tag = "" + // Iterate over all the fields to calculate the offset. + // This offset could have been stored directly in the array (to make the + // lookup faster), but by calculating it on-the-fly a bit of storage can be + // saved. + field := &descriptor.fields[0] + var offset uintptr = 0 + for i := 0; i < n; i++ { + offset += field.fieldType.Size() + + // Increment pointer to the next field. + field = (*structField)(unsafe.Add(unsafe.Pointer(field), unsafe.Sizeof(structField{}))) + + // Align the offset for the next field. + offset = align(offset, uintptr(field.fieldType.Align())) + } + + data := field.data + + // Read some flags of this field, like whether the field is an embedded + // field. See structFieldFlagAnonymous and similar flags. + flagsByte := *(*byte)(data) + data = unsafe.Add(data, 1) + + name := readStringZ(data) + data = unsafe.Add(data, len(name)) + + return rawStructFieldFromPointer(descriptor, field.fieldType, data, flagsByte, name, offset) +} + +// rawFieldByName returns nearly the same value as FieldByName but without converting the +// Type member to an interface. +// +// For internal use only. +func (t *rawType) rawFieldByName(n string) (rawStructField, []int, bool) { + if t.Kind() != Struct { + panic(&TypeError{"Field"}) + } + + type fieldWalker struct { + t *rawType + index []int + } + + queue := make([]fieldWalker, 0, 4) + queue = append(queue, fieldWalker{t, nil}) + + for len(queue) > 0 { + type result struct { + r rawStructField + index []int + } + + var found []result + var nextlevel []fieldWalker + + // For all the structs at this level.. + for _, ll := range queue { + // Iterate over all the fields looking for the matching name + // Also calculate field offset. + + descriptor := (*structType)(unsafe.Pointer(ll.t.underlying())) + var offset uintptr + field := &descriptor.fields[0] + + for i := uint16(0); i < descriptor.numField; i++ { + data := field.data + + // Read some flags of this field, like whether the field is an embedded + // field. See structFieldFlagAnonymous and similar flags. 
+ flagsByte := *(*byte)(data) + data = unsafe.Add(data, 1) + + name := readStringZ(data) + data = unsafe.Add(data, len(name)) + if name == n { + found = append(found, result{ + rawStructFieldFromPointer(descriptor, field.fieldType, data, flagsByte, name, offset), + append(ll.index, int(i)), + }) + } + + structOrPtrToStruct := field.fieldType.Kind() == Struct || (field.fieldType.Kind() == Pointer && field.fieldType.elem().Kind() == Struct) + if flagsByte&structFieldFlagIsEmbedded == structFieldFlagIsEmbedded && structOrPtrToStruct { + embedded := field.fieldType + if embedded.Kind() == Pointer { + embedded = embedded.elem() + } + + nextlevel = append(nextlevel, fieldWalker{ + t: embedded, + index: append(ll.index, int(i)), + }) + } + + offset += field.fieldType.Size() + + // update offset/field pointer if there *is* a next field + if i < descriptor.numField-1 { + + // Increment pointer to the next field. + field = (*structField)(unsafe.Add(unsafe.Pointer(field), unsafe.Sizeof(structField{}))) + + // Align the offset for the next field. + offset = align(offset, uintptr(field.fieldType.Align())) + } + } + } + + // found multiple hits at this level + if len(found) > 1 { + return rawStructField{}, nil, false } - // The third bit indicates whether this field is exported. - if flagsByte&4 != 0 { - // This field is exported. - field.PkgPath = "" - } else { - // This field is unexported. - // TODO: list the real package path here. Storing it should not - // significantly impact binary size as there is only a limited - // number of packages in any program. - field.PkgPath = "" + // found the field we were looking for + if len(found) == 1 { + r := found[0] + return r.r, r.index, true } + + // else len(found) == 0, move on to the next level + queue = append(queue[:0], nextlevel...) } - return field + // didn't find it + return rawStructField{}, nil, false } // Bits returns the number of bits that this type uses. It is only valid for // arithmetic types (integers, floats, and complex numbers). For other types, it // will panic. -func (t rawType) Bits() int { +func (t *rawType) Bits() int { kind := t.Kind() if kind >= Int && kind <= Complex128 { return int(t.Size()) * 8 @@ -520,34 +784,26 @@ func (t rawType) Bits() int { // Len returns the number of elements in this array. It panics of the type kind // is not Array. -func (t rawType) Len() int { +func (t *rawType) Len() int { if t.Kind() != Array { panic(TypeError{"Len"}) } - // skip past the element type - arrayIdentifier := t.stripPrefix() - _, p := readVarint(unsafe.Pointer(uintptr(unsafe.Pointer(&arrayTypesSidetable)) + uintptr(arrayIdentifier))) - - // Read the array length. - arrayLen, _ := readVarint(p) - return int(arrayLen) + return int((*arrayType)(unsafe.Pointer(t.underlying())).arrayLen) } // NumField returns the number of fields of a struct type. It panics for other // type kinds. -func (t rawType) NumField() int { +func (t *rawType) NumField() int { if t.Kind() != Struct { panic(&TypeError{"NumField"}) } - structIdentifier := t.stripPrefix() - n, _ := readVarint(unsafe.Pointer(uintptr(unsafe.Pointer(&structTypesSidetable)) + uintptr(structIdentifier))) - return int(n) + return int((*structType)(unsafe.Pointer(t.underlying())).numField) } // Size returns the size in bytes of a given type. It is similar to // unsafe.Sizeof. -func (t rawType) Size() uintptr { +func (t *rawType) Size() uintptr { switch t.Kind() { case Bool, Int8, Uint8: return 1 @@ -596,7 +852,7 @@ func (t rawType) Size() uintptr { // Align returns the alignment of this type. 
It is similar to calling // unsafe.Alignof. -func (t rawType) Align() int { +func (t *rawType) Align() int { switch t.Kind() { case Bool, Int8, Uint8: return int(unsafe.Alignof(int8(0))) @@ -648,23 +904,28 @@ func (t rawType) Align() int { // FieldAlign returns the alignment if this type is used in a struct field. It // is currently an alias for Align() but this might change in the future. -func (t rawType) FieldAlign() int { +func (t *rawType) FieldAlign() int { return t.Align() } // AssignableTo returns whether a value of type t can be assigned to a variable // of type u. -func (t rawType) AssignableTo(u Type) bool { - if t == u.(rawType) { +func (t *rawType) AssignableTo(u Type) bool { + if t == u.(*rawType) { + return true + } + + if u.Kind() == Interface && u.NumMethod() == 0 { return true } + if u.Kind() == Interface { panic("reflect: unimplemented: AssignableTo with interface") } return false } -func (t rawType) Implements(u Type) bool { +func (t *rawType) Implements(u Type) bool { if u.Kind() != Interface { panic("reflect: non-interface type passed to Type.Implements") } @@ -672,7 +933,7 @@ func (t rawType) Implements(u Type) bool { } // Comparable returns whether values of this type can be compared to each other. -func (t rawType) Comparable() bool { +func (t *rawType) Comparable() bool { switch t.Kind() { case Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Uintptr: return true @@ -709,36 +970,104 @@ func (t rawType) Comparable() bool { } } +// isbinary() returns if the hashmapAlgorithmBinary functions can be used on this type +func (t *rawType) isBinary() bool { + switch t.Kind() { + case Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Uintptr: + return true + case Pointer: + return true + case Array: + return t.elem().isBinary() + case Struct: + numField := t.NumField() + for i := 0; i < numField; i++ { + if !t.rawField(i).Type.isBinary() { + return false + } + } + return true + } + return false +} + func (t rawType) ChanDir() ChanDir { panic("unimplemented: (reflect.Type).ChanDir()") } -func (t rawType) ConvertibleTo(u Type) bool { +func (t *rawType) ConvertibleTo(u Type) bool { panic("unimplemented: (reflect.Type).ConvertibleTo()") } -func (t rawType) IsVariadic() bool { +func (t *rawType) IsVariadic() bool { panic("unimplemented: (reflect.Type).IsVariadic()") } -func (t rawType) NumIn() int { +func (t *rawType) NumIn() int { panic("unimplemented: (reflect.Type).NumIn()") } -func (t rawType) NumOut() int { +func (t *rawType) NumOut() int { panic("unimplemented: (reflect.Type).NumOut()") } -func (t rawType) NumMethod() int { - panic("unimplemented: (reflect.Type).NumMethod()") +func (t *rawType) NumMethod() int { + + if t.isNamed() { + return int((*namedType)(unsafe.Pointer(t)).numMethod) + } + + switch t.Kind() { + case Pointer: + return int((*ptrType)(unsafe.Pointer(t)).numMethod) + case Struct: + return int((*structType)(unsafe.Pointer(t)).numMethod) + } + + // Other types have no methods attached. Note we don't panic here. + return 0 } -func (t rawType) Name() string { - panic("unimplemented: (reflect.Type).Name()") +// Read and return a null terminated string starting from data. 
+func readStringZ(data unsafe.Pointer) string { + start := data + var len uintptr + for *(*byte)(data) != 0 { + len++ + data = unsafe.Add(data, 1) // C: data++ + } + + return *(*string)(unsafe.Pointer(&stringHeader{ + data: start, + len: len, + })) } -func (t rawType) Key() Type { - panic("unimplemented: (reflect.Type).Key()") +func (t *rawType) name() string { + ntype := (*namedType)(unsafe.Pointer(t)) + return readStringZ(unsafe.Pointer(&ntype.name[0])) +} + +func (t *rawType) Name() string { + if t.isNamed() { + name := t.name() + for i := 0; i < len(name); i++ { + if name[i] == '.' { + return name[i+1:] + } + } + panic("corrupt name data") + } + + if t.Kind() <= UnsafePointer { + return t.Kind().String() + } + + return "" +} + +func (t *rawType) Key() Type { + return t.key() } func (t rawType) In(i int) Type { @@ -749,20 +1078,71 @@ func (t rawType) Out(i int) Type { panic("unimplemented: (reflect.Type).Out()") } +func (t rawType) Method(i int) Method { + panic("unimplemented: (reflect.Type).Method()") +} + func (t rawType) MethodByName(name string) (Method, bool) { panic("unimplemented: (reflect.Type).MethodByName()") } -func (t rawType) PkgPath() string { - panic("unimplemented: (reflect.Type).PkgPath()") +func (t *rawType) PkgPath() string { + if t.isNamed() { + ntype := (*namedType)(unsafe.Pointer(t)) + return readStringZ(unsafe.Pointer(ntype.pkg)) + } + + return "" } -func (t rawType) FieldByName(name string) (StructField, bool) { - panic("unimplemented: (reflect.Type).FieldByName()") +func (t *rawType) FieldByName(name string) (StructField, bool) { + if t.Kind() != Struct { + panic(TypeError{"FieldByName"}) + } + + field, index, ok := t.rawFieldByName(name) + if !ok { + return StructField{}, false + } + + return StructField{ + Name: field.Name, + PkgPath: field.PkgPath, + Type: field.Type, // note: converts rawType to Type + Tag: field.Tag, + Anonymous: field.Anonymous, + Offset: field.Offset, + Index: index, + }, true } -func (t rawType) FieldByIndex(index []int) StructField { - panic("unimplemented: (reflect.Type).FieldByIndex()") +func (t *rawType) FieldByIndex(index []int) StructField { + ftype := t + var field rawStructField + + for _, n := range index { + structOrPtrToStruct := ftype.Kind() == Struct || (ftype.Kind() == Pointer && ftype.elem().Kind() == Struct) + if !structOrPtrToStruct { + panic(&TypeError{"FieldByIndex:" + ftype.Kind().String()}) + } + + if ftype.Kind() == Pointer { + ftype = ftype.elem() + } + + field = ftype.rawField(n) + ftype = field.Type + } + + return StructField{ + Name: field.Name, + PkgPath: field.PkgPath, + Type: field.Type, // note: converts rawType to Type + Tag: field.Tag, + Anonymous: field.Anonymous, + Offset: field.Offset, + Index: index, + } } // A StructField describes a single field in a struct. @@ -776,9 +1156,9 @@ type StructField struct { Type Type Tag StructTag // field tag string - Anonymous bool Offset uintptr Index []int // index sequence for Type.FieldByIndex + Anonymous bool } // IsExported reports whether the field is exported. @@ -792,10 +1172,10 @@ func (f StructField) IsExported() bool { type rawStructField struct { Name string PkgPath string - Type rawType + Type *rawType Tag StructTag - Anonymous bool Offset uintptr + Anonymous bool } // A StructTag is the tag string in a struct field. 
@@ -879,3 +1259,15 @@ func align(offset uintptr, alignment uintptr) uintptr { func SliceOf(t Type) Type { panic("unimplemented: reflect.SliceOf()") } + +func ArrayOf(n int, t Type) Type { + panic("unimplemented: reflect.ArrayOf()") +} + +func StructOf([]StructField) Type { + panic("unimplemented: reflect.StructOf()") +} + +func MapOf(key, value Type) Type { + panic("unimplemented: reflect.MapOf()") +} diff --git a/src/reflect/value.go b/src/reflect/value.go index fee3217212..f63a503ec9 100644 --- a/src/reflect/value.go +++ b/src/reflect/value.go @@ -17,7 +17,7 @@ const ( ) type Value struct { - typecode rawType + typecode *rawType value unsafe.Pointer flags valueFlags } @@ -44,15 +44,15 @@ func Indirect(v Value) Value { } //go:linkname composeInterface runtime.composeInterface -func composeInterface(rawType, unsafe.Pointer) interface{} +func composeInterface(unsafe.Pointer, unsafe.Pointer) interface{} //go:linkname decomposeInterface runtime.decomposeInterface -func decomposeInterface(i interface{}) (rawType, unsafe.Pointer) +func decomposeInterface(i interface{}) (unsafe.Pointer, unsafe.Pointer) func ValueOf(i interface{}) Value { typecode, value := decomposeInterface(i) return Value{ - typecode: typecode, + typecode: (*rawType)(typecode), value: value, flags: valueFlagExported, } @@ -81,11 +81,11 @@ func valueInterfaceUnsafe(v Value) interface{} { // value. var value uintptr for j := v.typecode.Size(); j != 0; j-- { - value = (value << 8) | uintptr(*(*uint8)(unsafe.Pointer(uintptr(v.value) + j - 1))) + value = (value << 8) | uintptr(*(*uint8)(unsafe.Add(v.value, j-1))) } v.value = unsafe.Pointer(value) } - return composeInterface(v.typecode, v.value) + return composeInterface(unsafe.Pointer(v.typecode), v.value) } func (v Value) Type() Type { @@ -128,7 +128,7 @@ func (v Value) IsZero() bool { default: // This should never happens, but will act as a safeguard for // later, as a default value doesn't makes sense here. - panic(&ValueError{"reflect.Value.IsZero", v.Kind()}) + panic(&ValueError{Method: "reflect.Value.IsZero", Kind: v.Kind()}) } } @@ -136,7 +136,7 @@ func (v Value) IsZero() bool { // // RawType returns the raw, underlying type code. It is used in the runtime // package and needs to be exported for the runtime package to access it. 
-func (v Value) RawType() rawType { +func (v Value) RawType() *rawType { return v.typecode } @@ -163,13 +163,10 @@ func (v Value) IsNil() bool { slice := (*sliceHeader)(v.value) return slice.data == nil case Interface: - if v.value == nil { - return true - } - _, val := decomposeInterface(*(*interface{})(v.value)) + val := *(*interface{})(v.value) return val == nil default: - panic(&ValueError{Method: "IsNil"}) + panic(&ValueError{Method: "IsNil", Kind: v.Kind()}) } } @@ -189,9 +186,13 @@ func (v Value) UnsafePointer() unsafe.Pointer { slice := (*sliceHeader)(v.value) return slice.data case Func: - panic("unimplemented: (reflect.Value).UnsafePointer()") - default: // not implemented: Func - panic(&ValueError{Method: "UnsafePointer"}) + fn := (*funcHeader)(v.value) + if fn.Context != nil { + return fn.Context + } + return fn.Code + default: + panic(&ValueError{Method: "UnsafePointer", Kind: v.Kind()}) } } @@ -205,7 +206,7 @@ func (v Value) pointer() unsafe.Pointer { } func (v Value) IsValid() bool { - return v.typecode != 0 + return v.typecode != nil } func (v Value) CanInterface() bool { @@ -217,7 +218,19 @@ func (v Value) CanAddr() bool { } func (v Value) Addr() Value { - panic("unimplemented: (reflect.Value).Addr()") + if !v.CanAddr() { + panic("reflect.Value.Addr of unaddressable value") + } + + return Value{ + typecode: pointerTo(v.typecode), + value: v.value, + flags: v.flags ^ valueFlagIndirect, + } +} + +func (v Value) UnsafeAddr() uintptr { + return uintptr(v.Addr().UnsafePointer()) } func (v Value) CanSet() bool { @@ -233,7 +246,7 @@ func (v Value) Bool() bool { return uintptr(v.value) != 0 } default: - panic(&ValueError{Method: "Bool"}) + panic(&ValueError{Method: "Bool", Kind: v.Kind()}) } } @@ -270,7 +283,7 @@ func (v Value) Int() int64 { return int64(int64(uintptr(v.value))) } default: - panic(&ValueError{Method: "Int"}) + panic(&ValueError{Method: "Int", Kind: v.Kind()}) } } @@ -313,7 +326,7 @@ func (v Value) Uint() uint64 { return uint64(uintptr(v.value)) } default: - panic(&ValueError{Method: "Uint"}) + panic(&ValueError{Method: "Uint", Kind: v.Kind()}) } } @@ -339,7 +352,7 @@ func (v Value) Float() float64 { return *(*float64)(unsafe.Pointer(&v.value)) } default: - panic(&ValueError{Method: "Float"}) + panic(&ValueError{Method: "Float", Kind: v.Kind()}) } } @@ -361,7 +374,7 @@ func (v Value) Complex() complex128 { // architectures with 128-bit pointers, however. return *(*complex128)(v.value) default: - panic(&ValueError{Method: "Complex"}) + panic(&ValueError{Method: "Complex", Kind: v.Kind()}) } } @@ -373,19 +386,111 @@ func (v Value) String() string { return *(*string)(v.value) default: // Special case because of the special treatment of .String() in Go. 
- return "" + return "<" + v.typecode.String() + " Value>" } } func (v Value) Bytes() []byte { - panic("unimplemented: (reflect.Value).Bytes()") + switch v.Kind() { + case Slice: + if v.typecode.elem().Kind() != Uint8 { + panic(&ValueError{Method: "Bytes", Kind: v.Kind()}) + } + return *(*[]byte)(v.value) + + case Array: + v.checkAddressable() + + if v.typecode.elem().Kind() != Uint8 { + panic(&ValueError{Method: "Bytes", Kind: v.Kind()}) + } + + // Small inline arrays are not addressable, so we only have to + // handle addressable arrays which will be stored as pointers + // in v.value + return unsafe.Slice((*byte)(v.value), v.Len()) + } + + panic(&ValueError{Method: "Bytes", Kind: v.Kind()}) } func (v Value) Slice(i, j int) Value { - panic("unimplemented: (reflect.Value).Slice()") + switch v.Kind() { + case Slice: + hdr := *(*sliceHeader)(v.value) + i, j := uintptr(i), uintptr(j) + + if j < i || hdr.cap < j { + slicePanic() + } + + elemSize := v.typecode.underlying().elem().Size() + + hdr.len = j - i + hdr.cap = hdr.cap - i + hdr.data = unsafe.Add(hdr.data, i*elemSize) + + return Value{ + typecode: v.typecode, + value: unsafe.Pointer(&hdr), + flags: v.flags, + } + + case Array: + // TODO(dgryski): can't do this yet because the resulting value needs type slice of v.elem(), not array of v.elem(). + // need to be able to look up this "new" type so pointer equality of types still works + + case String: + i, j := uintptr(i), uintptr(j) + str := *(*stringHeader)(v.value) + + if j < i || str.len < j { + slicePanic() + } + + hdr := stringHeader{ + data: unsafe.Add(str.data, i), + len: j - i, + } + + return Value{ + typecode: v.typecode, + value: unsafe.Pointer(&hdr), + flags: v.flags, + } + } + + panic(&ValueError{Method: "Slice", Kind: v.Kind()}) } func (v Value) Slice3(i, j, k int) Value { + switch v.Kind() { + case Slice: + hdr := *(*sliceHeader)(v.value) + i, j, k := uintptr(i), uintptr(j), uintptr(k) + + if j < i || k < j || hdr.len < k { + slicePanic() + } + + elemSize := v.typecode.underlying().elem().Size() + + hdr.len = j - i + hdr.cap = k - i + hdr.data = unsafe.Add(hdr.data, i*elemSize) + + return Value{ + typecode: v.typecode, + value: unsafe.Pointer(&hdr), + flags: v.flags, + } + + case Array: + // TODO(dgryski): can't do this yet because the resulting value needs type v.elem(), not array of v.elem(). + // need to be able to look up this "new" type so pointer equality of types still works + + } + panic("unimplemented: (reflect.Value).Slice3()") } @@ -410,7 +515,7 @@ func (v Value) Len() int { case String: return int((*stringHeader)(v.value).len) default: - panic(&ValueError{Method: "Len"}) + panic(&ValueError{Method: "Len", Kind: v.Kind()}) } } @@ -428,7 +533,7 @@ func (v Value) Cap() int { case Slice: return int((*sliceHeader)(v.value).cap) default: - panic(&ValueError{Method: "Cap"}) + panic(&ValueError{Method: "Cap", Kind: v.Kind()}) } } @@ -453,17 +558,20 @@ func (v Value) Elem() Value { case Interface: typecode, value := decomposeInterface(*(*interface{})(v.value)) return Value{ - typecode: typecode, + typecode: (*rawType)(typecode), value: value, flags: v.flags &^ valueFlagIndirect, } default: - panic(&ValueError{Method: "Elem"}) + panic(&ValueError{Method: "Elem", Kind: v.Kind()}) } } // Field returns the value of the i'th field of this struct. 
func (v Value) Field(i int) Value { + if v.Kind() != Struct { + panic(&ValueError{Method: "Field", Kind: v.Kind()}) + } structField := v.typecode.rawField(i) flags := v.flags if structField.PkgPath != "" { @@ -481,7 +589,7 @@ func (v Value) Field(i int) Value { return Value{ flags: flags, typecode: fieldType, - value: unsafe.Pointer(uintptr(v.value) + structField.Offset), + value: unsafe.Add(v.value, structField.Offset), } } @@ -496,7 +604,7 @@ func (v Value) Field(i int) Value { return Value{ flags: flags, typecode: fieldType, - value: unsafe.Pointer(uintptr(0)), + value: unsafe.Pointer(nil), } } @@ -504,7 +612,7 @@ func (v Value) Field(i int) Value { // The value was not stored in the interface before but will be // afterwards, so load the value (from the correct offset) and return // it. - ptr := unsafe.Pointer(uintptr(v.value) + structField.Offset) + ptr := unsafe.Add(v.value, structField.Offset) value := unsafe.Pointer(loadValue(ptr, fieldSize)) return Value{ flags: flags &^ valueFlagIndirect, @@ -523,6 +631,8 @@ func (v Value) Field(i int) Value { } } +var uint8Type = TypeOf(uint8(0)).(*rawType) + func (v Value) Index(i int) Value { switch v.Kind() { case Slice: @@ -535,8 +645,7 @@ func (v Value) Index(i int) Value { typecode: v.typecode.elem(), flags: v.flags | valueFlagIndirect, } - addr := uintptr(slice.data) + elem.typecode.Size()*uintptr(i) // pointer to new value - elem.value = unsafe.Pointer(addr) + elem.value = unsafe.Add(slice.data, elem.typecode.Size()*uintptr(i)) // pointer to new value return elem case String: // Extract a character from a string. @@ -550,8 +659,8 @@ func (v Value) Index(i int) Value { panic("reflect: string index out of range") } return Value{ - typecode: Uint8.basicType(), - value: unsafe.Pointer(uintptr(*(*uint8)(unsafe.Pointer(uintptr(s.data) + uintptr(i))))), + typecode: uint8Type, + value: unsafe.Pointer(uintptr(*(*uint8)(unsafe.Add(s.data, i)))), flags: v.flags & valueFlagExported, } case Array: @@ -571,18 +680,18 @@ func (v Value) Index(i int) Value { // indirect. Also, because size != 0 this implies that the array // length must be != 0, and thus that the total size is at least // elemSize. - addr := uintptr(v.value) + elemSize*uintptr(i) // pointer to new value + addr := unsafe.Add(v.value, elemSize*uintptr(i)) // pointer to new value return Value{ typecode: v.typecode.elem(), flags: v.flags, - value: unsafe.Pointer(addr), + value: addr, } } if size > unsafe.Sizeof(uintptr(0)) || v.isIndirect() { // The element fits in a pointer, but the array is not stored in the pointer directly. // Load the value from the pointer. - addr := unsafe.Pointer(uintptr(v.value) + elemSize*uintptr(i)) // pointer to new value + addr := unsafe.Add(v.value, elemSize*uintptr(i)) // pointer to new value value := addr if !v.isIndirect() { // Use a pointer to the value (don't load the value) if the @@ -606,7 +715,7 @@ func (v Value) Index(i int) Value { value: unsafe.Pointer(value), } default: - panic(&ValueError{Method: "Index"}) + panic(&ValueError{Method: "Index", Kind: v.Kind()}) } } @@ -619,7 +728,7 @@ func loadValue(ptr unsafe.Pointer, size uintptr) uintptr { for i := uintptr(0); i < size; i++ { loadedValue |= uintptr(*(*byte)(ptr)) << shift shift += 8 - ptr = unsafe.Pointer(uintptr(ptr) + 1) + ptr = unsafe.Add(ptr, 1) } return loadedValue } @@ -634,35 +743,146 @@ func (v Value) NumMethod() int { return v.typecode.NumMethod() } +// OverflowFloat reports whether the float64 x cannot be represented by v's type. +// It panics if v's Kind is not Float32 or Float64. 
func (v Value) OverflowFloat(x float64) bool { - panic("unimplemented: (reflect.Value).OverflowFloat()") + k := v.Kind() + switch k { + case Float32: + return overflowFloat32(x) + case Float64: + return false + } + panic(&ValueError{Method: "reflect.Value.OverflowFloat", Kind: v.Kind()}) +} + +func overflowFloat32(x float64) bool { + if x < 0 { + x = -x + } + return math.MaxFloat32 < x && x <= math.MaxFloat64 } func (v Value) MapKeys() []Value { - panic("unimplemented: (reflect.Value).MapKeys()") + if v.Kind() != Map { + panic(&ValueError{Method: "MapKeys", Kind: v.Kind()}) + } + + // empty map + if v.Len() == 0 { + return nil + } + + keys := make([]Value, 0, v.Len()) + + it := hashmapNewIterator() + k := New(v.typecode.Key()) + e := New(v.typecode.Elem()) + + for hashmapNext(v.pointer(), it, k.value, e.value) { + keys = append(keys, k.Elem()) + k = New(v.typecode.Key()) + } + + return keys } +//go:linkname hashmapStringGet runtime.hashmapStringGetUnsafePointer +func hashmapStringGet(m unsafe.Pointer, key string, value unsafe.Pointer, valueSize uintptr) bool + +//go:linkname hashmapBinaryGet runtime.hashmapBinaryGetUnsafePointer +func hashmapBinaryGet(m unsafe.Pointer, key, value unsafe.Pointer, valueSize uintptr) bool + +//go:linkname hashmapInterfaceGet runtime.hashmapInterfaceGetUnsafePointer +func hashmapInterfaceGet(m unsafe.Pointer, key interface{}, value unsafe.Pointer, valueSize uintptr) bool + func (v Value) MapIndex(key Value) Value { - panic("unimplemented: (reflect.Value).MapIndex()") + if v.Kind() != Map { + panic(&ValueError{Method: "MapIndex", Kind: v.Kind()}) + } + + // compare key type with actual key type of map + if key.typecode != v.typecode.key() { + // type error? + panic("reflect.Value.MapIndex: incompatible types for key") + } + + elemType := v.typecode.Elem() + elem := New(elemType) + + if key.Kind() == String { + if ok := hashmapStringGet(v.pointer(), *(*string)(key.value), elem.value, elemType.Size()); !ok { + return Value{} + } + return elem.Elem() + } else if key.typecode.isBinary() { + var keyptr unsafe.Pointer + if key.isIndirect() || key.typecode.Size() > unsafe.Sizeof(uintptr(0)) { + keyptr = key.value + } else { + keyptr = unsafe.Pointer(&key.value) + } + //TODO(dgryski): zero out padding bytes in key, if any + if ok := hashmapBinaryGet(v.pointer(), keyptr, elem.value, elemType.Size()); !ok { + return Value{} + } + return elem.Elem() + } else { + if ok := hashmapInterfaceGet(v.pointer(), key.Interface(), elem.value, elemType.Size()); !ok { + return Value{} + } + return elem.Elem() + } } +//go:linkname hashmapNewIterator runtime.hashmapNewIterator +func hashmapNewIterator() unsafe.Pointer + +//go:linkname hashmapNext runtime.hashmapNextUnsafePointer +func hashmapNext(m unsafe.Pointer, it unsafe.Pointer, key, value unsafe.Pointer) bool + func (v Value) MapRange() *MapIter { - panic("unimplemented: (reflect.Value).MapRange()") + if v.Kind() != Map { + panic(&ValueError{Method: "MapRange", Kind: v.Kind()}) + } + + return &MapIter{ + m: v, + it: hashmapNewIterator(), + } } type MapIter struct { + m Value + it unsafe.Pointer + key Value + val Value + + valid bool } func (it *MapIter) Key() Value { - panic("unimplemented: (*reflect.MapIter).Key()") + if !it.valid { + panic("reflect.MapIter.Key called on invalid iterator") + } + + return it.key.Elem() } func (it *MapIter) Value() Value { - panic("unimplemented: (*reflect.MapIter).Value()") + if !it.valid { + panic("reflect.MapIter.Value called on invalid iterator") + } + + return it.val.Elem() } func (it *MapIter) 
Next() bool { - panic("unimplemented: (*reflect.MapIter).Next()") + it.key = New(it.m.typecode.Key()) + it.val = New(it.m.typecode.Elem()) + + it.valid = hashmapNext(it.m.pointer(), it.it, it.key.value, it.val.value) + return it.valid } func (v Value) Set(x Value) { @@ -685,7 +905,7 @@ func (v Value) SetBool(x bool) { case Bool: *(*bool)(v.value) = x default: - panic(&ValueError{Method: "SetBool"}) + panic(&ValueError{Method: "SetBool", Kind: v.Kind()}) } } @@ -703,7 +923,7 @@ func (v Value) SetInt(x int64) { case Int64: *(*int64)(v.value) = x default: - panic(&ValueError{Method: "SetInt"}) + panic(&ValueError{Method: "SetInt", Kind: v.Kind()}) } } @@ -723,7 +943,7 @@ func (v Value) SetUint(x uint64) { case Uintptr: *(*uintptr)(v.value) = uintptr(x) default: - panic(&ValueError{Method: "SetUint"}) + panic(&ValueError{Method: "SetUint", Kind: v.Kind()}) } } @@ -735,7 +955,7 @@ func (v Value) SetFloat(x float64) { case Float64: *(*float64)(v.value) = x default: - panic(&ValueError{Method: "SetFloat"}) + panic(&ValueError{Method: "SetFloat", Kind: v.Kind()}) } } @@ -747,7 +967,7 @@ func (v Value) SetComplex(x complex128) { case Complex128: *(*complex128)(v.value) = x default: - panic(&ValueError{Method: "SetComplex"}) + panic(&ValueError{Method: "SetComplex", Kind: v.Kind()}) } } @@ -757,12 +977,18 @@ func (v Value) SetString(x string) { case String: *(*string)(v.value) = x default: - panic(&ValueError{Method: "SetString"}) + panic(&ValueError{Method: "SetString", Kind: v.Kind()}) } } func (v Value) SetBytes(x []byte) { - panic("unimplemented: (reflect.Value).SetBytes()") + v.checkAddressable() + if v.typecode.Kind() != Slice || v.typecode.elem().Kind() != Uint8 { + panic("reflect.Value.SetBytes called on not []byte") + } + + // copy the header contents over + *(*[]byte)(v.value) = x } func (v Value) SetCap(n int) { @@ -770,7 +996,15 @@ func (v Value) SetCap(n int) { } func (v Value) SetLen(n int) { - panic("unimplemented: (reflect.Value).SetLen()") + if v.typecode.Kind() != Slice { + panic(&ValueError{Method: "reflect.Value.SetLen", Kind: v.Kind()}) + } + + hdr := (*sliceHeader)(v.value) + if int(uintptr(n)) != n || uintptr(n) > hdr.cap { + panic("reflect.Value.SetLen: slice length out of range") + } + hdr.len = uintptr(n) } func (v Value) checkAddressable() { @@ -779,31 +1013,115 @@ func (v Value) checkAddressable() { } } +// OverflowInt reports whether the int64 x cannot be represented by v's type. +// It panics if v's Kind is not Int, Int8, Int16, Int32, or Int64. func (v Value) OverflowInt(x int64) bool { - panic("unimplemented: reflect.OverflowInt()") + switch v.Kind() { + case Int, Int8, Int16, Int32, Int64: + bitSize := v.typecode.Size() * 8 + trunc := (x << (64 - bitSize)) >> (64 - bitSize) + return x != trunc + } + panic(&ValueError{Method: "reflect.Value.OverflowInt", Kind: v.Kind()}) } +// OverflowUint reports whether the uint64 x cannot be represented by v's type. +// It panics if v's Kind is not Uint, Uintptr, Uint8, Uint16, Uint32, or Uint64. 
func (v Value) OverflowUint(x uint64) bool { - panic("unimplemented: reflect.OverflowUint()") + k := v.Kind() + switch k { + case Uint, Uintptr, Uint8, Uint16, Uint32, Uint64: + bitSize := v.typecode.Size() * 8 + trunc := (x << (64 - bitSize)) >> (64 - bitSize) + return x != trunc + } + panic(&ValueError{Method: "reflect.Value.OverflowUint", Kind: v.Kind()}) +} + +func (v Value) CanConvert(t Type) bool { + panic("unimplemented: (reflect.Value).CanConvert()") } func (v Value) Convert(t Type) Value { panic("unimplemented: (reflect.Value).Convert()") } +//go:linkname slicePanic runtime.slicePanic +func slicePanic() + func MakeSlice(typ Type, len, cap int) Value { - panic("unimplemented: reflect.MakeSlice()") + if typ.Kind() != Slice { + panic("reflect.MakeSlice of non-slice type") + } + + rtype := typ.(*rawType) + + ulen := uint(len) + ucap := uint(cap) + maxSize := (^uintptr(0)) / 2 + elementSize := rtype.elem().Size() + if elementSize > 1 { + maxSize /= uintptr(elementSize) + } + if ulen > ucap || ucap > uint(maxSize) { + slicePanic() + } + + // This can't overflow because of the above checks. + size := uintptr(ucap) * elementSize + + var slice sliceHeader + slice.cap = uintptr(ucap) + slice.len = uintptr(ulen) + slice.data = alloc(size, nil) + + return Value{ + typecode: rtype, + value: unsafe.Pointer(&slice), + flags: valueFlagExported, + } +} + +var zerobuffer unsafe.Pointer + +const zerobufferLen = 32 + +func init() { + // 32 characters of zero bytes + zerobufferStr := "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + s := (*stringHeader)(unsafe.Pointer(&zerobufferStr)) + zerobuffer = s.data } func Zero(typ Type) Value { - panic("unimplemented: reflect.Zero()") + if typ.Size() <= unsafe.Sizeof(uintptr(0)) { + return Value{ + typecode: typ.(*rawType), + value: nil, + flags: valueFlagExported, + } + } + + if typ.Size() <= zerobufferLen { + return Value{ + typecode: typ.(*rawType), + value: unsafe.Pointer(zerobuffer), + flags: valueFlagExported, + } + } + + return Value{ + typecode: typ.(*rawType), + value: alloc(typ.Size(), nil), + flags: valueFlagExported, + } } // New is the reflect equivalent of the new(T) keyword, returning a pointer to a // new value of the given type. func New(typ Type) Value { return Value{ - typecode: PtrTo(typ).(rawType), + typecode: pointerTo(typ.(*rawType)), value: alloc(typ.Size(), nil), flags: valueFlagExported, } @@ -861,16 +1179,112 @@ func alloc(size uintptr, layout unsafe.Pointer) unsafe.Pointer //go:linkname sliceAppend runtime.sliceAppend func sliceAppend(srcBuf, elemsBuf unsafe.Pointer, srcLen, srcCap, elemsLen uintptr, elemSize uintptr) (unsafe.Pointer, uintptr, uintptr) +//go:linkname sliceCopy runtime.sliceCopy +func sliceCopy(dst, src unsafe.Pointer, dstLen, srcLen uintptr, elemSize uintptr) int + // Copy copies the contents of src into dst until either // dst has been filled or src has been exhausted. 
func Copy(dst, src Value) int { - panic("unimplemented: reflect.Copy()") + compatibleTypes := false || + // dst and src are both slices or arrays with equal types + ((dst.typecode.Kind() == Slice || dst.typecode.Kind() == Array) && + (src.typecode.Kind() == Slice || src.typecode.Kind() == Array) && + (dst.typecode.elem() == src.typecode.elem())) || + // dst is array or slice of uint8 and src is string + ((dst.typecode.Kind() == Slice || dst.typecode.Kind() == Array) && + dst.typecode.elem().Kind() == Uint8 && + src.typecode.Kind() == String) + + if !compatibleTypes { + panic("Copy: type mismatch: " + dst.typecode.String() + "/" + src.typecode.String()) + } + + // Can read from an unaddressable array but not write to one. + if dst.typecode.Kind() == Array && !dst.isIndirect() { + panic("reflect.Copy: unaddressable array value") + } + + dstbuf, dstlen := buflen(dst) + srcbuf, srclen := buflen(src) + + return sliceCopy(dstbuf, srcbuf, dstlen, srclen, dst.typecode.elem().Size()) +} + +func buflen(v Value) (unsafe.Pointer, uintptr) { + var buf unsafe.Pointer + var len uintptr + switch v.typecode.Kind() { + case Slice: + hdr := (*sliceHeader)(v.value) + buf = hdr.data + len = hdr.len + case Array: + if v.isIndirect() { + buf = v.value + } else { + buf = unsafe.Pointer(&v.value) + } + len = uintptr(v.Len()) + case String: + hdr := (*stringHeader)(v.value) + buf = hdr.data + len = hdr.len + default: + // This shouldn't happen + panic("reflect.Copy: not slice or array or string") + } + + return buf, len +} + +//go:linkname sliceGrow runtime.sliceGrow +func sliceGrow(buf unsafe.Pointer, oldLen, oldCap, newCap, elemSize uintptr) (unsafe.Pointer, uintptr, uintptr) + +// extend slice to hold n new elements +func (v *Value) extendSlice(n int) { + if v.Kind() != Slice { + panic(&ValueError{Method: "extendSlice", Kind: v.Kind()}) + } + + var old sliceHeader + if v.value != nil { + old = *(*sliceHeader)(v.value) + } + + var nbuf unsafe.Pointer + var nlen, ncap uintptr + + if old.len+uintptr(n) > old.cap { + // we need to grow the slice + nbuf, nlen, ncap = sliceGrow(old.data, old.len, old.cap, old.cap+uintptr(n), v.typecode.elem().Size()) + } else { + // we can reuse the slice we have + nbuf = old.data + nlen = old.len + ncap = old.cap + } + + newslice := sliceHeader{ + data: nbuf, + len: nlen + uintptr(n), + cap: ncap, + } + + v.value = (unsafe.Pointer)(&newslice) } // Append appends the values x to a slice s and returns the resulting slice. // As in Go, each x's value must be assignable to the slice's element type. -func Append(s Value, x ...Value) Value { - panic("unimplemented: reflect.Append()") +func Append(v Value, x ...Value) Value { + if v.Kind() != Slice { + panic(&ValueError{Method: "Append", Kind: v.Kind()}) + } + oldLen := v.Len() + v.extendSlice(len(x)) + for i, xx := range x { + v.Index(oldLen + i).Set(xx) + } + return v } // AppendSlice appends a slice t to a slice s and returns the resulting slice. 
@@ -901,13 +1315,109 @@ func AppendSlice(s, t Value) Value { } } +//go:linkname hashmapStringSet runtime.hashmapStringSetUnsafePointer +func hashmapStringSet(m unsafe.Pointer, key string, value unsafe.Pointer) + +//go:linkname hashmapBinarySet runtime.hashmapBinarySetUnsafePointer +func hashmapBinarySet(m unsafe.Pointer, key, value unsafe.Pointer) + +//go:linkname hashmapInterfaceSet runtime.hashmapInterfaceSetUnsafePointer +func hashmapInterfaceSet(m unsafe.Pointer, key interface{}, value unsafe.Pointer) + +//go:linkname hashmapStringDelete runtime.hashmapStringDeleteUnsafePointer +func hashmapStringDelete(m unsafe.Pointer, key string) + +//go:linkname hashmapBinaryDelete runtime.hashmapBinaryDeleteUnsafePointer +func hashmapBinaryDelete(m unsafe.Pointer, key unsafe.Pointer) + +//go:linkname hashmapInterfaceDelete runtime.hashmapInterfaceDeleteUnsafePointer +func hashmapInterfaceDelete(m unsafe.Pointer, key interface{}) + func (v Value) SetMapIndex(key, elem Value) { - panic("unimplemented: (reflect.Value).SetMapIndex()") + if v.Kind() != Map { + panic(&ValueError{Method: "SetMapIndex", Kind: v.Kind()}) + } + + // compare key type with actual key type of map + if key.typecode != v.typecode.key() { + panic("reflect.Value.SetMapIndex: incompatible types for key") + } + + // if elem is the zero Value, it means delete + del := elem == Value{} + + if !del && elem.typecode != v.typecode.elem() { + panic("reflect.Value.SetMapIndex: incompatible types for value") + } + + if key.Kind() == String { + if del { + hashmapStringDelete(v.pointer(), *(*string)(key.value)) + } else { + var elemptr unsafe.Pointer + if elem.isIndirect() || elem.typecode.Size() > unsafe.Sizeof(uintptr(0)) { + elemptr = elem.value + } else { + elemptr = unsafe.Pointer(&elem.value) + } + hashmapStringSet(v.pointer(), *(*string)(key.value), elemptr) + } + + } else if key.typecode.isBinary() { + var keyptr unsafe.Pointer + if key.isIndirect() || key.typecode.Size() > unsafe.Sizeof(uintptr(0)) { + keyptr = key.value + } else { + keyptr = unsafe.Pointer(&key.value) + } + + if del { + hashmapBinaryDelete(v.pointer(), keyptr) + } else { + var elemptr unsafe.Pointer + if elem.isIndirect() || elem.typecode.Size() > unsafe.Sizeof(uintptr(0)) { + elemptr = elem.value + } else { + elemptr = unsafe.Pointer(&elem.value) + } + hashmapBinarySet(v.pointer(), keyptr, elemptr) + } + } else { + if del { + hashmapInterfaceDelete(v.pointer(), key.Interface()) + } else { + var elemptr unsafe.Pointer + if elem.isIndirect() || elem.typecode.Size() > unsafe.Sizeof(uintptr(0)) { + elemptr = elem.value + } else { + elemptr = unsafe.Pointer(&elem.value) + } + + hashmapInterfaceSet(v.pointer(), key.Interface(), elemptr) + } + } } // FieldByIndex returns the nested field corresponding to index. func (v Value) FieldByIndex(index []int) Value { - panic("unimplemented: (reflect.Value).FieldByIndex()") + if len(index) == 1 { + return v.Field(index[0]) + } + if v.Kind() != Struct { + panic(&ValueError{"FieldByIndex", v.Kind()}) + } + for i, x := range index { + if i > 0 { + if v.Kind() == Pointer && v.typecode.elem().Kind() == Struct { + if v.IsNil() { + panic("reflect: indirection through nil pointer to embedded struct") + } + v = v.Elem() + } + } + v = v.Field(x) + } + return v } // FieldByIndexErr returns the nested field corresponding to index. 
@@ -916,18 +1426,73 @@ func (v Value) FieldByIndexErr(index []int) (Value, error) { } func (v Value) FieldByName(name string) Value { - panic("unimplemented: (reflect.Value).FieldByName()") + if v.Kind() != Struct { + panic(&ValueError{"FieldByName", v.Kind()}) + } + + if field, ok := v.typecode.FieldByName(name); ok { + return v.FieldByIndex(field.Index) + } + return Value{} +} + +//go:linkname hashmapMake runtime.hashmapMakeUnsafePointer +func hashmapMake(keySize, valueSize uintptr, sizeHint uintptr, alg uint8) unsafe.Pointer + +// MakeMapWithSize creates a new map with the specified type and initial space +// for approximately n elements. +func MakeMapWithSize(typ Type, n int) Value { + + // TODO(dgryski): deduplicate these? runtime and reflect both need them. + const ( + hashmapAlgorithmBinary uint8 = iota + hashmapAlgorithmString + hashmapAlgorithmInterface + ) + + if typ.Kind() != Map { + panic(&ValueError{Method: "MakeMap", Kind: typ.Kind()}) + } + + if n < 0 { + panic("reflect.MakeMapWithSize: negative size hint") + } + + key := typ.Key().(*rawType) + val := typ.Elem().(*rawType) + + var alg uint8 + + if key.Kind() == String { + alg = hashmapAlgorithmString + } else if key.isBinary() { + alg = hashmapAlgorithmBinary + } else { + alg = hashmapAlgorithmInterface + } + + m := hashmapMake(key.Size(), val.Size(), uintptr(n), alg) + + return Value{ + typecode: typ.(*rawType), + value: m, + flags: valueFlagExported, + } } // MakeMap creates a new map with the specified type. func MakeMap(typ Type) Value { - panic("unimplemented: reflect.MakeMap()") + return MakeMapWithSize(typ, 8) } func (v Value) Call(in []Value) []Value { panic("unimplemented: (reflect.Value).Call()") } +func (v Value) Method(i int) Value { + panic("unimplemented: (reflect.Value).Method()") +} + func (v Value) MethodByName(name string) Value { panic("unimplemented: (reflect.Value).MethodByName()") } diff --git a/src/reflect/value_test.go b/src/reflect/value_test.go index 5698ede557..2bc8f92756 100644 --- a/src/reflect/value_test.go +++ b/src/reflect/value_test.go @@ -1,7 +1,9 @@ package reflect_test import ( + "encoding/base64" . 
"reflect" + "sort" "testing" ) @@ -30,3 +32,410 @@ func TestIndirectPointers(t *testing.T) { t.Errorf("bad indirect array index via reflect") } } + +func TestMap(t *testing.T) { + + m := make(map[string]int) + + mtyp := TypeOf(m) + + if got, want := mtyp.Key().Kind().String(), "string"; got != want { + t.Errorf("m.Type().Key().String()=%q, want %q", got, want) + } + + if got, want := mtyp.Elem().Kind().String(), "int"; got != want { + t.Errorf("m.Elem().String()=%q, want %q", got, want) + } + + m["foo"] = 2 + + mref := ValueOf(m) + two := mref.MapIndex(ValueOf("foo")) + + if got, want := two.Interface().(int), 2; got != want { + t.Errorf("MapIndex(`foo`)=%v, want %v", got, want) + } + + m["bar"] = 3 + m["baz"] = 4 + m["qux"] = 5 + + it := mref.MapRange() + + var gotKeys []string + for it.Next() { + k := it.Key() + v := it.Value() + + kstr := k.Interface().(string) + vint := v.Interface().(int) + + gotKeys = append(gotKeys, kstr) + + if m[kstr] != vint { + t.Errorf("m[%v]=%v, want %v", kstr, vint, m[kstr]) + } + } + var wantKeys []string + for k := range m { + wantKeys = append(wantKeys, k) + } + sort.Strings(gotKeys) + sort.Strings(wantKeys) + + if !equal(gotKeys, wantKeys) { + t.Errorf("MapRange return unexpected keys: got %v, want %v", gotKeys, wantKeys) + } + + refMapKeys := mref.MapKeys() + gotKeys = gotKeys[:0] + for _, v := range refMapKeys { + gotKeys = append(gotKeys, v.Interface().(string)) + } + + sort.Strings(gotKeys) + if !equal(gotKeys, wantKeys) { + t.Errorf("MapKeys return unexpected keys: got %v, want %v", gotKeys, wantKeys) + } + + mref.SetMapIndex(ValueOf("bar"), Value{}) + if _, ok := m["bar"]; ok { + t.Errorf("SetMapIndex failed to delete `bar`") + } + + mref.SetMapIndex(ValueOf("baz"), ValueOf(6)) + if got, want := m["baz"], 6; got != want { + t.Errorf("SetMapIndex(bar, 6) got %v, want %v", got, want) + } + + m2ref := MakeMap(mref.Type()) + m2ref.SetMapIndex(ValueOf("foo"), ValueOf(2)) + + m2 := m2ref.Interface().(map[string]int) + + if m2["foo"] != 2 { + t.Errorf("MakeMap failed to create map") + } + + type stringint struct { + s string + i int + } + + simap := make(map[stringint]int) + + refsimap := MakeMap(TypeOf(simap)) + + refsimap.SetMapIndex(ValueOf(stringint{"hello", 4}), ValueOf(6)) + + six := refsimap.MapIndex(ValueOf(stringint{"hello", 4})) + + if six.Interface().(int) != 6 { + t.Errorf("m[hello, 4]=%v, want 6", six) + } +} + +func TestSlice(t *testing.T) { + s := []int{0, 10, 20} + refs := ValueOf(s) + + for i := 3; i < 10; i++ { + refs = Append(refs, ValueOf(i*10)) + } + + s = refs.Interface().([]int) + + for i := 0; i < 10; i++ { + if s[i] != i*10 { + t.Errorf("s[%d]=%d, want %d", i, s[i], i*10) + } + } + + s28 := s[2:8] + s28ref := refs.Slice(2, 8) + + if len(s28) != s28ref.Len() || cap(s28) != s28ref.Cap() { + t.Errorf("Slice: len(s28)=%d s28ref.Len()=%d cap(s28)=%d s28ref.Cap()=%d\n", len(s28), s28ref.Len(), cap(s28), s28ref.Cap()) + } + + for i, got := range s28 { + want := int(s28ref.Index(i).Int()) + if got != want { + t.Errorf("s28[%d]=%d, want %d", i, got, want) + } + } + + s268 := s[2:6:8] + s268ref := refs.Slice3(2, 6, 8) + + if len(s268) != s268ref.Len() || cap(s268) != s268ref.Cap() { + t.Errorf("Slice3: len(s268)=%d s268ref.Len()=%d cap(s268)=%d s268ref.Cap()=%d\n", len(s268), s268ref.Len(), cap(s268), s268ref.Cap()) + } + + for i, got := range s268 { + want := int(s268ref.Index(i).Int()) + if got != want { + t.Errorf("s268[%d]=%d, want %d", i, got, want) + } + } + + // should be equivalent to s28 now, except for the capacity which doesn't 
change + s268ref.SetLen(6) + if len(s28) != s268ref.Len() || cap(s268) != s268ref.Cap() { + t.Errorf("SetLen: len(s268)=%d s268ref.Len()=%d cap(s268)=%d s268ref.Cap()=%d\n", len(s28), s268ref.Len(), cap(s268), s268ref.Cap()) + } + + for i, got := range s28 { + want := int(s268ref.Index(i).Int()) + if got != want { + t.Errorf("s28[%d]=%d, want %d", i, got, want) + } + } + + refs = MakeSlice(TypeOf(s), 5, 10) + s = refs.Interface().([]int) + + if len(s) != refs.Len() || cap(s) != refs.Cap() { + t.Errorf("len(s)=%v refs.Len()=%v cap(s)=%v refs.Cap()=%v", len(s), refs.Len(), cap(s), refs.Cap()) + } +} + +func TestBytes(t *testing.T) { + s := []byte("abcde") + refs := ValueOf(s) + + s2 := refs.Bytes() + + if !equal(s, s2) { + t.Errorf("Failed to get Bytes(): %v != %v", s, s2) + } + + Copy(refs, ValueOf("12345")) + + if string(s) != "12345" { + t.Errorf("Copy()=%q, want `12345`", string(s)) + } + + // test small arrays that fit in a pointer + a := [3]byte{10, 20, 30} + v := ValueOf(&a) + vslice := v.Elem().Bytes() + if len(vslice) != 3 || cap(vslice) != 3 { + t.Errorf("len(vslice)=%v, cap(vslice)=%v", len(vslice), cap(vslice)) + } + + for i, got := range vslice { + if want := (byte(i) + 1) * 10; got != want { + t.Errorf("vslice[%d]=%d, want %d", i, got, want) + } + } +} + +func TestNamedTypes(t *testing.T) { + type namedString string + + named := namedString("foo") + if got, want := TypeOf(named).Name(), "namedString"; got != want { + t.Errorf("TypeOf.Name()=%v, want %v", got, want) + } + + if got, want := TypeOf(named).String(), "reflect_test.namedString"; got != want { + t.Errorf("TypeOf.String()=%v, want %v", got, want) + } + + errorType := TypeOf((*error)(nil)).Elem() + if s := errorType.String(); s != "error" { + t.Errorf("error type = %v, want error", s) + } + + m := make(map[[4]uint16]string) + + if got, want := TypeOf(m).String(), "map[[4]uint16]string"; got != want { + t.Errorf("Type.String()=%v, want %v", got, want) + } + + s := struct { + a int8 + b int8 + c int8 + d int8 + e int8 + f int32 + }{} + + if got, want := TypeOf(s).String(), "struct { a int8; b int8; c int8; d int8; e int8; f int32 }"; got != want { + t.Errorf("Type.String()=%v, want %v", got, want) + } + + if got, want := ValueOf(m).String(), ""; got != want { + t.Errorf("Value.String()=%v, want %v", got, want) + } + + if got, want := TypeOf(base64.Encoding{}).String(), "base64.Encoding"; got != want { + t.Errorf("Type.String(base64.Encoding{})=%v, want %v", got, want) + } + + type Repository struct { + RoleName *string `json:"role_name,omitempty"` + } + + var repo *Repository + v := ValueOf(&repo).Elem() + n := New(v.Type().Elem()) + v.Set(n) +} + +func TestStruct(t *testing.T) { + type barStruct struct { + QuxString string + BazInt int + } + + type foobar struct { + Foo string `foo:"struct tag"` + Bar barStruct + } + + var fb foobar + fb.Bar.QuxString = "qux" + + reffb := TypeOf(fb) + + q := reffb.FieldByIndex([]int{1, 0}) + if want := "QuxString"; q.Name != want { + t.Errorf("FieldByIndex=%v, want %v", q.Name, want) + } + + var ok bool + q, ok = reffb.FieldByName("Foo") + if q.Name != "Foo" || !ok { + t.Errorf("FieldByName(Foo)=%v,%v, want Foo, true") + } + + if got, want := q.Tag, `foo:"struct tag"`; string(got) != want { + t.Errorf("StrucTag for Foo=%v, want %v", got, want) + } + + q, ok = reffb.FieldByName("Snorble") + if q.Name != "" || ok { + t.Errorf("FieldByName(Snorble)=%v,%v, want ``, false") + } +} + +func TestZero(t *testing.T) { + s := "hello, world" + var sptr *string = &s + v := ValueOf(&sptr).Elem() + 
v.Set(Zero(v.Type())) + + sptr = v.Interface().(*string) + + if sptr != nil { + t.Errorf("failed to set a nil string pointer") + } + + sl := []int{1, 2, 3} + v = ValueOf(&sl).Elem() + v.Set(Zero(v.Type())) + sl = v.Interface().([]int) + + if sl != nil { + t.Errorf("failed to set a nil slice") + } +} + +func addrDecode(body interface{}) { + vbody := ValueOf(body) + ptr := vbody.Elem() + pptr := ptr.Addr() + addrSetInt(pptr.Interface()) +} + +func addrSetInt(intf interface{}) { + ptr := intf.(*uint64) + *ptr = 112358 +} + +func TestAddr(t *testing.T) { + var n uint64 + addrDecode(&n) + if n != 112358 { + t.Errorf("Failed to set t=112358, got %v", n) + } + + v := ValueOf(&n) + if got, want := v.Elem().Addr().CanAddr(), false; got != want { + t.Errorf("Elem.Addr.CanAddr=%v, want %v", got, want) + } +} + +func TestNilType(t *testing.T) { + var a any = nil + typ := TypeOf(a) + if typ != nil { + t.Errorf("Type of any{nil} is not nil") + } +} + +func TestSetBytes(t *testing.T) { + var b []byte + refb := ValueOf(&b).Elem() + s := []byte("hello") + refb.SetBytes(s) + s[0] = 'b' + + refbSlice := refb.Interface().([]byte) + + if len(refbSlice) != len(s) || b[0] != s[0] || refbSlice[0] != s[0] { + t.Errorf("SetBytes(): reflection slice mismatch") + } +} + +type methodStruct struct { + i int +} + +func (m methodStruct) valueMethod1() int { + return m.i +} + +func (m methodStruct) valueMethod2() int { + return m.i +} + +func (m *methodStruct) pointerMethod1() int { + return m.i +} + +func (m *methodStruct) pointerMethod2() int { + return m.i +} + +func (m *methodStruct) pointerMethod3() int { + return m.i +} + +func TestNumMethods(t *testing.T) { + refptrt := TypeOf(&methodStruct{}) + if got, want := refptrt.NumMethod(), 2+3; got != want { + t.Errorf("Pointer Methods=%v, want %v", got, want) + } + + reft := refptrt.Elem() + if got, want := reft.NumMethod(), 2; got != want { + t.Errorf("Value Methods=%v, want %v", got, want) + } +} + +func equal[T comparable](a, b []T) bool { + if len(a) != len(b) { + return false + } + + for i, aa := range a { + if b[i] != aa { + return false + } + } + return true +} diff --git a/src/reflect/visiblefields.go b/src/reflect/visiblefields.go new file mode 100644 index 0000000000..9375faa110 --- /dev/null +++ b/src/reflect/visiblefields.go @@ -0,0 +1,105 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package reflect + +// VisibleFields returns all the visible fields in t, which must be a +// struct type. A field is defined as visible if it's accessible +// directly with a FieldByName call. The returned fields include fields +// inside anonymous struct members and unexported fields. They follow +// the same order found in the struct, with anonymous fields followed +// immediately by their promoted fields. +// +// For each element e of the returned slice, the corresponding field +// can be retrieved from a value v of type t by calling v.FieldByIndex(e.Index). +func VisibleFields(t Type) []StructField { + if t == nil { + panic("reflect: VisibleFields(nil)") + } + if t.Kind() != Struct { + panic("reflect.VisibleFields of non-struct type") + } + w := &visibleFieldsWalker{ + byName: make(map[string]int), + visiting: make(map[Type]bool), + fields: make([]StructField, 0, t.NumField()), + index: make([]int, 0, 2), + } + w.walk(t) + // Remove all the fields that have been hidden. 
+ // Use an in-place removal that avoids copying in + // the common case that there are no hidden fields. + j := 0 + for i := range w.fields { + f := &w.fields[i] + if f.Name == "" { + continue + } + if i != j { + // A field has been removed. We need to shuffle + // all the subsequent elements up. + w.fields[j] = *f + } + j++ + } + return w.fields[:j] +} + +type visibleFieldsWalker struct { + byName map[string]int + visiting map[Type]bool + fields []StructField + index []int +} + +// walk walks all the fields in the struct type t, visiting +// fields in index preorder and appending them to w.fields +// (this maintains the required ordering). +// Fields that have been overridden have their +// Name field cleared. +func (w *visibleFieldsWalker) walk(t Type) { + if w.visiting[t] { + return + } + w.visiting[t] = true + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + w.index = append(w.index, i) + add := true + if oldIndex, ok := w.byName[f.Name]; ok { + old := &w.fields[oldIndex] + if len(w.index) == len(old.Index) { + // Fields with the same name at the same depth + // cancel one another out. Set the field name + // to empty to signify that has happened, and + // there's no need to add this field. + old.Name = "" + add = false + } else if len(w.index) < len(old.Index) { + // The old field loses because it's deeper than the new one. + old.Name = "" + } else { + // The old field wins because it's shallower than the new one. + add = false + } + } + if add { + // Copy the index so that it's not overwritten + // by the other appends. + f.Index = append([]int(nil), w.index...) + w.byName[f.Name] = len(w.fields) + w.fields = append(w.fields, f) + } + if f.Anonymous { + if f.Type.Kind() == Pointer { + f.Type = f.Type.Elem() + } + if f.Type.Kind() == Struct { + w.walk(f.Type) + } + } + w.index = w.index[:len(w.index)-1] + } + delete(w.visiting, t) +} diff --git a/src/reflect/visiblefields_test.go b/src/reflect/visiblefields_test.go new file mode 100644 index 0000000000..e03198584e --- /dev/null +++ b/src/reflect/visiblefields_test.go @@ -0,0 +1,352 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package reflect_test + +import ( + . 
"reflect" + "strings" + "testing" +) + +type structField struct { + name string + index []int +} + +var fieldsTests = []struct { + testName string + val any + expect []structField +}{{ + testName: "SimpleStruct", + val: struct { + A int + B string + C bool + }{}, + expect: []structField{{ + name: "A", + index: []int{0}, + }, { + name: "B", + index: []int{1}, + }, { + name: "C", + index: []int{2}, + }}, +}, { + testName: "NonEmbeddedStructMember", + val: struct { + A struct { + X int + } + }{}, + expect: []structField{{ + name: "A", + index: []int{0}, + }}, +}, { + testName: "EmbeddedExportedStruct", + val: struct { + SFG + }{}, + expect: []structField{{ + name: "SFG", + index: []int{0}, + }, { + name: "F", + index: []int{0, 0}, + }, { + name: "G", + index: []int{0, 1}, + }}, +}, { + testName: "EmbeddedUnexportedStruct", + val: struct { + sFG + }{}, + expect: []structField{{ + name: "sFG", + index: []int{0}, + }, { + name: "F", + index: []int{0, 0}, + }, { + name: "G", + index: []int{0, 1}, + }}, +}, { + testName: "TwoEmbeddedStructsWithCancelingMembers", + val: struct { + SFG + SF + }{}, + expect: []structField{{ + name: "SFG", + index: []int{0}, + }, { + name: "G", + index: []int{0, 1}, + }, { + name: "SF", + index: []int{1}, + }}, +}, { + testName: "EmbeddedStructsWithSameFieldsAtDifferentDepths", + val: struct { + SFGH3 + SG1 + SFG2 + SF2 + L int + }{}, + expect: []structField{{ + name: "SFGH3", + index: []int{0}, + }, { + name: "SFGH2", + index: []int{0, 0}, + }, { + name: "SFGH1", + index: []int{0, 0, 0}, + }, { + name: "SFGH", + index: []int{0, 0, 0, 0}, + }, { + name: "H", + index: []int{0, 0, 0, 0, 2}, + }, { + name: "SG1", + index: []int{1}, + }, { + name: "SG", + index: []int{1, 0}, + }, { + name: "G", + index: []int{1, 0, 0}, + }, { + name: "SFG2", + index: []int{2}, + }, { + name: "SFG1", + index: []int{2, 0}, + }, { + name: "SFG", + index: []int{2, 0, 0}, + }, { + name: "SF2", + index: []int{3}, + }, { + name: "SF1", + index: []int{3, 0}, + }, { + name: "SF", + index: []int{3, 0, 0}, + }, { + name: "L", + index: []int{4}, + }}, +}, { + testName: "EmbeddedPointerStruct", + val: struct { + *SF + }{}, + expect: []structField{{ + name: "SF", + index: []int{0}, + }, { + name: "F", + index: []int{0, 0}, + }}, +}, { + testName: "EmbeddedNotAPointer", + val: struct { + M + }{}, + expect: []structField{{ + name: "M", + index: []int{0}, + }}, +}, { + testName: "RecursiveEmbedding", + val: Rec1{}, + expect: []structField{{ + name: "Rec2", + index: []int{0}, + }, { + name: "F", + index: []int{0, 0}, + }, { + name: "Rec1", + index: []int{0, 1}, + }}, +}, { + testName: "RecursiveEmbedding2", + val: Rec2{}, + expect: []structField{{ + name: "F", + index: []int{0}, + }, { + name: "Rec1", + index: []int{1}, + }, { + name: "Rec2", + index: []int{1, 0}, + }}, +}, { + testName: "RecursiveEmbedding3", + val: RS3{}, + expect: []structField{{ + name: "RS2", + index: []int{0}, + }, { + name: "RS1", + index: []int{1}, + }, { + name: "i", + index: []int{1, 0}, + }}, +}} + +type SFG struct { + F int + G int +} + +type SFG1 struct { + SFG +} + +type SFG2 struct { + SFG1 +} + +type SFGH struct { + F int + G int + H int +} + +type SFGH1 struct { + SFGH +} + +type SFGH2 struct { + SFGH1 +} + +type SFGH3 struct { + SFGH2 +} + +type SF struct { + F int +} + +type SF1 struct { + SF +} + +type SF2 struct { + SF1 +} + +type SG struct { + G int +} + +type SG1 struct { + SG +} + +type sFG struct { + F int + G int +} + +type RS1 struct { + i int +} + +type RS2 struct { + RS1 +} + +type RS3 struct { + RS2 + RS1 +} + 
+type M map[string]any + +type Rec1 struct { + *Rec2 +} + +type Rec2 struct { + F string + *Rec1 +} + +func TestFields(t *testing.T) { + for _, test := range fieldsTests { + test := test + t.Run(test.testName, func(t *testing.T) { + typ := TypeOf(test.val) + fields := VisibleFields(typ) + if got, want := len(fields), len(test.expect); got != want { + t.Fatalf("unexpected field count; got %d want %d", got, want) + } + + for j, field := range fields { + expect := test.expect[j] + t.Logf("field %d: %s", j, expect.name) + gotField := typ.FieldByIndex(field.Index) + // Unfortunately, FieldByIndex does not return + // a field with the same index that we passed in, + // so we set it to the expected value so that + // it can be compared later with the result of FieldByName. + gotField.Index = field.Index + expectField := typ.FieldByIndex(expect.index) + // ditto. + expectField.Index = expect.index + if !DeepEqual(gotField, expectField) { + t.Fatalf("unexpected field result\ngot %#v\nwant %#v", gotField, expectField) + } + + // Sanity check that we can actually access the field by the + // expected name. + gotField1, ok := typ.FieldByName(expect.name) + if !ok { + t.Fatalf("field %q not accessible by name", expect.name) + } + if !DeepEqual(gotField1, expectField) { + t.Fatalf("unexpected FieldByName result; got %#v want %#v", gotField1, expectField) + } + } + }) + } +} + +// Must not panic with nil embedded pointer. +func TestFieldByIndexErr(t *testing.T) { + // TODO(dgryski): FieldByIndexErr not implemented yet -- skip + return + + type A struct { + S string + } + type B struct { + *A + } + v := ValueOf(B{}) + _, err := v.FieldByIndexErr([]int{0, 0}) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "embedded struct field A") { + t.Fatal(err) + } +} diff --git a/src/runtime/asm_arm64.S b/src/runtime/asm_arm64.S index 679f55fa18..267b639951 100644 --- a/src/runtime/asm_arm64.S +++ b/src/runtime/asm_arm64.S @@ -4,7 +4,6 @@ _tinygo_scanCurrentStack: #else .section .text.tinygo_scanCurrentStack .global tinygo_scanCurrentStack -.type tinygo_scanCurrentStack, %function tinygo_scanCurrentStack: #endif // Sources: @@ -12,12 +11,16 @@ tinygo_scanCurrentStack: // * https://godbolt.org/z/qrvrEh // Save callee-saved registers. - stp x29, x30, [sp, #-96]! + stp x29, x30, [sp, #-160]! stp x28, x27, [sp, #16] stp x26, x25, [sp, #32] stp x24, x23, [sp, #48] stp x22, x21, [sp, #64] stp x20, x19, [sp, #80] + stp d8, d9, [sp, #96] + stp d10, d11, [sp, #112] + stp d12, d13, [sp, #128] + stp d14, d15, [sp, #144] // Scan the stack. mov x0, sp @@ -28,7 +31,7 @@ tinygo_scanCurrentStack: #endif // Restore stack state and return. - ldp x29, x30, [sp], #96 + ldp x29, x30, [sp], #160 ret @@ -38,7 +41,6 @@ _tinygo_longjmp: #else .section .text.tinygo_longjmp .global tinygo_longjmp -.type tinygo_longjmp, %function tinygo_longjmp: #endif // Note: the code we jump to assumes x0 is set to a non-zero value if we diff --git a/src/runtime/asm_arm64_windows.S b/src/runtime/asm_arm64_windows.S deleted file mode 100644 index 2437789978..0000000000 --- a/src/runtime/asm_arm64_windows.S +++ /dev/null @@ -1,36 +0,0 @@ -.section .text.tinygo_scanCurrentStack,"ax" -.global tinygo_scanCurrentStack -tinygo_scanCurrentStack: - // Sources: - // * https://learn.microsoft.com/en-us/cpp/build/arm64-windows-abi-conventions?view=msvc-170 - // * https://godbolt.org/z/foc1xncvb - - // Save callee-saved registers. - stp x29, x30, [sp, #-160]! 
- stp x28, x27, [sp, #16] - stp x26, x25, [sp, #32] - stp x24, x23, [sp, #48] - stp x22, x21, [sp, #64] - stp x20, x19, [sp, #80] - stp d8, d9, [sp, #96] - stp d10, d11, [sp, #112] - stp d12, d13, [sp, #128] - stp d14, d15, [sp, #144] - - // Scan the stack. - mov x0, sp - bl tinygo_scanstack - - // Restore stack state and return. - ldp x29, x30, [sp], #160 - ret - - -.section .text.tinygo_longjmp,"ax" -.global tinygo_longjmp -tinygo_longjmp: - // Note: the code we jump to assumes x0 is set to a non-zero value if we - // jump from here (which is conveniently already the case). - ldp x1, x2, [x0] // jumpSP, jumpPC - mov sp, x1 - br x2 diff --git a/src/runtime/chan.go b/src/runtime/chan.go index 5bc05929dc..f1abac4d08 100644 --- a/src/runtime/chan.go +++ b/src/runtime/chan.go @@ -243,13 +243,8 @@ func (ch *channel) push(value unsafe.Pointer) bool { // copy value to buffer memcpy( - unsafe.Pointer( // pointer to the base of the buffer + offset = pointer to destination element - uintptr(ch.buf)+ - uintptr( // element size * equivalent slice index = offset - ch.elementSize* // element size (bytes) - ch.bufHead, // index of first available buffer entry - ), - ), + unsafe.Add(ch.buf, // pointer to the base of the buffer + offset = pointer to destination element + ch.elementSize*ch.bufHead), // element size * equivalent slice index = offset value, ch.elementSize, ) @@ -274,7 +269,7 @@ func (ch *channel) pop(value unsafe.Pointer) bool { } // compute address of source - addr := unsafe.Pointer(uintptr(ch.buf) + (ch.elementSize * ch.bufTail)) + addr := unsafe.Add(ch.buf, (ch.elementSize * ch.bufTail)) // copy value from buffer memcpy( diff --git a/src/runtime/debug/debug.go b/src/runtime/debug/debug.go index 3bbbf71a23..fe3650cb2c 100644 --- a/src/runtime/debug/debug.go +++ b/src/runtime/debug/debug.go @@ -45,3 +45,8 @@ type Module struct { Sum string // checksum Replace *Module // replaced by this module } + +// Not implemented. +func SetGCPercent(n int) int { + return n +} diff --git a/src/runtime/dynamic_arm64.go b/src/runtime/dynamic_arm64.go index 645c797539..e167f6f2e6 100644 --- a/src/runtime/dynamic_arm64.go +++ b/src/runtime/dynamic_arm64.go @@ -43,9 +43,9 @@ func dynamicLoader(base uintptr, dyn *dyn64) { relasz = uint64(dyn.Val) / uint64(unsafe.Sizeof(rela64{})) } - ptr := uintptr(unsafe.Pointer(dyn)) - ptr += unsafe.Sizeof(dyn64{}) - dyn = (*dyn64)(unsafe.Pointer(ptr)) + ptr := unsafe.Pointer(dyn) + ptr = unsafe.Add(ptr, unsafe.Sizeof(dyn64{})) + dyn = (*dyn64)(ptr) } if rela == nil { @@ -70,9 +70,9 @@ func dynamicLoader(base uintptr, dyn *dyn64) { } } - rptr := uintptr(unsafe.Pointer(rela)) - rptr += unsafe.Sizeof(rela64{}) - rela = (*rela64)(unsafe.Pointer(rptr)) + rptr := unsafe.Pointer(rela) + rptr = unsafe.Add(rptr, unsafe.Sizeof(rela64{})) + rela = (*rela64)(rptr) relasz-- } } diff --git a/src/runtime/gc_blocks.go b/src/runtime/gc_blocks.go index 59aebb2ecd..54c3cb9130 100644 --- a/src/runtime/gc_blocks.go +++ b/src/runtime/gc_blocks.go @@ -146,7 +146,7 @@ func (b gcBlock) findNext() gcBlock { // State returns the current block state. func (b gcBlock) state() blockState { - stateBytePtr := (*uint8)(unsafe.Pointer(uintptr(metadataStart) + uintptr(b/blocksPerStateByte))) + stateBytePtr := (*uint8)(unsafe.Add(metadataStart, b/blocksPerStateByte)) return blockState(*stateBytePtr>>((b%blocksPerStateByte)*stateBits)) & blockStateMask } @@ -154,7 +154,7 @@ func (b gcBlock) state() blockState { // bits than the current state. 
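// Worked sketch of the block-state bit packing used by the state accessors
// above (state, setState, markFree, unmark). The concrete values here are
// assumptions for illustration only (2 state bits per block, so 4 block
// states per metadata byte); the real constants live in gc_blocks.go.
package main

import "fmt"

const (
	stateBits          = 2                // bits of state per block (assumed)
	blocksPerStateByte = 8 / stateBits    // 4 blocks share one metadata byte
	blockStateMask     = 1<<stateBits - 1 // 0b11
)

func main() {
	metadata := make([]byte, 4) // stand-in for the metadata area at metadataStart

	block := 9
	newState := byte(0b10) // e.g. a "head" state

	// setState: OR the state into the right 2-bit slot of the right byte.
	byteIndex := block / blocksPerStateByte           // 9/4 = 2
	shift := (block % blocksPerStateByte) * stateBits // (9%4)*2 = 2
	metadata[byteIndex] |= newState << shift

	// state: shift back down and mask off the neighbouring blocks' bits.
	got := (metadata[byteIndex] >> shift) & blockStateMask
	fmt.Printf("byte %d, shift %d, state %02b\n", byteIndex, shift, got)
}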
Allowed transitions: from free to any state and // from head to mark. func (b gcBlock) setState(newState blockState) { - stateBytePtr := (*uint8)(unsafe.Pointer(uintptr(metadataStart) + uintptr(b/blocksPerStateByte))) + stateBytePtr := (*uint8)(unsafe.Add(metadataStart, b/blocksPerStateByte)) *stateBytePtr |= uint8(newState << ((b % blocksPerStateByte) * stateBits)) if gcAsserts && b.state() != newState { runtimePanic("gc: setState() was not successful") @@ -163,7 +163,7 @@ func (b gcBlock) setState(newState blockState) { // markFree sets the block state to free, no matter what state it was in before. func (b gcBlock) markFree() { - stateBytePtr := (*uint8)(unsafe.Pointer(uintptr(metadataStart) + uintptr(b/blocksPerStateByte))) + stateBytePtr := (*uint8)(unsafe.Add(metadataStart, b/blocksPerStateByte)) *stateBytePtr &^= uint8(blockStateMask << ((b % blocksPerStateByte) * stateBits)) if gcAsserts && b.state() != blockStateFree { runtimePanic("gc: markFree() was not successful") @@ -180,7 +180,7 @@ func (b gcBlock) unmark() { runtimePanic("gc: unmark() on a block that is not marked") } clearMask := blockStateMask ^ blockStateHead // the bits to clear from the state - stateBytePtr := (*uint8)(unsafe.Pointer(uintptr(metadataStart) + uintptr(b/blocksPerStateByte))) + stateBytePtr := (*uint8)(unsafe.Add(metadataStart, b/blocksPerStateByte)) *stateBytePtr &^= uint8(clearMask << ((b % blocksPerStateByte) * stateBits)) if gcAsserts && b.state() != blockStateHead { runtimePanic("gc: unmark() was not successful") @@ -277,6 +277,10 @@ func alloc(size uintptr, layout unsafe.Pointer) unsafe.Pointer { size += align(unsafe.Sizeof(layout)) } + if interrupt.In() { + runtimePanic("alloc in interrupt") + } + gcTotalAlloc += uint64(size) gcMallocs++ @@ -687,3 +691,7 @@ func ReadMemStats(m *MemStats) { m.Frees = gcFrees m.Sys = uint64(heapEnd - heapStart) } + +func SetFinalizer(obj interface{}, finalizer interface{}) { + // Unimplemented. +} diff --git a/src/runtime/gc_custom.go b/src/runtime/gc_custom.go index 45857338f5..a34b7dce69 100644 --- a/src/runtime/gc_custom.go +++ b/src/runtime/gc_custom.go @@ -20,6 +20,7 @@ package runtime // - func free(ptr unsafe.Pointer) // - func markRoots(start, end uintptr) // - func GC() +// - func SetFinalizer(obj interface{}, finalizer interface{}) // - func ReadMemStats(ms *runtime.MemStats) // // @@ -51,6 +52,9 @@ func markRoots(start, end uintptr) // GC is called to explicitly run garbage collection. func GC() +// SetFinalizer registers a finalizer. +func SetFinalizer(obj interface{}, finalizer interface{}) + // ReadMemStats populates m with memory statistics. func ReadMemStats(ms *MemStats) diff --git a/src/runtime/gc_leaking.go b/src/runtime/gc_leaking.go index 2f2bdff17e..d99b8d125b 100644 --- a/src/runtime/gc_leaking.go +++ b/src/runtime/gc_leaking.go @@ -85,6 +85,10 @@ func GC() { // No-op. } +func SetFinalizer(obj interface{}, finalizer interface{}) { + // No-op. +} + func initHeap() { // preinit() may have moved heapStart; reset heapptr heapptr = heapStart diff --git a/src/runtime/gc_none.go b/src/runtime/gc_none.go index 859c0b5dbe..98636f5c4b 100644 --- a/src/runtime/gc_none.go +++ b/src/runtime/gc_none.go @@ -26,6 +26,10 @@ func GC() { // Unimplemented. } +func SetFinalizer(obj interface{}, finalizer interface{}) { + // Unimplemented. +} + func initHeap() { // Nothing to initialize. 
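// The SetFinalizer stubs added above (gc_blocks, gc_custom, gc_leaking,
// gc_none) accept the call but never run the finalizer. Hedged usage sketch:
// under these collectors, programs should treat SetFinalizer as a best-effort
// safety net and still perform explicit cleanup. The File type is made up.
package main

import "runtime"

type File struct{ fd int }

func (f *File) Close() { f.fd = -1 }

func open() *File {
	f := &File{fd: 3}
	// On the standard Go runtime this may eventually call Close; with the
	// stubs above it is a no-op, so it must not be the only cleanup path.
	runtime.SetFinalizer(f, (*File).Close)
	return f
}

func main() {
	f := open()
	defer f.Close() // explicit cleanup that works under every collector
	_ = f
}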
} diff --git a/src/runtime/hashmap.go b/src/runtime/hashmap.go index 684d9cfbf5..ab85aee876 100644 --- a/src/runtime/hashmap.go +++ b/src/runtime/hashmap.go @@ -49,6 +49,10 @@ type hashmapIterator struct { bucketIndex uint8 // current index into bucket } +func hashmapNewIterator() unsafe.Pointer { + return unsafe.Pointer(new(hashmapIterator)) +} + // Get the topmost 8 bits of the hash, without using a special value (like 0). func hashmapTopHash(hash uint32) uint8 { tophash := uint8(hash >> 24) @@ -84,6 +88,10 @@ func hashmapMake(keySize, valueSize uintptr, sizeHint uintptr, alg uint8) *hashm } } +func hashmapMakeUnsafePointer(keySize, valueSize uintptr, sizeHint uintptr, alg uint8) unsafe.Pointer { + return (unsafe.Pointer)(hashmapMake(keySize, valueSize, sizeHint, alg)) +} + func hashmapKeyEqualAlg(alg hashmapAlgorithm) func(x, y unsafe.Pointer, n uintptr) bool { switch alg { case hashmapAlgorithmBinary: @@ -142,10 +150,8 @@ func hashmapLen(m *hashmap) int { return int(m.count) } -// wrapper for use in reflect -func hashmapLenUnsafePointer(p unsafe.Pointer) int { - m := (*hashmap)(p) - return hashmapLen(m) +func hashmapLenUnsafePointer(m unsafe.Pointer) int { + return hashmapLen((*hashmap)(m)) } // Set a specified key to a given value. Grow the map if necessary. @@ -163,8 +169,7 @@ func hashmapSet(m *hashmap, key unsafe.Pointer, value unsafe.Pointer, hash uint3 numBuckets := uintptr(1) << m.bucketBits bucketNumber := (uintptr(hash) & (numBuckets - 1)) bucketSize := unsafe.Sizeof(hashmapBucket{}) + m.keySize*8 + m.valueSize*8 - bucketAddr := uintptr(m.buckets) + bucketSize*bucketNumber - bucket := (*hashmapBucket)(unsafe.Pointer(bucketAddr)) + bucket := (*hashmapBucket)(unsafe.Add(m.buckets, bucketSize*bucketNumber)) var lastBucket *hashmapBucket // See whether the key already exists somewhere. @@ -174,9 +179,9 @@ func hashmapSet(m *hashmap, key unsafe.Pointer, value unsafe.Pointer, hash uint3 for bucket != nil { for i := uintptr(0); i < 8; i++ { slotKeyOffset := unsafe.Sizeof(hashmapBucket{}) + m.keySize*uintptr(i) - slotKey := unsafe.Pointer(uintptr(unsafe.Pointer(bucket)) + slotKeyOffset) + slotKey := unsafe.Add(unsafe.Pointer(bucket), slotKeyOffset) slotValueOffset := unsafe.Sizeof(hashmapBucket{}) + m.keySize*8 + m.valueSize*uintptr(i) - slotValue := unsafe.Pointer(uintptr(unsafe.Pointer(bucket)) + slotValueOffset) + slotValue := unsafe.Add(unsafe.Pointer(bucket), slotValueOffset) if bucket.tophash[i] == 0 && emptySlotKey == nil { // Found an empty slot, store it for if we couldn't find an // existing slot. @@ -208,6 +213,10 @@ func hashmapSet(m *hashmap, key unsafe.Pointer, value unsafe.Pointer, hash uint3 *emptySlotTophash = tophash } +func hashmapSetUnsafePointer(m unsafe.Pointer, key unsafe.Pointer, value unsafe.Pointer, hash uint32) { + hashmapSet((*hashmap)(m), key, value, hash) +} + // hashmapInsertIntoNewBucket creates a new bucket, inserts the given key and // value into the bucket, and returns a pointer to this bucket. func hashmapInsertIntoNewBucket(m *hashmap, key, value unsafe.Pointer, tophash uint8) *hashmapBucket { @@ -215,9 +224,9 @@ func hashmapInsertIntoNewBucket(m *hashmap, key, value unsafe.Pointer, tophash u bucketBuf := alloc(bucketBufSize, nil) // Insert into the first slot, which is empty as it has just been allocated. 
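// The slot-address arithmetic in this file (and throughout this patch) is a
// mechanical refactor from unsafe.Pointer(uintptr(p)+off) to
// unsafe.Add(p, off), which keeps the value an unsafe.Pointer instead of
// round-tripping through uintptr. Standalone before/after sketch; the
// bucketHeader layout is made up and only stands in for hashmapBucket.
package main

import (
	"fmt"
	"unsafe"
)

type bucketHeader struct {
	tophash [8]uint8
	next    uintptr
}

func main() {
	const keySize = 4
	buf := make([]byte, int(unsafe.Sizeof(bucketHeader{}))+8*keySize)
	base := unsafe.Pointer(&buf[0])

	// Old pattern: only valid when done in a single expression, and easy to
	// get subtly wrong once intermediate uintptr variables appear.
	oldSlot := unsafe.Pointer(uintptr(base) + unsafe.Sizeof(bucketHeader{}) + keySize*2)

	// New pattern (Go 1.17+): same address, stays an unsafe.Pointer.
	newSlot := unsafe.Add(base, unsafe.Sizeof(bucketHeader{})+keySize*2)

	fmt.Println(oldSlot == newSlot) // true: both point at key slot index 2
}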
slotKeyOffset := unsafe.Sizeof(hashmapBucket{}) - slotKey := unsafe.Pointer(uintptr(bucketBuf) + slotKeyOffset) + slotKey := unsafe.Add(bucketBuf, slotKeyOffset) slotValueOffset := unsafe.Sizeof(hashmapBucket{}) + m.keySize*8 - slotValue := unsafe.Pointer(uintptr(bucketBuf) + slotValueOffset) + slotValue := unsafe.Add(bucketBuf, slotValueOffset) m.count++ memcpy(slotKey, key, m.keySize) memcpy(slotValue, value, m.valueSize) @@ -266,8 +275,7 @@ func hashmapGet(m *hashmap, key, value unsafe.Pointer, valueSize uintptr, hash u numBuckets := uintptr(1) << m.bucketBits bucketNumber := (uintptr(hash) & (numBuckets - 1)) bucketSize := unsafe.Sizeof(hashmapBucket{}) + m.keySize*8 + m.valueSize*8 - bucketAddr := uintptr(m.buckets) + bucketSize*bucketNumber - bucket := (*hashmapBucket)(unsafe.Pointer(bucketAddr)) + bucket := (*hashmapBucket)(unsafe.Add(m.buckets, bucketSize*bucketNumber)) tophash := uint8(hash >> 24) if tophash < 1 { @@ -279,9 +287,9 @@ func hashmapGet(m *hashmap, key, value unsafe.Pointer, valueSize uintptr, hash u for bucket != nil { for i := uintptr(0); i < 8; i++ { slotKeyOffset := unsafe.Sizeof(hashmapBucket{}) + m.keySize*uintptr(i) - slotKey := unsafe.Pointer(uintptr(unsafe.Pointer(bucket)) + slotKeyOffset) + slotKey := unsafe.Add(unsafe.Pointer(bucket), slotKeyOffset) slotValueOffset := unsafe.Sizeof(hashmapBucket{}) + m.keySize*8 + m.valueSize*uintptr(i) - slotValue := unsafe.Pointer(uintptr(unsafe.Pointer(bucket)) + slotValueOffset) + slotValue := unsafe.Add(unsafe.Pointer(bucket), slotValueOffset) if bucket.tophash[i] == tophash { // This could be the key we're looking for. if m.keyEqual(key, slotKey, m.keySize) { @@ -299,6 +307,10 @@ func hashmapGet(m *hashmap, key, value unsafe.Pointer, valueSize uintptr, hash u return false } +func hashmapGetUnsafePointer(m unsafe.Pointer, key, value unsafe.Pointer, valueSize uintptr, hash uint32) bool { + return hashmapGet((*hashmap)(m), key, value, valueSize, hash) +} + // Delete a given key from the map. No-op when the key does not exist in the // map. // @@ -313,8 +325,7 @@ func hashmapDelete(m *hashmap, key unsafe.Pointer, hash uint32) { numBuckets := uintptr(1) << m.bucketBits bucketNumber := (uintptr(hash) & (numBuckets - 1)) bucketSize := unsafe.Sizeof(hashmapBucket{}) + m.keySize*8 + m.valueSize*8 - bucketAddr := uintptr(m.buckets) + bucketSize*bucketNumber - bucket := (*hashmapBucket)(unsafe.Pointer(bucketAddr)) + bucket := (*hashmapBucket)(unsafe.Add(m.buckets, bucketSize*bucketNumber)) tophash := uint8(hash >> 24) if tophash < 1 { @@ -326,7 +337,7 @@ func hashmapDelete(m *hashmap, key unsafe.Pointer, hash uint32) { for bucket != nil { for i := uintptr(0); i < 8; i++ { slotKeyOffset := unsafe.Sizeof(hashmapBucket{}) + m.keySize*uintptr(i) - slotKey := unsafe.Pointer(uintptr(unsafe.Pointer(bucket)) + slotKeyOffset) + slotKey := unsafe.Add(unsafe.Pointer(bucket), slotKeyOffset) if bucket.tophash[i] == tophash { // This could be the key we're looking for. 
if m.keyEqual(key, slotKey, m.keySize) { @@ -368,8 +379,7 @@ func hashmapNext(m *hashmap, it *hashmapIterator, key, value unsafe.Pointer) boo return false } bucketSize := unsafe.Sizeof(hashmapBucket{}) + m.keySize*8 + m.valueSize*8 - bucketAddr := uintptr(it.buckets) + bucketSize*it.bucketNumber - it.bucket = (*hashmapBucket)(unsafe.Pointer(bucketAddr)) + it.bucket = (*hashmapBucket)(unsafe.Add(it.buckets, bucketSize*it.bucketNumber)) it.bucketNumber++ // next bucket } if it.bucket.tophash[it.bucketIndex] == 0 { @@ -378,16 +388,15 @@ func hashmapNext(m *hashmap, it *hashmapIterator, key, value unsafe.Pointer) boo continue } - bucketAddr := uintptr(unsafe.Pointer(it.bucket)) slotKeyOffset := unsafe.Sizeof(hashmapBucket{}) + m.keySize*uintptr(it.bucketIndex) - slotKey := unsafe.Pointer(bucketAddr + slotKeyOffset) + slotKey := unsafe.Add(unsafe.Pointer(it.bucket), slotKeyOffset) memcpy(key, slotKey, m.keySize) if it.buckets == m.buckets { // Our view of the buckets is the same as the parent map. // Just copy the value we have slotValueOffset := unsafe.Sizeof(hashmapBucket{}) + m.keySize*8 + m.valueSize*uintptr(it.bucketIndex) - slotValue := unsafe.Pointer(bucketAddr + slotValueOffset) + slotValue := unsafe.Add(unsafe.Pointer(it.bucket), slotValueOffset) memcpy(value, slotValue, m.valueSize) it.bucketIndex++ } else { @@ -409,6 +418,10 @@ func hashmapNext(m *hashmap, it *hashmapIterator, key, value unsafe.Pointer) boo } } +func hashmapNextUnsafePointer(m unsafe.Pointer, it unsafe.Pointer, key, value unsafe.Pointer) bool { + return hashmapNext((*hashmap)(m), (*hashmapIterator)(it), key, value) +} + // Hashmap with plain binary data keys (not containing strings etc.). func hashmapBinarySet(m *hashmap, key, value unsafe.Pointer) { if m == nil { @@ -418,6 +431,10 @@ func hashmapBinarySet(m *hashmap, key, value unsafe.Pointer) { hashmapSet(m, key, value, hash) } +func hashmapBinarySetUnsafePointer(m unsafe.Pointer, key, value unsafe.Pointer) { + hashmapBinarySet((*hashmap)(m), key, value) +} + func hashmapBinaryGet(m *hashmap, key, value unsafe.Pointer, valueSize uintptr) bool { if m == nil { memzero(value, uintptr(valueSize)) @@ -427,6 +444,10 @@ func hashmapBinaryGet(m *hashmap, key, value unsafe.Pointer, valueSize uintptr) return hashmapGet(m, key, value, valueSize, hash) } +func hashmapBinaryGetUnsafePointer(m unsafe.Pointer, key, value unsafe.Pointer, valueSize uintptr) bool { + return hashmapBinaryGet((*hashmap)(m), key, value, valueSize) +} + func hashmapBinaryDelete(m *hashmap, key unsafe.Pointer) { if m == nil { return @@ -435,6 +456,10 @@ func hashmapBinaryDelete(m *hashmap, key unsafe.Pointer) { hashmapDelete(m, key, hash) } +func hashmapBinaryDeleteUnsafePointer(m unsafe.Pointer, key unsafe.Pointer) { + hashmapBinaryDelete((*hashmap)(m), key) +} + // Hashmap with string keys (a common case). 
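// The *UnsafePointer wrappers added in this file all follow one pattern: the
// caller (reflect) holds the map header only as an unsafe.Pointer, so the
// runtime exposes thin shims that cast back to the concrete *hashmap and
// forward the call. Standalone toy version of that pattern; "store" is a
// made-up stand-in for hashmap.
package main

import (
	"fmt"
	"unsafe"
)

type store struct{ count int }

func storeLen(s *store) int { return s.count }

// Shim with an untyped first argument, shaped the way a caller that cannot
// name the store type would need it.
func storeLenUnsafePointer(p unsafe.Pointer) int {
	return storeLen((*store)(p))
}

func main() {
	s := &store{count: 3}
	fmt.Println(storeLenUnsafePointer(unsafe.Pointer(s))) // 3
}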
func hashmapStringEqual(x, y unsafe.Pointer, n uintptr) bool { @@ -459,6 +484,10 @@ func hashmapStringSet(m *hashmap, key string, value unsafe.Pointer) { hashmapSet(m, unsafe.Pointer(&key), value, hash) } +func hashmapStringSetUnsafePointer(m unsafe.Pointer, key string, value unsafe.Pointer) { + hashmapStringSet((*hashmap)(m), key, value) +} + func hashmapStringGet(m *hashmap, key string, value unsafe.Pointer, valueSize uintptr) bool { if m == nil { memzero(value, uintptr(valueSize)) @@ -468,6 +497,10 @@ func hashmapStringGet(m *hashmap, key string, value unsafe.Pointer, valueSize ui return hashmapGet(m, unsafe.Pointer(&key), value, valueSize, hash) } +func hashmapStringGetUnsafePointer(m unsafe.Pointer, key string, value unsafe.Pointer, valueSize uintptr) bool { + return hashmapStringGet((*hashmap)(m), key, value, valueSize) +} + func hashmapStringDelete(m *hashmap, key string) { if m == nil { return @@ -476,6 +509,10 @@ func hashmapStringDelete(m *hashmap, key string) { hashmapDelete(m, unsafe.Pointer(&key), hash) } +func hashmapStringDeleteUnsafePointer(m unsafe.Pointer, key string) { + hashmapStringDelete((*hashmap)(m), key) +} + // Hashmap with interface keys (for everything else). // This is a method that is intentionally unexported in the reflect package. It @@ -506,7 +543,7 @@ func hashmapFloat64Hash(ptr unsafe.Pointer, seed uintptr) uint32 { func hashmapInterfaceHash(itf interface{}, seed uintptr) uint32 { x := reflect.ValueOf(itf) - if x.RawType() == 0 { + if x.RawType() == nil { return 0 // nil interface } @@ -532,10 +569,10 @@ func hashmapInterfaceHash(itf interface{}, seed uintptr) uint32 { case reflect.Float64: return hashmapFloat64Hash(ptr, seed) case reflect.Complex64: - rptr, iptr := ptr, unsafe.Pointer(uintptr(ptr)+4) + rptr, iptr := ptr, unsafe.Add(ptr, 4) return hashmapFloat32Hash(rptr, seed) ^ hashmapFloat32Hash(iptr, seed) case reflect.Complex128: - rptr, iptr := ptr, unsafe.Pointer(uintptr(ptr)+8) + rptr, iptr := ptr, unsafe.Add(ptr, 8) return hashmapFloat64Hash(rptr, seed) ^ hashmapFloat64Hash(iptr, seed) case reflect.String: return hashmapStringHash(x.String(), seed) @@ -563,7 +600,7 @@ func hashmapInterfaceHash(itf interface{}, seed uintptr) uint32 { } func hashmapInterfacePtrHash(iptr unsafe.Pointer, size uintptr, seed uintptr) uint32 { - _i := *(*_interface)(iptr) + _i := *(*interface{})(iptr) return hashmapInterfaceHash(_i, seed) } @@ -579,6 +616,10 @@ func hashmapInterfaceSet(m *hashmap, key interface{}, value unsafe.Pointer) { hashmapSet(m, unsafe.Pointer(&key), value, hash) } +func hashmapInterfaceSetUnsafePointer(m unsafe.Pointer, key interface{}, value unsafe.Pointer) { + hashmapInterfaceSet((*hashmap)(m), key, value) +} + func hashmapInterfaceGet(m *hashmap, key interface{}, value unsafe.Pointer, valueSize uintptr) bool { if m == nil { memzero(value, uintptr(valueSize)) @@ -588,6 +629,10 @@ func hashmapInterfaceGet(m *hashmap, key interface{}, value unsafe.Pointer, valu return hashmapGet(m, unsafe.Pointer(&key), value, valueSize, hash) } +func hashmapInterfaceGetUnsafePointer(m unsafe.Pointer, key interface{}, value unsafe.Pointer, valueSize uintptr) bool { + return hashmapInterfaceGet((*hashmap)(m), key, value, valueSize) +} + func hashmapInterfaceDelete(m *hashmap, key interface{}) { if m == nil { return @@ -595,3 +640,7 @@ func hashmapInterfaceDelete(m *hashmap, key interface{}) { hash := hashmapInterfaceHash(key, m.seed) hashmapDelete(m, unsafe.Pointer(&key), hash) } + +func hashmapInterfaceDeleteUnsafePointer(m unsafe.Pointer, key interface{}) { + 
hashmapInterfaceDelete((*hashmap)(m), key) +} diff --git a/src/runtime/interface.go b/src/runtime/interface.go index 63fd69ec81..8718c140a0 100644 --- a/src/runtime/interface.go +++ b/src/runtime/interface.go @@ -11,17 +11,17 @@ import ( ) type _interface struct { - typecode uintptr + typecode unsafe.Pointer value unsafe.Pointer } //go:inline -func composeInterface(typecode uintptr, value unsafe.Pointer) _interface { +func composeInterface(typecode, value unsafe.Pointer) _interface { return _interface{typecode, value} } //go:inline -func decomposeInterface(i _interface) (uintptr, unsafe.Pointer) { +func decomposeInterface(i _interface) (unsafe.Pointer, unsafe.Pointer) { return i.typecode, i.value } @@ -34,7 +34,7 @@ func reflectValueEqual(x, y reflect.Value) bool { // Note: doing a x.Type() == y.Type() comparison would not work here as that // would introduce an infinite recursion: comparing two reflect.Type values // is done with this reflectValueEqual runtime call. - if x.RawType() == 0 || y.RawType() == 0 { + if x.RawType() == nil || y.RawType() == nil { // One of them is nil. return x.RawType() == y.RawType() } @@ -94,48 +94,13 @@ func interfaceTypeAssert(ok bool) { // lowered to inline IR in the interface lowering pass. // See compiler/interface-lowering.go for details. -type interfaceMethodInfo struct { - signature *uint8 // external *i8 with a name identifying the Go function signature - funcptr uintptr // bitcast from the actual function pointer -} - -type typecodeID struct { - // Depending on the type kind of this typecodeID, this pointer is something - // different: - // * basic types: null - // * named type: the underlying type - // * interface: null - // * chan/pointer/slice/array: the element type - // * struct: bitcast of global with structField array - // * func/map: TODO - references *typecodeID - - // The array length, for array types. - length uintptr - - methodSet *interfaceMethodInfo // nil or a GEP of an array - - // The type that's a pointer to this type, nil if it is already a pointer. - // Keeping the type struct alive here is important so that values from - // reflect.New (which uses reflect.PtrTo) can be used in type asserts etc. - ptrTo *typecodeID - - // typeAssert is a ptrtoint of a declared interface assert function. - // It only exists to make the rtcalls pass easier. - typeAssert uintptr -} - -// structField is used by the compiler to pass information to the interface -// lowering pass. It is not used in the final binary. type structField struct { - typecode *typecodeID // type of this struct field - name *uint8 // pointer to char array - tag *uint8 // pointer to char array, or nil - embedded bool + typecode unsafe.Pointer // type of this struct field + data *uint8 // pointer to byte array containing name, tag, and 'embedded' flag } // Pseudo function call used during a type assert. It is used during interface // lowering, to assign the lowest type numbers to the types with the most type // asserts. Also, it is replaced with const false if this type assert can never // happen. -func typeAssert(actualType uintptr, assertedType *uint8) bool +func typeAssert(actualType unsafe.Pointer, assertedType *uint8) bool diff --git a/src/runtime/interrupt/interrupt_avr.go b/src/runtime/interrupt/interrupt_avr.go index 25c8caa023..0af71a89e5 100644 --- a/src/runtime/interrupt/interrupt_avr.go +++ b/src/runtime/interrupt/interrupt_avr.go @@ -34,3 +34,12 @@ func Restore(state State) { "state": state, }) } + +// In returns whether the system is currently in an interrupt. 
+// +// Warning: this always returns false on AVR, as there does not appear to be a +// reliable way to determine whether we're currently running inside an interrupt +// handler. +func In() bool { + return false +} diff --git a/src/runtime/interrupt/interrupt_cortexm.go b/src/runtime/interrupt/interrupt_cortexm.go index a7127c3ef6..a34dff7879 100644 --- a/src/runtime/interrupt/interrupt_cortexm.go +++ b/src/runtime/interrupt/interrupt_cortexm.go @@ -50,3 +50,13 @@ func Disable() (state State) { func Restore(state State) { arm.EnableInterrupts(uintptr(state)) } + +// In returns whether the system is currently in an interrupt. +func In() bool { + // The VECTACTIVE field gives the instruction vector that is currently + // active (in handler mode), or 0 if not in an interrupt. + // Documentation: + // https://developer.arm.com/documentation/dui0497/a/cortex-m0-peripherals/system-control-block/interrupt-control-and-state-register + vectactive := uint8(arm.SCB.ICSR.Get()) + return vectactive != 0 +} diff --git a/src/runtime/interrupt/interrupt_esp32c3.go b/src/runtime/interrupt/interrupt_esp32c3.go index 5a9337ed15..7d9be3937e 100644 --- a/src/runtime/interrupt/interrupt_esp32c3.go +++ b/src/runtime/interrupt/interrupt_esp32c3.go @@ -34,7 +34,7 @@ func (i Interrupt) Enable() error { esp.INTERRUPT_CORE0.CPU_INT_TYPE.SetBits(1 << i.num) // Set default threshold to defaultThreshold - reg := (*volatile.Register32)(unsafe.Pointer((uintptr(unsafe.Pointer(&esp.INTERRUPT_CORE0.CPU_INT_PRI_0)) + uintptr(i.num)*4))) + reg := (*volatile.Register32)(unsafe.Add(unsafe.Pointer(&esp.INTERRUPT_CORE0.CPU_INT_PRI_0), i.num*4)) reg.Set(defaultThreshold) // Reset interrupt before reenabling @@ -171,7 +171,7 @@ func handleInterrupt() { mepc := riscv.MEPC.Get() // Useing threshold to temporary disable this interrupts. // FYI: using CPU interrupt enable bit make runtime to loose interrupts. - reg := (*volatile.Register32)(unsafe.Pointer((uintptr(unsafe.Pointer(&esp.INTERRUPT_CORE0.CPU_INT_PRI_0)) + uintptr(interruptNumber)*4))) + reg := (*volatile.Register32)(unsafe.Add(unsafe.Pointer(&esp.INTERRUPT_CORE0.CPU_INT_PRI_0), interruptNumber*4)) thresholdSave := reg.Get() reg.Set(disableThreshold) riscv.Asm("fence") diff --git a/src/runtime/interrupt/interrupt_gameboyadvance.go b/src/runtime/interrupt/interrupt_gameboyadvance.go index 72f2a81940..13f5fbe09d 100644 --- a/src/runtime/interrupt/interrupt_gameboyadvance.go +++ b/src/runtime/interrupt/interrupt_gameboyadvance.go @@ -2,49 +2,31 @@ package interrupt -import ( - "runtime/volatile" - "unsafe" -) - -const ( - IRQ_VBLANK = 0 - IRQ_HBLANK = 1 - IRQ_VCOUNT = 2 - IRQ_TIMER0 = 3 - IRQ_TIMER1 = 4 - IRQ_TIMER2 = 5 - IRQ_TIMER3 = 6 - IRQ_COM = 7 - IRQ_DMA0 = 8 - IRQ_DMA1 = 9 - IRQ_DMA2 = 10 - IRQ_DMA3 = 11 - IRQ_KEYPAD = 12 - IRQ_GAMEPAK = 13 -) +// This is good documentation of the GBA: https://www.akkit.org/info/gbatek.htm -var ( - regInterruptEnable = (*volatile.Register16)(unsafe.Pointer(uintptr(0x4000200))) - regInterruptRequestFlags = (*volatile.Register16)(unsafe.Pointer(uintptr(0x4000202))) - regGlobalInterruptEnable = (*volatile.Register16)(unsafe.Pointer(uintptr(0x4000208))) +import ( + "device/gba" ) // Enable enables this interrupt. Right after calling this function, the // interrupt may be invoked if it was already pending. 
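// Usage sketch for the interrupt.In() accessor added above (and relied on by
// the new "alloc in interrupt" guard in gc_blocks.go): code that can run both
// from normal context and from an interrupt handler can check In() and defer
// anything that is not interrupt-safe. The flag below is a made-up example,
// not TinyGo API; a real program would use runtime/volatile or sync/atomic
// for state shared with a handler.
package main

import "runtime/interrupt"

var pending bool

// poke requests background work. From an interrupt handler it only sets a
// flag; the main loop performs the (potentially allocating) work later.
func poke() {
	if interrupt.In() {
		pending = true
		return
	}
	doWork()
}

func doWork() {
	println("doing work") // stands in for work that may allocate, so never run it from an ISR
}

func main() {
	poke()
	for { // busy-wait for brevity; a real program would sleep or wait for events
		if pending {
			pending = false
			doWork()
		}
	}
}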
func (irq Interrupt) Enable() { - regInterruptEnable.SetBits(1 << uint(irq.num)) + gba.INTERRUPT.IE.SetBits(1 << uint(irq.num)) } +var inInterrupt bool + //export handleInterrupt func handleInterrupt() { - flags := regInterruptRequestFlags.Get() + inInterrupt = true + flags := gba.INTERRUPT.IF.Get() for i := 0; i < 14; i++ { if flags&(1<DEMCR |= CoreDebug_DEMCR_TRCENA_Msk; //DWT->CTRL |= DWT_CTRL_CYCCNTENA_Msk; + + // Disable automatic NVM write operations + sam.NVMCTRL.SetCTRLA_WMODE(sam.NVMCTRL_CTRLA_WMODE_MAN) } func initRTC() { @@ -367,6 +371,10 @@ func initADCClock() { sam.GCLK_PCHCTRL_CHEN) } +func enableCache() { + sam.CMCC.CTRL.SetBits(sam.CMCC_CTRL_CEN) +} + func waitForEvents() { arm.Asm("wfe") } diff --git a/src/runtime/runtime_avr.go b/src/runtime/runtime_avr.go index aff6f78985..bf9860ed08 100644 --- a/src/runtime/runtime_avr.go +++ b/src/runtime/runtime_avr.go @@ -52,7 +52,7 @@ func preinit() { ptr := unsafe.Pointer(&_sbss) for ptr != unsafe.Pointer(&_ebss) { *(*uint8)(ptr) = 0 - ptr = unsafe.Pointer(uintptr(ptr) + 1) + ptr = unsafe.Add(ptr, 1) } } diff --git a/src/runtime/runtime_cortexm.go b/src/runtime/runtime_cortexm.go index 55ccc7d33b..137122b8fd 100644 --- a/src/runtime/runtime_cortexm.go +++ b/src/runtime/runtime_cortexm.go @@ -26,7 +26,7 @@ func preinit() { ptr := unsafe.Pointer(&_sbss) for ptr != unsafe.Pointer(&_ebss) { *(*uint32)(ptr) = 0 - ptr = unsafe.Pointer(uintptr(ptr) + 4) + ptr = unsafe.Add(ptr, 4) } // Initialize .data: global variables initialized from flash. @@ -34,8 +34,8 @@ func preinit() { dst := unsafe.Pointer(&_sdata) for dst != unsafe.Pointer(&_edata) { *(*uint32)(dst) = *(*uint32)(src) - dst = unsafe.Pointer(uintptr(dst) + 4) - src = unsafe.Pointer(uintptr(src) + 4) + dst = unsafe.Add(dst, 4) + src = unsafe.Add(src, 4) } } diff --git a/src/runtime/runtime_esp32c3.go b/src/runtime/runtime_esp32c3.go index 561ba4bfde..8a4c40df4c 100644 --- a/src/runtime/runtime_esp32c3.go +++ b/src/runtime/runtime_esp32c3.go @@ -78,7 +78,7 @@ func interruptInit() { priReg := &esp.INTERRUPT_CORE0.CPU_INT_PRI_1 for i := 0; i < 31; i++ { priReg.Set(0) - priReg = (*volatile.Register32)(unsafe.Pointer(uintptr(unsafe.Pointer(priReg)) + uintptr(4))) + priReg = (*volatile.Register32)(unsafe.Add(unsafe.Pointer(priReg), 4)) } // default threshold for interrupts is 5 diff --git a/src/runtime/runtime_esp32xx.go b/src/runtime/runtime_esp32xx.go index 36b4996572..b2a16f872a 100644 --- a/src/runtime/runtime_esp32xx.go +++ b/src/runtime/runtime_esp32xx.go @@ -32,7 +32,7 @@ func clearbss() { ptr := unsafe.Pointer(&_sbss) for ptr != unsafe.Pointer(&_ebss) { *(*uint32)(ptr) = 0 - ptr = unsafe.Pointer(uintptr(ptr) + 4) + ptr = unsafe.Add(ptr, 4) } } diff --git a/src/runtime/runtime_esp8266.go b/src/runtime/runtime_esp8266.go index 9b1bd807af..b12a8b68fe 100644 --- a/src/runtime/runtime_esp8266.go +++ b/src/runtime/runtime_esp8266.go @@ -76,7 +76,7 @@ func preinit() { ptr := unsafe.Pointer(&_sbss) for ptr != unsafe.Pointer(&_ebss) { *(*uint32)(ptr) = 0 - ptr = unsafe.Pointer(uintptr(ptr) + 4) + ptr = unsafe.Add(ptr, 4) } } diff --git a/src/runtime/runtime_fe310.go b/src/runtime/runtime_fe310.go index 22afca4077..01cc3ba119 100644 --- a/src/runtime/runtime_fe310.go +++ b/src/runtime/runtime_fe310.go @@ -25,6 +25,12 @@ func main() { // Zero the threshold value to allow all priorities of interrupts. sifive.PLIC.THRESHOLD.Set(0) + // Zero MCAUSE, which is set to the reset reason on reset. It must be zeroed + // to make interrupt.In() work. 
+ // This would also be a good time to save the reset reason, but that hasn't + // been implemented yet. + riscv.MCAUSE.Set(0) + // Set the interrupt address. // Note that this address must be aligned specially, otherwise the MODE bits // of MTVEC won't be zero. @@ -73,6 +79,10 @@ func handleInterrupt() { // misaligned loads). However, for now we'll just print a fatal error. handleException(code) } + + // Zero MCAUSE so that it can later be used to see whether we're in an + // interrupt or not. + riscv.MCAUSE.Set(0) } // initPeripherals configures periperhals the way the runtime expects them. diff --git a/src/runtime/runtime_k210.go b/src/runtime/runtime_k210.go index 298a420ff0..5998de69db 100644 --- a/src/runtime/runtime_k210.go +++ b/src/runtime/runtime_k210.go @@ -31,6 +31,12 @@ func main() { kendryte.PLIC.PRIORITY[i].Set(0) } + // Zero MCAUSE, which is set to the reset reason on reset. It must be zeroed + // to make interrupt.In() work. + // This would also be a good time to save the reset reason, but that hasn't + // been implemented yet. + riscv.MCAUSE.Set(0) + // Set the interrupt address. // Note that this address must be aligned specially, otherwise the MODE bits // of MTVEC won't be zero. @@ -93,6 +99,10 @@ func handleInterrupt() { // misaligned loads). However, for now we'll just print a fatal error. handleException(code) } + + // Zero MCAUSE so that it can later be used to see whether we're in an + // interrupt or not. + riscv.MCAUSE.Set(0) } // initPeripherals configures periperhals the way the runtime expects them. diff --git a/src/runtime/runtime_nintendoswitch.go b/src/runtime/runtime_nintendoswitch.go index d03d71496b..f2606023ff 100644 --- a/src/runtime/runtime_nintendoswitch.go +++ b/src/runtime/runtime_nintendoswitch.go @@ -109,7 +109,7 @@ func write(fd int32, buf *byte, count int) int { // TODO: Proper handling write for i := 0; i < count; i++ { putchar(*buf) - buf = (*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(buf)) + 1)) + buf = (*byte)(unsafe.Add(unsafe.Pointer(buf), 1)) } return count } diff --git a/src/runtime/runtime_tinygoriscv.go b/src/runtime/runtime_tinygoriscv.go index f4385f3058..065cdf3f0d 100644 --- a/src/runtime/runtime_tinygoriscv.go +++ b/src/runtime/runtime_tinygoriscv.go @@ -24,7 +24,7 @@ func preinit() { ptr := unsafe.Pointer(&_sbss) for ptr != unsafe.Pointer(&_ebss) { *(*uint32)(ptr) = 0 - ptr = unsafe.Pointer(uintptr(ptr) + 4) + ptr = unsafe.Add(ptr, 4) } // Initialize .data: global variables initialized from flash. @@ -32,7 +32,7 @@ func preinit() { dst := unsafe.Pointer(&_sdata) for dst != unsafe.Pointer(&_edata) { *(*uint32)(dst) = *(*uint32)(src) - dst = unsafe.Pointer(uintptr(dst) + 4) - src = unsafe.Pointer(uintptr(src) + 4) + dst = unsafe.Add(dst, 4) + src = unsafe.Add(src, 4) } } diff --git a/src/runtime/runtime_tinygoriscv64.go b/src/runtime/runtime_tinygoriscv64.go index 44049a51b1..7162979fc1 100644 --- a/src/runtime/runtime_tinygoriscv64.go +++ b/src/runtime/runtime_tinygoriscv64.go @@ -24,7 +24,7 @@ func preinit() { ptr := unsafe.Pointer(&_sbss) for ptr != unsafe.Pointer(&_ebss) { *(*uint64)(ptr) = 0 - ptr = unsafe.Pointer(uintptr(ptr) + 8) + ptr = unsafe.Add(ptr, 8) } // Initialize .data: global variables initialized from flash. 
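// The MCAUSE handling added above (zeroing it after reset and again at the
// end of handleInterrupt) suggests that interrupt.In() on these RISC-V
// targets amounts to "is MCAUSE currently non-zero?". The snippet below is a
// hedged sketch of that idea only, not the actual TinyGo implementation; it
// assumes a device/riscv MCAUSE register accessor like the one used in the
// patch.
package interrupt

import "device/riscv"

// In reports whether the CPU is currently executing a trap handler, under the
// assumption that MCAUSE is zeroed everywhere outside of trap handling.
func In() bool {
	return riscv.MCAUSE.Get() != 0
}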
@@ -32,7 +32,7 @@ func preinit() { dst := unsafe.Pointer(&_sdata) for dst != unsafe.Pointer(&_edata) { *(*uint64)(dst) = *(*uint64)(src) - dst = unsafe.Pointer(uintptr(dst) + 8) - src = unsafe.Pointer(uintptr(src) + 8) + dst = unsafe.Add(dst, 8) + src = unsafe.Add(src, 8) } } diff --git a/src/runtime/runtime_unix.go b/src/runtime/runtime_unix.go index 39ea125c16..8af3d673c4 100644 --- a/src/runtime/runtime_unix.go +++ b/src/runtime/runtime_unix.go @@ -106,7 +106,7 @@ func os_runtime_args() []string { arg.length = length arg.ptr = (*byte)(*argv) // This is the Go equivalent of "argv++" in C. - argv = (*unsafe.Pointer)(unsafe.Pointer(uintptr(unsafe.Pointer(argv)) + unsafe.Sizeof(argv))) + argv = (*unsafe.Pointer)(unsafe.Add(unsafe.Pointer(argv), unsafe.Sizeof(argv))) } } return args @@ -129,7 +129,7 @@ func syscall_runtime_envs() []string { numEnvs := 0 for *env != nil { numEnvs++ - env = (*unsafe.Pointer)(unsafe.Pointer(uintptr(unsafe.Pointer(env)) + unsafe.Sizeof(environ))) + env = (*unsafe.Pointer)(unsafe.Add(unsafe.Pointer(env), unsafe.Sizeof(environ))) } // Create a string slice of all environment variables. @@ -144,7 +144,7 @@ func syscall_runtime_envs() []string { length: length, } envs = append(envs, *(*string)(unsafe.Pointer(&s))) - env = (*unsafe.Pointer)(unsafe.Pointer(uintptr(unsafe.Pointer(env)) + unsafe.Sizeof(environ))) + env = (*unsafe.Pointer)(unsafe.Add(unsafe.Pointer(env), unsafe.Sizeof(environ))) } return envs diff --git a/src/runtime/runtime_windows.go b/src/runtime/runtime_windows.go index 6be6c32b51..30cb00c1ba 100644 --- a/src/runtime/runtime_windows.go +++ b/src/runtime/runtime_windows.go @@ -85,7 +85,7 @@ func os_runtime_args() []string { arg.length = length arg.ptr = (*byte)(*argv) // This is the Go equivalent of "argv++" in C. - argv = (*unsafe.Pointer)(unsafe.Pointer(uintptr(unsafe.Pointer(argv)) + unsafe.Sizeof(argv))) + argv = (*unsafe.Pointer)(unsafe.Add(unsafe.Pointer(argv), unsafe.Sizeof(argv))) } } return args diff --git a/src/runtime/slice.go b/src/runtime/slice.go index b58fab360c..2269047a8c 100644 --- a/src/runtime/slice.go +++ b/src/runtime/slice.go @@ -37,7 +37,7 @@ func sliceAppend(srcBuf, elemsBuf unsafe.Pointer, srcLen, srcCap, elemsLen uintp } // The slice fits (after possibly allocating a new one), append it in-place. - memmove(unsafe.Pointer(uintptr(srcBuf)+srcLen*elemSize), elemsBuf, elemsLen*elemSize) + memmove(unsafe.Add(srcBuf, srcLen*elemSize), elemsBuf, elemsLen*elemSize) return srcBuf, srcLen + elemsLen, srcCap } @@ -51,3 +51,32 @@ func sliceCopy(dst, src unsafe.Pointer, dstLen, srcLen uintptr, elemSize uintptr memmove(dst, src, n*elemSize) return int(n) } + +// sliceGrow returns a new slice with space for at least newCap elements +func sliceGrow(oldBuf unsafe.Pointer, oldLen, oldCap, newCap, elemSize uintptr) (unsafe.Pointer, uintptr, uintptr) { + + // TODO(dgryski): sliceGrow() and sliceAppend() should be refactored to share the base growth code. + + if oldCap >= newCap { + // No need to grow, return the input slice. 
+ return oldBuf, oldLen, oldCap + } + + // allow nil slice + if oldCap == 0 { + oldCap++ + } + + // grow capacity + for oldCap < newCap { + oldCap *= 2 + } + + buf := alloc(oldCap*elemSize, nil) + if oldLen > 0 { + // copy any data to new slice + memmove(buf, oldBuf, oldLen*elemSize) + } + + return buf, oldLen, oldCap +} diff --git a/src/runtime/string.go b/src/runtime/string.go index 2064629189..13bfcd0ed2 100644 --- a/src/runtime/string.go +++ b/src/runtime/string.go @@ -61,7 +61,7 @@ func stringConcat(x, y _string) _string { length := x.length + y.length buf := alloc(length, nil) memcpy(buf, unsafe.Pointer(x.ptr), x.length) - memcpy(unsafe.Pointer(uintptr(buf)+x.length), unsafe.Pointer(y.ptr), y.length) + memcpy(unsafe.Add(buf, x.length), unsafe.Pointer(y.ptr), y.length) return _string{ptr: (*byte)(buf), length: length} } } @@ -107,7 +107,7 @@ func stringFromRunes(runeSlice []rune) (s _string) { for _, r := range runeSlice { array, numBytes := encodeUTF8(r) for _, c := range array[:numBytes] { - *(*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(s.ptr)) + index)) = c + *(*byte)(unsafe.Add(unsafe.Pointer(s.ptr), index)) = c index++ } } @@ -243,7 +243,7 @@ func isContinuation(b byte) bool { func cgo_CString(s _string) unsafe.Pointer { buf := malloc(s.length + 1) memcpy(buf, unsafe.Pointer(s.ptr), s.length) - *(*byte)(unsafe.Pointer(uintptr(buf) + s.length)) = 0 // trailing 0 byte + *(*byte)(unsafe.Add(buf, s.length)) = 0 // trailing 0 byte return buf } diff --git a/src/syscall/net.go b/src/syscall/net.go index 531fa80d8f..5f8c50da9a 100644 --- a/src/syscall/net.go +++ b/src/syscall/net.go @@ -32,3 +32,22 @@ type Conn interface { // SyscallConn returns a raw network connection. SyscallConn() (RawConn, error) } + +const ( + AF_INET = 0x2 + SOCK_STREAM = 0x1 + SOCK_DGRAM = 0x2 + SOL_SOCKET = 0x1 + SO_KEEPALIVE = 0x9 + SOL_TCP = 0x6 + TCP_KEEPINTVL = 0x5 + IPPROTO_TCP = 0x6 + IPPROTO_UDP = 0x11 + F_SETFL = 0x4 + + // TINYGO: Made up, not a real IP protocol number. This is used to + // create a TLS socket on the device, assuming the device supports mbed + // TLS. + + IPPROTO_TLS = 0xFE +) diff --git a/src/syscall/syscall_libc.go b/src/syscall/syscall_libc.go index 313ae36f78..68072faff2 100644 --- a/src/syscall/syscall_libc.go +++ b/src/syscall/syscall_libc.go @@ -258,7 +258,7 @@ func Environ() []string { for environ := libc_environ; *environ != nil; { length += libc_strlen(*environ) vars++ - environ = (*unsafe.Pointer)(unsafe.Pointer(uintptr(unsafe.Pointer(environ)) + unsafe.Sizeof(environ))) + environ = (*unsafe.Pointer)(unsafe.Add(unsafe.Pointer(environ), unsafe.Sizeof(environ))) } // allocate our backing slice for the strings @@ -287,7 +287,7 @@ func Environ() []string { // add s to our list of environment variables envs = append(envs, s) // environ++ - environ = (*unsafe.Pointer)(unsafe.Pointer(uintptr(unsafe.Pointer(environ)) + unsafe.Sizeof(environ))) + environ = (*unsafe.Pointer)(unsafe.Add(unsafe.Pointer(environ), unsafe.Sizeof(environ))) } return envs } diff --git a/src/syscall/syscall_libc_darwin.go b/src/syscall/syscall_libc_darwin.go index 2a0dd206e3..875a11dffd 100644 --- a/src/syscall/syscall_libc_darwin.go +++ b/src/syscall/syscall_libc_darwin.go @@ -53,7 +53,6 @@ const ( DT_UNKNOWN = 0x0 DT_WHT = 0xe F_GETFL = 0x3 - F_SETFL = 0x4 O_NONBLOCK = 0x4 ) @@ -147,6 +146,11 @@ type Timespec struct { Nsec int64 } +// Unix returns the time stored in ts as seconds plus nanoseconds. 
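// Worked example of the capacity growth in the sliceGrow helper above:
// starting from a (possibly zero) old capacity, it doubles until the
// requested capacity fits. Standalone sketch that mirrors only the
// arithmetic; it makes no claims about the allocator side.
package main

import "fmt"

func grownCap(oldCap, newCap int) int {
	if oldCap >= newCap {
		return oldCap
	}
	if oldCap == 0 { // allow nil slices
		oldCap = 1
	}
	for oldCap < newCap {
		oldCap *= 2
	}
	return oldCap
}

func main() {
	fmt.Println(grownCap(0, 5)) // 8 (1 -> 2 -> 4 -> 8)
	fmt.Println(grownCap(3, 4)) // 6 (3 -> 6)
	fmt.Println(grownCap(8, 8)) // 8 (already large enough)
}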
+func (ts *Timespec) Unix() (sec int64, nsec int64) { + return int64(ts.Sec), int64(ts.Nsec) +} + // Source: upstream ztypes_darwin_amd64.go type Dirent struct { Ino uint64 @@ -158,7 +162,6 @@ type Dirent struct { Pad_cgo_0 [3]byte } -// Go chose Linux's field names for Stat_t, see https://github.com/golang/go/issues/31735 type Stat_t struct { Dev int32 Mode uint16 @@ -168,10 +171,10 @@ type Stat_t struct { Gid uint32 Rdev int32 Pad_cgo_0 [4]byte - Atim Timespec - Mtim Timespec - Ctim Timespec - Btim Timespec + Atimespec Timespec + Mtimespec Timespec + Ctimespec Timespec + Btimespec Timespec Size int64 Blocks int64 Blksize int32 diff --git a/src/syscall/syscall_libc_wasi.go b/src/syscall/syscall_libc_wasi.go index 18118ab401..d80986a704 100644 --- a/src/syscall/syscall_libc_wasi.go +++ b/src/syscall/syscall_libc_wasi.go @@ -90,7 +90,6 @@ const ( // ../../lib/wasi-libc/expected/wasm32-wasi/predefined-macros.txt F_GETFL = 3 - F_SETFL = 4 ) // These values are needed as a stub until Go supports WASI as a full target. @@ -210,10 +209,14 @@ type Timespec struct { Nsec int64 } +// Unix returns the time stored in ts as seconds plus nanoseconds. +func (ts *Timespec) Unix() (sec int64, nsec int64) { + return int64(ts.Sec), int64(ts.Nsec) +} + // https://github.com/WebAssembly/wasi-libc/blob/main/libc-bottom-half/headers/public/__struct_stat.h // https://github.com/WebAssembly/wasi-libc/blob/main/libc-bottom-half/headers/public/__typedef_ino_t.h // etc. -// Go chose Linux's field names for Stat_t, see https://github.com/golang/go/issues/31735 type Stat_t struct { Dev uint64 Ino uint64 diff --git a/targets/arm.ld b/targets/arm.ld index e4155b9090..39b5c75ddb 100644 --- a/targets/arm.ld +++ b/targets/arm.ld @@ -42,6 +42,8 @@ SECTIONS *(.data) *(.data.*) . = ALIGN(4); + *(.ramfuncs*) /* Functions that must execute from RAM */ + . = ALIGN(4); _edata = .; /* used by startup code */ } >RAM AT>FLASH_TEXT @@ -69,3 +71,7 @@ _heap_start = _ebss; _heap_end = ORIGIN(RAM) + LENGTH(RAM); _globals_start = _sdata; _globals_end = _ebss; + +/* For the flash API */ +__flash_data_start = LOADADDR(.data) + SIZEOF(.data); +__flash_data_end = ORIGIN(FLASH_TEXT) + LENGTH(FLASH_TEXT); diff --git a/targets/cortex-m-qemu.json b/targets/cortex-m-qemu.json index 80e087aa56..5a1758dbfb 100644 --- a/targets/cortex-m-qemu.json +++ b/targets/cortex-m-qemu.json @@ -2,6 +2,7 @@ "inherits": ["cortex-m3"], "build-tags": ["qemu", "lm3s6965"], "linkerscript": "targets/lm3s6965.ld", + "default-stack-size": 4096, "extra-files": [ "targets/cortex-m-qemu.s" ], diff --git a/targets/cortex-m-qemu.s b/targets/cortex-m-qemu.s index fdbecc8fa1..685c7fd5c8 100644 --- a/targets/cortex-m-qemu.s +++ b/targets/cortex-m-qemu.s @@ -23,6 +23,7 @@ Default_Handler: .section .isr_vector, "a", %progbits .global __isr_vector +__isr_vector: // Interrupt vector as defined by Cortex-M, starting with the stack top. // On reset, SP is initialized with *0x0 and PC is loaded with *0x4, loading // _stack_top and Reset_Handler. 
@@ -54,3 +55,5 @@ Default_Handler: IRQ DebugMon_Handler IRQ PendSV_Handler IRQ SysTick_Handler + +.size __isr_vector, .-__isr_vector diff --git a/testdata/corpus.yaml b/testdata/corpus.yaml index 7b8e21d6ba..02fde0c585 100644 --- a/testdata/corpus.yaml +++ b/testdata/corpus.yaml @@ -281,3 +281,4 @@ - repo: github.com/russross/blackfriday version: v2 - repo: github.com/soypat/mu8 +- repo: github.com/brandondube/pctl diff --git a/testdata/gc.go b/testdata/gc.go index eb594db6cb..456d763b4c 100644 --- a/testdata/gc.go +++ b/testdata/gc.go @@ -1,5 +1,7 @@ package main +import "runtime" + var xorshift32State uint32 = 1 func xorshift32(x uint32) uint32 { @@ -17,6 +19,7 @@ func randuint32() uint32 { func main() { testNonPointerHeap() + testKeepAlive() } var scalarSlices [4][]byte @@ -64,3 +67,10 @@ func testNonPointerHeap() { } println("ok") } + +func testKeepAlive() { + // There isn't much we can test, but at least we can test that + // runtime.KeepAlive compiles correctly. + var x int + runtime.KeepAlive(&x) +} diff --git a/testdata/map.go b/testdata/map.go index d30889910e..d746cf9fc5 100644 --- a/testdata/map.go +++ b/testdata/map.go @@ -129,6 +129,8 @@ func main() { floatcmplx() mapgrow() + + interfacerehash() } func floatcmplx() { @@ -274,3 +276,35 @@ func mapgrow() { } println("done") } + +type Counter interface { + count() int +} + +type counter struct { + i int +} + +func (c *counter) count() int { + return c.i +} + +func interfacerehash() { + m := make(map[Counter]int) + + for i := 0; i < 20; i++ { + c := &counter{i} + m[c] = i + } + + var failures int + for k, v := range m { + if got := m[k]; got != v { + println("lookup failure got", got, "want", v) + failures++ + } + } + if failures == 0 { + println("no interface lookup failures") + } +} diff --git a/testdata/map.txt b/testdata/map.txt index 6bf04c80d6..d5e553b1a7 100644 --- a/testdata/map.txt +++ b/testdata/map.txt @@ -80,3 +80,4 @@ tested growing of a map 2 2 done +no interface lookup failures diff --git a/testdata/reflect.go b/testdata/reflect.go index 47359e66ea..1a92e47ab7 100644 --- a/testdata/reflect.go +++ b/testdata/reflect.go @@ -3,6 +3,7 @@ package main import ( "errors" "reflect" + "strconv" "unsafe" ) @@ -17,8 +18,8 @@ type ( Y int16 } mystruct struct { - n int `foo:"bar"` - some point + n int `foo:"bar"` + some point "some\x00tag" zero struct{} buf []byte Buf []byte @@ -480,7 +481,8 @@ func showValue(rv reflect.Value, indent string) { for i := 0; i < rv.NumField(); i++ { field := rt.Field(i) println(indent+" field:", i, field.Name) - println(indent+" tag:", field.Tag) + println(indent+" pkg:", field.PkgPath) + println(indent+" tag:", strconv.Quote(string(field.Tag))) println(indent+" embedded:", field.Anonymous) println(indent+" exported:", field.IsExported()) showValue(rv.Field(i), indent+" ") diff --git a/testdata/reflect.txt b/testdata/reflect.txt index 03a7e5e590..e4a92a5e1c 100644 --- a/testdata/reflect.txt +++ b/testdata/reflect.txt @@ -233,7 +233,8 @@ reflect type: struct reflect type: struct struct: 1 field: 0 error - tag: + pkg: main + tag: "" embedded: true exported: false reflect type: interface caninterface=false @@ -242,19 +243,22 @@ reflect type: struct reflect type: struct struct: 3 field: 0 a - tag: + pkg: main + tag: "" embedded: false exported: false reflect type: uint8 caninterface=false uint: 42 field: 1 b - tag: + pkg: main + tag: "" embedded: false exported: false reflect type: int16 caninterface=false int: 321 field: 2 c - tag: + pkg: main + tag: "" embedded: false exported: false reflect type: 
int8 caninterface=false @@ -262,37 +266,43 @@ reflect type: struct reflect type: struct comparable=false struct: 5 field: 0 n - tag: foo:"bar" + pkg: main + tag: "foo:\"bar\"" embedded: false exported: false reflect type: int caninterface=false int: 5 field: 1 some - tag: + pkg: main + tag: "some\x00tag" embedded: false exported: false reflect type: struct caninterface=false struct: 2 field: 0 X - tag: + pkg: + tag: "" embedded: false exported: true reflect type: int16 caninterface=false int: -5 field: 1 Y - tag: + pkg: + tag: "" embedded: false exported: true reflect type: int16 caninterface=false int: 3 field: 2 zero - tag: + pkg: main + tag: "" embedded: false exported: false reflect type: struct caninterface=false struct: 0 field: 3 buf - tag: + pkg: main + tag: "" embedded: false exported: false reflect type: slice caninterface=false comparable=false @@ -306,7 +316,8 @@ reflect type: struct comparable=false reflect type: uint8 addrable=true caninterface=false uint: 111 field: 4 Buf - tag: + pkg: + tag: "" embedded: false exported: true reflect type: slice comparable=false @@ -322,14 +333,16 @@ reflect type: ptr reflect type: struct settable=true addrable=true struct: 2 field: 0 next - tag: description:"chain" + pkg: main + tag: "description:\"chain\"" embedded: false exported: false reflect type: ptr addrable=true caninterface=false pointer: false struct nil: true field: 1 foo - tag: + pkg: main + tag: "" embedded: false exported: false reflect type: int addrable=true caninterface=false @@ -337,13 +350,15 @@ reflect type: ptr reflect type: struct struct: 2 field: 0 A - tag: + pkg: + tag: "" embedded: false exported: true reflect type: uintptr uint: 2 field: 1 B - tag: + pkg: + tag: "" embedded: false exported: true reflect type: uintptr diff --git a/testdata/slice.go b/testdata/slice.go index 18b80b7c49..fbb1d45cc5 100644 --- a/testdata/slice.go +++ b/testdata/slice.go @@ -6,6 +6,8 @@ type MySlice [32]byte type myUint8 uint8 +type RecursiveSlice []RecursiveSlice + // Indexing into slice with named type (regression test). var array = [4]int{ myUint8(2): 3, @@ -160,6 +162,10 @@ func main() { for _, c := range named { assert(c == 0) } + + // Test recursive slices. 
+ rs := []RecursiveSlice(nil) + println("len:", len(rs)) } func printslice(name string, s []int) { diff --git a/testdata/slice.txt b/testdata/slice.txt index ea8d4491ab..d16a0bda9f 100644 --- a/testdata/slice.txt +++ b/testdata/slice.txt @@ -16,3 +16,4 @@ bytes: len=6 cap=6 data: 1 2 3 102 111 111 slice to array pointer: 1 -2 20 4 unsafe.Add array: 1 5 8 4 unsafe.Slice array: 3 3 9 15 4 +len: 0 diff --git a/tools/gen-device-svd/gen-device-svd.go b/tools/gen-device-svd/gen-device-svd.go index ce1f7ed138..8eb308c7f5 100755 --- a/tools/gen-device-svd/gen-device-svd.go +++ b/tools/gen-device-svd/gen-device-svd.go @@ -90,9 +90,10 @@ type SVDCluster struct { } type Device struct { - Metadata *Metadata - Interrupts []*Interrupt - Peripherals []*Peripheral + Metadata *Metadata + Interrupts []*Interrupt + Peripherals []*Peripheral + PeripheralDict map[string]*Peripheral } type Metadata struct { @@ -191,6 +192,142 @@ func cleanName(text string) string { return text } +func processSubCluster(p *Peripheral, cluster *SVDCluster, clusterOffset uint64, clusterName string, peripheralDict map[string]*Peripheral) []*Peripheral { + var peripheralsList []*Peripheral + clusterPrefix := clusterName + "_" + cpRegisters := []*PeripheralField{} + + for _, regEl := range cluster.Registers { + cpRegisters = append(cpRegisters, parseRegister(p.GroupName, regEl, p.BaseAddress+clusterOffset, clusterPrefix)...) + } + // handle sub-clusters of registers + for _, subClusterEl := range cluster.Clusters { + subclusterName := strings.ReplaceAll(subClusterEl.Name, "[%s]", "") + subclusterPrefix := subclusterName + "_" + subclusterOffset, err := strconv.ParseUint(subClusterEl.AddressOffset, 0, 32) + if err != nil { + panic(err) + } + subdim := *subClusterEl.Dim + subdimIncrement, err := strconv.ParseInt(subClusterEl.DimIncrement, 0, 32) + if err != nil { + panic(err) + } + + if subdim > 1 { + subcpRegisters := []*PeripheralField{} + for _, regEl := range subClusterEl.Registers { + subcpRegisters = append(subcpRegisters, parseRegister(p.GroupName, regEl, p.BaseAddress+clusterOffset+subclusterOffset, subclusterPrefix)...) + } + + cpRegisters = append(cpRegisters, &PeripheralField{ + Name: subclusterName, + Address: p.BaseAddress + clusterOffset + subclusterOffset, + Description: subClusterEl.Description, + Registers: subcpRegisters, + Array: subdim, + ElementSize: int(subdimIncrement), + ShortName: clusterPrefix + subclusterName, + }) + } else { + for _, regEl := range subClusterEl.Registers { + cpRegisters = append(cpRegisters, parseRegister(regEl.Name, regEl, p.BaseAddress+clusterOffset+subclusterOffset, subclusterPrefix)...) 
+ } + } + } + + sort.SliceStable(cpRegisters, func(i, j int) bool { + return cpRegisters[i].Address < cpRegisters[j].Address + }) + clusterPeripheral := &Peripheral{ + Name: p.Name + "_" + clusterName, + GroupName: p.GroupName + "_" + clusterName, + Description: p.Description + " - " + clusterName, + ClusterName: clusterName, + BaseAddress: p.BaseAddress + clusterOffset, + Registers: cpRegisters, + } + peripheralsList = append(peripheralsList, clusterPeripheral) + peripheralDict[clusterPeripheral.Name] = clusterPeripheral + p.Subtypes = append(p.Subtypes, clusterPeripheral) + + return peripheralsList +} + +func processCluster(p *Peripheral, clusters []*SVDCluster, peripheralDict map[string]*Peripheral) []*Peripheral { + var peripheralsList []*Peripheral + for _, cluster := range clusters { + clusterName := strings.ReplaceAll(cluster.Name, "[%s]", "") + if cluster.DimIndex != nil { + clusterName = strings.ReplaceAll(clusterName, "%s", "") + } + clusterPrefix := clusterName + "_" + clusterOffset, err := strconv.ParseUint(cluster.AddressOffset, 0, 32) + if err != nil { + panic(err) + } + var dim, dimIncrement int + if cluster.Dim == nil { + // Nordic SVD have sub-clusters with another sub-clusters. + if clusterOffset == 0 || len(cluster.Clusters) > 0 { + peripheralsList = append(peripheralsList, processSubCluster(p, cluster, clusterOffset, clusterName, peripheralDict)...) + continue + } + dim = -1 + dimIncrement = -1 + } else { + dim = *cluster.Dim + if dim == 1 { + dimIncrement = -1 + } else { + inc, err := strconv.ParseUint(cluster.DimIncrement, 0, 32) + if err != nil { + panic(err) + } + dimIncrement = int(inc) + } + } + clusterRegisters := []*PeripheralField{} + for _, regEl := range cluster.Registers { + regName := p.GroupName + if regName == "" { + regName = p.Name + } + clusterRegisters = append(clusterRegisters, parseRegister(regName, regEl, p.BaseAddress+clusterOffset, clusterPrefix)...) + } + sort.SliceStable(clusterRegisters, func(i, j int) bool { + return clusterRegisters[i].Address < clusterRegisters[j].Address + }) + if dimIncrement == -1 && len(clusterRegisters) > 0 { + lastReg := clusterRegisters[len(clusterRegisters)-1] + lastAddress := lastReg.Address + if lastReg.Array != -1 { + lastAddress = lastReg.Address + uint64(lastReg.Array*lastReg.ElementSize) + } + firstAddress := clusterRegisters[0].Address + dimIncrement = int(lastAddress - firstAddress) + } + + if !unicode.IsUpper(rune(clusterName[0])) && !unicode.IsDigit(rune(clusterName[0])) { + clusterName = strings.ToUpper(clusterName) + } + + p.Registers = append(p.Registers, &PeripheralField{ + Name: clusterName, + Address: p.BaseAddress + clusterOffset, + Description: cluster.Description, + Registers: clusterRegisters, + Array: dim, + ElementSize: dimIncrement, + ShortName: clusterName, + }) + } + sort.SliceStable(p.Registers, func(i, j int) bool { + return p.Registers[i].Address < p.Registers[j].Address + }) + return peripheralsList +} + // Read ARM SVD files. func readSVD(path, sourceURL string) (*Device, error) { // Open the XML file. @@ -293,133 +430,7 @@ func readSVD(path, sourceURL string) (*Device, error) { } p.Registers = append(p.Registers, parseRegister(regName, register, baseAddress, "")...) 
} - for _, cluster := range periphEl.Clusters { - clusterName := strings.ReplaceAll(cluster.Name, "[%s]", "") - if cluster.DimIndex != nil { - clusterName = strings.ReplaceAll(clusterName, "%s", "") - } - clusterPrefix := clusterName + "_" - clusterOffset, err := strconv.ParseUint(cluster.AddressOffset, 0, 32) - if err != nil { - panic(err) - } - var dim, dimIncrement int - if cluster.Dim == nil { - if clusterOffset == 0 { - // make this a separate peripheral - cpRegisters := []*PeripheralField{} - for _, regEl := range cluster.Registers { - cpRegisters = append(cpRegisters, parseRegister(groupName, regEl, baseAddress, clusterName+"_")...) - } - // handle sub-clusters of registers - for _, subClusterEl := range cluster.Clusters { - subclusterName := strings.ReplaceAll(subClusterEl.Name, "[%s]", "") - subclusterPrefix := subclusterName + "_" - subclusterOffset, err := strconv.ParseUint(subClusterEl.AddressOffset, 0, 32) - if err != nil { - panic(err) - } - subdim := *subClusterEl.Dim - subdimIncrement, err := strconv.ParseInt(subClusterEl.DimIncrement, 0, 32) - if err != nil { - panic(err) - } - - if subdim > 1 { - subcpRegisters := []*PeripheralField{} - subregSize := 0 - for _, regEl := range subClusterEl.Registers { - size, err := strconv.ParseInt(*regEl.Size, 0, 32) - if err != nil { - panic(err) - } - subregSize += int(size) - subcpRegisters = append(subcpRegisters, parseRegister(groupName, regEl, baseAddress+subclusterOffset, subclusterPrefix)...) - } - cpRegisters = append(cpRegisters, &PeripheralField{ - Name: subclusterName, - Address: baseAddress + subclusterOffset, - Description: subClusterEl.Description, - Registers: subcpRegisters, - Array: subdim, - ElementSize: int(subdimIncrement), - ShortName: clusterPrefix + subclusterName, - }) - } else { - for _, regEl := range subClusterEl.Registers { - cpRegisters = append(cpRegisters, parseRegister(regEl.Name, regEl, baseAddress+subclusterOffset, subclusterPrefix)...) - } - } - } - - sort.SliceStable(cpRegisters, func(i, j int) bool { - return cpRegisters[i].Address < cpRegisters[j].Address - }) - clusterPeripheral := &Peripheral{ - Name: periphEl.Name + "_" + clusterName, - GroupName: groupName + "_" + clusterName, - Description: description + " - " + clusterName, - ClusterName: clusterName, - BaseAddress: baseAddress, - Registers: cpRegisters, - } - peripheralsList = append(peripheralsList, clusterPeripheral) - peripheralDict[clusterPeripheral.Name] = clusterPeripheral - p.Subtypes = append(p.Subtypes, clusterPeripheral) - continue - } - dim = -1 - dimIncrement = -1 - } else { - dim = *cluster.Dim - if dim == 1 { - dimIncrement = -1 - } else { - inc, err := strconv.ParseUint(cluster.DimIncrement, 0, 32) - if err != nil { - panic(err) - } - dimIncrement = int(inc) - } - } - clusterRegisters := []*PeripheralField{} - for _, regEl := range cluster.Registers { - regName := groupName - if regName == "" { - regName = periphEl.Name - } - clusterRegisters = append(clusterRegisters, parseRegister(regName, regEl, baseAddress+clusterOffset, clusterPrefix)...) 
- } - sort.SliceStable(clusterRegisters, func(i, j int) bool { - return clusterRegisters[i].Address < clusterRegisters[j].Address - }) - if dimIncrement == -1 && len(clusterRegisters) > 0 { - lastReg := clusterRegisters[len(clusterRegisters)-1] - lastAddress := lastReg.Address - if lastReg.Array != -1 { - lastAddress = lastReg.Address + uint64(lastReg.Array*lastReg.ElementSize) - } - firstAddress := clusterRegisters[0].Address - dimIncrement = int(lastAddress - firstAddress) - } - - if !unicode.IsUpper(rune(clusterName[0])) && !unicode.IsDigit(rune(clusterName[0])) { - clusterName = strings.ToUpper(clusterName) - } - - p.Registers = append(p.Registers, &PeripheralField{ - Name: clusterName, - Address: baseAddress + clusterOffset, - Description: cluster.Description, - Registers: clusterRegisters, - Array: dim, - ElementSize: dimIncrement, - ShortName: clusterName, - }) - } - sort.SliceStable(p.Registers, func(i, j int) bool { - return p.Registers[i].Address < p.Registers[j].Address - }) + peripheralsList = append(peripheralsList, processCluster(p, periphEl.Clusters, peripheralDict)...) } // Make a sorted list of interrupts. @@ -459,9 +470,10 @@ func readSVD(path, sourceURL string) (*Device, error) { metadata.NVICPrioBits = device.CPU.NVICPrioBits } return &Device{ - Metadata: metadata, - Interrupts: interruptList, - Peripherals: peripheralsList, + Metadata: metadata, + Interrupts: interruptList, + Peripherals: peripheralsList, + PeripheralDict: peripheralDict, }, nil } @@ -979,10 +991,11 @@ var ( address := peripheral.BaseAddress type clusterInfo struct { - name string - address uint64 - size uint64 - registers []*PeripheralField + name string + description string + address uint64 + size uint64 + registers []*PeripheralField } clusters := []clusterInfo{} for _, register := range peripheral.Registers { @@ -1024,7 +1037,7 @@ var ( if register.Registers != nil { // This is a cluster, not a register. Create the cluster type. regType = peripheral.GroupName + "_" + register.Name - clusters = append(clusters, clusterInfo{regType, register.Address, uint64(register.ElementSize), register.Registers}) + clusters = append(clusters, clusterInfo{regType, register.Description, register.Address, uint64(register.ElementSize), register.Registers}) regType = regType + "_Type" subaddress := register.Address for _, subregister := range register.Registers { @@ -1075,7 +1088,16 @@ var ( continue } + if _, ok := device.PeripheralDict[cluster.name]; ok { + continue + } + fmt.Fprintln(w) + if cluster.description != "" { + for _, l := range splitLine(cluster.description) { + fmt.Fprintf(w, "// %s\n", l) + } + } fmt.Fprintf(w, "type %s_Type struct {\n", cluster.name) address := cluster.address @@ -1116,7 +1138,7 @@ var ( if register.Registers != nil { // This is a cluster, not a register. Create the cluster type. 
regType = peripheral.GroupName + "_" + register.Name - clusters = append(clusters, clusterInfo{regType, register.Address, uint64(register.ElementSize), register.Registers}) + clusters = append(clusters, clusterInfo{regType, register.Description, register.Address, uint64(register.ElementSize), register.Registers}) regType = regType + "_Type" subaddress := register.Address @@ -1400,6 +1422,9 @@ __isr_vector: for _, intr := range device.Interrupts { fmt.Fprintf(w, " IRQ %s_IRQHandler\n", intr.Name) } + w.WriteString(` +.size __isr_vector, .-__isr_vector +`) return w.Flush() } diff --git a/transform/globals.go b/transform/globals.go deleted file mode 100644 index cc506f0244..0000000000 --- a/transform/globals.go +++ /dev/null @@ -1,20 +0,0 @@ -package transform - -import "tinygo.org/x/go-llvm" - -// This file implements small transformations on globals (functions and global -// variables) for specific ABIs/architectures. - -// ApplyFunctionSections puts every function in a separate section. This makes -// it possible for the linker to remove dead code. It is the equivalent of -// passing -ffunction-sections to a C compiler. -func ApplyFunctionSections(mod llvm.Module) { - llvmFn := mod.FirstFunction() - for !llvmFn.IsNil() { - if !llvmFn.IsDeclaration() && llvmFn.Section() == "" { - name := llvmFn.Name() - llvmFn.SetSection(".text." + name) - } - llvmFn = llvm.NextFunction(llvmFn) - } -} diff --git a/transform/globals_test.go b/transform/globals_test.go deleted file mode 100644 index 1b6f243b8c..0000000000 --- a/transform/globals_test.go +++ /dev/null @@ -1,15 +0,0 @@ -package transform_test - -import ( - "testing" - - "github.com/tinygo-org/tinygo/transform" - "tinygo.org/x/go-llvm" -) - -func TestApplyFunctionSections(t *testing.T) { - t.Parallel() - testTransform(t, "testdata/globals-function-sections", func(mod llvm.Module) { - transform.ApplyFunctionSections(mod) - }) -} diff --git a/transform/interface-lowering.go b/transform/interface-lowering.go index 55d1af39b7..ebd47ff8f6 100644 --- a/transform/interface-lowering.go +++ b/transform/interface-lowering.go @@ -13,8 +13,7 @@ package transform // // typeAssert: // Replaced with an icmp instruction so it can be directly used in a type -// switch. This is very easy to optimize for LLVM: it will often translate a -// type switch into a regular switch statement. +// switch. // // interface type assert: // These functions are defined by creating a big type switch over all the @@ -54,10 +53,11 @@ type methodInfo struct { // typeInfo describes a single concrete Go type, which can be a basic or a named // type. If it is a named type, it may have methods. type typeInfo struct { - name string - typecode llvm.Value - methodSet llvm.Value - methods []*methodInfo + name string + typecode llvm.Value + typecodeGEP llvm.Value + methodSet llvm.Value + methods []*methodInfo } // getMethod looks up the method on this type with the given signature and @@ -91,6 +91,8 @@ type lowerInterfacesPass struct { difiles map[string]llvm.Metadata ctx llvm.Context uintptrType llvm.Type + targetData llvm.TargetData + i8ptrType llvm.Type types map[string]*typeInfo signatures map[string]*signatureInfo interfaces map[string]*interfaceInfo @@ -101,14 +103,17 @@ type lowerInterfacesPass struct { // before LLVM can work on them. This is done so that a few cleanup passes can // run before assigning the final type codes. 
func LowerInterfaces(mod llvm.Module, config *compileopts.Config) error { + ctx := mod.Context() targetData := llvm.NewTargetData(mod.DataLayout()) defer targetData.Dispose() p := &lowerInterfacesPass{ mod: mod, config: config, - builder: mod.Context().NewBuilder(), - ctx: mod.Context(), + builder: ctx.NewBuilder(), + ctx: ctx, + targetData: targetData, uintptrType: mod.Context().IntType(targetData.PointerSize() * 8), + i8ptrType: llvm.PointerType(ctx.Int8Type(), 0), types: make(map[string]*typeInfo), signatures: make(map[string]*signatureInfo), interfaces: make(map[string]*interfaceInfo), @@ -151,11 +156,26 @@ func (p *lowerInterfacesPass) run() error { } p.types[name] = t initializer := global.Initializer() - if initializer.IsNil() { - continue + firstField := p.builder.CreateExtractValue(initializer, 0, "") + if firstField.Type() != p.ctx.Int8Type() { + // This type has a method set at index 0. Change the GEP to + // point to index 1 (the meta byte). + t.typecodeGEP = llvm.ConstGEP(global.GlobalValueType(), global, []llvm.Value{ + llvm.ConstInt(p.ctx.Int32Type(), 0, false), + llvm.ConstInt(p.ctx.Int32Type(), 1, false), + }) + methodSet := stripPointerCasts(firstField) + if !strings.HasSuffix(methodSet.Name(), "$methodset") { + panic("expected method set") + } + p.addTypeMethods(t, methodSet) + } else { + // This type has no method set. + t.typecodeGEP = llvm.ConstGEP(global.GlobalValueType(), global, []llvm.Value{ + llvm.ConstInt(p.ctx.Int32Type(), 0, false), + llvm.ConstInt(p.ctx.Int32Type(), 0, false), + }) } - methodSet := p.builder.CreateExtractValue(initializer, 2, "") - p.addTypeMethods(t, methodSet) } } } @@ -266,10 +286,10 @@ func (p *lowerInterfacesPass) run() error { actualType := use.Operand(0) name := strings.TrimPrefix(use.Operand(1).Name(), "reflect/types.typeid:") if t, ok := p.types[name]; ok { - // The type exists in the program, so lower to a regular integer + // The type exists in the program, so lower to a regular pointer // comparison. p.builder.SetInsertPointBefore(use) - commaOk := p.builder.CreateICmp(llvm.IntEQ, llvm.ConstPtrToInt(t.typecode, p.uintptrType), actualType, "typeassert.ok") + commaOk := p.builder.CreateICmp(llvm.IntEQ, t.typecodeGEP, actualType, "typeassert.ok") use.ReplaceAllUsesWith(commaOk) } else { // The type does not exist in the program, so lower to a constant @@ -282,16 +302,54 @@ func (p *lowerInterfacesPass) run() error { use.EraseFromParentAsInstruction() } + // Create a sorted list of type names, for predictable iteration. + var typeNames []string + for name := range p.types { + typeNames = append(typeNames, name) + } + sort.Strings(typeNames) + // Remove all method sets, which are now unnecessary and inhibit later - // optimizations if they are left in place. Also remove references to the - // interface type assert functions just to be sure. - zeroUintptr := llvm.ConstNull(p.uintptrType) - for _, t := range p.types { - initializer := t.typecode.Initializer() - methodSet := p.builder.CreateExtractValue(initializer, 2, "") - initializer = p.builder.CreateInsertValue(initializer, llvm.ConstNull(methodSet.Type()), 2, "") - initializer = p.builder.CreateInsertValue(initializer, zeroUintptr, 4, "") - t.typecode.SetInitializer(initializer) + // optimizations if they are left in place. 
+ zero := llvm.ConstInt(p.ctx.Int32Type(), 0, false) + for _, name := range typeNames { + t := p.types[name] + if !t.methodSet.IsNil() { + initializer := t.typecode.Initializer() + var newInitializerFields []llvm.Value + for i := 1; i < initializer.Type().StructElementTypesCount(); i++ { + newInitializerFields = append(newInitializerFields, p.builder.CreateExtractValue(initializer, i, "")) + } + newInitializer := p.ctx.ConstStruct(newInitializerFields, false) + typecodeName := t.typecode.Name() + newGlobal := llvm.AddGlobal(p.mod, newInitializer.Type(), typecodeName+".tmp") + newGlobal.SetInitializer(newInitializer) + newGlobal.SetLinkage(t.typecode.Linkage()) + newGlobal.SetGlobalConstant(true) + newGlobal.SetAlignment(t.typecode.Alignment()) + for _, use := range getUses(t.typecode) { + if !use.IsAConstantExpr().IsNil() { + opcode := use.Opcode() + if opcode == llvm.GetElementPtr && use.OperandsCount() == 3 { + if use.Operand(1).ZExtValue() == 0 && use.Operand(2).ZExtValue() == 1 { + gep := p.builder.CreateInBoundsGEP(newGlobal.GlobalValueType(), newGlobal, []llvm.Value{zero, zero}, "") + use.ReplaceAllUsesWith(gep) + } + } + } + } + // Fallback. + if hasUses(t.typecode) { + bitcast := llvm.ConstBitCast(newGlobal, p.i8ptrType) + negativeOffset := -int64(p.targetData.TypeAllocSize(p.i8ptrType)) + gep := p.builder.CreateInBoundsGEP(p.ctx.Int8Type(), bitcast, []llvm.Value{llvm.ConstInt(p.ctx.Int32Type(), uint64(negativeOffset), true)}, "") + bitcast2 := llvm.ConstBitCast(gep, t.typecode.Type()) + t.typecode.ReplaceAllUsesWith(bitcast2) + } + t.typecode.EraseFromParentAsGlobal() + newGlobal.SetName(typecodeName) + t.typecode = newGlobal + } } return nil @@ -301,22 +359,22 @@ func (p *lowerInterfacesPass) run() error { // retrieves the signatures and the references to the method functions // themselves for later type<->interface matching. func (p *lowerInterfacesPass) addTypeMethods(t *typeInfo, methodSet llvm.Value) { - if !t.methodSet.IsNil() || methodSet.IsNull() { + if !t.methodSet.IsNil() { // no methods or methods already read return } - if !methodSet.IsAConstantExpr().IsNil() && methodSet.Opcode() == llvm.GetElementPtr { - methodSet = methodSet.Operand(0) // get global from GEP, for LLVM 14 (non-opaque pointers) - } // This type has methods, collect all methods of this type. 
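As context for the loop that follows: the method set global now stores two parallel arrays, the method signatures (field 1) and the corresponding wrapper functions (field 2), and addTypeMethods zips them together. A plain-Go sketch of the lookup this enables (an illustration with made-up data; field 0 of the real global is not read by this pass):

package main

// methodSet mimics the parallel-array layout that addTypeMethods reads:
// signatures[i] identifies a method signature, wrappers[i] is the function
// implementing it for this concrete type.
type methodSet struct {
	signatures []string
	wrappers   []func()
}

// lookup matches a signature to its wrapper, much like typeInfo.getMethod
// does when interface calls are lowered to concrete calls.
func (m methodSet) lookup(sig string) (func(), bool) {
	for i, s := range m.signatures {
		if s == sig {
			return m.wrappers[i], true
		}
	}
	return nil, false
}

func main() {
	ms := methodSet{
		signatures: []string{"Error() string", "String() string"},
		wrappers:   []func(){func() {}, func() {}},
	}
	_, ok := ms.lookup("String() string")
	println("found:", ok)
}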
t.methodSet = methodSet set := methodSet.Initializer() // get value from global - for i := 0; i < set.Type().ArrayLength(); i++ { - methodData := p.builder.CreateExtractValue(set, i, "") - signatureGlobal := p.builder.CreateExtractValue(methodData, 0, "") + signatures := p.builder.CreateExtractValue(set, 1, "") + wrappers := p.builder.CreateExtractValue(set, 2, "") + numMethods := signatures.Type().ArrayLength() + for i := 0; i < numMethods; i++ { + signatureGlobal := p.builder.CreateExtractValue(signatures, i, "") + function := p.builder.CreateExtractValue(wrappers, i, "") + function = stripPointerCasts(function) // strip bitcasts signatureName := signatureGlobal.Name() - function := p.builder.CreateExtractValue(methodData, 1, "").Operand(0) signature := p.getSignature(signatureName) method := &methodInfo{ function: function, @@ -401,7 +459,7 @@ func (p *lowerInterfacesPass) defineInterfaceImplementsFunc(fn llvm.Value, itf * actualType := fn.Param(0) for _, typ := range itf.types { nextBlock := p.ctx.AddBasicBlock(fn, typ.name+".next") - cmp := p.builder.CreateICmp(llvm.IntEQ, actualType, llvm.ConstPtrToInt(typ.typecode, p.uintptrType), typ.name+".icmp") + cmp := p.builder.CreateICmp(llvm.IntEQ, actualType, typ.typecodeGEP, typ.name+".icmp") p.builder.CreateCondBr(cmp, thenBlock, nextBlock) p.builder.SetInsertPointAtEnd(nextBlock) } @@ -440,7 +498,7 @@ func (p *lowerInterfacesPass) defineInterfaceMethodFunc(fn llvm.Value, itf *inte params[i] = fn.Param(i + 1) } params = append(params, - llvm.Undef(llvm.PointerType(p.ctx.Int8Type(), 0)), + llvm.Undef(p.i8ptrType), ) // Start chain in the entry block. @@ -472,7 +530,7 @@ func (p *lowerInterfacesPass) defineInterfaceMethodFunc(fn llvm.Value, itf *inte // Create type check (if/else). bb := p.ctx.AddBasicBlock(fn, typ.name) next := p.ctx.AddBasicBlock(fn, typ.name+".next") - cmp := p.builder.CreateICmp(llvm.IntEQ, actualType, llvm.ConstPtrToInt(typ.typecode, p.uintptrType), typ.name+".icmp") + cmp := p.builder.CreateICmp(llvm.IntEQ, actualType, typ.typecodeGEP, typ.name+".icmp") p.builder.CreateCondBr(cmp, bb, next) // The function we will redirect to when the interface has this type. @@ -522,7 +580,7 @@ func (p *lowerInterfacesPass) defineInterfaceMethodFunc(fn llvm.Value, itf *inte // method on a nil interface. nilPanic := p.mod.NamedFunction("runtime.nilPanic") p.builder.CreateCall(nilPanic.GlobalValueType(), nilPanic, []llvm.Value{ - llvm.Undef(llvm.PointerType(p.ctx.Int8Type(), 0)), + llvm.Undef(p.i8ptrType), }, "") p.builder.CreateUnreachable() } diff --git a/transform/optimizer.go b/transform/optimizer.go index 80be631913..20258ef4fe 100644 --- a/transform/optimizer.go +++ b/transform/optimizer.go @@ -116,7 +116,6 @@ func Optimize(mod llvm.Module, config *compileopts.Config, optLevel, sizeLevel i goPasses.Run(mod) // Run TinyGo-specific interprocedural optimizations. - LowerReflect(mod) OptimizeAllocs(mod, config.Options.PrintAllocs, func(pos token.Position, msg string) { fmt.Fprintln(os.Stderr, pos.String()+": "+msg) }) @@ -129,7 +128,6 @@ func Optimize(mod llvm.Module, config *compileopts.Config, optLevel, sizeLevel i if err != nil { return []error{err} } - LowerReflect(mod) errs := LowerInterrupts(mod) if len(errs) > 0 { return errs diff --git a/transform/reflect.go b/transform/reflect.go deleted file mode 100644 index b994df61c2..0000000000 --- a/transform/reflect.go +++ /dev/null @@ -1,567 +0,0 @@ -package transform - -// This file has some compiler support for run-time reflection using the reflect -// package. 
In particular, it encodes type information in type codes in such a -// way that the reflect package can decode the type from this information. -// Where needed, it also adds some side tables for looking up more information -// about a type, when that information cannot be stored directly in the type -// code. -// -// Go has 26 different type kinds. -// -// Type kinds are subdivided in basic types (see the list of basicTypes below) -// that are mostly numeric literals and non-basic (or "complex") types that are -// more difficult to encode. These non-basic types come in two forms: -// * Prefix types (pointer, slice, interface, channel): these just add -// something to an existing type. For example, a pointer like *int just adds -// the fact that it's a pointer to an existing type (int). -// These are encoded efficiently by adding a prefix to a type code. -// * Types with multiple fields (struct, array, func, map). All of these have -// multiple fields contained within. Most obviously structs can contain many -// types as fields. Also arrays contain not just the element type but also -// the length parameter which can be any arbitrary number and thus may not -// fit in a type code. -// These types are encoded using side tables. -// -// This distinction is also important for how named types are encoded. At the -// moment, named basic type just get a unique number assigned while named -// non-basic types have their underlying type stored in a sidetable. - -import ( - "encoding/binary" - "go/ast" - "math/big" - "sort" - "strings" - - "tinygo.org/x/go-llvm" -) - -// A list of basic types and their numbers. This list should be kept in sync -// with the list of Kind constants of type.go in the reflect package. -var basicTypes = map[string]int64{ - "bool": 1, - "int": 2, - "int8": 3, - "int16": 4, - "int32": 5, - "int64": 6, - "uint": 7, - "uint8": 8, - "uint16": 9, - "uint32": 10, - "uint64": 11, - "uintptr": 12, - "float32": 13, - "float64": 14, - "complex64": 15, - "complex128": 16, - "string": 17, - "unsafe.Pointer": 18, -} - -// A list of non-basic types. Adding 19 to this number will give the Kind as -// used in src/reflect/types.go, and it must be kept in sync with that list. -var nonBasicTypes = map[string]int64{ - "chan": 0, - "interface": 1, - "pointer": 2, - "slice": 3, - "array": 4, - "func": 5, - "map": 6, - "struct": 7, -} - -// typeCodeAssignmentState keeps some global state around for type code -// assignments, used to assign one unique type code to each Go type. -type typeCodeAssignmentState struct { - // Builder used purely for constant operations (because LLVM 15 removed many - // llvm.Const* functions). - builder llvm.Builder - - // An integer that's incremented each time it's used to give unique IDs to - // type codes that are not yet fully supported otherwise by the reflect - // package (or are simply unused in the compiled program). - fallbackIndex int - - // This is the length of an uintptr. Only used occasionally to know whether - // a given number can be encoded as a varint. - uintptrLen int - - // Map of named types to their type code. It is important that named types - // get unique IDs for each type. - namedBasicTypes map[string]int - namedNonBasicTypes map[string]int - - // Map of array types to their type code. - arrayTypes map[string]int - arrayTypesSidetable []byte - needsArrayTypesSidetable bool - - // Map of struct types to their type code. 
- structTypes map[string]int - structTypesSidetable []byte - needsStructNamesSidetable bool - - // Map of struct names and tags to their name string. - structNames map[string]int - structNamesSidetable []byte - needsStructTypesSidetable bool - - // This byte array is stored in reflect.namedNonBasicTypesSidetable and is - // used at runtime to get details about a named non-basic type. - // Entries are varints (see makeVarint below and readVarint in - // reflect/sidetables.go for the encoding): one varint per entry. The - // integers in namedNonBasicTypes are indices into this array. Because these - // are varints, most type codes are really small (just one byte). - // - // Note that this byte buffer is not created when it is not needed - // (reflect.namedNonBasicTypesSidetable has no uses), see - // needsNamedTypesSidetable. - namedNonBasicTypesSidetable []uint64 - - // This indicates whether namedNonBasicTypesSidetable needs to be created at - // all. If it is false, namedNonBasicTypesSidetable will contain simple - // monotonically increasing numbers. - needsNamedNonBasicTypesSidetable bool -} - -// LowerReflect is used to assign a type code to each type in the program -// that is ever stored in an interface. It tries to use the smallest possible -// numbers to make the code that works with interfaces as small as possible. -func LowerReflect(mod llvm.Module) { - // if reflect were not used, we could skip generating the sidetable - // this does not help in practice, and is difficult to do correctly - - // Obtain slice of all types in the program. - type typeInfo struct { - typecode llvm.Value - name string - numUses int - } - var types []*typeInfo - for global := mod.FirstGlobal(); !global.IsNil(); global = llvm.NextGlobal(global) { - if strings.HasPrefix(global.Name(), "reflect/types.type:") { - types = append(types, &typeInfo{ - typecode: global, - name: global.Name(), - numUses: len(getUses(global)), - }) - } - } - - // Sort the slice in a way that often used types are assigned a type code - // first. - sort.Slice(types, func(i, j int) bool { - if types[i].numUses != types[j].numUses { - return types[i].numUses < types[j].numUses - } - // It would make more sense to compare the name in the other direction, - // but for some reason that increases binary size. Could be a fluke, but - // could also have some good reason (and possibly hint at a small - // optimization). - return types[i].name > types[j].name - }) - - // Assign typecodes the way the reflect package expects. 
- targetData := llvm.NewTargetData(mod.DataLayout()) - defer targetData.Dispose() - uintptrType := mod.Context().IntType(targetData.PointerSize() * 8) - state := typeCodeAssignmentState{ - builder: mod.Context().NewBuilder(), - fallbackIndex: 1, - uintptrLen: targetData.PointerSize() * 8, - namedBasicTypes: make(map[string]int), - namedNonBasicTypes: make(map[string]int), - arrayTypes: make(map[string]int), - structTypes: make(map[string]int), - structNames: make(map[string]int), - needsNamedNonBasicTypesSidetable: len(getUses(mod.NamedGlobal("reflect.namedNonBasicTypesSidetable"))) != 0, - needsStructTypesSidetable: len(getUses(mod.NamedGlobal("reflect.structTypesSidetable"))) != 0, - needsStructNamesSidetable: len(getUses(mod.NamedGlobal("reflect.structNamesSidetable"))) != 0, - needsArrayTypesSidetable: len(getUses(mod.NamedGlobal("reflect.arrayTypesSidetable"))) != 0, - } - defer state.builder.Dispose() - for _, t := range types { - num := state.getTypeCodeNum(t.typecode) - if num.BitLen() > state.uintptrLen || !num.IsUint64() { - // TODO: support this in some way, using a side table for example. - // That's less efficient but better than not working at all. - // Particularly important on systems with 16-bit pointers (e.g. - // AVR). - panic("compiler: could not store type code number inside interface type code") - } - - // Replace each use of the type code global with the constant type code. - for _, use := range getUses(t.typecode) { - if use.IsAConstantExpr().IsNil() { - continue - } - typecode := llvm.ConstInt(uintptrType, num.Uint64(), false) - switch use.Opcode() { - case llvm.PtrToInt: - // Already of the correct type. - case llvm.BitCast: - // Could happen when stored in an interface (which is of type - // i8*). - typecode = llvm.ConstIntToPtr(typecode, use.Type()) - default: - panic("unexpected constant expression") - } - use.ReplaceAllUsesWith(typecode) - } - } - - // Only create this sidetable when it is necessary. - if state.needsNamedNonBasicTypesSidetable { - global := replaceGlobalIntWithArray(mod, "reflect.namedNonBasicTypesSidetable", state.namedNonBasicTypesSidetable) - global.SetLinkage(llvm.InternalLinkage) - global.SetUnnamedAddr(true) - global.SetGlobalConstant(true) - } - if state.needsArrayTypesSidetable { - global := replaceGlobalIntWithArray(mod, "reflect.arrayTypesSidetable", state.arrayTypesSidetable) - global.SetLinkage(llvm.InternalLinkage) - global.SetUnnamedAddr(true) - global.SetGlobalConstant(true) - } - if state.needsStructTypesSidetable { - global := replaceGlobalIntWithArray(mod, "reflect.structTypesSidetable", state.structTypesSidetable) - global.SetLinkage(llvm.InternalLinkage) - global.SetUnnamedAddr(true) - global.SetGlobalConstant(true) - } - if state.needsStructNamesSidetable { - global := replaceGlobalIntWithArray(mod, "reflect.structNamesSidetable", state.structNamesSidetable) - global.SetLinkage(llvm.InternalLinkage) - global.SetUnnamedAddr(true) - global.SetGlobalConstant(true) - } - - // Remove most objects created for interface and reflect lowering. - // They would normally be removed anyway in later passes, but not always. - // It also cleans up the IR for testing. 
- for _, typ := range types { - initializer := typ.typecode.Initializer() - references := state.builder.CreateExtractValue(initializer, 0, "") - typ.typecode.SetInitializer(llvm.ConstNull(initializer.Type())) - if strings.HasPrefix(typ.name, "reflect/types.type:struct:") { - // Structs have a 'references' field that is not a typecode but - // a pointer to a runtime.structField array and therefore a - // bitcast. This global should be erased separately, otherwise - // typecode objects cannot be erased. - structFields := references - if !structFields.IsAConstantExpr().IsNil() && structFields.Opcode() == llvm.BitCast { - structFields = structFields.Operand(0) // get global from bitcast, for LLVM 14 compatibility (non-opaque pointers) - } - structFields.EraseFromParentAsGlobal() - } - } -} - -// getTypeCodeNum returns the typecode for a given type as expected by the -// reflect package. Also see getTypeCodeName, which serializes types to a string -// based on a types.Type value for this function. -func (state *typeCodeAssignmentState) getTypeCodeNum(typecode llvm.Value) *big.Int { - // Note: see src/reflect/type.go for bit allocations. - class, value := getClassAndValueFromTypeCode(typecode) - name := "" - if class == "named" { - name = value - typecode = state.builder.CreateExtractValue(typecode.Initializer(), 0, "") - class, value = getClassAndValueFromTypeCode(typecode) - } - if class == "basic" { - // Basic types follow the following bit pattern: - // ...xxxxx0 - // where xxxxx is allocated for the 18 possible basic types and all the - // upper bits are used to indicate the named type. - num, ok := basicTypes[value] - if !ok { - panic("invalid basic type: " + value) - } - if name != "" { - // This type is named, set the upper bits to the name ID. - num |= int64(state.getBasicNamedTypeNum(name)) << 5 - } - return big.NewInt(num << 1) - } else { - // Non-baisc types use the following bit pattern: - // ...nxxx1 - // where xxx indicates the non-basic type. The upper bits contain - // whatever the type contains. Types that wrap a single other type - // (channel, interface, pointer, slice) just contain the bits of the - // wrapped type. Other types (like struct) need more fields and thus - // cannot be encoded as a simple prefix. - var classNumber int64 - if n, ok := nonBasicTypes[class]; ok { - classNumber = n - } else { - panic("unknown type kind: " + class) - } - var num *big.Int - lowBits := (classNumber << 1) + 1 // the 5 low bits of the typecode - if name == "" { - num = state.getNonBasicTypeCode(class, typecode) - } else { - // We must return a named type here. But first check whether it - // has already been defined. - if index, ok := state.namedNonBasicTypes[name]; ok { - num := big.NewInt(int64(index)) - num.Lsh(num, 5).Or(num, big.NewInt((classNumber<<1)+1+(1<<4))) - return num - } - lowBits |= 1 << 4 // set the 'n' bit (see above) - if !state.needsNamedNonBasicTypesSidetable { - // Use simple small integers in this case, to make these numbers - // smaller. - index := len(state.namedNonBasicTypes) + 1 - state.namedNonBasicTypes[name] = index - num = big.NewInt(int64(index)) - } else { - // We need to store full type information. - // First allocate a number in the named non-basic type - // sidetable. 
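To make the bit packing in getTypeCodeNum concrete, the codes can be recomputed by hand. The snippet below is an illustration only, using the kind numbers from the basicTypes and nonBasicTypes maps of this (now removed) file; the name ID of 1 is an arbitrary assumption, since real IDs depend on assignment order:

package main

import "fmt"

func main() {
	const intKind = 2     // basicTypes["int"]
	const pointerKind = 2 // nonBasicTypes["pointer"]

	// Plain basic type: pattern ...xxxxx0, i.e. the kind shifted left by one.
	fmt.Println(intKind << 1) // 4, the value reflect_test.go expects for basic:int

	// Named basic type (say `type Celsius int`) with an assumed name ID of 1:
	// the name ID occupies the bits above the 5-bit kind field.
	nameID := 1
	fmt.Println((intKind | nameID<<5) << 1) // 68

	// *int: a prefix (non-basic) type. The low bits are (kind<<1)|1, and the
	// upper bits carry the code of the wrapped type (4 for int).
	fmt.Println(intKind<<1<<5 | pointerKind<<1 | 1) // 133
}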
- index := len(state.namedNonBasicTypesSidetable) - state.namedNonBasicTypesSidetable = append(state.namedNonBasicTypesSidetable, 0) - state.namedNonBasicTypes[name] = index - // Get the typecode of the underlying type (which could be the - // element type in the case of pointers, for example). - num = state.getNonBasicTypeCode(class, typecode) - if num.BitLen() > state.uintptrLen || !num.IsUint64() { - panic("cannot store value in sidetable") - } - // Now update the side table with the number we just - // determined. We need this multi-step approach to avoid stack - // overflow due to adding types recursively in the case of - // linked lists (a pointer which points to a struct that - // contains that same pointer). - state.namedNonBasicTypesSidetable[index] = num.Uint64() - num = big.NewInt(int64(index)) - } - } - // Concatenate the 'num' and 'lowBits' bitstrings. - num.Lsh(num, 5).Or(num, big.NewInt(lowBits)) - return num - } -} - -// getNonBasicTypeCode is used by getTypeCodeNum. It returns the upper bits of -// the type code used there in the type code. -func (state *typeCodeAssignmentState) getNonBasicTypeCode(class string, typecode llvm.Value) *big.Int { - switch class { - case "chan", "pointer", "slice": - // Prefix-style type kinds. The upper bits contain the element type. - sub := state.builder.CreateExtractValue(typecode.Initializer(), 0, "") - return state.getTypeCodeNum(sub) - case "array": - // An array is basically a pair of (typecode, length) stored in a - // sidetable. - return big.NewInt(int64(state.getArrayTypeNum(typecode))) - case "struct": - // More complicated type kind. The upper bits contain the index to the - // struct type in the struct types sidetable. - return big.NewInt(int64(state.getStructTypeNum(typecode))) - default: - // Type has not yet been implemented, so fall back by using a unique - // number. - num := big.NewInt(int64(state.fallbackIndex)) - state.fallbackIndex++ - return num - } -} - -// getClassAndValueFromTypeCode takes a typecode (a llvm.Value of type -// runtime.typecodeID), looks at the name, and extracts the typecode class and -// value from it. For example, for a typecode with the following name: -// -// reflect/types.type:pointer:named:reflect.ValueError -// -// It extracts: -// -// class = "pointer" -// value = "named:reflect.ValueError" -func getClassAndValueFromTypeCode(typecode llvm.Value) (class, value string) { - typecodeName := typecode.Name() - const prefix = "reflect/types.type:" - if !strings.HasPrefix(typecodeName, prefix) { - panic("unexpected typecode name: " + typecodeName) - } - id := typecodeName[len(prefix):] - class = id[:strings.IndexByte(id, ':')] - value = id[len(class)+1:] - return -} - -// getBasicNamedTypeNum returns an appropriate (unique) number for the given -// named type. If the name already has a number that number is returned, else a -// new number is returned. The number is always non-zero. -func (state *typeCodeAssignmentState) getBasicNamedTypeNum(name string) int { - if num, ok := state.namedBasicTypes[name]; ok { - return num - } - num := len(state.namedBasicTypes) + 1 - state.namedBasicTypes[name] = num - return num -} - -// getArrayTypeNum returns the array type number, which is an index into the -// reflect.arrayTypesSidetable or a unique number for this type if this table is -// not used. -func (state *typeCodeAssignmentState) getArrayTypeNum(typecode llvm.Value) int { - name := typecode.Name() - if num, ok := state.arrayTypes[name]; ok { - // This array type already has an entry in the sidetable. 
Don't store - // it twice. - return num - } - - if !state.needsArrayTypesSidetable { - // We don't need array sidetables, so we can just assign monotonically - // increasing numbers to each array type. - num := len(state.arrayTypes) - state.arrayTypes[name] = num - return num - } - - elemTypeCode := state.builder.CreateExtractValue(typecode.Initializer(), 0, "") - elemTypeNum := state.getTypeCodeNum(elemTypeCode) - if elemTypeNum.BitLen() > state.uintptrLen || !elemTypeNum.IsUint64() { - // TODO: make this a regular error - panic("array element type has a type code that is too big") - } - - // The array side table is a sequence of {element type, array length}. - arrayLength := state.builder.CreateExtractValue(typecode.Initializer(), 1, "").ZExtValue() - buf := makeVarint(elemTypeNum.Uint64()) - buf = append(buf, makeVarint(arrayLength)...) - - index := len(state.arrayTypesSidetable) - state.arrayTypes[name] = index - state.arrayTypesSidetable = append(state.arrayTypesSidetable, buf...) - return index -} - -// getStructTypeNum returns the struct type number, which is an index into -// reflect.structTypesSidetable or an unique number for every struct if this -// sidetable is not needed in the to-be-compiled program. -func (state *typeCodeAssignmentState) getStructTypeNum(typecode llvm.Value) int { - name := typecode.Name() - if num, ok := state.structTypes[name]; ok { - // This struct already has an assigned type code. - return num - } - - if !state.needsStructTypesSidetable { - // We don't need struct sidetables, so we can just assign monotonically - // increasing numbers to each struct type. - num := len(state.structTypes) - state.structTypes[name] = num - return num - } - - // Get the fields this struct type contains. - // The struct number will be the start index of - structTypeGlobal := stripPointerCasts(state.builder.CreateExtractValue(typecode.Initializer(), 0, "")).Initializer() - numFields := structTypeGlobal.Type().ArrayLength() - - // The first data that is stored in the struct sidetable is the number of - // fields this struct contains. This is usually just a single byte because - // most structs don't contain that many fields, but make it a varint just - // to be sure. - buf := makeVarint(uint64(numFields)) - - // Iterate over every field in the struct. - // Every field is stored sequentially in the struct sidetable. Fields can - // be retrieved from this list of fields at runtime by iterating over all - // of them until the right field has been found. - // Perhaps adding some index would speed things up, but it would also make - // the sidetable bigger. - for i := 0; i < numFields; i++ { - // Collect some information about this field. - field := state.builder.CreateExtractValue(structTypeGlobal, i, "") - - nameGlobal := state.builder.CreateExtractValue(field, 1, "") - if nameGlobal == llvm.ConstPointerNull(nameGlobal.Type()) { - panic("compiler: no name for this struct field") - } - fieldNameBytes := getGlobalBytes(stripPointerCasts(nameGlobal), state.builder) - fieldNameNumber := state.getStructNameNumber(fieldNameBytes) - - // See whether this struct field has an associated tag, and if so, - // store that tag in the tags sidetable. 
- tagGlobal := state.builder.CreateExtractValue(field, 2, "") - hasTag := false - tagNumber := 0 - if tagGlobal != llvm.ConstPointerNull(tagGlobal.Type()) { - hasTag = true - tagBytes := getGlobalBytes(stripPointerCasts(tagGlobal), state.builder) - tagNumber = state.getStructNameNumber(tagBytes) - } - - // The 'embedded' or 'anonymous' flag for this field. - embedded := state.builder.CreateExtractValue(field, 3, "").ZExtValue() != 0 - - // The first byte in the struct types sidetable is a flags byte with - // two bits in it. - flagsByte := byte(0) - if embedded { - flagsByte |= 1 - } - if hasTag { - flagsByte |= 2 - } - if ast.IsExported(string(fieldNameBytes)) { - flagsByte |= 4 - } - buf = append(buf, flagsByte) - - // Get the type number and add it to the buffer. - // All fields have a type, so include it directly here. - typeNum := state.getTypeCodeNum(state.builder.CreateExtractValue(field, 0, "")) - if typeNum.BitLen() > state.uintptrLen || !typeNum.IsUint64() { - // TODO: make this a regular error - panic("struct field has a type code that is too big") - } - buf = append(buf, makeVarint(typeNum.Uint64())...) - - // Add the name. - buf = append(buf, makeVarint(uint64(fieldNameNumber))...) - - // Add the tag, if there is one. - if hasTag { - buf = append(buf, makeVarint(uint64(tagNumber))...) - } - } - - num := len(state.structTypesSidetable) - state.structTypes[name] = num - state.structTypesSidetable = append(state.structTypesSidetable, buf...) - return num -} - -// getStructNameNumber stores this string (name or tag) onto the struct names -// sidetable. The format is a varint of the length of the struct, followed by -// the raw bytes of the name. Multiple identical strings are stored under the -// same name for space efficiency. -func (state *typeCodeAssignmentState) getStructNameNumber(nameBytes []byte) int { - name := string(nameBytes) - if n, ok := state.structNames[name]; ok { - // This name was used before, re-use it now (for space efficiency). - return n - } - // This name is not yet in the names sidetable. Add it now. - n := len(state.structNamesSidetable) - state.structNames[name] = n - state.structNamesSidetable = append(state.structNamesSidetable, makeVarint(uint64(len(nameBytes)))...) - state.structNamesSidetable = append(state.structNamesSidetable, nameBytes...) - return n -} - -// makeVarint is a small helper function that returns the bytes of the number in -// varint encoding. -func makeVarint(n uint64) []byte { - buf := make([]byte, binary.MaxVarintLen64) - return buf[:binary.PutUvarint(buf, n)] -} diff --git a/transform/reflect_test.go b/transform/reflect_test.go deleted file mode 100644 index 242337024a..0000000000 --- a/transform/reflect_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package transform_test - -import ( - "testing" - - "github.com/tinygo-org/tinygo/transform" - "tinygo.org/x/go-llvm" -) - -type reflectAssert struct { - call llvm.Value - name string - expectedNumber uint64 -} - -// Test reflect lowering. This code looks at IR like this: -// -// call void @main.assertType(i32 ptrtoint (%runtime.typecodeID* @"reflect/types.type:basic:int" to i32), i8* inttoptr (i32 3 to i8*), i32 4, i8* undef, i8* undef) -// -// and verifies that the ptrtoint constant (the first parameter of -// @main.assertType) is replaced with the correct type code. 
The expected -// output is this: -// -// call void @main.assertType(i32 4, i8* inttoptr (i32 3 to i8*), i32 4, i8* undef, i8* undef) -// -// The first and third parameter are compared and must match, the second -// parameter is ignored. -func TestReflect(t *testing.T) { - t.Parallel() - - mod := compileGoFileForTesting(t, "./testdata/reflect.go") - - // Run the instcombine pass, to clean up the IR a bit (especially - // insertvalue/extractvalue instructions). - pm := llvm.NewPassManager() - defer pm.Dispose() - pm.AddInstructionCombiningPass() - pm.Run(mod) - - // Get a list of all the asserts in the source code. - assertType := mod.NamedFunction("main.assertType") - var asserts []reflectAssert - for user := assertType.FirstUse(); !user.IsNil(); user = user.NextUse() { - use := user.User() - if use.IsACallInst().IsNil() { - t.Fatal("expected call use of main.assertType") - } - global := use.Operand(0).Operand(0) - expectedNumber := use.Operand(2).ZExtValue() - asserts = append(asserts, reflectAssert{ - call: use, - name: global.Name(), - expectedNumber: expectedNumber, - }) - } - - // Sanity check to show that the test is actually testing anything. - if len(asserts) < 3 { - t.Errorf("expected at least 3 test cases, got %d", len(asserts)) - } - - // Now lower the type codes. - transform.LowerReflect(mod) - - // Check whether the values are as expected. - for _, assert := range asserts { - actualNumberValue := assert.call.Operand(0) - if actualNumberValue.IsAConstantInt().IsNil() { - t.Errorf("expected to see a constant for %s, got something else", assert.name) - continue - } - actualNumber := actualNumberValue.ZExtValue() - if actualNumber != assert.expectedNumber { - t.Errorf("%s: expected number 0b%b, got 0b%b", assert.name, assert.expectedNumber, actualNumber) - } - } -} diff --git a/transform/rtcalls.go b/transform/rtcalls.go index 209e15ae13..36d2853b6c 100644 --- a/transform/rtcalls.go +++ b/transform/rtcalls.go @@ -113,11 +113,6 @@ func OptimizeReflectImplements(mod llvm.Module) { builder := mod.Context().NewBuilder() defer builder.Dispose() - // Get a few useful object for use later. - targetData := llvm.NewTargetData(mod.DataLayout()) - defer targetData.Dispose() - uintptrType := mod.Context().IntType(targetData.PointerSize() * 8) - // Look up the (reflect.Value).Implements() method. var implementsFunc llvm.Value for fn := mod.FirstFunction(); !fn.IsNil(); fn = llvm.NextFunction(fn) { @@ -141,14 +136,13 @@ func OptimizeReflectImplements(mod llvm.Module) { } interfaceType := stripPointerCasts(call.Operand(2)) if interfaceType.IsAGlobalVariable().IsNil() { - // The asserted interface is not constant, so can't optimize this - // code. + // Interface is unknown at compile time. This can't be optimized. continue } if strings.HasPrefix(interfaceType.Name(), "reflect/types.type:named:") { // Get the underlying type. - interfaceType = builder.CreateExtractValue(interfaceType.Initializer(), 0, "") + interfaceType = stripPointerCasts(builder.CreateExtractValue(interfaceType.Initializer(), 3, "")) } if !strings.HasPrefix(interfaceType.Name(), "reflect/types.type:interface:") { // This is an error. The Type passed to Implements should be of @@ -156,16 +150,15 @@ func OptimizeReflectImplements(mod llvm.Module) { // reported at runtime. continue } - if interfaceType.IsAGlobalVariable().IsNil() { - // Interface is unknown at compile time. This can't be optimized. 
+ typeAssertFunction := mod.NamedFunction(strings.TrimPrefix(interfaceType.Name(), "reflect/types.type:") + ".$typeassert") + if typeAssertFunction.IsNil() { continue } - typeAssertFunction := builder.CreateExtractValue(interfaceType.Initializer(), 4, "").Operand(0) // Replace Implements call with the type assert call. builder.SetInsertPointBefore(call) implements := builder.CreateCall(typeAssertFunction.GlobalValueType(), typeAssertFunction, []llvm.Value{ - builder.CreatePtrToInt(call.Operand(0), uintptrType, ""), // typecode to check + call.Operand(0), // typecode to check }, "") call.ReplaceAllUsesWith(implements) call.EraseFromParentAsInstruction() diff --git a/transform/stacksize.go b/transform/stacksize.go index 3e46a5794d..2f7a6c1d6f 100644 --- a/transform/stacksize.go +++ b/transform/stacksize.go @@ -1,8 +1,11 @@ package transform import ( + "path/filepath" + "github.com/tinygo-org/tinygo/compileopts" "github.com/tinygo-org/tinygo/compiler/llvmutil" + "github.com/tinygo-org/tinygo/goenv" "tinygo.org/x/go-llvm" ) @@ -47,10 +50,49 @@ func CreateStackSizeLoads(mod llvm.Module, config *compileopts.Config) []string stackSizesGlobal.SetSection(".tinygo_stacksizes") defaultStackSizes := make([]llvm.Value, len(functions)) defaultStackSize := llvm.ConstInt(functions[0].Type(), config.StackSize(), false) + alignment := targetData.ABITypeAlignment(functions[0].Type()) for i := range defaultStackSizes { defaultStackSizes[i] = defaultStackSize } stackSizesGlobal.SetInitializer(llvm.ConstArray(functions[0].Type(), defaultStackSizes)) + stackSizesGlobal.SetAlignment(alignment) + // TODO: make this a constant. For some reason, that incrases code size though. + if config.Debug() { + dibuilder := llvm.NewDIBuilder(mod) + dibuilder.CreateCompileUnit(llvm.DICompileUnit{ + Language: 0xb, // DW_LANG_C99 (0xc, off-by-one?) + File: "", + Dir: "", + Producer: "TinyGo", + Optimized: true, + }) + ditype := dibuilder.CreateArrayType(llvm.DIArrayType{ + SizeInBits: targetData.TypeAllocSize(stackSizesGlobalType) * 8, + AlignInBits: uint32(alignment * 8), + ElementType: dibuilder.CreateBasicType(llvm.DIBasicType{ + Name: "uintptr", + SizeInBits: targetData.TypeAllocSize(functions[0].Type()) * 8, + Encoding: llvm.DW_ATE_unsigned, + }), + Subscripts: []llvm.DISubrange{ + { + Lo: 0, + Count: int64(len(functions)), + }, + }, + }) + diglobal := dibuilder.CreateGlobalVariableExpression(llvm.Metadata{}, llvm.DIGlobalVariableExpression{ + Name: "internal/task.stackSizes", + File: dibuilder.CreateFile("internal/task/task_stack.go", filepath.Join(goenv.Get("TINYGOROOT"), "src")), + Line: 1, + Type: ditype, + Expr: dibuilder.CreateExpression(nil), + }) + stackSizesGlobal.AddMetadata(0, diglobal) + + dibuilder.Finalize() + dibuilder.Destroy() + } // Add all relevant values to llvm.used (for LTO). llvmutil.AppendToGlobal(mod, "llvm.used", append([]llvm.Value{stackSizesGlobal}, functionValues...)...) diff --git a/transform/testdata/allocs.ll b/transform/testdata/allocs.ll index 58af2ea82e..1c2fdd5aa4 100644 --- a/transform/testdata/allocs.ll +++ b/transform/testdata/allocs.ll @@ -3,57 +3,51 @@ target triple = "armv7m-none-eabi" @runtime.zeroSizedAlloc = internal global i8 0, align 1 -declare nonnull i8* @runtime.alloc(i32, i8*) +declare nonnull ptr @runtime.alloc(i32, ptr) ; Test allocating a single int (i32) that should be allocated on the stack. 
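The rewritten allocs.ll below feeds the heap-to-stack transform (OptimizeAllocs) with opaque-pointer IR. What it checks is the usual escape distinction, which in hypothetical Go source terms (the test itself works purely at the IR level) looks like:

package main

// The allocation in nonEscaping never leaves the function, so the
// corresponding runtime.alloc call can be rewritten into a stack alloca,
// as in the testInt/testArray cases below.
func nonEscaping() int {
	p := new(int)
	*p = 5
	return *p
}

// escaping hands the pointer to its caller, so the allocation has to stay
// on the heap, as in the testEscapingReturn case below.
func escaping() *int {
	p := new(int)
	*p = 5
	return p
}

func main() {
	println(nonEscaping(), *escaping())
}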
define void @testInt() { - %1 = call i8* @runtime.alloc(i32 4, i8* null) - %2 = bitcast i8* %1 to i32* - store i32 5, i32* %2 + %alloc = call ptr @runtime.alloc(i32 4, ptr null) + store i32 5, ptr %alloc ret void } ; Test allocating an array of 3 i16 values that should be allocated on the ; stack. define i16 @testArray() { - %1 = call i8* @runtime.alloc(i32 6, i8* null) - %2 = bitcast i8* %1 to i16* - %3 = getelementptr i16, i16* %2, i32 1 - store i16 5, i16* %3 - %4 = getelementptr i16, i16* %2, i32 2 - %5 = load i16, i16* %4 - ret i16 %5 + %alloc = call ptr @runtime.alloc(i32 6, ptr null) + %alloc.1 = getelementptr i16, ptr %alloc, i32 1 + store i16 5, ptr %alloc.1 + %alloc.2 = getelementptr i16, ptr %alloc, i32 2 + %val = load i16, ptr %alloc.2 + ret i16 %val } ; Call a function that will let the pointer escape, so the heap-to-stack ; transform shouldn't be applied. define void @testEscapingCall() { - %1 = call i8* @runtime.alloc(i32 4, i8* null) - %2 = bitcast i8* %1 to i32* - %3 = call i32* @escapeIntPtr(i32* %2) + %alloc = call ptr @runtime.alloc(i32 4, ptr null) + %val = call ptr @escapeIntPtr(ptr %alloc) ret void } define void @testEscapingCall2() { - %1 = call i8* @runtime.alloc(i32 4, i8* null) - %2 = bitcast i8* %1 to i32* - %3 = call i32* @escapeIntPtrSometimes(i32* %2, i32* %2) + %alloc = call ptr @runtime.alloc(i32 4, ptr null) + %val = call ptr @escapeIntPtrSometimes(ptr %alloc, ptr %alloc) ret void } ; Call a function that doesn't let the pointer escape. define void @testNonEscapingCall() { - %1 = call i8* @runtime.alloc(i32 4, i8* null) - %2 = bitcast i8* %1 to i32* - %3 = call i32* @noescapeIntPtr(i32* %2) + %alloc = call ptr @runtime.alloc(i32 4, ptr null) + %val = call ptr @noescapeIntPtr(ptr %alloc) ret void } ; Return the allocated value, which lets it escape. -define i32* @testEscapingReturn() { - %1 = call i8* @runtime.alloc(i32 4, i8* null) - %2 = bitcast i8* %1 to i32* - ret i32* %2 +define ptr @testEscapingReturn() { + %alloc = call ptr @runtime.alloc(i32 4, ptr null) + ret ptr %alloc } ; Do a non-escaping allocation in a loop. @@ -61,25 +55,23 @@ define void @testNonEscapingLoop() { entry: br label %loop loop: - %0 = call i8* @runtime.alloc(i32 4, i8* null) - %1 = bitcast i8* %0 to i32* - %2 = call i32* @noescapeIntPtr(i32* %1) - %3 = icmp eq i32* null, %2 - br i1 %3, label %loop, label %end + %alloc = call ptr @runtime.alloc(i32 4, ptr null) + %ptr = call ptr @noescapeIntPtr(ptr %alloc) + %result = icmp eq ptr null, %ptr + br i1 %result, label %loop, label %end end: ret void } ; Test a zero-sized allocation. 
define void @testZeroSizedAlloc() { - %1 = call i8* @runtime.alloc(i32 0, i8* null) - %2 = bitcast i8* %1 to i32* - %3 = call i32* @noescapeIntPtr(i32* %2) + %alloc = call ptr @runtime.alloc(i32 0, ptr null) + %ptr = call ptr @noescapeIntPtr(ptr %alloc) ret void } -declare i32* @escapeIntPtr(i32*) +declare ptr @escapeIntPtr(ptr) -declare i32* @noescapeIntPtr(i32* nocapture) +declare ptr @noescapeIntPtr(ptr nocapture) -declare i32* @escapeIntPtrSometimes(i32* nocapture, i32*) +declare ptr @escapeIntPtrSometimes(ptr nocapture, ptr) diff --git a/transform/testdata/allocs.out.ll b/transform/testdata/allocs.out.ll index 48f9b7685f..d1b07e6c42 100644 --- a/transform/testdata/allocs.out.ll +++ b/transform/testdata/allocs.out.ll @@ -3,53 +3,47 @@ target triple = "armv7m-none-eabi" @runtime.zeroSizedAlloc = internal global i8 0, align 1 -declare nonnull i8* @runtime.alloc(i32, i8*) +declare nonnull ptr @runtime.alloc(i32, ptr) define void @testInt() { %stackalloc.alloca = alloca [4 x i8], align 4 - store [4 x i8] zeroinitializer, [4 x i8]* %stackalloc.alloca, align 4 - %stackalloc = bitcast [4 x i8]* %stackalloc.alloca to i32* - store i32 5, i32* %stackalloc, align 4 + store [4 x i8] zeroinitializer, ptr %stackalloc.alloca, align 4 + store i32 5, ptr %stackalloc.alloca, align 4 ret void } define i16 @testArray() { %stackalloc.alloca = alloca [6 x i8], align 2 - store [6 x i8] zeroinitializer, [6 x i8]* %stackalloc.alloca, align 2 - %stackalloc = bitcast [6 x i8]* %stackalloc.alloca to i16* - %1 = getelementptr i16, i16* %stackalloc, i32 1 - store i16 5, i16* %1, align 2 - %2 = getelementptr i16, i16* %stackalloc, i32 2 - %3 = load i16, i16* %2, align 2 - ret i16 %3 + store [6 x i8] zeroinitializer, ptr %stackalloc.alloca, align 2 + %alloc.1 = getelementptr i16, ptr %stackalloc.alloca, i32 1 + store i16 5, ptr %alloc.1, align 2 + %alloc.2 = getelementptr i16, ptr %stackalloc.alloca, i32 2 + %val = load i16, ptr %alloc.2, align 2 + ret i16 %val } define void @testEscapingCall() { - %1 = call i8* @runtime.alloc(i32 4, i8* null) - %2 = bitcast i8* %1 to i32* - %3 = call i32* @escapeIntPtr(i32* %2) + %alloc = call ptr @runtime.alloc(i32 4, ptr null) + %val = call ptr @escapeIntPtr(ptr %alloc) ret void } define void @testEscapingCall2() { - %1 = call i8* @runtime.alloc(i32 4, i8* null) - %2 = bitcast i8* %1 to i32* - %3 = call i32* @escapeIntPtrSometimes(i32* %2, i32* %2) + %alloc = call ptr @runtime.alloc(i32 4, ptr null) + %val = call ptr @escapeIntPtrSometimes(ptr %alloc, ptr %alloc) ret void } define void @testNonEscapingCall() { %stackalloc.alloca = alloca [4 x i8], align 4 - store [4 x i8] zeroinitializer, [4 x i8]* %stackalloc.alloca, align 4 - %stackalloc = bitcast [4 x i8]* %stackalloc.alloca to i32* - %1 = call i32* @noescapeIntPtr(i32* %stackalloc) + store [4 x i8] zeroinitializer, ptr %stackalloc.alloca, align 4 + %val = call ptr @noescapeIntPtr(ptr %stackalloc.alloca) ret void } -define i32* @testEscapingReturn() { - %1 = call i8* @runtime.alloc(i32 4, i8* null) - %2 = bitcast i8* %1 to i32* - ret i32* %2 +define ptr @testEscapingReturn() { + %alloc = call ptr @runtime.alloc(i32 4, ptr null) + ret ptr %alloc } define void @testNonEscapingLoop() { @@ -58,24 +52,22 @@ entry: br label %loop loop: ; preds = %loop, %entry - store [4 x i8] zeroinitializer, [4 x i8]* %stackalloc.alloca, align 4 - %stackalloc = bitcast [4 x i8]* %stackalloc.alloca to i32* - %0 = call i32* @noescapeIntPtr(i32* %stackalloc) - %1 = icmp eq i32* null, %0 - br i1 %1, label %loop, label %end + store [4 x i8] 
zeroinitializer, ptr %stackalloc.alloca, align 4 + %ptr = call ptr @noescapeIntPtr(ptr %stackalloc.alloca) + %result = icmp eq ptr null, %ptr + br i1 %result, label %loop, label %end end: ; preds = %loop ret void } define void @testZeroSizedAlloc() { - %1 = bitcast i8* @runtime.zeroSizedAlloc to i32* - %2 = call i32* @noescapeIntPtr(i32* %1) + %ptr = call ptr @noescapeIntPtr(ptr @runtime.zeroSizedAlloc) ret void } -declare i32* @escapeIntPtr(i32*) +declare ptr @escapeIntPtr(ptr) -declare i32* @noescapeIntPtr(i32* nocapture) +declare ptr @noescapeIntPtr(ptr nocapture) -declare i32* @escapeIntPtrSometimes(i32* nocapture, i32*) +declare ptr @escapeIntPtrSometimes(ptr nocapture, ptr) diff --git a/transform/testdata/gc-stackslots.ll b/transform/testdata/gc-stackslots.ll index c217fb9ad2..7f196e9b42 100644 --- a/transform/testdata/gc-stackslots.ll +++ b/transform/testdata/gc-stackslots.ll @@ -1,101 +1,99 @@ target datalayout = "e-m:e-p:32:32-i64:64-n32:64-S128" target triple = "wasm32-unknown-unknown-wasm" -%runtime.stackChainObject = type { %runtime.stackChainObject*, i32 } - -@runtime.stackChainStart = external global %runtime.stackChainObject* +@runtime.stackChainStart = external global ptr @someGlobal = global i8 3 -@ptrGlobal = global i8** null +@ptrGlobal = global ptr null -declare void @runtime.trackPointer(i8* nocapture readonly) +declare void @runtime.trackPointer(ptr nocapture readonly) -declare noalias nonnull i8* @runtime.alloc(i32, i8*) +declare noalias nonnull ptr @runtime.alloc(i32, ptr) ; Generic function that returns a pointer (that must be tracked). -define i8* @getPointer() { - ret i8* @someGlobal +define ptr @getPointer() { + ret ptr @someGlobal } -define i8* @needsStackSlots() { +define ptr @needsStackSlots() { ; Tracked pointer. Although, in this case the value is immediately returned ; so tracking it is not really necessary. - %ptr = call i8* @runtime.alloc(i32 4, i8* null) - call void @runtime.trackPointer(i8* %ptr) + %ptr = call ptr @runtime.alloc(i32 4, ptr null) + call void @runtime.trackPointer(ptr %ptr) call void @someArbitraryFunction() - %val = load i8, i8* @someGlobal - ret i8* %ptr + %val = load i8, ptr @someGlobal + ret ptr %ptr } ; Check some edge cases of pointer tracking. -define i8* @needsStackSlots2() { +define ptr @needsStackSlots2() { ; Only one stack slot should be created for this (but at the moment, one is ; created for each call to runtime.trackPointer). - %ptr1 = call i8* @getPointer() - call void @runtime.trackPointer(i8* %ptr1) - call void @runtime.trackPointer(i8* %ptr1) - call void @runtime.trackPointer(i8* %ptr1) + %ptr1 = call ptr @getPointer() + call void @runtime.trackPointer(ptr %ptr1) + call void @runtime.trackPointer(ptr %ptr1) + call void @runtime.trackPointer(ptr %ptr1) ; Create a pointer that does not need to be tracked (but is tracked). - %ptr2 = getelementptr i8, i8* @someGlobal, i32 0 - call void @runtime.trackPointer(i8* %ptr2) + %ptr2 = getelementptr i8, ptr @someGlobal, i32 0 + call void @runtime.trackPointer(ptr %ptr2) ; Here is finally the point where an allocation happens. - %unused = call i8* @runtime.alloc(i32 4, i8* null) - call void @runtime.trackPointer(i8* %unused) + %unused = call ptr @runtime.alloc(i32 4, ptr null) + call void @runtime.trackPointer(ptr %unused) - ret i8* %ptr1 + ret ptr %ptr1 } ; Return a pointer from a caller. Because it doesn't allocate, no stack objects ; need to be created. 
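gc-stackslots.ll exercises the GC stack-slot pass used on targets like this test's wasm32 triple, where the goroutine stack cannot be scanned directly: each runtime.trackPointer call is lowered into a slot of a per-function stack object that is linked into the chain headed by runtime.stackChainStart, so the collector can still find heap pointers held in local variables across calls. A minimal Go-level sketch of the shape needsStackSlots tests (reconstructed from the IR; the package name is invented):

package sample

func someArbitraryFunction() {}

// The pointer returned by new must stay visible to the GC while
// someArbitraryFunction runs, since that call may allocate and trigger
// a collection; hence the stack slot in the expected output.
func needsStackSlots() *int {
    p := new(int)
    someArbitraryFunction()
    return p
}

Functions that never allocate, such as noAllocatingFunction defined next, get no stack object at all, and their runtime.trackPointer calls are simply dropped in the expected output.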
-define i8* @noAllocatingFunction() { - %ptr = call i8* @getPointer() - call void @runtime.trackPointer(i8* %ptr) - ret i8* %ptr +define ptr @noAllocatingFunction() { + %ptr = call ptr @getPointer() + call void @runtime.trackPointer(ptr %ptr) + ret ptr %ptr } -define i8* @fibNext(i8* %x, i8* %y) { - %x.val = load i8, i8* %x - %y.val = load i8, i8* %y +define ptr @fibNext(ptr %x, ptr %y) { + %x.val = load i8, ptr %x + %y.val = load i8, ptr %y %out.val = add i8 %x.val, %y.val - %out.alloc = call i8* @runtime.alloc(i32 1, i8* null) - call void @runtime.trackPointer(i8* %out.alloc) - store i8 %out.val, i8* %out.alloc - ret i8* %out.alloc + %out.alloc = call ptr @runtime.alloc(i32 1, ptr null) + call void @runtime.trackPointer(ptr %out.alloc) + store i8 %out.val, ptr %out.alloc + ret ptr %out.alloc } -define i8* @allocLoop() { +define ptr @allocLoop() { entry: - %entry.x = call i8* @runtime.alloc(i32 1, i8* null) - call void @runtime.trackPointer(i8* %entry.x) - %entry.y = call i8* @runtime.alloc(i32 1, i8* null) - call void @runtime.trackPointer(i8* %entry.y) - store i8 1, i8* %entry.y + %entry.x = call ptr @runtime.alloc(i32 1, ptr null) + call void @runtime.trackPointer(ptr %entry.x) + %entry.y = call ptr @runtime.alloc(i32 1, ptr null) + call void @runtime.trackPointer(ptr %entry.y) + store i8 1, ptr %entry.y br label %loop loop: - %prev.y = phi i8* [ %entry.y, %entry ], [ %prev.x, %loop ] - %prev.x = phi i8* [ %entry.x, %entry ], [ %next.x, %loop ] - call void @runtime.trackPointer(i8* %prev.x) - call void @runtime.trackPointer(i8* %prev.y) - %next.x = call i8* @fibNext(i8* %prev.x, i8* %prev.y) - call void @runtime.trackPointer(i8* %next.x) - %next.x.val = load i8, i8* %next.x + %prev.y = phi ptr [ %entry.y, %entry ], [ %prev.x, %loop ] + %prev.x = phi ptr [ %entry.x, %entry ], [ %next.x, %loop ] + call void @runtime.trackPointer(ptr %prev.x) + call void @runtime.trackPointer(ptr %prev.y) + %next.x = call ptr @fibNext(ptr %prev.x, ptr %prev.y) + call void @runtime.trackPointer(ptr %next.x) + %next.x.val = load i8, ptr %next.x %loop.done = icmp ult i8 40, %next.x.val br i1 %loop.done, label %end, label %loop end: - ret i8* %next.x + ret ptr %next.x } -declare [32 x i8]* @arrayAlloc() +declare ptr @arrayAlloc() define void @testGEPBitcast() { - %arr = call [32 x i8]* @arrayAlloc() - %arr.bitcast = getelementptr [32 x i8], [32 x i8]* %arr, i32 0, i32 0 - call void @runtime.trackPointer(i8* %arr.bitcast) - %other = call i8* @runtime.alloc(i32 1, i8* null) - call void @runtime.trackPointer(i8* %other) + %arr = call ptr @arrayAlloc() + %arr.bitcast = getelementptr [32 x i8], ptr %arr, i32 0, i32 0 + call void @runtime.trackPointer(ptr %arr.bitcast) + %other = call ptr @runtime.alloc(i32 1, ptr null) + call void @runtime.trackPointer(ptr %other) ret void } @@ -104,18 +102,17 @@ define void @someArbitraryFunction() { } define void @earlyPopRegression() { - %x.alloc = call i8* @runtime.alloc(i32 4, i8* null) - call void @runtime.trackPointer(i8* %x.alloc) - %x = bitcast i8* %x.alloc to i8** + %x.alloc = call ptr @runtime.alloc(i32 4, ptr null) + call void @runtime.trackPointer(ptr %x.alloc) ; At this point the pass used to pop the stack chain, resulting in a potential use-after-free during allocAndSave. 
- musttail call void @allocAndSave(i8** %x) + musttail call void @allocAndSave(ptr %x.alloc) ret void } -define void @allocAndSave(i8** %x) { - %y = call i8* @runtime.alloc(i32 4, i8* null) - call void @runtime.trackPointer(i8* %y) - store i8* %y, i8** %x - store i8** %x, i8*** @ptrGlobal +define void @allocAndSave(ptr %x) { + %y = call ptr @runtime.alloc(i32 4, ptr null) + call void @runtime.trackPointer(ptr %y) + store ptr %y, ptr %x + store ptr %x, ptr @ptrGlobal ret void -} \ No newline at end of file +} diff --git a/transform/testdata/gc-stackslots.out.ll b/transform/testdata/gc-stackslots.out.ll index 83d1c841b8..f80d0c96c5 100644 --- a/transform/testdata/gc-stackslots.out.ll +++ b/transform/testdata/gc-stackslots.out.ll @@ -1,141 +1,134 @@ target datalayout = "e-m:e-p:32:32-i64:64-n32:64-S128" target triple = "wasm32-unknown-unknown-wasm" -%runtime.stackChainObject = type { %runtime.stackChainObject*, i32 } - -@runtime.stackChainStart = internal global %runtime.stackChainObject* null +@runtime.stackChainStart = internal global ptr null @someGlobal = global i8 3 -@ptrGlobal = global i8** null +@ptrGlobal = global ptr null -declare void @runtime.trackPointer(i8* nocapture readonly) +declare void @runtime.trackPointer(ptr nocapture readonly) -declare noalias nonnull i8* @runtime.alloc(i32, i8*) +declare noalias nonnull ptr @runtime.alloc(i32, ptr) -define i8* @getPointer() { - ret i8* @someGlobal +define ptr @getPointer() { + ret ptr @someGlobal } -define i8* @needsStackSlots() { - %gc.stackobject = alloca { %runtime.stackChainObject*, i32, i8* }, align 8 - store { %runtime.stackChainObject*, i32, i8* } { %runtime.stackChainObject* null, i32 1, i8* null }, { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject, align 4 - %1 = load %runtime.stackChainObject*, %runtime.stackChainObject** @runtime.stackChainStart, align 4 - %2 = getelementptr { %runtime.stackChainObject*, i32, i8* }, { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject, i32 0, i32 0 - store %runtime.stackChainObject* %1, %runtime.stackChainObject** %2, align 4 - %3 = bitcast { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject to %runtime.stackChainObject* - store %runtime.stackChainObject* %3, %runtime.stackChainObject** @runtime.stackChainStart, align 4 - %ptr = call i8* @runtime.alloc(i32 4, i8* null) - %4 = getelementptr { %runtime.stackChainObject*, i32, i8* }, { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject, i32 0, i32 2 - store i8* %ptr, i8** %4, align 4 +define ptr @needsStackSlots() { + %gc.stackobject = alloca { ptr, i32, ptr }, align 8 + store { ptr, i32, ptr } { ptr null, i32 1, ptr null }, ptr %gc.stackobject, align 4 + %1 = load ptr, ptr @runtime.stackChainStart, align 4 + %2 = getelementptr { ptr, i32, ptr }, ptr %gc.stackobject, i32 0, i32 0 + store ptr %1, ptr %2, align 4 + store ptr %gc.stackobject, ptr @runtime.stackChainStart, align 4 + %ptr = call ptr @runtime.alloc(i32 4, ptr null) + %3 = getelementptr { ptr, i32, ptr }, ptr %gc.stackobject, i32 0, i32 2 + store ptr %ptr, ptr %3, align 4 call void @someArbitraryFunction() - %val = load i8, i8* @someGlobal, align 1 - store %runtime.stackChainObject* %1, %runtime.stackChainObject** @runtime.stackChainStart, align 4 - ret i8* %ptr + %val = load i8, ptr @someGlobal, align 1 + store ptr %1, ptr @runtime.stackChainStart, align 4 + ret ptr %ptr } -define i8* @needsStackSlots2() { - %gc.stackobject = alloca { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }, align 8 - store { %runtime.stackChainObject*, i32, i8*, i8*, i8*, 
i8*, i8* } { %runtime.stackChainObject* null, i32 5, i8* null, i8* null, i8* null, i8* null, i8* null }, { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }* %gc.stackobject, align 4 - %1 = load %runtime.stackChainObject*, %runtime.stackChainObject** @runtime.stackChainStart, align 4 - %2 = getelementptr { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }, { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }* %gc.stackobject, i32 0, i32 0 - store %runtime.stackChainObject* %1, %runtime.stackChainObject** %2, align 4 - %3 = bitcast { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }* %gc.stackobject to %runtime.stackChainObject* - store %runtime.stackChainObject* %3, %runtime.stackChainObject** @runtime.stackChainStart, align 4 - %ptr1 = call i8* @getPointer() - %4 = getelementptr { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }, { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }* %gc.stackobject, i32 0, i32 4 - store i8* %ptr1, i8** %4, align 4 - %5 = getelementptr { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }, { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }* %gc.stackobject, i32 0, i32 3 - store i8* %ptr1, i8** %5, align 4 - %6 = getelementptr { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }, { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }* %gc.stackobject, i32 0, i32 2 - store i8* %ptr1, i8** %6, align 4 - %ptr2 = getelementptr i8, i8* @someGlobal, i32 0 - %7 = getelementptr { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }, { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }* %gc.stackobject, i32 0, i32 5 - store i8* %ptr2, i8** %7, align 4 - %unused = call i8* @runtime.alloc(i32 4, i8* null) - %8 = getelementptr { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }, { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }* %gc.stackobject, i32 0, i32 6 - store i8* %unused, i8** %8, align 4 - store %runtime.stackChainObject* %1, %runtime.stackChainObject** @runtime.stackChainStart, align 4 - ret i8* %ptr1 +define ptr @needsStackSlots2() { + %gc.stackobject = alloca { ptr, i32, ptr, ptr, ptr, ptr, ptr }, align 8 + store { ptr, i32, ptr, ptr, ptr, ptr, ptr } { ptr null, i32 5, ptr null, ptr null, ptr null, ptr null, ptr null }, ptr %gc.stackobject, align 4 + %1 = load ptr, ptr @runtime.stackChainStart, align 4 + %2 = getelementptr { ptr, i32, ptr, ptr, ptr, ptr, ptr }, ptr %gc.stackobject, i32 0, i32 0 + store ptr %1, ptr %2, align 4 + store ptr %gc.stackobject, ptr @runtime.stackChainStart, align 4 + %ptr1 = call ptr @getPointer() + %3 = getelementptr { ptr, i32, ptr, ptr, ptr, ptr, ptr }, ptr %gc.stackobject, i32 0, i32 4 + store ptr %ptr1, ptr %3, align 4 + %4 = getelementptr { ptr, i32, ptr, ptr, ptr, ptr, ptr }, ptr %gc.stackobject, i32 0, i32 3 + store ptr %ptr1, ptr %4, align 4 + %5 = getelementptr { ptr, i32, ptr, ptr, ptr, ptr, ptr }, ptr %gc.stackobject, i32 0, i32 2 + store ptr %ptr1, ptr %5, align 4 + %ptr2 = getelementptr i8, ptr @someGlobal, i32 0 + %6 = getelementptr { ptr, i32, ptr, ptr, ptr, ptr, ptr }, ptr %gc.stackobject, i32 0, i32 5 + store ptr %ptr2, ptr %6, align 4 + %unused = call ptr @runtime.alloc(i32 4, ptr null) + %7 = getelementptr { ptr, i32, ptr, ptr, ptr, ptr, ptr }, ptr %gc.stackobject, i32 0, i32 6 + store ptr %unused, ptr %7, align 4 + store ptr %1, ptr @runtime.stackChainStart, align 4 + ret ptr %ptr1 } -define i8* @noAllocatingFunction() { - %ptr = call i8* @getPointer() - ret i8* %ptr +define ptr 
@noAllocatingFunction() { + %ptr = call ptr @getPointer() + ret ptr %ptr } -define i8* @fibNext(i8* %x, i8* %y) { - %gc.stackobject = alloca { %runtime.stackChainObject*, i32, i8* }, align 8 - store { %runtime.stackChainObject*, i32, i8* } { %runtime.stackChainObject* null, i32 1, i8* null }, { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject, align 4 - %1 = load %runtime.stackChainObject*, %runtime.stackChainObject** @runtime.stackChainStart, align 4 - %2 = getelementptr { %runtime.stackChainObject*, i32, i8* }, { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject, i32 0, i32 0 - store %runtime.stackChainObject* %1, %runtime.stackChainObject** %2, align 4 - %3 = bitcast { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject to %runtime.stackChainObject* - store %runtime.stackChainObject* %3, %runtime.stackChainObject** @runtime.stackChainStart, align 4 - %x.val = load i8, i8* %x, align 1 - %y.val = load i8, i8* %y, align 1 +define ptr @fibNext(ptr %x, ptr %y) { + %gc.stackobject = alloca { ptr, i32, ptr }, align 8 + store { ptr, i32, ptr } { ptr null, i32 1, ptr null }, ptr %gc.stackobject, align 4 + %1 = load ptr, ptr @runtime.stackChainStart, align 4 + %2 = getelementptr { ptr, i32, ptr }, ptr %gc.stackobject, i32 0, i32 0 + store ptr %1, ptr %2, align 4 + store ptr %gc.stackobject, ptr @runtime.stackChainStart, align 4 + %x.val = load i8, ptr %x, align 1 + %y.val = load i8, ptr %y, align 1 %out.val = add i8 %x.val, %y.val - %out.alloc = call i8* @runtime.alloc(i32 1, i8* null) - %4 = getelementptr { %runtime.stackChainObject*, i32, i8* }, { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject, i32 0, i32 2 - store i8* %out.alloc, i8** %4, align 4 - store i8 %out.val, i8* %out.alloc, align 1 - store %runtime.stackChainObject* %1, %runtime.stackChainObject** @runtime.stackChainStart, align 4 - ret i8* %out.alloc + %out.alloc = call ptr @runtime.alloc(i32 1, ptr null) + %3 = getelementptr { ptr, i32, ptr }, ptr %gc.stackobject, i32 0, i32 2 + store ptr %out.alloc, ptr %3, align 4 + store i8 %out.val, ptr %out.alloc, align 1 + store ptr %1, ptr @runtime.stackChainStart, align 4 + ret ptr %out.alloc } -define i8* @allocLoop() { +define ptr @allocLoop() { entry: - %gc.stackobject = alloca { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }, align 8 - store { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* } { %runtime.stackChainObject* null, i32 5, i8* null, i8* null, i8* null, i8* null, i8* null }, { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }* %gc.stackobject, align 4 - %0 = load %runtime.stackChainObject*, %runtime.stackChainObject** @runtime.stackChainStart, align 4 - %1 = getelementptr { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }, { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }* %gc.stackobject, i32 0, i32 0 - store %runtime.stackChainObject* %0, %runtime.stackChainObject** %1, align 4 - %2 = bitcast { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }* %gc.stackobject to %runtime.stackChainObject* - store %runtime.stackChainObject* %2, %runtime.stackChainObject** @runtime.stackChainStart, align 4 - %entry.x = call i8* @runtime.alloc(i32 1, i8* null) - %3 = getelementptr { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }, { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }* %gc.stackobject, i32 0, i32 2 - store i8* %entry.x, i8** %3, align 4 - %entry.y = call i8* @runtime.alloc(i32 1, i8* null) - %4 = getelementptr { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }, { 
%runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }* %gc.stackobject, i32 0, i32 3 - store i8* %entry.y, i8** %4, align 4 - store i8 1, i8* %entry.y, align 1 + %gc.stackobject = alloca { ptr, i32, ptr, ptr, ptr, ptr, ptr }, align 8 + store { ptr, i32, ptr, ptr, ptr, ptr, ptr } { ptr null, i32 5, ptr null, ptr null, ptr null, ptr null, ptr null }, ptr %gc.stackobject, align 4 + %0 = load ptr, ptr @runtime.stackChainStart, align 4 + %1 = getelementptr { ptr, i32, ptr, ptr, ptr, ptr, ptr }, ptr %gc.stackobject, i32 0, i32 0 + store ptr %0, ptr %1, align 4 + store ptr %gc.stackobject, ptr @runtime.stackChainStart, align 4 + %entry.x = call ptr @runtime.alloc(i32 1, ptr null) + %2 = getelementptr { ptr, i32, ptr, ptr, ptr, ptr, ptr }, ptr %gc.stackobject, i32 0, i32 2 + store ptr %entry.x, ptr %2, align 4 + %entry.y = call ptr @runtime.alloc(i32 1, ptr null) + %3 = getelementptr { ptr, i32, ptr, ptr, ptr, ptr, ptr }, ptr %gc.stackobject, i32 0, i32 3 + store ptr %entry.y, ptr %3, align 4 + store i8 1, ptr %entry.y, align 1 br label %loop loop: ; preds = %loop, %entry - %prev.y = phi i8* [ %entry.y, %entry ], [ %prev.x, %loop ] - %prev.x = phi i8* [ %entry.x, %entry ], [ %next.x, %loop ] - %5 = getelementptr { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }, { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }* %gc.stackobject, i32 0, i32 5 - store i8* %prev.y, i8** %5, align 4 - %6 = getelementptr { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }, { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }* %gc.stackobject, i32 0, i32 4 - store i8* %prev.x, i8** %6, align 4 - %next.x = call i8* @fibNext(i8* %prev.x, i8* %prev.y) - %7 = getelementptr { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }, { %runtime.stackChainObject*, i32, i8*, i8*, i8*, i8*, i8* }* %gc.stackobject, i32 0, i32 6 - store i8* %next.x, i8** %7, align 4 - %next.x.val = load i8, i8* %next.x, align 1 + %prev.y = phi ptr [ %entry.y, %entry ], [ %prev.x, %loop ] + %prev.x = phi ptr [ %entry.x, %entry ], [ %next.x, %loop ] + %4 = getelementptr { ptr, i32, ptr, ptr, ptr, ptr, ptr }, ptr %gc.stackobject, i32 0, i32 5 + store ptr %prev.y, ptr %4, align 4 + %5 = getelementptr { ptr, i32, ptr, ptr, ptr, ptr, ptr }, ptr %gc.stackobject, i32 0, i32 4 + store ptr %prev.x, ptr %5, align 4 + %next.x = call ptr @fibNext(ptr %prev.x, ptr %prev.y) + %6 = getelementptr { ptr, i32, ptr, ptr, ptr, ptr, ptr }, ptr %gc.stackobject, i32 0, i32 6 + store ptr %next.x, ptr %6, align 4 + %next.x.val = load i8, ptr %next.x, align 1 %loop.done = icmp ult i8 40, %next.x.val br i1 %loop.done, label %end, label %loop end: ; preds = %loop - store %runtime.stackChainObject* %0, %runtime.stackChainObject** @runtime.stackChainStart, align 4 - ret i8* %next.x + store ptr %0, ptr @runtime.stackChainStart, align 4 + ret ptr %next.x } -declare [32 x i8]* @arrayAlloc() +declare ptr @arrayAlloc() define void @testGEPBitcast() { - %gc.stackobject = alloca { %runtime.stackChainObject*, i32, i8*, i8* }, align 8 - store { %runtime.stackChainObject*, i32, i8*, i8* } { %runtime.stackChainObject* null, i32 2, i8* null, i8* null }, { %runtime.stackChainObject*, i32, i8*, i8* }* %gc.stackobject, align 4 - %1 = load %runtime.stackChainObject*, %runtime.stackChainObject** @runtime.stackChainStart, align 4 - %2 = getelementptr { %runtime.stackChainObject*, i32, i8*, i8* }, { %runtime.stackChainObject*, i32, i8*, i8* }* %gc.stackobject, i32 0, i32 0 - store %runtime.stackChainObject* %1, %runtime.stackChainObject** %2, align 4 
- %3 = bitcast { %runtime.stackChainObject*, i32, i8*, i8* }* %gc.stackobject to %runtime.stackChainObject* - store %runtime.stackChainObject* %3, %runtime.stackChainObject** @runtime.stackChainStart, align 4 - %arr = call [32 x i8]* @arrayAlloc() - %arr.bitcast = getelementptr [32 x i8], [32 x i8]* %arr, i32 0, i32 0 - %4 = getelementptr { %runtime.stackChainObject*, i32, i8*, i8* }, { %runtime.stackChainObject*, i32, i8*, i8* }* %gc.stackobject, i32 0, i32 2 - store i8* %arr.bitcast, i8** %4, align 4 - %other = call i8* @runtime.alloc(i32 1, i8* null) - %5 = getelementptr { %runtime.stackChainObject*, i32, i8*, i8* }, { %runtime.stackChainObject*, i32, i8*, i8* }* %gc.stackobject, i32 0, i32 3 - store i8* %other, i8** %5, align 4 - store %runtime.stackChainObject* %1, %runtime.stackChainObject** @runtime.stackChainStart, align 4 + %gc.stackobject = alloca { ptr, i32, ptr, ptr }, align 8 + store { ptr, i32, ptr, ptr } { ptr null, i32 2, ptr null, ptr null }, ptr %gc.stackobject, align 4 + %1 = load ptr, ptr @runtime.stackChainStart, align 4 + %2 = getelementptr { ptr, i32, ptr, ptr }, ptr %gc.stackobject, i32 0, i32 0 + store ptr %1, ptr %2, align 4 + store ptr %gc.stackobject, ptr @runtime.stackChainStart, align 4 + %arr = call ptr @arrayAlloc() + %arr.bitcast = getelementptr [32 x i8], ptr %arr, i32 0, i32 0 + %3 = getelementptr { ptr, i32, ptr, ptr }, ptr %gc.stackobject, i32 0, i32 2 + store ptr %arr.bitcast, ptr %3, align 4 + %other = call ptr @runtime.alloc(i32 1, ptr null) + %4 = getelementptr { ptr, i32, ptr, ptr }, ptr %gc.stackobject, i32 0, i32 3 + store ptr %other, ptr %4, align 4 + store ptr %1, ptr @runtime.stackChainStart, align 4 ret void } @@ -144,35 +137,32 @@ define void @someArbitraryFunction() { } define void @earlyPopRegression() { - %gc.stackobject = alloca { %runtime.stackChainObject*, i32, i8* }, align 8 - store { %runtime.stackChainObject*, i32, i8* } { %runtime.stackChainObject* null, i32 1, i8* null }, { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject, align 4 - %1 = load %runtime.stackChainObject*, %runtime.stackChainObject** @runtime.stackChainStart, align 4 - %2 = getelementptr { %runtime.stackChainObject*, i32, i8* }, { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject, i32 0, i32 0 - store %runtime.stackChainObject* %1, %runtime.stackChainObject** %2, align 4 - %3 = bitcast { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject to %runtime.stackChainObject* - store %runtime.stackChainObject* %3, %runtime.stackChainObject** @runtime.stackChainStart, align 4 - %x.alloc = call i8* @runtime.alloc(i32 4, i8* null) - %4 = getelementptr { %runtime.stackChainObject*, i32, i8* }, { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject, i32 0, i32 2 - store i8* %x.alloc, i8** %4, align 4 - %x = bitcast i8* %x.alloc to i8** - call void @allocAndSave(i8** %x) - store %runtime.stackChainObject* %1, %runtime.stackChainObject** @runtime.stackChainStart, align 4 + %gc.stackobject = alloca { ptr, i32, ptr }, align 8 + store { ptr, i32, ptr } { ptr null, i32 1, ptr null }, ptr %gc.stackobject, align 4 + %1 = load ptr, ptr @runtime.stackChainStart, align 4 + %2 = getelementptr { ptr, i32, ptr }, ptr %gc.stackobject, i32 0, i32 0 + store ptr %1, ptr %2, align 4 + store ptr %gc.stackobject, ptr @runtime.stackChainStart, align 4 + %x.alloc = call ptr @runtime.alloc(i32 4, ptr null) + %3 = getelementptr { ptr, i32, ptr }, ptr %gc.stackobject, i32 0, i32 2 + store ptr %x.alloc, ptr %3, align 4 + call void @allocAndSave(ptr %x.alloc) + store ptr %1, ptr 
@runtime.stackChainStart, align 4 ret void } -define void @allocAndSave(i8** %x) { - %gc.stackobject = alloca { %runtime.stackChainObject*, i32, i8* }, align 8 - store { %runtime.stackChainObject*, i32, i8* } { %runtime.stackChainObject* null, i32 1, i8* null }, { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject, align 4 - %1 = load %runtime.stackChainObject*, %runtime.stackChainObject** @runtime.stackChainStart, align 4 - %2 = getelementptr { %runtime.stackChainObject*, i32, i8* }, { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject, i32 0, i32 0 - store %runtime.stackChainObject* %1, %runtime.stackChainObject** %2, align 4 - %3 = bitcast { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject to %runtime.stackChainObject* - store %runtime.stackChainObject* %3, %runtime.stackChainObject** @runtime.stackChainStart, align 4 - %y = call i8* @runtime.alloc(i32 4, i8* null) - %4 = getelementptr { %runtime.stackChainObject*, i32, i8* }, { %runtime.stackChainObject*, i32, i8* }* %gc.stackobject, i32 0, i32 2 - store i8* %y, i8** %4, align 4 - store i8* %y, i8** %x, align 4 - store i8** %x, i8*** @ptrGlobal, align 4 - store %runtime.stackChainObject* %1, %runtime.stackChainObject** @runtime.stackChainStart, align 4 +define void @allocAndSave(ptr %x) { + %gc.stackobject = alloca { ptr, i32, ptr }, align 8 + store { ptr, i32, ptr } { ptr null, i32 1, ptr null }, ptr %gc.stackobject, align 4 + %1 = load ptr, ptr @runtime.stackChainStart, align 4 + %2 = getelementptr { ptr, i32, ptr }, ptr %gc.stackobject, i32 0, i32 0 + store ptr %1, ptr %2, align 4 + store ptr %gc.stackobject, ptr @runtime.stackChainStart, align 4 + %y = call ptr @runtime.alloc(i32 4, ptr null) + %3 = getelementptr { ptr, i32, ptr }, ptr %gc.stackobject, i32 0, i32 2 + store ptr %y, ptr %3, align 4 + store ptr %y, ptr %x, align 4 + store ptr %x, ptr @ptrGlobal, align 4 + store ptr %1, ptr @runtime.stackChainStart, align 4 ret void } diff --git a/transform/testdata/globals-function-sections.ll b/transform/testdata/globals-function-sections.ll deleted file mode 100644 index 505ba5aa5f..0000000000 --- a/transform/testdata/globals-function-sections.ll +++ /dev/null @@ -1,8 +0,0 @@ -target datalayout = "e-m:e-p:32:32-Fi8-i64:64-v128:64:128-a:0:32-n32-S64" -target triple = "armv7em-none-eabi" - -declare void @foo() - -define void @bar() { - ret void -} diff --git a/transform/testdata/globals-function-sections.out.ll b/transform/testdata/globals-function-sections.out.ll deleted file mode 100644 index e3d03ed07b..0000000000 --- a/transform/testdata/globals-function-sections.out.ll +++ /dev/null @@ -1,8 +0,0 @@ -target datalayout = "e-m:e-p:32:32-Fi8-i64:64-v128:64:128-a:0:32-n32-S64" -target triple = "armv7em-none-eabi" - -declare void @foo() - -define void @bar() section ".text.bar" { - ret void -} diff --git a/transform/testdata/interface.ll b/transform/testdata/interface.ll index 4d8e818de9..76ed029b47 100644 --- a/transform/testdata/interface.ll +++ b/transform/testdata/interface.ll @@ -1,70 +1,70 @@ target datalayout = "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64" target triple = "armv7m-none-eabi" -%runtime.typecodeID = type { %runtime.typecodeID*, i32, %runtime.interfaceMethodInfo*, %runtime.typecodeID*, i32 } -%runtime.interfaceMethodInfo = type { i8*, i32 } - -@"reflect/types.type:basic:uint8" = private constant %runtime.typecodeID zeroinitializer +@"reflect/types.type:basic:uint8" = linkonce_odr constant { i8, ptr } { i8 8, ptr @"reflect/types.type:pointer:basic:uint8" }, align 4 
+@"reflect/types.type:pointer:basic:uint8" = linkonce_odr constant { i8, ptr } { i8 21, ptr @"reflect/types.type:basic:uint8" }, align 4 @"reflect/types.typeid:basic:uint8" = external constant i8 @"reflect/types.typeid:basic:int16" = external constant i8 -@"reflect/types.type:basic:int" = private constant %runtime.typecodeID zeroinitializer +@"reflect/types.type:basic:int" = linkonce_odr constant { i8, ptr } { i8 2, ptr @"reflect/types.type:pointer:basic:int" }, align 4 +@"reflect/types.type:pointer:basic:int" = linkonce_odr constant { i8, ptr } { i8 21, ptr @"reflect/types.type:basic:int" }, align 4 @"reflect/methods.NeverImplementedMethod()" = linkonce_odr constant i8 0 @"reflect/methods.Double() int" = linkonce_odr constant i8 0 -@"Number$methodset" = private constant [1 x %runtime.interfaceMethodInfo] [%runtime.interfaceMethodInfo { i8* @"reflect/methods.Double() int", i32 ptrtoint (i32 (i8*, i8*)* @"(Number).Double$invoke" to i32) }] -@"reflect/types.type:named:Number" = private constant %runtime.typecodeID { %runtime.typecodeID* @"reflect/types.type:basic:int", i32 0, %runtime.interfaceMethodInfo* getelementptr inbounds ([1 x %runtime.interfaceMethodInfo], [1 x %runtime.interfaceMethodInfo]* @"Number$methodset", i32 0, i32 0), %runtime.typecodeID* null, i32 0 } +@"Number$methodset" = linkonce_odr unnamed_addr constant { i32, [1 x ptr], { ptr } } { i32 1, [1 x ptr] [ptr @"reflect/methods.Double() int"], { ptr } { ptr @"(Number).Double$invoke" } } +@"reflect/types.type:named:Number" = linkonce_odr constant { ptr, i8, ptr, ptr } { ptr @"Number$methodset", i8 34, ptr @"reflect/types.type:pointer:named:Number", ptr @"reflect/types.type:basic:int" }, align 4 +@"reflect/types.type:pointer:named:Number" = linkonce_odr constant { i8, ptr } { i8 21, ptr getelementptr inbounds ({ ptr, i8, ptr, ptr }, ptr @"reflect/types.type:named:Number", i32 0, i32 1) }, align 4 -declare i1 @runtime.typeAssert(i32, i8*) +declare i1 @runtime.typeAssert(ptr, ptr) declare void @runtime.printuint8(i8) declare void @runtime.printint16(i16) declare void @runtime.printint32(i32) declare void @runtime.printptr(i32) declare void @runtime.printnl() -declare void @runtime.nilPanic(i8*) +declare void @runtime.nilPanic(ptr) define void @printInterfaces() { - call void @printInterface(i32 ptrtoint (%runtime.typecodeID* @"reflect/types.type:basic:int" to i32), i8* inttoptr (i32 5 to i8*)) - call void @printInterface(i32 ptrtoint (%runtime.typecodeID* @"reflect/types.type:basic:uint8" to i32), i8* inttoptr (i8 120 to i8*)) - call void @printInterface(i32 ptrtoint (%runtime.typecodeID* @"reflect/types.type:named:Number" to i32), i8* inttoptr (i32 3 to i8*)) + call void @printInterface(ptr @"reflect/types.type:basic:int", ptr inttoptr (i32 5 to ptr)) + call void @printInterface(ptr @"reflect/types.type:basic:uint8", ptr inttoptr (i8 120 to ptr)) + call void @printInterface(ptr getelementptr inbounds ({ ptr, i8, ptr, ptr }, ptr @"reflect/types.type:named:Number", i32 0, i32 1), ptr inttoptr (i32 3 to ptr)) ret void } -define void @printInterface(i32 %typecode, i8* %value) { - %isUnmatched = call i1 @Unmatched$typeassert(i32 %typecode) +define void @printInterface(ptr %typecode, ptr %value) { + %isUnmatched = call i1 @Unmatched$typeassert(ptr %typecode) br i1 %isUnmatched, label %typeswitch.Unmatched, label %typeswitch.notUnmatched typeswitch.Unmatched: - %unmatched = ptrtoint i8* %value to i32 + %unmatched = ptrtoint ptr %value to i32 call void @runtime.printptr(i32 %unmatched) call void @runtime.printnl() ret void 
typeswitch.notUnmatched: - %isDoubler = call i1 @Doubler$typeassert(i32 %typecode) + %isDoubler = call i1 @Doubler$typeassert(ptr %typecode) br i1 %isDoubler, label %typeswitch.Doubler, label %typeswitch.notDoubler typeswitch.Doubler: - %doubler.result = call i32 @"Doubler.Double$invoke"(i8* %value, i32 %typecode, i8* undef) + %doubler.result = call i32 @"Doubler.Double$invoke"(ptr %value, ptr %typecode, ptr undef) call void @runtime.printint32(i32 %doubler.result) ret void typeswitch.notDoubler: - %isByte = call i1 @runtime.typeAssert(i32 %typecode, i8* nonnull @"reflect/types.typeid:basic:uint8") + %isByte = call i1 @runtime.typeAssert(ptr %typecode, ptr nonnull @"reflect/types.typeid:basic:uint8") br i1 %isByte, label %typeswitch.byte, label %typeswitch.notByte typeswitch.byte: - %byte = ptrtoint i8* %value to i8 + %byte = ptrtoint ptr %value to i8 call void @runtime.printuint8(i8 %byte) call void @runtime.printnl() ret void typeswitch.notByte: ; this is a type assert that always fails - %isInt16 = call i1 @runtime.typeAssert(i32 %typecode, i8* nonnull @"reflect/types.typeid:basic:int16") + %isInt16 = call i1 @runtime.typeAssert(ptr %typecode, ptr nonnull @"reflect/types.typeid:basic:int16") br i1 %isInt16, label %typeswitch.int16, label %typeswitch.notInt16 typeswitch.int16: - %int16 = ptrtoint i8* %value to i16 + %int16 = ptrtoint ptr %value to i16 call void @runtime.printint16(i16 %int16) call void @runtime.printnl() ret void @@ -73,22 +73,22 @@ typeswitch.notInt16: ret void } -define i32 @"(Number).Double"(i32 %receiver, i8* %context) { +define i32 @"(Number).Double"(i32 %receiver, ptr %context) { %ret = mul i32 %receiver, 2 ret i32 %ret } -define i32 @"(Number).Double$invoke"(i8* %receiverPtr, i8* %context) { - %receiver = ptrtoint i8* %receiverPtr to i32 - %ret = call i32 @"(Number).Double"(i32 %receiver, i8* undef) +define i32 @"(Number).Double$invoke"(ptr %receiverPtr, ptr %context) { + %receiver = ptrtoint ptr %receiverPtr to i32 + %ret = call i32 @"(Number).Double"(i32 %receiver, ptr undef) ret i32 %ret } -declare i32 @"Doubler.Double$invoke"(i8* %receiver, i32 %typecode, i8* %context) #0 +declare i32 @"Doubler.Double$invoke"(ptr %receiver, ptr %typecode, ptr %context) #0 -declare i1 @Doubler$typeassert(i32 %typecode) #1 +declare i1 @Doubler$typeassert(ptr %typecode) #1 -declare i1 @Unmatched$typeassert(i32 %typecode) #2 +declare i1 @Unmatched$typeassert(ptr %typecode) #2 attributes #0 = { "tinygo-invoke"="reflect/methods.Double() int" "tinygo-methods"="reflect/methods.Double() int" } attributes #1 = { "tinygo-methods"="reflect/methods.Double() int" } diff --git a/transform/testdata/interface.out.ll b/transform/testdata/interface.out.ll index 262df21084..cb041ab1db 100644 --- a/transform/testdata/interface.out.ll +++ b/transform/testdata/interface.out.ll @@ -1,12 +1,12 @@ target datalayout = "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64" target triple = "armv7m-none-eabi" -%runtime.typecodeID = type { %runtime.typecodeID*, i32, %runtime.interfaceMethodInfo*, %runtime.typecodeID*, i32 } -%runtime.interfaceMethodInfo = type { i8*, i32 } - -@"reflect/types.type:basic:uint8" = private constant %runtime.typecodeID zeroinitializer -@"reflect/types.type:basic:int" = private constant %runtime.typecodeID zeroinitializer -@"reflect/types.type:named:Number" = private constant %runtime.typecodeID { %runtime.typecodeID* @"reflect/types.type:basic:int", i32 0, %runtime.interfaceMethodInfo* null, %runtime.typecodeID* null, i32 0 } +@"reflect/types.type:basic:uint8" = linkonce_odr constant 
{ i8, ptr } { i8 8, ptr @"reflect/types.type:pointer:basic:uint8" }, align 4 +@"reflect/types.type:pointer:basic:uint8" = linkonce_odr constant { i8, ptr } { i8 21, ptr @"reflect/types.type:basic:uint8" }, align 4 +@"reflect/types.type:basic:int" = linkonce_odr constant { i8, ptr } { i8 2, ptr @"reflect/types.type:pointer:basic:int" }, align 4 +@"reflect/types.type:pointer:basic:int" = linkonce_odr constant { i8, ptr } { i8 21, ptr @"reflect/types.type:basic:int" }, align 4 +@"reflect/types.type:pointer:named:Number" = linkonce_odr constant { i8, ptr } { i8 21, ptr @"reflect/types.type:named:Number" }, align 4 +@"reflect/types.type:named:Number" = linkonce_odr constant { i8, ptr, ptr } { i8 34, ptr @"reflect/types.type:pointer:named:Number", ptr @"reflect/types.type:basic:int" }, align 4 declare void @runtime.printuint8(i8) @@ -18,40 +18,40 @@ declare void @runtime.printptr(i32) declare void @runtime.printnl() -declare void @runtime.nilPanic(i8*) +declare void @runtime.nilPanic(ptr) define void @printInterfaces() { - call void @printInterface(i32 ptrtoint (%runtime.typecodeID* @"reflect/types.type:basic:int" to i32), i8* inttoptr (i32 5 to i8*)) - call void @printInterface(i32 ptrtoint (%runtime.typecodeID* @"reflect/types.type:basic:uint8" to i32), i8* inttoptr (i8 120 to i8*)) - call void @printInterface(i32 ptrtoint (%runtime.typecodeID* @"reflect/types.type:named:Number" to i32), i8* inttoptr (i32 3 to i8*)) + call void @printInterface(ptr @"reflect/types.type:basic:int", ptr inttoptr (i32 5 to ptr)) + call void @printInterface(ptr @"reflect/types.type:basic:uint8", ptr inttoptr (i8 120 to ptr)) + call void @printInterface(ptr @"reflect/types.type:named:Number", ptr inttoptr (i32 3 to ptr)) ret void } -define void @printInterface(i32 %typecode, i8* %value) { - %isUnmatched = call i1 @"Unmatched$typeassert"(i32 %typecode) +define void @printInterface(ptr %typecode, ptr %value) { + %isUnmatched = call i1 @"Unmatched$typeassert"(ptr %typecode) br i1 %isUnmatched, label %typeswitch.Unmatched, label %typeswitch.notUnmatched typeswitch.Unmatched: ; preds = %0 - %unmatched = ptrtoint i8* %value to i32 + %unmatched = ptrtoint ptr %value to i32 call void @runtime.printptr(i32 %unmatched) call void @runtime.printnl() ret void typeswitch.notUnmatched: ; preds = %0 - %isDoubler = call i1 @"Doubler$typeassert"(i32 %typecode) + %isDoubler = call i1 @"Doubler$typeassert"(ptr %typecode) br i1 %isDoubler, label %typeswitch.Doubler, label %typeswitch.notDoubler typeswitch.Doubler: ; preds = %typeswitch.notUnmatched - %doubler.result = call i32 @"Doubler.Double$invoke"(i8* %value, i32 %typecode, i8* undef) + %doubler.result = call i32 @"Doubler.Double$invoke"(ptr %value, ptr %typecode, ptr undef) call void @runtime.printint32(i32 %doubler.result) ret void typeswitch.notDoubler: ; preds = %typeswitch.notUnmatched - %typeassert.ok = icmp eq i32 ptrtoint (%runtime.typecodeID* @"reflect/types.type:basic:uint8" to i32), %typecode + %typeassert.ok = icmp eq ptr @"reflect/types.type:basic:uint8", %typecode br i1 %typeassert.ok, label %typeswitch.byte, label %typeswitch.notByte typeswitch.byte: ; preds = %typeswitch.notDoubler - %byte = ptrtoint i8* %value to i8 + %byte = ptrtoint ptr %value to i8 call void @runtime.printuint8(i8 %byte) call void @runtime.printnl() ret void @@ -60,7 +60,7 @@ typeswitch.notByte: ; preds = %typeswitch.notDoubl br i1 false, label %typeswitch.int16, label %typeswitch.notInt16 typeswitch.int16: ; preds = %typeswitch.notByte - %int16 = ptrtoint i8* %value to i16 + %int16 = ptrtoint 
ptr %value to i16 call void @runtime.printint16(i16 %int16) call void @runtime.printnl() ret void @@ -69,34 +69,34 @@ typeswitch.notInt16: ; preds = %typeswitch.notByte ret void } -define i32 @"(Number).Double"(i32 %receiver, i8* %context) { +define i32 @"(Number).Double"(i32 %receiver, ptr %context) { %ret = mul i32 %receiver, 2 ret i32 %ret } -define i32 @"(Number).Double$invoke"(i8* %receiverPtr, i8* %context) { - %receiver = ptrtoint i8* %receiverPtr to i32 - %ret = call i32 @"(Number).Double"(i32 %receiver, i8* undef) +define i32 @"(Number).Double$invoke"(ptr %receiverPtr, ptr %context) { + %receiver = ptrtoint ptr %receiverPtr to i32 + %ret = call i32 @"(Number).Double"(i32 %receiver, ptr undef) ret i32 %ret } -define internal i32 @"Doubler.Double$invoke"(i8* %receiver, i32 %actualType, i8* %context) unnamed_addr #0 { +define internal i32 @"Doubler.Double$invoke"(ptr %receiver, ptr %actualType, ptr %context) unnamed_addr #0 { entry: - %"named:Number.icmp" = icmp eq i32 %actualType, ptrtoint (%runtime.typecodeID* @"reflect/types.type:named:Number" to i32) + %"named:Number.icmp" = icmp eq ptr %actualType, @"reflect/types.type:named:Number" br i1 %"named:Number.icmp", label %"named:Number", label %"named:Number.next" "named:Number": ; preds = %entry - %0 = call i32 @"(Number).Double$invoke"(i8* %receiver, i8* undef) + %0 = call i32 @"(Number).Double$invoke"(ptr %receiver, ptr undef) ret i32 %0 "named:Number.next": ; preds = %entry - call void @runtime.nilPanic(i8* undef) + call void @runtime.nilPanic(ptr undef) unreachable } -define internal i1 @"Doubler$typeassert"(i32 %actualType) unnamed_addr #1 { +define internal i1 @"Doubler$typeassert"(ptr %actualType) unnamed_addr #1 { entry: - %"named:Number.icmp" = icmp eq i32 %actualType, ptrtoint (%runtime.typecodeID* @"reflect/types.type:named:Number" to i32) + %"named:Number.icmp" = icmp eq ptr %actualType, @"reflect/types.type:named:Number" br i1 %"named:Number.icmp", label %then, label %"named:Number.next" then: ; preds = %entry @@ -106,7 +106,7 @@ then: ; preds = %entry ret i1 false } -define internal i1 @"Unmatched$typeassert"(i32 %actualType) unnamed_addr #2 { +define internal i1 @"Unmatched$typeassert"(ptr %actualType) unnamed_addr #2 { entry: ret i1 false diff --git a/transform/testdata/interrupt.ll b/transform/testdata/interrupt.ll index 1436827b00..f4d2e0226d 100644 --- a/transform/testdata/interrupt.ll +++ b/transform/testdata/interrupt.ll @@ -1,39 +1,39 @@ target datalayout = "e-m:e-p:32:32-Fi8-i64:64-v128:64:128-a:0:32-n32-S64" target triple = "armv7em-none-eabi" -%machine.UART = type { %machine.RingBuffer* } +%machine.UART = type { ptr } %machine.RingBuffer = type { [128 x %"runtime/volatile.Register8"], %"runtime/volatile.Register8", %"runtime/volatile.Register8" } %"runtime/volatile.Register8" = type { i8 } -%"runtime/interrupt.handle" = type { i8*, i32, %"runtime/interrupt.Interrupt" } +%"runtime/interrupt.handle" = type { ptr, i32, %"runtime/interrupt.Interrupt" } %"runtime/interrupt.Interrupt" = type { i32 } -@"runtime/interrupt.$interrupt2" = private unnamed_addr constant %"runtime/interrupt.handle" { i8* bitcast (%machine.UART* @machine.UART0 to i8*), i32 ptrtoint (void (i32, i8*)* @"(*machine.UART).handleInterrupt$bound" to i32), %"runtime/interrupt.Interrupt" { i32 2 } } -@machine.UART0 = internal global %machine.UART { %machine.RingBuffer* @"machine$alloc.335" } +@"runtime/interrupt.$interrupt2" = private unnamed_addr constant %"runtime/interrupt.handle" { ptr @machine.UART0, i32 ptrtoint (ptr 
@"(*machine.UART).handleInterrupt$bound" to i32), %"runtime/interrupt.Interrupt" { i32 2 } } +@machine.UART0 = internal global %machine.UART { ptr @"machine$alloc.335" } @"machine$alloc.335" = internal global %machine.RingBuffer zeroinitializer -declare void @"runtime/interrupt.callHandlers"(i32, i8*) local_unnamed_addr +declare void @"runtime/interrupt.callHandlers"(i32, ptr) local_unnamed_addr -declare void @"device/arm.EnableIRQ"(i32, i8* nocapture readnone) +declare void @"device/arm.EnableIRQ"(i32, ptr nocapture readnone) -declare void @"device/arm.SetPriority"(i32, i32, i8* nocapture readnone) +declare void @"device/arm.SetPriority"(i32, i32, ptr nocapture readnone) declare void @"runtime/interrupt.use"(%"runtime/interrupt.Interrupt") -define void @runtime.initAll(i8* nocapture readnone) unnamed_addr { +define void @runtime.initAll(ptr nocapture readnone) unnamed_addr { entry: - call void @"device/arm.SetPriority"(i32 ptrtoint (%"runtime/interrupt.handle"* @"runtime/interrupt.$interrupt2" to i32), i32 192, i8* undef) - call void @"device/arm.EnableIRQ"(i32 ptrtoint (%"runtime/interrupt.handle"* @"runtime/interrupt.$interrupt2" to i32), i8* undef) - call void @"runtime/interrupt.use"(%"runtime/interrupt.Interrupt" { i32 ptrtoint (%"runtime/interrupt.handle"* @"runtime/interrupt.$interrupt2" to i32) }) + call void @"device/arm.SetPriority"(i32 ptrtoint (ptr @"runtime/interrupt.$interrupt2" to i32), i32 192, ptr undef) + call void @"device/arm.EnableIRQ"(i32 ptrtoint (ptr @"runtime/interrupt.$interrupt2" to i32), ptr undef) + call void @"runtime/interrupt.use"(%"runtime/interrupt.Interrupt" { i32 ptrtoint (ptr @"runtime/interrupt.$interrupt2" to i32) }) ret void } define void @UARTE0_UART0_IRQHandler() { - call void @"runtime/interrupt.callHandlers"(i32 2, i8* undef) + call void @"runtime/interrupt.callHandlers"(i32 2, ptr undef) ret void } define void @NFCT_IRQHandler() { - call void @"runtime/interrupt.callHandlers"(i32 5, i8* undef) + call void @"runtime/interrupt.callHandlers"(i32 5, ptr undef) ret void } @@ -45,22 +45,21 @@ entry: ] switch.body2: - call void @"runtime/interrupt.callHandlers"(i32 2, i8* undef) + call void @"runtime/interrupt.callHandlers"(i32 2, ptr undef) ret void switch.body5: - call void @"runtime/interrupt.callHandlers"(i32 5, i8* undef) + call void @"runtime/interrupt.callHandlers"(i32 5, ptr undef) ret void switch.done: ret void } -define internal void @"(*machine.UART).handleInterrupt$bound"(i32, i8* nocapture %context) { +define internal void @"(*machine.UART).handleInterrupt$bound"(i32, ptr nocapture %context) { entry: - %unpack.ptr = bitcast i8* %context to %machine.UART* - call void @"(*machine.UART).handleInterrupt"(%machine.UART* %unpack.ptr, i32 %0, i8* undef) + call void @"(*machine.UART).handleInterrupt"(ptr %context, i32 %0, ptr undef) ret void } -declare void @"(*machine.UART).handleInterrupt"(%machine.UART* nocapture, i32, i8* nocapture readnone) +declare void @"(*machine.UART).handleInterrupt"(ptr nocapture, i32, ptr nocapture readnone) diff --git a/transform/testdata/interrupt.out.ll b/transform/testdata/interrupt.out.ll index 7eb9f0a878..3663c0a33c 100644 --- a/transform/testdata/interrupt.out.ll +++ b/transform/testdata/interrupt.out.ll @@ -1,31 +1,31 @@ target datalayout = "e-m:e-p:32:32-Fi8-i64:64-v128:64:128-a:0:32-n32-S64" target triple = "armv7em-none-eabi" -%machine.UART = type { %machine.RingBuffer* } +%machine.UART = type { ptr } %machine.RingBuffer = type { [128 x %"runtime/volatile.Register8"], %"runtime/volatile.Register8", 
%"runtime/volatile.Register8" } %"runtime/volatile.Register8" = type { i8 } %"runtime/interrupt.Interrupt" = type { i32 } -@machine.UART0 = internal global %machine.UART { %machine.RingBuffer* @"machine$alloc.335" } +@machine.UART0 = internal global %machine.UART { ptr @"machine$alloc.335" } @"machine$alloc.335" = internal global %machine.RingBuffer zeroinitializer -declare void @"runtime/interrupt.callHandlers"(i32, i8*) local_unnamed_addr +declare void @"runtime/interrupt.callHandlers"(i32, ptr) local_unnamed_addr -declare void @"device/arm.EnableIRQ"(i32, i8* nocapture readnone) +declare void @"device/arm.EnableIRQ"(i32, ptr nocapture readnone) -declare void @"device/arm.SetPriority"(i32, i32, i8* nocapture readnone) +declare void @"device/arm.SetPriority"(i32, i32, ptr nocapture readnone) declare void @"runtime/interrupt.use"(%"runtime/interrupt.Interrupt") -define void @runtime.initAll(i8* nocapture readnone %0) unnamed_addr { +define void @runtime.initAll(ptr nocapture readnone %0) unnamed_addr { entry: - call void @"device/arm.SetPriority"(i32 2, i32 192, i8* undef) - call void @"device/arm.EnableIRQ"(i32 2, i8* undef) + call void @"device/arm.SetPriority"(i32 2, i32 192, ptr undef) + call void @"device/arm.EnableIRQ"(i32 2, ptr undef) ret void } define void @UARTE0_UART0_IRQHandler() { - call void @"(*machine.UART).handleInterrupt$bound"(i32 2, i8* bitcast (%machine.UART* @machine.UART0 to i8*)) + call void @"(*machine.UART).handleInterrupt$bound"(i32 2, ptr @machine.UART0) ret void } @@ -37,7 +37,7 @@ entry: ] switch.body2: ; preds = %entry - call void @"(*machine.UART).handleInterrupt$bound"(i32 2, i8* bitcast (%machine.UART* @machine.UART0 to i8*)) + call void @"(*machine.UART).handleInterrupt$bound"(i32 2, ptr @machine.UART0) ret void switch.body5: ; preds = %entry @@ -47,11 +47,10 @@ switch.done: ; preds = %entry ret void } -define internal void @"(*machine.UART).handleInterrupt$bound"(i32 %0, i8* nocapture %context) { +define internal void @"(*machine.UART).handleInterrupt$bound"(i32 %0, ptr nocapture %context) { entry: - %unpack.ptr = bitcast i8* %context to %machine.UART* - call void @"(*machine.UART).handleInterrupt"(%machine.UART* %unpack.ptr, i32 %0, i8* undef) + call void @"(*machine.UART).handleInterrupt"(ptr %context, i32 %0, ptr undef) ret void } -declare void @"(*machine.UART).handleInterrupt"(%machine.UART* nocapture, i32, i8* nocapture readnone) +declare void @"(*machine.UART).handleInterrupt"(ptr nocapture, i32, ptr nocapture readnone) diff --git a/transform/testdata/maps.ll b/transform/testdata/maps.ll index 0bf0042466..78f6819da7 100644 --- a/transform/testdata/maps.ll +++ b/transform/testdata/maps.ll @@ -1,28 +1,25 @@ target datalayout = "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64" target triple = "armv7m-none-eabi" -%runtime.hashmap = type { %runtime.hashmap*, i8*, i32, i8, i8, i8 } - @answer = constant [6 x i8] c"answer" ; func(keySize, valueSize uint8, sizeHint uintptr) *runtime.hashmap -declare nonnull %runtime.hashmap* @runtime.hashmapMake(i8, i8, i32) +declare nonnull ptr @runtime.hashmapMake(i8, i8, i32) ; func(map[string]int, string, unsafe.Pointer) -declare void @runtime.hashmapStringSet(%runtime.hashmap* nocapture, i8*, i32, i8* nocapture readonly) +declare void @runtime.hashmapStringSet(ptr nocapture, ptr, i32, ptr nocapture readonly) ; func(map[string]int, string, unsafe.Pointer) -declare i1 @runtime.hashmapStringGet(%runtime.hashmap* nocapture, i8*, i32, i8* nocapture) +declare i1 @runtime.hashmapStringGet(ptr nocapture, ptr, i32, ptr 
nocapture) define void @testUnused() { ; create the map - %map = call %runtime.hashmap* @runtime.hashmapMake(i8 4, i8 4, i32 0) + %map = call ptr @runtime.hashmapMake(i8 4, i8 4, i32 0) ; create the value to be stored %hashmap.value = alloca i32 - store i32 42, i32* %hashmap.value + store i32 42, ptr %hashmap.value ; store the value - %hashmap.value.bitcast = bitcast i32* %hashmap.value to i8* - call void @runtime.hashmapStringSet(%runtime.hashmap* %map, i8* getelementptr inbounds ([6 x i8], [6 x i8]* @answer, i32 0, i32 0), i32 6, i8* %hashmap.value.bitcast) + call void @runtime.hashmapStringSet(ptr %map, ptr @answer, i32 6, ptr %hashmap.value) ret void } @@ -30,26 +27,24 @@ define void @testUnused() { ; return 42), but isn't at the moment. define i32 @testReadonly() { ; create the map - %map = call %runtime.hashmap* @runtime.hashmapMake(i8 4, i8 4, i32 0) + %map = call ptr @runtime.hashmapMake(i8 4, i8 4, i32 0) ; create the value to be stored %hashmap.value = alloca i32 - store i32 42, i32* %hashmap.value + store i32 42, ptr %hashmap.value ; store the value - %hashmap.value.bitcast = bitcast i32* %hashmap.value to i8* - call void @runtime.hashmapStringSet(%runtime.hashmap* %map, i8* getelementptr inbounds ([6 x i8], [6 x i8]* @answer, i32 0, i32 0), i32 6, i8* %hashmap.value.bitcast) + call void @runtime.hashmapStringSet(ptr %map, ptr @answer, i32 6, ptr %hashmap.value) ; load the value back %hashmap.value2 = alloca i32 - %hashmap.value2.bitcast = bitcast i32* %hashmap.value2 to i8* - %commaOk = call i1 @runtime.hashmapStringGet(%runtime.hashmap* %map, i8* getelementptr inbounds ([6 x i8], [6 x i8]* @answer, i32 0, i32 0), i32 6, i8* %hashmap.value2.bitcast) - %loadedValue = load i32, i32* %hashmap.value2 + %commaOk = call i1 @runtime.hashmapStringGet(ptr %map, ptr @answer, i32 6, ptr %hashmap.value2) + %loadedValue = load i32, ptr %hashmap.value2 ret i32 %loadedValue } -define %runtime.hashmap* @testUsed() { - %1 = call %runtime.hashmap* @runtime.hashmapMake(i8 4, i8 4, i32 0) - ret %runtime.hashmap* %1 +define ptr @testUsed() { + %1 = call ptr @runtime.hashmapMake(i8 4, i8 4, i32 0) + ret ptr %1 } diff --git a/transform/testdata/maps.out.ll b/transform/testdata/maps.out.ll index 81f9badab2..1bd0743953 100644 --- a/transform/testdata/maps.out.ll +++ b/transform/testdata/maps.out.ll @@ -1,34 +1,30 @@ target datalayout = "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64" target triple = "armv7m-none-eabi" -%runtime.hashmap = type { %runtime.hashmap*, i8*, i32, i8, i8, i8 } - @answer = constant [6 x i8] c"answer" -declare nonnull %runtime.hashmap* @runtime.hashmapMake(i8, i8, i32) +declare nonnull ptr @runtime.hashmapMake(i8, i8, i32) -declare void @runtime.hashmapStringSet(%runtime.hashmap* nocapture, i8*, i32, i8* nocapture readonly) +declare void @runtime.hashmapStringSet(ptr nocapture, ptr, i32, ptr nocapture readonly) -declare i1 @runtime.hashmapStringGet(%runtime.hashmap* nocapture, i8*, i32, i8* nocapture) +declare i1 @runtime.hashmapStringGet(ptr nocapture, ptr, i32, ptr nocapture) define void @testUnused() { ret void } define i32 @testReadonly() { - %map = call %runtime.hashmap* @runtime.hashmapMake(i8 4, i8 4, i32 0) + %map = call ptr @runtime.hashmapMake(i8 4, i8 4, i32 0) %hashmap.value = alloca i32, align 4 - store i32 42, i32* %hashmap.value, align 4 - %hashmap.value.bitcast = bitcast i32* %hashmap.value to i8* - call void @runtime.hashmapStringSet(%runtime.hashmap* %map, i8* getelementptr inbounds ([6 x i8], [6 x i8]* @answer, i32 0, i32 0), i32 6, i8* 
%hashmap.value.bitcast) + store i32 42, ptr %hashmap.value, align 4 + call void @runtime.hashmapStringSet(ptr %map, ptr @answer, i32 6, ptr %hashmap.value) %hashmap.value2 = alloca i32, align 4 - %hashmap.value2.bitcast = bitcast i32* %hashmap.value2 to i8* - %commaOk = call i1 @runtime.hashmapStringGet(%runtime.hashmap* %map, i8* getelementptr inbounds ([6 x i8], [6 x i8]* @answer, i32 0, i32 0), i32 6, i8* %hashmap.value2.bitcast) - %loadedValue = load i32, i32* %hashmap.value2, align 4 + %commaOk = call i1 @runtime.hashmapStringGet(ptr %map, ptr @answer, i32 6, ptr %hashmap.value2) + %loadedValue = load i32, ptr %hashmap.value2, align 4 ret i32 %loadedValue } -define %runtime.hashmap* @testUsed() { - %1 = call %runtime.hashmap* @runtime.hashmapMake(i8 4, i8 4, i32 0) - ret %runtime.hashmap* %1 +define ptr @testUsed() { + %1 = call ptr @runtime.hashmapMake(i8 4, i8 4, i32 0) + ret ptr %1 } diff --git a/transform/testdata/panic.ll b/transform/testdata/panic.ll index 4f0f0a167b..660e30f2f5 100644 --- a/transform/testdata/panic.ll +++ b/transform/testdata/panic.ll @@ -3,12 +3,12 @@ target triple = "armv7m-none-eabi" @"runtime.lookupPanic$string" = constant [18 x i8] c"index out of range" -declare void @runtime.runtimePanic(i8*, i32) +declare void @runtime.runtimePanic(ptr, i32) -declare void @runtime._panic(i32, i8*) +declare void @runtime._panic(i32, ptr) define void @runtime.lookupPanic() { - call void @runtime.runtimePanic(i8* getelementptr inbounds ([18 x i8], [18 x i8]* @"runtime.lookupPanic$string", i64 0, i64 0), i32 18) + call void @runtime.runtimePanic(ptr @"runtime.lookupPanic$string", i32 18) ret void } @@ -16,7 +16,7 @@ define void @runtime.lookupPanic() { ; func someFunc(x interface{}) { ; panic(x) ; } -define void @someFunc(i32 %typecode, i8* %value) { - call void @runtime._panic(i32 %typecode, i8* %value) +define void @someFunc(i32 %typecode, ptr %value) { + call void @runtime._panic(i32 %typecode, ptr %value) unreachable } diff --git a/transform/testdata/panic.out.ll b/transform/testdata/panic.out.ll index 8612ae9a67..458e4c2477 100644 --- a/transform/testdata/panic.out.ll +++ b/transform/testdata/panic.out.ll @@ -3,19 +3,19 @@ target triple = "armv7m-none-eabi" @"runtime.lookupPanic$string" = constant [18 x i8] c"index out of range" -declare void @runtime.runtimePanic(i8*, i32) +declare void @runtime.runtimePanic(ptr, i32) -declare void @runtime._panic(i32, i8*) +declare void @runtime._panic(i32, ptr) define void @runtime.lookupPanic() { call void @llvm.trap() - call void @runtime.runtimePanic(i8* getelementptr inbounds ([18 x i8], [18 x i8]* @"runtime.lookupPanic$string", i64 0, i64 0), i32 18) + call void @runtime.runtimePanic(ptr @"runtime.lookupPanic$string", i32 18) ret void } -define void @someFunc(i32 %typecode, i8* %value) { +define void @someFunc(i32 %typecode, ptr %value) { call void @llvm.trap() - call void @runtime._panic(i32 %typecode, i8* %value) + call void @runtime._panic(i32 %typecode, ptr %value) unreachable } diff --git a/transform/testdata/reflect-implements.ll b/transform/testdata/reflect-implements.ll index ca6dcb8c5e..46536483b2 100644 --- a/transform/testdata/reflect-implements.ll +++ b/transform/testdata/reflect-implements.ll @@ -1,19 +1,14 @@ target datalayout = "e-m:e-p:32:32-p270:32:32-p271:32:32-p272:64:64-f64:32:64-f80:32-n8:16:32-S128" target triple = "i686--linux" -%runtime.typecodeID = type { %runtime.typecodeID*, i32, %runtime.interfaceMethodInfo*, %runtime.typecodeID*, i32 } -%runtime.interfaceMethodInfo = type { i8*, i32 } 
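Two smaller test pairs appear just above. maps.ll and maps.out.ll cover the map-optimization pass: the map in testUnused is written but never read, so the whole allocation is removed, while the read-only map in testReadonly is left untouched even though it could in principle be folded to a constant. panic.ll and panic.out.ll cover the trap-on-panic lowering, which inserts a call to @llvm.trap() in front of each call into the runtime's panic machinery. An illustrative Go-level sketch of the two map cases (the function names come from the .ll files, the bodies are reconstructed):

package sample

func testUnused() {
    m := make(map[string]int)
    m["answer"] = 42 // never read back, so the whole map is dropped
}

func testReadonly() int {
    m := make(map[string]int)
    m["answer"] = 42
    return m["answer"] // could in principle fold to 42, but is not folded today
}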
+%runtime._interface = type { ptr, ptr } -@"reflect/types.type:named:error" = linkonce_odr constant %runtime.typecodeID { %runtime.typecodeID* @"reflect/types.type:interface:{Error:func:{}{basic:string}}", i32 0, %runtime.interfaceMethodInfo* null, %runtime.typecodeID* null, i32 ptrtoint (i1 (i32)* @"error.$typeassert" to i32) } -@"reflect/types.type:interface:{Error:func:{}{basic:string}}" = linkonce_odr constant %runtime.typecodeID { %runtime.typecodeID* bitcast ([1 x i8*]* @"reflect/types.interface:interface{Error() string}$interface" to %runtime.typecodeID*), i32 0, %runtime.interfaceMethodInfo* null, %runtime.typecodeID* null, i32 ptrtoint (i1 (i32)* @"error.$typeassert" to i32) } -@"reflect/methods.Error() string" = linkonce_odr constant i8 0 -@"reflect/types.interface:interface{Error() string}$interface" = linkonce_odr constant [1 x i8*] [i8* @"reflect/methods.Error() string"] -@"reflect/methods.Align() int" = linkonce_odr constant i8 0 -@"reflect/methods.Implements(reflect.Type) bool" = linkonce_odr constant i8 0 -@"reflect.Type$interface" = linkonce_odr constant [2 x i8*] [i8* @"reflect/methods.Align() int", i8* @"reflect/methods.Implements(reflect.Type) bool"] -@"reflect/types.type:named:reflect.rawType" = linkonce_odr constant %runtime.typecodeID { %runtime.typecodeID* @"reflect/types.type:basic:uintptr", i32 0, %runtime.interfaceMethodInfo* getelementptr inbounds ([20 x %runtime.interfaceMethodInfo], [20 x %runtime.interfaceMethodInfo]* @"reflect.rawType$methodset", i32 0, i32 0), %runtime.typecodeID* null, i32 0 } -@"reflect.rawType$methodset" = linkonce_odr constant [20 x %runtime.interfaceMethodInfo] zeroinitializer -@"reflect/types.type:basic:uintptr" = linkonce_odr constant %runtime.typecodeID zeroinitializer +@"reflect/types.type:named:error" = internal constant { i8, i16, ptr, ptr } { i8 52, i16 0, ptr @"reflect/types.type:pointer:named:error", ptr @"reflect/types.type:interface:{Error:func:{}{basic:string}}" }, align 4 +@"reflect/types.type:interface:{Error:func:{}{basic:string}}" = internal constant { i8, ptr } { i8 20, ptr @"reflect/types.type:pointer:interface:{Error:func:{}{basic:string}}" }, align 4 +@"reflect/types.type:pointer:interface:{Error:func:{}{basic:string}}" = internal constant { i8, ptr } { i8 21, ptr @"reflect/types.type:interface:{Error:func:{}{basic:string}}" }, align 4 +@"reflect/types.type:pointer:named:error" = internal constant { i8, i16, ptr } { i8 21, i16 0, ptr @"reflect/types.type:named:error" }, align 4 +@"reflect/types.type:pointer:named:reflect.rawType" = internal constant { ptr, i8, i16, ptr } { ptr null, i8 21, i16 0, ptr null }, align 4 +@"reflect/methods.Implements(reflect.Type) bool" = internal constant i8 0, align 1 ; var errorType = reflect.TypeOf((*error)(nil)).Elem() ; func isError(typ reflect.Type) bool { @@ -22,9 +17,9 @@ target triple = "i686--linux" ; The type itself is stored in %typ.value, %typ.typecode just refers to the ; type of reflect.Type. This function can be optimized because errorType is ; known at compile time (after the interp pass has run). 
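Reconstructed from the comments above, the Go source behind main.isError below (and the main.isUnknown function that follows it) is approximately the following (a sketch; only the package name is assumed):

package sample

import "reflect"

var errorType = reflect.TypeOf((*error)(nil)).Elem()

// Because errorType is known at compile time once the interp pass has run,
// the Implements call here can be lowered to the error interface's generated
// $typeassert helper, as the expected output shows.
func isError(typ reflect.Type) bool {
    return typ.Implements(errorType)
}

// Both arguments are dynamic here, so this Implements call cannot be lowered
// and stays a call to reflect.Type.Implements$invoke.
func isUnknown(typ, itf reflect.Type) bool {
    return typ.Implements(itf)
}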
-define i1 @main.isError(i32 %typ.typecode, i8* %typ.value, i8* %context) { +define i1 @main.isError(ptr %typ.typecode, ptr %typ.value, ptr %context) { entry: - %result = call i1 @"reflect.Type.Implements$invoke"(i8* %typ.value, i32 ptrtoint (%runtime.typecodeID* @"reflect/types.type:named:reflect.rawType" to i32), i8* bitcast (%runtime.typecodeID* @"reflect/types.type:named:error" to i8*), i32 %typ.typecode, i8* undef) + %result = call i1 @"reflect.Type.Implements$invoke"(ptr %typ.value, ptr getelementptr inbounds ({ ptr, i8, ptr }, ptr @"reflect/types.type:pointer:named:reflect.rawType", i32 0, i32 1), ptr @"reflect/types.type:named:error", ptr %typ.typecode, ptr undef) ret i1 %result } @@ -33,14 +28,14 @@ entry: ; func isUnknown(typ, itf reflect.Type) bool { ; return typ.Implements(itf) ; } -define i1 @main.isUnknown(i32 %typ.typecode, i8* %typ.value, i32 %itf.typecode, i8* %itf.value, i8* %context) { +define i1 @main.isUnknown(ptr %typ.typecode, ptr %typ.value, ptr %itf.typecode, ptr %itf.value, ptr %context) { entry: - %result = call i1 @"reflect.Type.Implements$invoke"(i8* %typ.value, i32 %itf.typecode, i8* %itf.value, i32 %typ.typecode, i8* undef) + %result = call i1 @"reflect.Type.Implements$invoke"(ptr %typ.value, ptr %itf.typecode, ptr %itf.value, ptr %typ.typecode, ptr undef) ret i1 %result } -declare i1 @"reflect.Type.Implements$invoke"(i8*, i32, i8*, i32, i8*) #0 -declare i1 @"error.$typeassert"(i32) #1 +declare i1 @"reflect.Type.Implements$invoke"(ptr, ptr, ptr, ptr, ptr) #0 +declare i1 @"interface:{Error:func:{}{basic:string}}.$typeassert"(ptr %0) #1 attributes #0 = { "tinygo-invoke"="reflect/methods.Implements(reflect.Type) bool" "tinygo-methods"="reflect/methods.Align() int; reflect/methods.Implements(reflect.Type) bool" } attributes #1 = { "tinygo-methods"="reflect/methods.Error() string" } diff --git a/transform/testdata/reflect-implements.out.ll b/transform/testdata/reflect-implements.out.ll index 0093e2b0a1..b7b759c018 100644 --- a/transform/testdata/reflect-implements.out.ll +++ b/transform/testdata/reflect-implements.out.ll @@ -1,36 +1,28 @@ target datalayout = "e-m:e-p:32:32-p270:32:32-p271:32:32-p272:64:64-f64:32:64-f80:32-n8:16:32-S128" target triple = "i686--linux" -%runtime.typecodeID = type { %runtime.typecodeID*, i32, %runtime.interfaceMethodInfo*, %runtime.typecodeID*, i32 } -%runtime.interfaceMethodInfo = type { i8*, i32 } +@"reflect/types.type:named:error" = internal constant { i8, i16, ptr, ptr } { i8 52, i16 0, ptr @"reflect/types.type:pointer:named:error", ptr @"reflect/types.type:interface:{Error:func:{}{basic:string}}" }, align 4 +@"reflect/types.type:interface:{Error:func:{}{basic:string}}" = internal constant { i8, ptr } { i8 20, ptr @"reflect/types.type:pointer:interface:{Error:func:{}{basic:string}}" }, align 4 +@"reflect/types.type:pointer:interface:{Error:func:{}{basic:string}}" = internal constant { i8, ptr } { i8 21, ptr @"reflect/types.type:interface:{Error:func:{}{basic:string}}" }, align 4 +@"reflect/types.type:pointer:named:error" = internal constant { i8, i16, ptr } { i8 21, i16 0, ptr @"reflect/types.type:named:error" }, align 4 +@"reflect/types.type:pointer:named:reflect.rawType" = internal constant { ptr, i8, i16, ptr } { ptr null, i8 21, i16 0, ptr null }, align 4 +@"reflect/methods.Implements(reflect.Type) bool" = internal constant i8 0, align 1 -@"reflect/types.type:named:error" = linkonce_odr constant %runtime.typecodeID { %runtime.typecodeID* @"reflect/types.type:interface:{Error:func:{}{basic:string}}", i32 0, 
%runtime.interfaceMethodInfo* null, %runtime.typecodeID* null, i32 ptrtoint (i1 (i32)* @"error.$typeassert" to i32) } -@"reflect/types.type:interface:{Error:func:{}{basic:string}}" = linkonce_odr constant %runtime.typecodeID { %runtime.typecodeID* bitcast ([1 x i8*]* @"reflect/types.interface:interface{Error() string}$interface" to %runtime.typecodeID*), i32 0, %runtime.interfaceMethodInfo* null, %runtime.typecodeID* null, i32 ptrtoint (i1 (i32)* @"error.$typeassert" to i32) } -@"reflect/methods.Error() string" = linkonce_odr constant i8 0 -@"reflect/types.interface:interface{Error() string}$interface" = linkonce_odr constant [1 x i8*] [i8* @"reflect/methods.Error() string"] -@"reflect/methods.Align() int" = linkonce_odr constant i8 0 -@"reflect/methods.Implements(reflect.Type) bool" = linkonce_odr constant i8 0 -@"reflect.Type$interface" = linkonce_odr constant [2 x i8*] [i8* @"reflect/methods.Align() int", i8* @"reflect/methods.Implements(reflect.Type) bool"] -@"reflect/types.type:named:reflect.rawType" = linkonce_odr constant %runtime.typecodeID { %runtime.typecodeID* @"reflect/types.type:basic:uintptr", i32 0, %runtime.interfaceMethodInfo* getelementptr inbounds ([20 x %runtime.interfaceMethodInfo], [20 x %runtime.interfaceMethodInfo]* @"reflect.rawType$methodset", i32 0, i32 0), %runtime.typecodeID* null, i32 0 } -@"reflect.rawType$methodset" = linkonce_odr constant [20 x %runtime.interfaceMethodInfo] zeroinitializer -@"reflect/types.type:basic:uintptr" = linkonce_odr constant %runtime.typecodeID zeroinitializer - -define i1 @main.isError(i32 %typ.typecode, i8* %typ.value, i8* %context) { +define i1 @main.isError(ptr %typ.typecode, ptr %typ.value, ptr %context) { entry: - %0 = ptrtoint i8* %typ.value to i32 - %1 = call i1 @"error.$typeassert"(i32 %0) - ret i1 %1 + %0 = call i1 @"interface:{Error:func:{}{basic:string}}.$typeassert"(ptr %typ.value) + ret i1 %0 } -define i1 @main.isUnknown(i32 %typ.typecode, i8* %typ.value, i32 %itf.typecode, i8* %itf.value, i8* %context) { +define i1 @main.isUnknown(ptr %typ.typecode, ptr %typ.value, ptr %itf.typecode, ptr %itf.value, ptr %context) { entry: - %result = call i1 @"reflect.Type.Implements$invoke"(i8* %typ.value, i32 %itf.typecode, i8* %itf.value, i32 %typ.typecode, i8* undef) + %result = call i1 @"reflect.Type.Implements$invoke"(ptr %typ.value, ptr %itf.typecode, ptr %itf.value, ptr %typ.typecode, ptr undef) ret i1 %result } -declare i1 @"reflect.Type.Implements$invoke"(i8*, i32, i8*, i32, i8*) #0 +declare i1 @"reflect.Type.Implements$invoke"(ptr, ptr, ptr, ptr, ptr) #0 -declare i1 @"error.$typeassert"(i32) #1 +declare i1 @"interface:{Error:func:{}{basic:string}}.$typeassert"(ptr) #1 attributes #0 = { "tinygo-invoke"="reflect/methods.Implements(reflect.Type) bool" "tinygo-methods"="reflect/methods.Align() int; reflect/methods.Implements(reflect.Type) bool" } attributes #1 = { "tinygo-methods"="reflect/methods.Error() string" } diff --git a/transform/testdata/stacksize.ll b/transform/testdata/stacksize.ll index f80a712172..4df5874a67 100644 --- a/transform/testdata/stacksize.ll +++ b/transform/testdata/stacksize.ll @@ -1,15 +1,15 @@ target datalayout = "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64" target triple = "armv7m-none-eabi" -declare i32 @"internal/task.getGoroutineStackSize"(i32, i8*, i8*) +declare i32 @"internal/task.getGoroutineStackSize"(i32, ptr, ptr) -declare void @"runtime.run$1$gowrapper"(i8*) +declare void @"runtime.run$1$gowrapper"(ptr) -declare void @"internal/task.start"(i32, i8*, i32) +declare void 
@"internal/task.start"(i32, ptr, i32) define void @Reset_Handler() { entry: - %stacksize = call i32 @"internal/task.getGoroutineStackSize"(i32 ptrtoint (void (i8*)* @"runtime.run$1$gowrapper" to i32), i8* undef, i8* undef) - call void @"internal/task.start"(i32 ptrtoint (void (i8*)* @"runtime.run$1$gowrapper" to i32), i8* undef, i32 %stacksize) + %stacksize = call i32 @"internal/task.getGoroutineStackSize"(i32 ptrtoint (ptr @"runtime.run$1$gowrapper" to i32), ptr undef, ptr undef) + call void @"internal/task.start"(i32 ptrtoint (ptr @"runtime.run$1$gowrapper" to i32), ptr undef, i32 %stacksize) ret void } diff --git a/transform/testdata/stacksize.out.ll b/transform/testdata/stacksize.out.ll index cea820ec1b..4efc4a22d6 100644 --- a/transform/testdata/stacksize.out.ll +++ b/transform/testdata/stacksize.out.ll @@ -1,18 +1,18 @@ target datalayout = "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64" target triple = "armv7m-none-eabi" -@"internal/task.stackSizes" = global [1 x i32] [i32 1024], section ".tinygo_stacksizes" -@llvm.used = appending global [2 x i8*] [i8* bitcast ([1 x i32]* @"internal/task.stackSizes" to i8*), i8* bitcast (void (i8*)* @"runtime.run$1$gowrapper" to i8*)] +@"internal/task.stackSizes" = global [1 x i32] [i32 1024], section ".tinygo_stacksizes", align 4 +@llvm.used = appending global [2 x ptr] [ptr @"internal/task.stackSizes", ptr @"runtime.run$1$gowrapper"] -declare i32 @"internal/task.getGoroutineStackSize"(i32, i8*, i8*) +declare i32 @"internal/task.getGoroutineStackSize"(i32, ptr, ptr) -declare void @"runtime.run$1$gowrapper"(i8*) +declare void @"runtime.run$1$gowrapper"(ptr) -declare void @"internal/task.start"(i32, i8*, i32) +declare void @"internal/task.start"(i32, ptr, i32) define void @Reset_Handler() { entry: - %stacksize1 = load i32, i32* getelementptr inbounds ([1 x i32], [1 x i32]* @"internal/task.stackSizes", i32 0, i32 0), align 4 - call void @"internal/task.start"(i32 ptrtoint (void (i8*)* @"runtime.run$1$gowrapper" to i32), i8* undef, i32 %stacksize1) + %stacksize1 = load i32, ptr @"internal/task.stackSizes", align 4 + call void @"internal/task.start"(i32 ptrtoint (ptr @"runtime.run$1$gowrapper" to i32), ptr undef, i32 %stacksize1) ret void } diff --git a/transform/testdata/stringequal.ll b/transform/testdata/stringequal.ll index d355fc45e9..0d6ed7fb20 100644 --- a/transform/testdata/stringequal.ll +++ b/transform/testdata/stringequal.ll @@ -3,17 +3,17 @@ target triple = "armv7m-none-eabi" @zeroString = constant [0 x i8] zeroinitializer -declare i1 @runtime.stringEqual(i8*, i32, i8*, i32, i8*) +declare i1 @runtime.stringEqual(ptr, i32, ptr, i32, ptr) -define i1 @main.stringCompareEqualConstantZero(i8* %s1.data, i32 %s1.len, i8* %context) { +define i1 @main.stringCompareEqualConstantZero(ptr %s1.data, i32 %s1.len, ptr %context) { entry: - %0 = call i1 @runtime.stringEqual(i8* %s1.data, i32 %s1.len, i8* getelementptr inbounds ([0 x i8], [0 x i8]* @zeroString, i32 0, i32 0), i32 0, i8* undef) + %0 = call i1 @runtime.stringEqual(ptr %s1.data, i32 %s1.len, ptr @zeroString, i32 0, ptr undef) ret i1 %0 } -define i1 @main.stringCompareUnequalConstantZero(i8* %s1.data, i32 %s1.len, i8* %context) { +define i1 @main.stringCompareUnequalConstantZero(ptr %s1.data, i32 %s1.len, ptr %context) { entry: - %0 = call i1 @runtime.stringEqual(i8* %s1.data, i32 %s1.len, i8* getelementptr inbounds ([0 x i8], [0 x i8]* @zeroString, i32 0, i32 0), i32 0, i8* undef) + %0 = call i1 @runtime.stringEqual(ptr %s1.data, i32 %s1.len, ptr @zeroString, i32 0, ptr undef) %1 = xor i1 %0, 
true ret i1 %1 } diff --git a/transform/testdata/stringequal.out.ll b/transform/testdata/stringequal.out.ll index d148c84fde..f2aeb95aba 100644 --- a/transform/testdata/stringequal.out.ll +++ b/transform/testdata/stringequal.out.ll @@ -3,15 +3,15 @@ target triple = "armv7m-none-eabi" @zeroString = constant [0 x i8] zeroinitializer -declare i1 @runtime.stringEqual(i8*, i32, i8*, i32, i8*) +declare i1 @runtime.stringEqual(ptr, i32, ptr, i32, ptr) -define i1 @main.stringCompareEqualConstantZero(i8* %s1.data, i32 %s1.len, i8* %context) { +define i1 @main.stringCompareEqualConstantZero(ptr %s1.data, i32 %s1.len, ptr %context) { entry: %0 = icmp eq i32 %s1.len, 0 ret i1 %0 } -define i1 @main.stringCompareUnequalConstantZero(i8* %s1.data, i32 %s1.len, i8* %context) { +define i1 @main.stringCompareUnequalConstantZero(ptr %s1.data, i32 %s1.len, ptr %context) { entry: %0 = icmp eq i32 %s1.len, 0 %1 = xor i1 %0, true diff --git a/transform/testdata/stringtobytes.ll b/transform/testdata/stringtobytes.ll index f3cec82349..fa43f3d02f 100644 --- a/transform/testdata/stringtobytes.ll +++ b/transform/testdata/stringtobytes.ll @@ -3,30 +3,30 @@ target triple = "x86_64--linux" @str = constant [6 x i8] c"foobar" -declare { i8*, i64, i64 } @runtime.stringToBytes(i8*, i64) +declare { ptr, i64, i64 } @runtime.stringToBytes(ptr, i64) -declare void @printSlice(i8* nocapture readonly, i64, i64) +declare void @printSlice(ptr nocapture readonly, i64, i64) -declare void @writeToSlice(i8* nocapture, i64, i64) +declare void @writeToSlice(ptr nocapture, i64, i64) ; Test that runtime.stringToBytes can be fully optimized away. define void @testReadOnly() { entry: - %0 = call fastcc { i8*, i64, i64 } @runtime.stringToBytes(i8* getelementptr inbounds ([6 x i8], [6 x i8]* @str, i32 0, i32 0), i64 6) - %1 = extractvalue { i8*, i64, i64 } %0, 0 - %2 = extractvalue { i8*, i64, i64 } %0, 1 - %3 = extractvalue { i8*, i64, i64 } %0, 2 - call fastcc void @printSlice(i8* %1, i64 %2, i64 %3) + %0 = call fastcc { ptr, i64, i64 } @runtime.stringToBytes(ptr @str, i64 6) + %1 = extractvalue { ptr, i64, i64 } %0, 0 + %2 = extractvalue { ptr, i64, i64 } %0, 1 + %3 = extractvalue { ptr, i64, i64 } %0, 2 + call fastcc void @printSlice(ptr %1, i64 %2, i64 %3) ret void } ; Test that even though the slice is written to, some values can be propagated. 
define void @testReadWrite() { entry: - %0 = call fastcc { i8*, i64, i64 } @runtime.stringToBytes(i8* getelementptr inbounds ([6 x i8], [6 x i8]* @str, i32 0, i32 0), i64 6) - %1 = extractvalue { i8*, i64, i64 } %0, 0 - %2 = extractvalue { i8*, i64, i64 } %0, 1 - %3 = extractvalue { i8*, i64, i64 } %0, 2 - call fastcc void @writeToSlice(i8* %1, i64 %2, i64 %3) + %0 = call fastcc { ptr, i64, i64 } @runtime.stringToBytes(ptr @str, i64 6) + %1 = extractvalue { ptr, i64, i64 } %0, 0 + %2 = extractvalue { ptr, i64, i64 } %0, 1 + %3 = extractvalue { ptr, i64, i64 } %0, 2 + call fastcc void @writeToSlice(ptr %1, i64 %2, i64 %3) ret void } diff --git a/transform/testdata/stringtobytes.out.ll b/transform/testdata/stringtobytes.out.ll index 49b065818d..30aa520ad1 100644 --- a/transform/testdata/stringtobytes.out.ll +++ b/transform/testdata/stringtobytes.out.ll @@ -3,22 +3,22 @@ target triple = "x86_64--linux" @str = constant [6 x i8] c"foobar" -declare { i8*, i64, i64 } @runtime.stringToBytes(i8*, i64) +declare { ptr, i64, i64 } @runtime.stringToBytes(ptr, i64) -declare void @printSlice(i8* nocapture readonly, i64, i64) +declare void @printSlice(ptr nocapture readonly, i64, i64) -declare void @writeToSlice(i8* nocapture, i64, i64) +declare void @writeToSlice(ptr nocapture, i64, i64) define void @testReadOnly() { entry: - call fastcc void @printSlice(i8* getelementptr inbounds ([6 x i8], [6 x i8]* @str, i32 0, i32 0), i64 6, i64 6) + call fastcc void @printSlice(ptr @str, i64 6, i64 6) ret void } define void @testReadWrite() { entry: - %0 = call fastcc { i8*, i64, i64 } @runtime.stringToBytes(i8* getelementptr inbounds ([6 x i8], [6 x i8]* @str, i32 0, i32 0), i64 6) - %1 = extractvalue { i8*, i64, i64 } %0, 0 - call fastcc void @writeToSlice(i8* %1, i64 6, i64 6) + %0 = call fastcc { ptr, i64, i64 } @runtime.stringToBytes(ptr @str, i64 6) + %1 = extractvalue { ptr, i64, i64 } %0, 0 + call fastcc void @writeToSlice(ptr %1, i64 6, i64 6) ret void } diff --git a/transform/testdata/wasm-abi.ll b/transform/testdata/wasm-abi.ll index 79e253473b..ade4b5af56 100644 --- a/transform/testdata/wasm-abi.ll +++ b/transform/testdata/wasm-abi.ll @@ -1,19 +1,19 @@ target datalayout = "e-m:e-p:32:32-i64:64-n32:64-S128" target triple = "wasm32-unknown-unknown-wasm" -declare i64 @externalCall(i8*, i32, i64) +declare i64 @externalCall(ptr, i32, i64) -define internal i64 @testCall(i8* %ptr, i32 %len, i64 %foo) { - %val = call i64 @externalCall(i8* %ptr, i32 %len, i64 %foo) +define internal i64 @testCall(ptr %ptr, i32 %len, i64 %foo) { + %val = call i64 @externalCall(ptr %ptr, i32 %len, i64 %foo) ret i64 %val } -define internal i64 @testCallNonEntry(i8* %ptr, i32 %len) { +define internal i64 @testCallNonEntry(ptr %ptr, i32 %len) { entry: br label %bb1 bb1: - %val = call i64 @externalCall(i8* %ptr, i32 %len, i64 3) + %val = call i64 @externalCall(ptr %ptr, i32 %len, i64 3) ret i64 %val } diff --git a/transform/testdata/wasm-abi.out.ll b/transform/testdata/wasm-abi.out.ll index cf63c3d792..a1fc7d6a9b 100644 --- a/transform/testdata/wasm-abi.out.ll +++ b/transform/testdata/wasm-abi.out.ll @@ -1,27 +1,27 @@ target datalayout = "e-m:e-p:32:32-i64:64-n32:64-S128" target triple = "wasm32-unknown-unknown-wasm" -declare i64 @"externalCall$i64wrap"(i8*, i32, i64) +declare i64 @"externalCall$i64wrap"(ptr, i32, i64) -define internal i64 @testCall(i8* %ptr, i32 %len, i64 %foo) { +define internal i64 @testCall(ptr %ptr, i32 %len, i64 %foo) { %i64asptr = alloca i64, align 8 %i64asptr1 = alloca i64, align 8 - store i64 %foo, i64* 
%i64asptr1, align 8 - call void @externalCall(i64* %i64asptr, i8* %ptr, i32 %len, i64* %i64asptr1) - %retval = load i64, i64* %i64asptr, align 8 + store i64 %foo, ptr %i64asptr1, align 8 + call void @externalCall(ptr %i64asptr, ptr %ptr, i32 %len, ptr %i64asptr1) + %retval = load i64, ptr %i64asptr, align 8 ret i64 %retval } -define internal i64 @testCallNonEntry(i8* %ptr, i32 %len) { +define internal i64 @testCallNonEntry(ptr %ptr, i32 %len) { entry: %i64asptr = alloca i64, align 8 %i64asptr1 = alloca i64, align 8 br label %bb1 bb1: ; preds = %entry - store i64 3, i64* %i64asptr1, align 8 - call void @externalCall(i64* %i64asptr, i8* %ptr, i32 %len, i64* %i64asptr1) - %retval = load i64, i64* %i64asptr, align 8 + store i64 3, ptr %i64asptr1, align 8 + call void @externalCall(ptr %i64asptr, ptr %ptr, i32 %len, ptr %i64asptr1) + %retval = load i64, ptr %i64asptr, align 8 ret i64 %retval } @@ -35,11 +35,11 @@ define internal void @callExportedFunction(i64 %foo) { ret void } -declare void @externalCall(i64*, i8*, i32, i64*) +declare void @externalCall(ptr, ptr, i32, ptr) -define void @exportedFunction(i64* %0) { +define void @exportedFunction(ptr %0) { entry: - %i64 = load i64, i64* %0, align 8 + %i64 = load i64, ptr %0, align 8 call void @"exportedFunction$i64wrap"(i64 %i64) ret void }