diff --git a/cmd/ktor/main.go b/cmd/ktor/main.go index 2ff619a..8a35b83 100644 --- a/cmd/ktor/main.go +++ b/cmd/ktor/main.go @@ -1,6 +1,7 @@ package main import ( + "context" _ "embed" "fmt" "github.com/ktorio/ktor-cli/internal/app/cli" @@ -9,6 +10,7 @@ import ( "github.com/ktorio/ktor-cli/internal/app/utils" "io" "log" + "net" "net/http" "os" "path/filepath" @@ -58,7 +60,13 @@ func main() { cli.WriteUsage(os.Stdout) case cli.NewCommand: client := &http.Client{ - Timeout: 30 * time.Second, + Transport: &http.Transport{ + DialContext: func(_ context.Context, network, addr string) (net.Conn, error) { + return net.DialTimeout(network, addr, 5*time.Second) + }, + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: 10 * time.Second, + }, } projectName := utils.CleanProjectName(args.CommandArgs[0]) diff --git a/internal/app/cli/output.go b/internal/app/cli/output.go index 4a6a0e9..a9ea313 100644 --- a/internal/app/cli/output.go +++ b/internal/app/cli/output.go @@ -64,6 +64,8 @@ func HandleAppError(projectDir string, err error) (reportLog bool) { fmt.Fprintf(os.Stderr, "Unable to download JDK %s for %s %s\n", je.Descriptor.Version, je.Descriptor.Platform, je.Arch) case app.JdkServerError: fmt.Fprintf(os.Stderr, "Unexpected error occurred while connecting to a JDK server. Please try again later.\n") + case app.JdkServerDownloadError: + fmt.Fprintf(os.Stderr, "An error occurred while downloading from a JDK server. Please try again later.\n") case app.JdkVerificationFailed: fmt.Fprintln(os.Stderr, "Checksum verification for the downloaded JDK failed") case app.GradlewChmodError: diff --git a/internal/app/errors.go b/internal/app/errors.go index b505ceb..656098c 100644 --- a/internal/app/errors.go +++ b/internal/app/errors.go @@ -21,6 +21,7 @@ const ( ExtractRootDirExistError UnableLocateJdkError JdkServerError + JdkServerDownloadError JdkVerificationFailed ) diff --git a/internal/app/generate/project.go b/internal/app/generate/project.go index 9b81ab9..3f5047f 100644 --- a/internal/app/generate/project.go +++ b/internal/app/generate/project.go @@ -73,7 +73,7 @@ func Project(client *http.Client, logger *log.Logger, projectDir, project string len(zipBytes), logger.Writer() == io.Discard, ) - defer progressBar.Finish() + defer progressBar.Done() _, err = archive.ExtractZip(reader, int64(len(zipBytes)), projectDir, logger) diff --git a/internal/app/jdk/download.go b/internal/app/jdk/download.go index dd9d56d..af79737 100644 --- a/internal/app/jdk/download.go +++ b/internal/app/jdk/download.go @@ -40,8 +40,15 @@ func DownloadJdk(client *http.Client, d *Descriptor, logger *log.Logger) ([]byte } reader, progressBar := progress.NewReader(resp.Body, "Downloading JDK... ", utils.ContentLength(resp), true) - defer progressBar.Finish() - return io.ReadAll(reader) + b, err := io.ReadAll(reader) + + if err != nil { + progressBar.Stop() + return nil, &app.Error{Err: err, Kind: app.JdkServerDownloadError} + } + + progressBar.Done() + return b, nil } func hasJdkBuild(d *Descriptor) bool { diff --git a/internal/app/jdk/fetch.go b/internal/app/jdk/fetch.go index 1577823..a9a3deb 100644 --- a/internal/app/jdk/fetch.go +++ b/internal/app/jdk/fetch.go @@ -45,7 +45,7 @@ func fetch(client *http.Client, d *Descriptor, outDir string, logger *log.Logger len(jdkBytes), logger.Writer() == io.Discard, ) - defer progressBar.Finish() + defer progressBar.Done() extractedDirs, extractErr = archive.ExtractZip(reader, int64(len(jdkBytes)), outDir, logger) } else { @@ -55,7 +55,7 @@ func fetch(client *http.Client, d *Descriptor, outDir string, logger *log.Logger len(jdkBytes), logger.Writer() == io.Discard, ) - defer progressBar.Finish() + defer progressBar.Done() extractedDirs, extractErr = archive.ExtractTarGz(reader, outDir, logger) } diff --git a/internal/app/network/project.go b/internal/app/network/project.go index cb56ab5..9e1cd71 100644 --- a/internal/app/network/project.go +++ b/internal/app/network/project.go @@ -55,7 +55,7 @@ func NewProject(client *http.Client, payload ProjectPayload) ([]byte, error) { utils.ContentLength(resp), true, ) - defer progressBar.Finish() + defer progressBar.Done() bodyBytes, err := io.ReadAll(reader) if err != nil { diff --git a/internal/app/progress/percent.go b/internal/app/progress/percent.go index a7fb4c9..bbcc23c 100644 --- a/internal/app/progress/percent.go +++ b/internal/app/progress/percent.go @@ -64,7 +64,7 @@ func (b *Percent) tick(p []byte) (n int, err error) { return len(p), nil } -func (b *Percent) Finish() (err error) { +func (b *Percent) Done() (err error) { if !b.enabled { return nil } @@ -77,3 +77,12 @@ func (b *Percent) Finish() (err error) { fmt.Fprintf(b.Writer, "%s100%%\n", b.prefix) return } + +func (b *Percent) Stop() (err error) { + if !b.enabled { + return nil + } + + fmt.Fprintln(b.Writer) + return +} diff --git a/internal/app/progress/percent_test.go b/internal/app/progress/percent_test.go index d7f1e28..cb5305a 100644 --- a/internal/app/progress/percent_test.go +++ b/internal/app/progress/percent_test.go @@ -78,7 +78,7 @@ func checkAllWrites(t *testing.T, p *Percent, b *strings.Builder, cases []testCa for _, test := range cases { var err error if test.writeLen == -1 { - err = p.Finish() + err = p.Done() } else { _, err = p.Write(make([]byte, test.writeLen)) } @@ -93,7 +93,7 @@ func checkAllWrites(t *testing.T, p *Percent, b *strings.Builder, cases []testCa for _, test := range cases { var err error if test.writeLen == -1 { - err = p.Finish() + err = p.Done() } else { _, err = p.WriteAt(make([]byte, test.writeLen), test.offset) }