|
- package maintnotifications
-
- import (
- "context"
- "errors"
- "fmt"
- "net"
- "sync"
- "sync/atomic"
- "time"
-
- "github.com/redis/go-redis/v9/internal"
- "github.com/redis/go-redis/v9/internal/interfaces"
- "github.com/redis/go-redis/v9/internal/maintnotifications/logs"
- "github.com/redis/go-redis/v9/internal/pool"
- "github.com/redis/go-redis/v9/push"
- )
-
- // Push notification type constants for maintenance
- const (
- NotificationMoving = "MOVING"
- NotificationMigrating = "MIGRATING"
- NotificationMigrated = "MIGRATED"
- NotificationFailingOver = "FAILING_OVER"
- NotificationFailedOver = "FAILED_OVER"
- )
-
- // maintenanceNotificationTypes contains all notification types that maintenance handles
- var maintenanceNotificationTypes = []string{
- NotificationMoving,
- NotificationMigrating,
- NotificationMigrated,
- NotificationFailingOver,
- NotificationFailedOver,
- }
-
- // NotificationHook is called before and after notification processing
- // PreHook can modify the notification and return false to skip processing
- // PostHook is called after successful processing
- type NotificationHook interface {
- PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool)
- PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error)
- }
-
- // MovingOperationKey provides a unique key for tracking MOVING operations
- // that combines sequence ID with connection identifier to handle duplicate
- // sequence IDs across multiple connections to the same node.
- type MovingOperationKey struct {
- SeqID int64 // Sequence ID from MOVING notification
- ConnID uint64 // Unique connection identifier
- }
-
- // String returns a string representation of the key for debugging
- func (k MovingOperationKey) String() string {
- return fmt.Sprintf("seq:%d-conn:%d", k.SeqID, k.ConnID)
- }
-
- // Manager provides a simplified upgrade functionality with hooks and atomic state.
- type Manager struct {
- client interfaces.ClientInterface
- config *Config
- options interfaces.OptionsInterface
- pool pool.Pooler
-
- // MOVING operation tracking - using sync.Map for better concurrent performance
- activeMovingOps sync.Map // map[MovingOperationKey]*MovingOperation
-
- // Atomic state tracking - no locks needed for state queries
- activeOperationCount atomic.Int64 // Number of active operations
- closed atomic.Bool // Manager closed state
-
- // Notification hooks for extensibility
- hooks []NotificationHook
- hooksMu sync.RWMutex // Protects hooks slice
- poolHooksRef *PoolHook
- }
-
- // MovingOperation tracks an active MOVING operation.
- type MovingOperation struct {
- SeqID int64
- NewEndpoint string
- StartTime time.Time
- Deadline time.Time
- }
-
- // NewManager creates a new simplified manager.
- func NewManager(client interfaces.ClientInterface, pool pool.Pooler, config *Config) (*Manager, error) {
- if client == nil {
- return nil, ErrInvalidClient
- }
-
- hm := &Manager{
- client: client,
- pool: pool,
- options: client.GetOptions(),
- config: config.Clone(),
- hooks: make([]NotificationHook, 0),
- }
-
- // Set up push notification handling
- if err := hm.setupPushNotifications(); err != nil {
- return nil, err
- }
-
- return hm, nil
- }
-
- // GetPoolHook creates a pool hook with a custom dialer.
- func (hm *Manager) InitPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) {
- poolHook := hm.createPoolHook(baseDialer)
- hm.pool.AddPoolHook(poolHook)
- }
-
- // setupPushNotifications sets up push notification handling by registering with the client's processor.
- func (hm *Manager) setupPushNotifications() error {
- processor := hm.client.GetPushProcessor()
- if processor == nil {
- return ErrInvalidClient // Client doesn't support push notifications
- }
-
- // Create our notification handler
- handler := &NotificationHandler{manager: hm, operationsManager: hm}
-
- // Register handlers for all upgrade notifications with the client's processor
- for _, notificationType := range maintenanceNotificationTypes {
- if err := processor.RegisterHandler(notificationType, handler, true); err != nil {
- return errors.New(logs.FailedToRegisterHandler(notificationType, err))
- }
- }
-
- return nil
- }
-
- // TrackMovingOperationWithConnID starts a new MOVING operation with a specific connection ID.
- func (hm *Manager) TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error {
- // Create composite key
- key := MovingOperationKey{
- SeqID: seqID,
- ConnID: connID,
- }
-
- // Create MOVING operation record
- movingOp := &MovingOperation{
- SeqID: seqID,
- NewEndpoint: newEndpoint,
- StartTime: time.Now(),
- Deadline: deadline,
- }
-
- // Use LoadOrStore for atomic check-and-set operation
- if _, loaded := hm.activeMovingOps.LoadOrStore(key, movingOp); loaded {
- // Duplicate MOVING notification, ignore
- if internal.LogLevel.DebugOrAbove() { // Debug level
- internal.Logger.Printf(context.Background(), logs.DuplicateMovingOperation(connID, newEndpoint, seqID))
- }
- return nil
- }
- if internal.LogLevel.DebugOrAbove() { // Debug level
- internal.Logger.Printf(context.Background(), logs.TrackingMovingOperation(connID, newEndpoint, seqID))
- }
-
- // Increment active operation count atomically
- hm.activeOperationCount.Add(1)
-
- return nil
- }
-
- // UntrackOperationWithConnID completes a MOVING operation with a specific connection ID.
- func (hm *Manager) UntrackOperationWithConnID(seqID int64, connID uint64) {
- // Create composite key
- key := MovingOperationKey{
- SeqID: seqID,
- ConnID: connID,
- }
-
- // Remove from active operations atomically
- if _, loaded := hm.activeMovingOps.LoadAndDelete(key); loaded {
- if internal.LogLevel.DebugOrAbove() { // Debug level
- internal.Logger.Printf(context.Background(), logs.UntrackingMovingOperation(connID, seqID))
- }
- // Decrement active operation count only if operation existed
- hm.activeOperationCount.Add(-1)
- } else {
- if internal.LogLevel.DebugOrAbove() { // Debug level
- internal.Logger.Printf(context.Background(), logs.OperationNotTracked(connID, seqID))
- }
- }
- }
-
- // GetActiveMovingOperations returns active operations with composite keys.
- // WARNING: This method creates a new map and copies all operations on every call.
- // Use sparingly, especially in hot paths or high-frequency logging.
- func (hm *Manager) GetActiveMovingOperations() map[MovingOperationKey]*MovingOperation {
- result := make(map[MovingOperationKey]*MovingOperation)
-
- // Iterate over sync.Map to build result
- hm.activeMovingOps.Range(func(key, value interface{}) bool {
- k := key.(MovingOperationKey)
- op := value.(*MovingOperation)
-
- // Create a copy to avoid sharing references
- result[k] = &MovingOperation{
- SeqID: op.SeqID,
- NewEndpoint: op.NewEndpoint,
- StartTime: op.StartTime,
- Deadline: op.Deadline,
- }
- return true // Continue iteration
- })
-
- return result
- }
-
- // IsHandoffInProgress returns true if any handoff is in progress.
- // Uses atomic counter for lock-free operation.
- func (hm *Manager) IsHandoffInProgress() bool {
- return hm.activeOperationCount.Load() > 0
- }
-
- // GetActiveOperationCount returns the number of active operations.
- // Uses atomic counter for lock-free operation.
- func (hm *Manager) GetActiveOperationCount() int64 {
- return hm.activeOperationCount.Load()
- }
-
- // Close closes the manager.
- func (hm *Manager) Close() error {
- // Use atomic operation for thread-safe close check
- if !hm.closed.CompareAndSwap(false, true) {
- return nil // Already closed
- }
-
- // Shutdown the pool hook if it exists
- if hm.poolHooksRef != nil {
- // Use a timeout to prevent hanging indefinitely
- shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer cancel()
-
- err := hm.poolHooksRef.Shutdown(shutdownCtx)
- if err != nil {
- // was not able to close pool hook, keep closed state false
- hm.closed.Store(false)
- return err
- }
- // Remove the pool hook from the pool
- if hm.pool != nil {
- hm.pool.RemovePoolHook(hm.poolHooksRef)
- }
- }
-
- // Clear all active operations
- hm.activeMovingOps.Range(func(key, value interface{}) bool {
- hm.activeMovingOps.Delete(key)
- return true
- })
-
- // Reset counter
- hm.activeOperationCount.Store(0)
-
- return nil
- }
-
- // GetState returns current state using atomic counter for lock-free operation.
- func (hm *Manager) GetState() State {
- if hm.activeOperationCount.Load() > 0 {
- return StateMoving
- }
- return StateIdle
- }
-
- // processPreHooks calls all pre-hooks and returns the modified notification and whether to continue processing.
- func (hm *Manager) processPreHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
- hm.hooksMu.RLock()
- defer hm.hooksMu.RUnlock()
-
- currentNotification := notification
-
- for _, hook := range hm.hooks {
- modifiedNotification, shouldContinue := hook.PreHook(ctx, notificationCtx, notificationType, currentNotification)
- if !shouldContinue {
- return modifiedNotification, false
- }
- currentNotification = modifiedNotification
- }
-
- return currentNotification, true
- }
-
- // processPostHooks calls all post-hooks with the processing result.
- func (hm *Manager) processPostHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
- hm.hooksMu.RLock()
- defer hm.hooksMu.RUnlock()
-
- for _, hook := range hm.hooks {
- hook.PostHook(ctx, notificationCtx, notificationType, notification, result)
- }
- }
-
- // createPoolHook creates a pool hook with this manager already set.
- func (hm *Manager) createPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) *PoolHook {
- if hm.poolHooksRef != nil {
- return hm.poolHooksRef
- }
- // Get pool size from client options for better worker defaults
- poolSize := 0
- if hm.options != nil {
- poolSize = hm.options.GetPoolSize()
- }
-
- hm.poolHooksRef = NewPoolHookWithPoolSize(baseDialer, hm.options.GetNetwork(), hm.config, hm, poolSize)
- hm.poolHooksRef.SetPool(hm.pool)
-
- return hm.poolHooksRef
- }
-
- func (hm *Manager) AddNotificationHook(notificationHook NotificationHook) {
- hm.hooksMu.Lock()
- defer hm.hooksMu.Unlock()
- hm.hooks = append(hm.hooks, notificationHook)
- }
|