diff --git a/demo/file-header_test.go b/demo/filetool/file-header_test.go similarity index 100% rename from demo/file-header_test.go rename to demo/filetool/file-header_test.go diff --git a/demo/httpreq/curl_test.go b/demo/httpreq/curl_test.go new file mode 100644 index 0000000..af34e4a --- /dev/null +++ b/demo/httpreq/curl_test.go @@ -0,0 +1,24 @@ +package httpreq + +import ( + "strings" + "testing" + + "github.com/ahuigo/gohttptool/httpreq" +) + +func TestCurl(t *testing.T) { + curl, err := httpreq.R(). + SetParams(map[string]string{"p": "1"}). + AddCookieKV("count", "1"). + AddFileHeader("file", "test.txt", []byte("hello world")). + ToCurl() + if err != nil { + t.Fatal(err) + } + if !strings.HasPrefix(curl, "curl ") { + t.Fatal("bad curl: ", curl) + }else{ + t.Log("curl: ", curl) + } +} diff --git a/go.mod b/go.mod index 83bf0f6..517ba10 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/ahuigo/gohttptool go 1.22.1 + +require github.com/pkg/errors v0.9.1 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..7c401c3 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/httpreq/curl.go b/httpreq/curl.go new file mode 100644 index 0000000..119d3e1 --- /dev/null +++ b/httpreq/curl.go @@ -0,0 +1,88 @@ +package httpreq + +import ( + "bytes" + "io" + "net/http" + "net/http/cookiejar" + + "net/url" + "strings" + + "github.com/ahuigo/gohttptool/shell" +) + +func (r *request) FromCurl(curl string) { + +} +func (r *request) ToCurl() (curl string, err error) { + if httpreq, err := r.ToRequest(); err != nil { + return "", err + } else { + curl := buildCurlRequest(httpreq, nil) + return curl, nil + } +} + +func buildCurlRequest(req *http.Request, httpCookiejar http.CookieJar) (curl string) { + // 1. Generate curl raw headers + curl = "curl -X " + req.Method + " " + // req.Host + req.URL.Path + "?" + req.URL.RawQuery + " " + req.Proto + " " + headers := dumpCurlHeaders(req) + for _, kv := range *headers { + curl += `-H ` + shell.Quote(kv[0]+": "+kv[1]) + ` ` + } + + // 2. Generate curl cookies + if cookieJar, ok := httpCookiejar.(*cookiejar.Jar); ok { + cookies := cookieJar.Cookies(req.URL) + if len(cookies) > 0 { + curl += ` -H ` + shell.Quote(dumpCurlCookies(cookies)) + " " + } + } + + // 3. Generate curl body + if req.Body != nil { + buf, _ := io.ReadAll(req.Body) + req.Body = io.NopCloser(bytes.NewBuffer(buf)) // important!! + curl += `-d ` + shell.Quote(string(buf)) + } + + urlString := shell.Quote(req.URL.String()) + if urlString == "''" { + urlString = "'http://unexecuted-request'" + } + curl += " " + urlString + return curl +} + +// dumpCurlCookies dumps cookies to curl format +func dumpCurlCookies(cookies []*http.Cookie) string { + sb := strings.Builder{} + sb.WriteString("Cookie: ") + for _, cookie := range cookies { + sb.WriteString(cookie.Name + "=" + url.QueryEscape(cookie.Value) + "&") + } + return strings.TrimRight(sb.String(), "&") +} + +// dumpCurlHeaders dumps headers to curl format +func dumpCurlHeaders(req *http.Request) *[][2]string { + headers := [][2]string{} + for k, vs := range req.Header { + for _, v := range vs { + headers = append(headers, [2]string{k, v}) + } + } + n := len(headers) + for i := 0; i < n; i++ { + for j := n - 1; j > i; j-- { + jj := j - 1 + h1, h2 := headers[j], headers[jj] + if h1[0] < h2[0] { + headers[jj], headers[j] = headers[j], headers[jj] + } + } + } + return &headers +} diff --git a/httpreq/req-builder.go b/httpreq/req-builder.go new file mode 100644 index 0000000..c1d81a3 --- /dev/null +++ b/httpreq/req-builder.go @@ -0,0 +1,143 @@ +package httpreq + +import ( + "bytes" + "fmt" + "io" + "mime/multipart" + "net/http" + "net/url" + "os" + "strings" + + "github.com/pkg/errors" +) + +func (session *request) ToRequest() (*http.Request, error) { + var dataType = ContentType(session.rawreq.Header.Get("Content-Type")) + var origurl = session.url + if len(session.files) > 0 || len(session.fileHeaders) > 0 { + dataType = ContentTypeFormData + } + + URL, err := session.buildURLParams(origurl) + if err != nil { + return nil, err + } + if URL.Scheme == "" || URL.Host == "" { + err = &url.Error{Op: "parse", URL: origurl, Err: fmt.Errorf("failed")} + return nil, err + } + + switch dataType { + case ContentTypeFormEncode: + if len(session.datas) > 0 { + formEncodeValues := session.buildFormEncode(session.datas) + session.setBodyFormEncode(formEncodeValues) + } + case ContentTypeFormData: + // multipart/form-data + session.buildFilesAndForms() + } + + if session.rawreq.Body == nil && session.rawreq.Method != "GET" { + session.rawreq.Body = http.NoBody + } + + session.rawreq.URL = URL + + return session.rawreq, nil +} + +// build post Form encode +func (session *request) buildFormEncode(datas map[string]string) (Forms url.Values) { + Forms = url.Values{} + for key, value := range datas { + Forms.Add(key, value) + } + return Forms +} + +// set form urlencode +func (session *request) setBodyFormEncode(Forms url.Values) { + data := Forms.Encode() + session.rawreq.Body = io.NopCloser(strings.NewReader(data)) + session.rawreq.ContentLength = int64(len(data)) +} + +func (r *request) buildURLParams(userURL string) (*url.URL, error) { + params := r.params + paramsArray := r.paramsList + if strings.HasPrefix(userURL, "/") { + userURL = "http://localhost" + userURL + }else if userURL == ""{ + userURL = "http://unknown" + } + parsedURL, err := url.Parse(userURL) + + if err != nil { + return nil, err + } + + values := parsedURL.Query() + + for key, value := range params { + values.Set(key, value) + } + for key, vals := range paramsArray { + for _, v := range vals { + values.Add(key, v) + } + } + parsedURL.RawQuery = values.Encode() + return parsedURL, nil +} + +func (r *request) buildFilesAndForms() error { + files := r.files + datas := r.datas + filesHeaders := r.fileHeaders + //handle file multipart + var b bytes.Buffer + w := multipart.NewWriter(&b) + + for k, v := range datas { + w.WriteField(k, v) + } + + for field, path := range files { + part, err := w.CreateFormFile(field, path) + if err != nil { + fmt.Printf("Upload %s failed!", path) + panic(err) + } + file, err := os.Open(path) + if err != nil { + err = errors.WithMessagef(err, "Open %s", path) + return err + } + _, err = io.Copy(part, file) + if err != nil { + return err + } + } + for field, fileheader := range filesHeaders { + part, err := w.CreateFormFile(field, fileheader.Filename) + if err != nil { + fmt.Printf("Upload %s failed!", field) + panic(err) + } + _, err = io.Copy(part, bytes.NewReader([]byte(fileheader.content))) + if err != nil { + return err + } + } + + w.Close() + // set file header example: + // "Content-Type": "multipart/form-data; boundary=------------------------7d87eceb5520850c", + r.rawreq.Body = io.NopCloser(bytes.NewReader(b.Bytes())) + r.rawreq.ContentLength = int64(b.Len()) + r.rawreq.Header.Set("Content-Type", w.FormDataContentType()) + return nil +} diff --git a/httpreq/req.go b/httpreq/req.go new file mode 100644 index 0000000..23425a3 --- /dev/null +++ b/httpreq/req.go @@ -0,0 +1,119 @@ +package httpreq + +import ( + "context" + "net/http" +) + +type ContentType string + +const ( + ContentTypeNone ContentType = "" + ContentTypeFormEncode ContentType = "application/x-www-form-urlencoded" + ContentTypeFormData ContentType = "multipart/form-data" + ContentTypeJson ContentType = "application/json" + ContentTypePlain ContentType = "text/plain" +) + +type fileHeader struct { + Filename string + // Header textproto.MIMEHeader + Size int64 + content []byte + // tmpfile string + // tmpoff int64 + // tmpshared bool +} + +type request struct { + rawreq *http.Request + url string + files map[string]string // field -> path + fileHeaders map[string]fileHeader // field -> contents + datas map[string]string // key -> value + params map[string]string // key -> value + paramsList map[string][]string // key -> value list +} + +func R() *request { + return &request{ + rawreq: &http.Request{ + Method: "GET", + Header: make(http.Header), + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + }, + files: make(map[string]string), + fileHeaders: make(map[string]fileHeader), + datas: make(map[string]string), + params: make(map[string]string), + paramsList: make(map[string][]string), + } +} + +func (r *request) SetAuth(key, value string) { + r.rawreq.SetBasicAuth(key, value) +} + +func (r *request) SetHeader(key, value string) { + r.rawreq.Header.Set(key, value) +} +func (r *request) AddFile(fieldname, path string) *request { + r.files[fieldname] = path + return r +} + +func (r *request) AddFileHeader(fieldname, filename string, content []byte) *request { + r.fileHeaders[fieldname] = fileHeader{ + Filename: filename, + content: content, + Size: int64(len(content)), + } + return r +} + +func (r *request) AddCookies(cookies []*http.Cookie) *request { + for _, cookie := range cookies { + r.rawreq.AddCookie(cookie) + } + return r +} +func (r *request) AddCookieKV(name, value string) *request { + cookie := &http.Cookie{ + Name: name, + Value: value, + } + r.rawreq.AddCookie(cookie) + return r +} + +func (r *request) SetUrl(url string) *request { + r.url = url + return r +} + +func (r *request) SetMethod(method string) *request { + r.rawreq.Method = method + return r +} + +func (r *request) SetParams(params map[string]string) *request { + r.params = params + return r +} + +func (r *request) GetRawreq() *http.Request { + return r.rawreq +} + +func (r *request) SetCtx(ctx context.Context) *request { + r.rawreq = r.rawreq.WithContext(ctx) + return r +} + +func (r *request) EnableTrace(ctx context.Context) *request { + trace := clientTraceNew(r.rawreq.Context()) + r.rawreq = r.rawreq.WithContext(trace.ctx) + return r +} diff --git a/httpreq/trace.go b/httpreq/trace.go new file mode 100644 index 0000000..24593ae --- /dev/null +++ b/httpreq/trace.go @@ -0,0 +1,181 @@ +// Copyright (c) 2015-2021 Jeevanandam M (jeeva@myjeeva.com), All rights reserved. + +package httpreq + +import ( + "context" + "crypto/tls" + "net" + "net/http/httptrace" + "time" +) + +//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ +// TraceInfo struct +//_______________________________________________________________________ + +// TraceInfo struct is used provide request trace info such as DNS lookup +// duration, Connection obtain duration, Server processing duration, etc. +// +// Since v2.0.0 +type TraceInfo struct { + // DNSLookup is a duration that transport took to perform + // DNS lookup. + DNSLookup time.Duration + + // ConnTime is a duration that took to obtain a successful connection. + ConnTime time.Duration + + // TCPConnTime is a duration that took to obtain the TCP connection. + TCPConnTime time.Duration + + // TLSHandshake is a duration that TLS handshake took place. + TLSHandshake time.Duration + + // ServerTime is a duration that server took to respond first byte. + ServerTime time.Duration + + // ResponseTime is a duration since first response byte from server to + // request completion. + ResponseTime time.Duration + + // TotalTime is a duration that total request took end-to-end. + TotalTime time.Duration + + // IsConnReused is whether this connection has been previously + // used for another HTTP request. + IsConnReused bool + + // IsConnWasIdle is whether this connection was obtained from an + // idle pool. + IsConnWasIdle bool + + // ConnIdleTime is a duration how long the connection was previously + // idle, if IsConnWasIdle is true. + ConnIdleTime time.Duration + + // RequestAttempt is to represent the request attempt made during a Resty + // request execution flow, including retry count. + RequestAttempt int + + // RemoteAddr returns the remote network address. + RemoteAddr net.Addr +} + +//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ +// ClientTrace struct and its methods +//_______________________________________________________________________ + +// tracer struct maps the `httptrace.ClientTrace` hooks into Fields +// with same naming for easy understanding. Plus additional insights +// Request. +type clientTrace struct { + getConn time.Time + dnsStart time.Time + dnsDone time.Time + connectDone time.Time + tlsHandshakeStart time.Time + tlsHandshakeDone time.Time + gotConn time.Time + gotFirstResponseByte time.Time + endTime time.Time + gotConnInfo httptrace.GotConnInfo + ctx context.Context +} + +//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ +// Trace unexported methods +//_______________________________________________________________________ + +func clientTraceNew(ctx context.Context) *clientTrace { + trace := &clientTrace{} + trace.ctx = trace.createContext(ctx) + return trace +} + +func (t *clientTrace) createContext(ctx context.Context) context.Context { + return httptrace.WithClientTrace( + ctx, + &httptrace.ClientTrace{ + DNSStart: func(_ httptrace.DNSStartInfo) { + t.dnsStart = time.Now() + }, + DNSDone: func(_ httptrace.DNSDoneInfo) { + t.dnsDone = time.Now() + }, + ConnectStart: func(_, _ string) { + if t.dnsDone.IsZero() { + t.dnsDone = time.Now() + } + if t.dnsStart.IsZero() { + t.dnsStart = t.dnsDone + } + }, + ConnectDone: func(net, addr string, err error) { + t.connectDone = time.Now() + }, + GetConn: func(_ string) { + t.getConn = time.Now() + }, + GotConn: func(ci httptrace.GotConnInfo) { + t.gotConn = time.Now() + t.gotConnInfo = ci + }, + GotFirstResponseByte: func() { + t.gotFirstResponseByte = time.Now() + }, + TLSHandshakeStart: func() { + t.tlsHandshakeStart = time.Now() + }, + TLSHandshakeDone: func(_ tls.ConnectionState, _ error) { + t.tlsHandshakeDone = time.Now() + }, + }, + ) +} + +func (ct *clientTrace) TraceInfo() TraceInfo { + if ct == nil { + return TraceInfo{} + } + + ti := TraceInfo{ + DNSLookup: ct.dnsDone.Sub(ct.dnsStart), + TLSHandshake: ct.tlsHandshakeDone.Sub(ct.tlsHandshakeStart), + ServerTime: ct.gotFirstResponseByte.Sub(ct.gotConn), + IsConnReused: ct.gotConnInfo.Reused, + IsConnWasIdle: ct.gotConnInfo.WasIdle, + ConnIdleTime: ct.gotConnInfo.IdleTime, + // RequestAttempt: r.Attempt, + } + + // Calculate the total time accordingly, + // when connection is reused + if ct.gotConnInfo.Reused { + ti.TotalTime = ct.endTime.Sub(ct.getConn) + } else { + ti.TotalTime = ct.endTime.Sub(ct.dnsStart) + } + + // Only calculate on successful connections + if !ct.connectDone.IsZero() { + ti.TCPConnTime = ct.connectDone.Sub(ct.dnsDone) + } + + // Only calculate on successful connections + if !ct.gotConn.IsZero() { + ti.ConnTime = ct.gotConn.Sub(ct.getConn) + } + + // Only calculate on successful connections + if !ct.gotFirstResponseByte.IsZero() { + ti.ResponseTime = ct.endTime.Sub(ct.gotFirstResponseByte) + } + + // Capture remote address info when connection is non-nil + if ct.gotConnInfo.Conn != nil { + ti.RemoteAddr = ct.gotConnInfo.Conn.RemoteAddr() + } + + return ti +} diff --git a/shell/shellescape.go b/shell/shellescape.go new file mode 100644 index 0000000..9ada7ad --- /dev/null +++ b/shell/shellescape.go @@ -0,0 +1,34 @@ +/* +Package shellescape provides the shellescape.Quote to escape arbitrary +strings for a safe use as command line arguments in the most common +POSIX shells. + +The original Python package which this work was inspired by can be found +at https://pypi.python.org/pypi/shellescape. +*/ +package shell + +import ( + "regexp" + "strings" +) + +var pattern *regexp.Regexp + +func init() { + pattern = regexp.MustCompile(`[^\w@%+=:,./-]`) +} + +// Quote returns a shell-escaped version of the string s. The returned value +// is a string that can safely be used as one token in a shell command line. +func Quote(s string) string { + if len(s) == 0 { + return "''" + } + + if pattern.MatchString(s) { + return "'" + strings.ReplaceAll(s, "'", "'\"'\"'") + "'" + } + + return s +}