diff --git a/mq_impl/diskq/consumer.go b/mq_impl/diskq/consumer.go index bdbdd29..b4255e4 100644 --- a/mq_impl/diskq/consumer.go +++ b/mq_impl/diskq/consumer.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "sync" + "sync/atomic" "time" "github.com/Ccheers/kratos-mq/mq" @@ -24,6 +25,8 @@ type ConsumerImpl struct { sf singleflight.Group consumerChan map[string]chan mq.Message pool routinepool.Pool + + status uint32 } func NewConsumer(c *config.Config, logger log.Logger) (mq.Consumer, error) { @@ -33,6 +36,7 @@ func NewConsumer(c *config.Config, logger log.Logger) (mq.Consumer, error) { logger: logger, consumerChan: make(map[string]chan mq.Message), pool: routinepool.NewPool("[diskq][Consumer]", 4, routinepool.NewConfig()), + status: statusRunning, }, nil } @@ -52,6 +56,9 @@ func (x *ConsumerImpl) Subscribe(ctx context.Context, topic string, channel stri ch := make(chan mq.Message, 1) x.pool.Go(func(ctx context.Context) { for { + if atomic.LoadUint32(&x.status) == statusClosed { + return + } select { case body := <-queue.ReadChan(): msg, err := mq.NewMessageFromByte(body) @@ -70,9 +77,12 @@ func (x *ConsumerImpl) Subscribe(ctx context.Context, topic string, channel stri } func (x *ConsumerImpl) Close(ctx context.Context) error { + if !atomic.CompareAndSwapUint32(&x.status, statusRunning, statusClosed) { + return nil + } x.mu.Lock() defer x.mu.Unlock() - for uniKey, _ := range x.consumerChan { + for uniKey := range x.consumerChan { close(x.consumerChan[uniKey]) } err := gDiskQueueManager.Close(ctx) diff --git a/mq_impl/diskq/status.go b/mq_impl/diskq/status.go new file mode 100644 index 0000000..304ad3f --- /dev/null +++ b/mq_impl/diskq/status.go @@ -0,0 +1,6 @@ +package diskq + +const ( + statusRunning uint32 = 0 + statusClosed uint32 = 1 +) diff --git a/mq_impl/mqtt/consumer.go b/mq_impl/mqtt/consumer.go index 3d36580..31e0d19 100644 --- a/mq_impl/mqtt/consumer.go +++ b/mq_impl/mqtt/consumer.go @@ -7,7 +7,6 @@ import ( "sync/atomic" "github.com/Ccheers/kratos-mq/mq" - "github.com/eclipse/paho.mqtt.golang" "github.com/go-kratos/kratos/v2/log" ) @@ -51,6 +50,9 @@ func (x *ConsumerImpl) Subscribe(ctx context.Context, topic string, channel stri ch := make(chan mq.Message, 1) token := x.client.Subscribe(topic, x.cfg.WillQos, func(client mqtt.Client, message mqtt.Message) { + if atomic.LoadUint32(&x.status) == statusClosed { + return + } msg, err := mq.NewMessageFromByte(message.Payload()) if err != nil { x.logger.Errorw("topic", topic, "channel", channel, "payload", string(message.Payload()), "err", err) diff --git a/mq_impl/nsq/consumer.go b/mq_impl/nsq/consumer.go index 5548272..5820ebe 100644 --- a/mq_impl/nsq/consumer.go +++ b/mq_impl/nsq/consumer.go @@ -4,11 +4,11 @@ import ( "context" "fmt" "sync" + "sync/atomic" "github.com/Ccheers/kratos-mq/mq" "github.com/Ccheers/kratos-mq/mq_impl/nsq/config" "github.com/go-kratos/kratos/v2/log" - "github.com/nsqio/go-nsq" ) var _ mq.Consumer = (*ConsumerImpl)(nil) @@ -22,6 +22,8 @@ type ConsumerImpl struct { mu sync.Mutex consumerMap map[string]*nsq.Consumer consumerChan map[string]chan mq.Message + + status uint32 } func NewConsumer(c *config.Config, logger log.Logger) (mq.Consumer, error) { @@ -53,6 +55,9 @@ func (x *ConsumerImpl) Subscribe(ctx context.Context, topic string, channel stri ch := make(chan mq.Message, 1) consumer.AddHandler(nsq.HandlerFunc(func(message *nsq.Message) error { + if atomic.LoadUint32(&x.status) == statusClosed { + return nil + } msg, err := mq.NewMessageFromByte(message.Body) if err != nil { return err @@ -73,6 +78,9 @@ func (x *ConsumerImpl) Subscribe(ctx context.Context, topic string, channel stri } func (x *ConsumerImpl) Close(ctx context.Context) error { + if !atomic.CompareAndSwapUint32(&x.status, statusRunning, statusClosed) { + return nil + } x.mu.Lock() defer x.mu.Unlock() for uniKey, consumer := range x.consumerMap { diff --git a/mq_impl/nsq/status.go b/mq_impl/nsq/status.go new file mode 100644 index 0000000..2f87350 --- /dev/null +++ b/mq_impl/nsq/status.go @@ -0,0 +1,6 @@ +package nsq + +const ( + statusRunning uint32 = 0 + statusClosed uint32 = 1 +)