diff --git a/cmd/proxy/actions/app_proxy.go b/cmd/proxy/actions/app_proxy.go index 47ad60ac8..86c2a2478 100644 --- a/cmd/proxy/actions/app_proxy.go +++ b/cmd/proxy/actions/app_proxy.go @@ -134,7 +134,7 @@ type athensLoggerForRedis struct { } func (l *athensLoggerForRedis) Printf(ctx context.Context, format string, v ...any) { - l.logger.WithContext(ctx).Infof(format, v...) + l.logger.WithContext(ctx).Printf(format, v...) } func getSingleFlight(l *log.Logger, c *config.Config, s storage.Backend, checker storage.Checker) (stash.Wrapper, error) { diff --git a/cmd/proxy/actions/basicauth_test.go b/cmd/proxy/actions/basicauth_test.go index 57515c8c5..1f02f84fc 100644 --- a/cmd/proxy/actions/basicauth_test.go +++ b/cmd/proxy/actions/basicauth_test.go @@ -3,13 +3,12 @@ package actions import ( "bytes" "context" + "log/slog" "net/http" "net/http/httptest" "strings" "testing" - "log/slog" - "github.com/gomods/athens/pkg/log" ) diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index a7d8d1e2e..e56547c4d 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -37,7 +37,7 @@ func main() { stdlog.Fatalf("Could not load config file: %v", err) } - logLvl := slog.Level(0) + var logLvl slog.Level err = logLvl.UnmarshalText([]byte(conf.LogLevel)) if err != nil { stdlog.Fatalf("Could not parse log level %q: %v", conf.LogLevel, err) @@ -45,9 +45,19 @@ func main() { logger := athenslog.New(conf.CloudRuntime, logLvl, conf.LogFormat) + // Turn standard logger output into slog Errors. + logrusErrorWriter := logger.WriterLevel(slog.LevelError) + defer func() { + if err := logrusErrorWriter.Close(); err != nil { + logger.WithError(err).Warn("Could not close logrus writer pipe") + } + }() + stdlog.SetOutput(logrusErrorWriter) + stdlog.SetFlags(stdlog.Flags() &^ (stdlog.Ldate | stdlog.Ltime)) + handler, err := actions.App(logger, conf) if err != nil { - logger.With(err.Error()).Error("Could not create App") + logger.WithError(err).Fatal("Could not create App") } srv := &http.Server{ @@ -66,7 +76,7 @@ func main() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(conf.ShutdownTimeout)) defer cancel() if err := srv.Shutdown(ctx); err != nil { - logger.With(err).Error("Could not shut down server") + logger.WithError(err).Fatal("Could not shut down server") } close(idleConnsClosed) }() @@ -77,7 +87,7 @@ func main() { // not to expose profiling data and avoid DoS attacks (profiling slows down the service) // https://www.farsightsecurity.com/txt-record/2016/10/28/cmikk-go-remote-profiling/ logger.WithField("port", conf.PprofPort).Infof("starting pprof") - logger.Error(http.ListenAndServe(conf.PprofPort, nil).Error()) //nolint:gosec // This should not be exposed to the world. + logger.Fatal(http.ListenAndServe(conf.PprofPort, nil)) //nolint:gosec // This should not be exposed to the world. }() } @@ -86,19 +96,19 @@ func main() { if conf.UnixSocket != "" { logger := logger.WithField("unixSocket", conf.UnixSocket) - logger.Infof("Starting application") + logger.Info("Starting application") ln, err = net.Listen("unix", conf.UnixSocket) if err != nil { - logger.WithError(err).Fatalf("Could not listen on Unix domain socket") + logger.WithError(err).Fatal("Could not listen on Unix domain socket") } } else { logger := logger.WithField("tcpPort", conf.Port) - logger.Infof("Starting application") + logger.Info("Starting application") ln, err = net.Listen("tcp", conf.Port) if err != nil { - logger.WithError(err).Fatalf("Could not listen on TCP port") + logger.WithError(err).Fatal("Could not listen on TCP port") } } @@ -109,8 +119,9 @@ func main() { } if !errors.Is(err, http.ErrServerClosed) { - logger.WithError(err).Fatalf("Could not start server") + logger.WithError(err).Fatal("Could not start server") } <-idleConnsClosed + } diff --git a/pkg/download/protocol.go b/pkg/download/protocol.go index c4cc76911..a6942a9c7 100644 --- a/pkg/download/protocol.go +++ b/pkg/download/protocol.go @@ -299,9 +299,6 @@ func union(list1, list2 []string) []string { func copyContextWithCustomTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { ctxCopy, cancel := context.WithTimeout(context.Background(), timeout) ctxCopy = requestid.SetInContext(ctxCopy, requestid.FromContext(ctx)) - - if entry := log.EntryFromContext(ctx); entry != nil { - ctxCopy = log.SetEntryInContext(ctxCopy, &entry) - } + ctxCopy = log.SetEntryInContext(ctxCopy, log.EntryFromContext(ctx)) return ctxCopy, cancel } diff --git a/pkg/download/protocol_test.go b/pkg/download/protocol_test.go index 5037eae22..e2d46ad2c 100644 --- a/pkg/download/protocol_test.go +++ b/pkg/download/protocol_test.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "io" + "log/slog" "os" "path/filepath" "regexp" @@ -504,15 +505,27 @@ var _ log.Entry = &testEntry{} func (e *testEntry) Debugf(format string, args ...any) { e.msg = format } -func (*testEntry) Infof(format string, args ...any) {} -func (*testEntry) Warnf(format string, args ...any) {} -func (*testEntry) Errorf(format string, args ...any) {} -func (*testEntry) Fatalf(format string, args ...any) {} -func (*testEntry) WithFields(fields map[string]any) log.Entry { return nil } -func (*testEntry) SystemErr(err error) {} -func (*testEntry) WithContext(ctx context.Context) log.Entry { return nil } -func (*testEntry) WithError(err error) log.Entry { return nil } -func (*testEntry) WithField(key string, value any) log.Entry { return nil } +func (*testEntry) Infof(format string, args ...any) {} +func (*testEntry) Warnf(format string, args ...any) {} +func (*testEntry) Errorf(format string, args ...any) {} +func (*testEntry) Fatalf(format string, args ...any) {} +func (*testEntry) Panicf(format string, args ...any) {} +func (*testEntry) Printf(format string, args ...any) {} + +func (*testEntry) Debug(args ...any) {} +func (*testEntry) Info(args ...any) {} +func (*testEntry) Warn(args ...any) {} +func (*testEntry) Error(args ...any) {} +func (*testEntry) Fatal(args ...any) {} +func (*testEntry) Panic(args ...any) {} +func (*testEntry) Print(args ...any) {} + +func (*testEntry) WithFields(fields map[string]any) log.Entry { return nil } +func (*testEntry) SystemErr(err error) {} +func (*testEntry) WithField(key string, value any) log.Entry { return nil } +func (*testEntry) WithError(err error) log.Entry { return nil } +func (*testEntry) WithContext(ctx context.Context) log.Entry { return nil } +func (*testEntry) WriterLevel(level slog.Level) *io.PipeWriter { return nil } func Test_copyContextWithCustomTimeout(t *testing.T) { testEntry := &testEntry{} diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go index 189ce9622..67799ea7b 100644 --- a/pkg/errors/errors.go +++ b/pkg/errors/errors.go @@ -6,6 +6,8 @@ import ( "log/slog" "net/http" "runtime" + + "github.com/sirupsen/logrus" ) // Kind enums. @@ -145,13 +147,13 @@ func Severity(err error) slog.Level { // Expect is a helper that returns an Info level // if the error has the expected kind, otherwise // it returns an Error level. -func Expect(err error, kinds ...int) slog.Level { +func Expect(err error, kinds ...int) logrus.Level { for _, kind := range kinds { if Kind(err) == kind { - return slog.LevelInfo + return logrus.InfoLevel } } - return slog.LevelError + return logrus.ErrorLevel } // Kind recursively searches for the diff --git a/pkg/log/entry.go b/pkg/log/entry.go index 221ba43d6..97cefb392 100644 --- a/pkg/log/entry.go +++ b/pkg/log/entry.go @@ -1,46 +1,64 @@ package log import ( + "bufio" "context" + "fmt" + "io" "log/slog" + "os" "github.com/gomods/athens/pkg/errors" ) -// Entry is an abstraction to the -// Logger and the logrus.Entry -// so that *Logger always creates -// an Entry copy which ensures no -// Fields are being overwritten. type Entry interface { - // Basic Logging Operation - Debugf(format string, args ...any) - Infof(format string, args ...any) - Warnf(format string, args ...any) - Errorf(format string, args ...any) - Fatalf(format string, args ...any) - - // Attach contextual information to the logging entry - WithFields(fields map[string]any) Entry + // Keep the existing interface methods unchanged + Debugf(string, ...interface{}) + Infof(string, ...interface{}) + Warnf(string, ...interface{}) + Errorf(string, ...interface{}) + Fatalf(string, ...interface{}) + Panicf(string, ...interface{}) + Printf(string, ...interface{}) - WithField(key string, value any) Entry + Debug(...interface{}) + Info(...interface{}) + Warn(...interface{}) + Error(...interface{}) + Fatal(...interface{}) + Panic(...interface{}) + Print(...interface{}) + WithFields(fields map[string]any) Entry + WithField(key string, value any) Entry WithError(err error) Entry - WithContext(ctx context.Context) Entry - - // SystemErr is a method that disects the error - // and logs the appropriate level and fields for it. SystemErr(err error) + WriterLevel(level slog.Level) *io.PipeWriter } type entry struct { - *slog.Logger + logger *slog.Logger } func (e *entry) WithFields(fields map[string]any) Entry { - ent := e.WithFields(fields) - return ent + attrs := make([]any, 0, len(fields)*2) + for k, v := range fields { + attrs = append(attrs, slog.Any(k, v)) + } + return &entry{logger: e.logger.With(attrs...)} +} + +func (e *entry) WithField(key string, value any) Entry { + return &entry{logger: e.logger.With(key, value)} +} + +func (e *entry) WithError(err error) Entry { + return &entry{logger: e.logger.With("error", err)} +} + +func (e *entry) WithContext(ctx context.Context) Entry { + return &entry{logger: e.logger.With("context", ctx)} } func (e *entry) SystemErr(err error) { @@ -62,13 +80,102 @@ func (e *entry) SystemErr(err error) { ent.Errorf("%v", err) } } -func errFields(err errors.Error) map[string]any { - f := map[string]any{} - f["operation"] = err.Op - f["kind"] = errors.KindText(err) - f["module"] = err.Module - f["version"] = err.Version - f["ops"] = errors.Ops(err) +func (e *entry) Debug(args ...interface{}) { + e.logger.Debug(fmt.Sprint(args...)) +} + +func (e *entry) Info(args ...interface{}) { + e.logger.Info(fmt.Sprint(args...)) +} + +func (e *entry) Warn(args ...interface{}) { + e.logger.Warn(fmt.Sprint(args...)) +} + +func (e *entry) Error(args ...interface{}) { + e.logger.Error(fmt.Sprint(args...)) +} + +func (e *entry) Fatal(args ...interface{}) { + e.logger.Error(fmt.Sprint(args...)) // slog doesn't have Fatal, using Error +} + +func (e *entry) Panic(args ...interface{}) { + e.logger.Error(fmt.Sprint(args...)) // slog doesn't have Panic, using Error +} + +func (e *entry) Print(args ...interface{}) { + e.logger.Info(fmt.Sprint(args...)) +} + +func (e *entry) Debugf(format string, args ...interface{}) { + e.logger.Debug(fmt.Sprintf(format, args...)) +} + +func (e *entry) Infof(format string, args ...interface{}) { + e.logger.Info(fmt.Sprintf(format, args...)) +} + +func (e *entry) Warnf(format string, args ...interface{}) { + e.logger.Warn(fmt.Sprintf(format, args...)) +} + +func (e *entry) Errorf(format string, args ...interface{}) { + e.logger.Error(fmt.Sprintf(format, args...)) +} + +func (e *entry) Fatalf(format string, args ...interface{}) { + e.logger.Error(fmt.Sprintf(format, args...)) + os.Exit(1) +} + +func (e *entry) Panicf(format string, args ...interface{}) { + e.logger.Error(fmt.Sprintf(format, args...)) // slog doesn't have Panic +} + +func (e *entry) Printf(format string, args ...interface{}) { + e.logger.Info(fmt.Sprintf(format, args...)) +} + +func (e *entry) WriterLevel(level slog.Level) *io.PipeWriter { + reader, writer := io.Pipe() + + var logFunc func(args ...interface{}) + + // Determine which log function to use based on the specified log level + switch level { + case slog.LevelDebug: + logFunc = e.Debug + case slog.LevelInfo: + logFunc = e.Print + case slog.LevelWarn: + logFunc = e.Warn + case slog.LevelError: + logFunc = e.Error + default: + logFunc = e.Print + } + + // Start a new goroutine to scan and write to logger + go func(r *io.PipeReader, logFn func(...interface{})) { + scanner := bufio.NewScanner(r) + scanner.Buffer(make([]byte, 65536), 65536) + for scanner.Scan() { + logFn(scanner.Text()) + } + r.Close() + }(reader, logFunc) + + return writer +} + +func errFields(err errors.Error) map[string]any { + f := map[string]any{ + "kind": errors.KindText(err), + "module": err.Module, + "version": err.Version, + "ops": errors.Ops(err), + } return f } diff --git a/pkg/log/log.go b/pkg/log/log.go index 421b665d2..de35ea089 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -1,8 +1,10 @@ package log import ( + "bufio" "bytes" "context" + "io" "log/slog" "os" ) @@ -33,13 +35,17 @@ func New(cloudProvider string, level slog.Level, format string) *Logger { // SystemErr Entry implementation. func (l *Logger) SystemErr(err error) { - e := &entry{Logger: l.Logger} + e := &entry{l.Logger} e.SystemErr(err) } // WithFields Entry implementation. func (l *Logger) WithFields(fields map[string]any) Entry { - return l.WithFields(fields) + attrs := make([]any, 0, len(fields)) + for k, v := range fields { + attrs = append(attrs, slog.Any(k, v)) + } + return &entry{logger: l.Logger.With(attrs...)} } func (l *Logger) WithField(key string, value any) Entry { @@ -63,8 +69,22 @@ func (l *Logger) WithContext(ctx context.Context) Entry { return l.WithFields(keys) } +// Define WriterLevel +func (l *Logger) WriterLevel(level slog.Level) *io.PipeWriter { + pipeReader, pipeWriter := io.Pipe() + go func() { + scanner := bufio.NewScanner(pipeReader) + for scanner. + Scan() { + l.Info(scanner.Text()) + } + }() + return pipeWriter +} + // NoOpLogger provides a Logger that does nothing. func NoOpLogger() *Logger { - l := slog.New(slog.NewTextHandler(os.Stdout, nil)) - return &Logger{Logger: l} + return &Logger{ + Logger: &slog.Logger{}, + } } diff --git a/pkg/log/log_context.go b/pkg/log/log_context.go index 176ad4d22..54a65ddf3 100644 --- a/pkg/log/log_context.go +++ b/pkg/log/log_context.go @@ -9,7 +9,7 @@ type ctxKey string const logEntryKey ctxKey = "log-entry-context-key" // SetEntryInContext stores an Entry in the request context. -func SetEntryInContext(ctx context.Context, e *Entry) context.Context { +func SetEntryInContext(ctx context.Context, e Entry) context.Context { return context.WithValue(ctx, logEntryKey, e) } @@ -19,8 +19,7 @@ func SetEntryInContext(ctx context.Context, e *Entry) context.Context { func EntryFromContext(ctx context.Context) Entry { e, ok := ctx.Value(logEntryKey).(Entry) if !ok || e == nil { - //return a new entry - return &Entry{} + return &entry{NoOpLogger().Logger} } return e } diff --git a/pkg/middleware/log_entry.go b/pkg/middleware/log_entry.go index 00bc2e6e3..8d6016cf7 100644 --- a/pkg/middleware/log_entry.go +++ b/pkg/middleware/log_entry.go @@ -6,6 +6,7 @@ import ( "github.com/gomods/athens/pkg/log" "github.com/gomods/athens/pkg/requestid" "github.com/gorilla/mux" + "github.com/sirupsen/logrus" ) // LogEntryMiddleware builds a log.Entry, setting the request fields @@ -14,11 +15,11 @@ func LogEntryMiddleware(lggr *log.Logger) mux.MiddlewareFunc { return func(h http.Handler) http.Handler { f := func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - ent := lggr.With( - "http-method", r.Method, - "http-path", r.URL.Path, - "request-id", requestid.FromContext(ctx), - ) + ent := lggr.WithFields(logrus.Fields{ + "http-method": r.Method, + "http-path": r.URL.Path, + "request-id": requestid.FromContext(ctx), + }) ctx = log.SetEntryInContext(ctx, ent) r = r.WithContext(ctx) h.ServeHTTP(w, r) diff --git a/pkg/middleware/log_entry_test.go b/pkg/middleware/log_entry_test.go index 835b7e528..7ce3e87b2 100644 --- a/pkg/middleware/log_entry_test.go +++ b/pkg/middleware/log_entry_test.go @@ -2,10 +2,11 @@ package middleware import ( "bytes" - "encoding/json" + "fmt" "log/slog" "net/http" "net/http/httptest" + "strings" "testing" "github.com/gomods/athens/pkg/log" @@ -22,8 +23,11 @@ func TestLogContext(t *testing.T) { r := mux.NewRouter() r.HandleFunc("/test", h) - buf := new(bytes.Buffer) - lggr := log.New("", slog.LevelInfo, "") + var buf bytes.Buffer + lggr := log.New("", slog.LevelDebug, "") + opts := slog.HandlerOptions{Level: slog.LevelDebug} + handler := slog.NewJSONHandler(&buf, &opts) + lggr.Logger = slog.New(handler) r.Use(LogEntryMiddleware(lggr)) @@ -31,19 +35,6 @@ func TestLogContext(t *testing.T) { req, _ := http.NewRequest("GET", "/test", nil) r.ServeHTTP(w, req) - var logEntry map[string]interface{} - err := json.Unmarshal(buf.Bytes(), &logEntry) - assert.NoError(t, err) - - expectedFields := map[string]interface{}{ - "level": "INFO", - "msg": "test", - "http-method": "GET", - "http-path": "/test", - "request-id": "", - } - - for k, v := range expectedFields { - assert.Equal(t, v, logEntry[k], "Log entry should contain %s with value %v", k, v) - } + expected := `{"http-method":"GET","http-path":"/test","level":"info","msg":"test","request-id":""}` + assert.True(t, strings.Contains(buf.String(), expected), fmt.Sprintf("%s should contain: %s", buf.String(), expected)) }