您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 

939 行
22 KiB

  1. package redis
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "errors"
  6. "fmt"
  7. "net"
  8. "strconv"
  9. "sync"
  10. "sync/atomic"
  11. "time"
  12. "github.com/cespare/xxhash/v2"
  13. "github.com/dgryski/go-rendezvous" //nolint
  14. "github.com/redis/go-redis/v9/auth"
  15. "github.com/redis/go-redis/v9/internal"
  16. "github.com/redis/go-redis/v9/internal/hashtag"
  17. "github.com/redis/go-redis/v9/internal/pool"
  18. "github.com/redis/go-redis/v9/internal/proto"
  19. "github.com/redis/go-redis/v9/internal/rand"
  20. )
  // errRingShardsDown is the sentinel error reported when the ring has no
  // live shard that a request could be routed to.
  var errRingShardsDown = errors.New("redis: all ring shards are down")
  22. // defaultHeartbeatFn is the default function used to check the shard liveness
  23. var defaultHeartbeatFn = func(ctx context.Context, client *Client) bool {
  24. err := client.Ping(ctx).Err()
  25. return err == nil || err == pool.ErrPoolTimeout
  26. }
  27. //------------------------------------------------------------------------------
  // ConsistentHash maps a key to the name of the shard that should serve it.
  type ConsistentHash interface {
  	// Get returns the shard name selected for key.
  	Get(key string) string
  }
  31. type rendezvousWrapper struct {
  32. *rendezvous.Rendezvous
  33. }
  34. func (w rendezvousWrapper) Get(key string) string {
  35. return w.Lookup(key)
  36. }
  37. func newRendezvous(shards []string) ConsistentHash {
  38. return rendezvousWrapper{rendezvous.New(shards, xxhash.Sum64String)}
  39. }
  40. //------------------------------------------------------------------------------
  41. // RingOptions are used to configure a ring client and should be
  42. // passed to NewRing.
  43. type RingOptions struct {
  44. // Map of name => host:port addresses of ring shards.
  45. Addrs map[string]string
  46. // NewClient creates a shard client with provided options.
  47. NewClient func(opt *Options) *Client
  48. // ClientName will execute the `CLIENT SETNAME ClientName` command for each conn.
  49. ClientName string
  50. // Frequency of executing HeartbeatFn to check shards availability.
  51. // Shard is considered down after 3 subsequent failed checks.
  52. HeartbeatFrequency time.Duration
  53. // A function used to check the shard liveness
  54. // if not set, defaults to defaultHeartbeatFn
  55. HeartbeatFn func(ctx context.Context, client *Client) bool
  56. // NewConsistentHash returns a consistent hash that is used
  57. // to distribute keys across the shards.
  58. //
  59. // See https://medium.com/@dgryski/consistent-hashing-algorithmic-tradeoffs-ef6b8e2fcae8
  60. // for consistent hashing algorithmic tradeoffs.
  61. NewConsistentHash func(shards []string) ConsistentHash
  62. // Following options are copied from Options struct.
  63. Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
  64. OnConnect func(ctx context.Context, cn *Conn) error
  65. Protocol int
  66. Username string
  67. Password string
  68. // CredentialsProvider allows the username and password to be updated
  69. // before reconnecting. It should return the current username and password.
  70. CredentialsProvider func() (username string, password string)
  71. // CredentialsProviderContext is an enhanced parameter of CredentialsProvider,
  72. // done to maintain API compatibility. In the future,
  73. // there might be a merge between CredentialsProviderContext and CredentialsProvider.
  74. // There will be a conflict between them; if CredentialsProviderContext exists, we will ignore CredentialsProvider.
  75. CredentialsProviderContext func(ctx context.Context) (username string, password string, err error)
  76. // StreamingCredentialsProvider is used to retrieve the credentials
  77. // for the connection from an external source. Those credentials may change
  78. // during the connection lifetime. This is useful for managed identity
  79. // scenarios where the credentials are retrieved from an external source.
  80. //
  81. // Currently, this is a placeholder for the future implementation.
  82. StreamingCredentialsProvider auth.StreamingCredentialsProvider
  83. DB int
  84. MaxRetries int
  85. MinRetryBackoff time.Duration
  86. MaxRetryBackoff time.Duration
  87. DialTimeout time.Duration
  88. ReadTimeout time.Duration
  89. WriteTimeout time.Duration
  90. ContextTimeoutEnabled bool
  91. // PoolFIFO uses FIFO mode for each node connection pool GET/PUT (default LIFO).
  92. PoolFIFO bool
  93. PoolSize int
  94. PoolTimeout time.Duration
  95. MinIdleConns int
  96. MaxIdleConns int
  97. MaxActiveConns int
  98. ConnMaxIdleTime time.Duration
  99. ConnMaxLifetime time.Duration
  100. // ReadBufferSize is the size of the bufio.Reader buffer for each connection.
  101. // Larger buffers can improve performance for commands that return large responses.
  102. // Smaller buffers can improve memory usage for larger pools.
  103. //
  104. // default: 32KiB (32768 bytes)
  105. ReadBufferSize int
  106. // WriteBufferSize is the size of the bufio.Writer buffer for each connection.
  107. // Larger buffers can improve performance for large pipelines and commands with many arguments.
  108. // Smaller buffers can improve memory usage for larger pools.
  109. //
  110. // default: 32KiB (32768 bytes)
  111. WriteBufferSize int
  112. TLSConfig *tls.Config
  113. Limiter Limiter
  114. // DisableIndentity - Disable set-lib on connect.
  115. //
  116. // default: false
  117. //
  118. // Deprecated: Use DisableIdentity instead.
  119. DisableIndentity bool
  120. // DisableIdentity is used to disable CLIENT SETINFO command on connect.
  121. //
  122. // default: false
  123. DisableIdentity bool
  124. IdentitySuffix string
  125. UnstableResp3 bool
  126. }
  127. func (opt *RingOptions) init() {
  128. if opt.NewClient == nil {
  129. opt.NewClient = func(opt *Options) *Client {
  130. return NewClient(opt)
  131. }
  132. }
  133. if opt.HeartbeatFrequency == 0 {
  134. opt.HeartbeatFrequency = 500 * time.Millisecond
  135. }
  136. if opt.HeartbeatFn == nil {
  137. opt.HeartbeatFn = defaultHeartbeatFn
  138. }
  139. if opt.NewConsistentHash == nil {
  140. opt.NewConsistentHash = newRendezvous
  141. }
  142. switch opt.MaxRetries {
  143. case -1:
  144. opt.MaxRetries = 0
  145. case 0:
  146. opt.MaxRetries = 3
  147. }
  148. switch opt.MinRetryBackoff {
  149. case -1:
  150. opt.MinRetryBackoff = 0
  151. case 0:
  152. opt.MinRetryBackoff = 8 * time.Millisecond
  153. }
  154. switch opt.MaxRetryBackoff {
  155. case -1:
  156. opt.MaxRetryBackoff = 0
  157. case 0:
  158. opt.MaxRetryBackoff = 512 * time.Millisecond
  159. }
  160. if opt.ReadBufferSize == 0 {
  161. opt.ReadBufferSize = proto.DefaultBufferSize
  162. }
  163. if opt.WriteBufferSize == 0 {
  164. opt.WriteBufferSize = proto.DefaultBufferSize
  165. }
  166. }
  167. func (opt *RingOptions) clientOptions() *Options {
  168. return &Options{
  169. ClientName: opt.ClientName,
  170. Dialer: opt.Dialer,
  171. OnConnect: opt.OnConnect,
  172. Protocol: opt.Protocol,
  173. Username: opt.Username,
  174. Password: opt.Password,
  175. CredentialsProvider: opt.CredentialsProvider,
  176. CredentialsProviderContext: opt.CredentialsProviderContext,
  177. StreamingCredentialsProvider: opt.StreamingCredentialsProvider,
  178. DB: opt.DB,
  179. MaxRetries: -1,
  180. DialTimeout: opt.DialTimeout,
  181. ReadTimeout: opt.ReadTimeout,
  182. WriteTimeout: opt.WriteTimeout,
  183. ContextTimeoutEnabled: opt.ContextTimeoutEnabled,
  184. PoolFIFO: opt.PoolFIFO,
  185. PoolSize: opt.PoolSize,
  186. PoolTimeout: opt.PoolTimeout,
  187. MinIdleConns: opt.MinIdleConns,
  188. MaxIdleConns: opt.MaxIdleConns,
  189. MaxActiveConns: opt.MaxActiveConns,
  190. ConnMaxIdleTime: opt.ConnMaxIdleTime,
  191. ConnMaxLifetime: opt.ConnMaxLifetime,
  192. ReadBufferSize: opt.ReadBufferSize,
  193. WriteBufferSize: opt.WriteBufferSize,
  194. TLSConfig: opt.TLSConfig,
  195. Limiter: opt.Limiter,
  196. DisableIdentity: opt.DisableIdentity,
  197. DisableIndentity: opt.DisableIndentity,
  198. IdentitySuffix: opt.IdentitySuffix,
  199. UnstableResp3: opt.UnstableResp3,
  200. }
  201. }
  202. //------------------------------------------------------------------------------
  203. type ringShard struct {
  204. Client *Client
  205. down int32
  206. addr string
  207. }
  208. func newRingShard(opt *RingOptions, addr string) *ringShard {
  209. clopt := opt.clientOptions()
  210. clopt.Addr = addr
  211. return &ringShard{
  212. Client: opt.NewClient(clopt),
  213. addr: addr,
  214. }
  215. }
  216. func (shard *ringShard) String() string {
  217. var state string
  218. if shard.IsUp() {
  219. state = "up"
  220. } else {
  221. state = "down"
  222. }
  223. return fmt.Sprintf("%s is %s", shard.Client, state)
  224. }
  225. func (shard *ringShard) IsDown() bool {
  226. const threshold = 3
  227. return atomic.LoadInt32(&shard.down) >= threshold
  228. }
  229. func (shard *ringShard) IsUp() bool {
  230. return !shard.IsDown()
  231. }
  232. // Vote votes to set shard state and returns true if state was changed.
  233. func (shard *ringShard) Vote(up bool) bool {
  234. if up {
  235. changed := shard.IsDown()
  236. atomic.StoreInt32(&shard.down, 0)
  237. return changed
  238. }
  239. if shard.IsDown() {
  240. return false
  241. }
  242. atomic.AddInt32(&shard.down, 1)
  243. return shard.IsDown()
  244. }
  245. //------------------------------------------------------------------------------
  246. type ringSharding struct {
  247. opt *RingOptions
  248. mu sync.RWMutex
  249. shards *ringShards
  250. closed bool
  251. hash ConsistentHash
  252. numShard int
  253. onNewNode []func(rdb *Client)
  254. // ensures exclusive access to SetAddrs so there is no need
  255. // to hold mu for the duration of potentially long shard creation
  256. setAddrsMu sync.Mutex
  257. }
  258. type ringShards struct {
  259. m map[string]*ringShard
  260. list []*ringShard
  261. }
  262. func newRingSharding(opt *RingOptions) *ringSharding {
  263. c := &ringSharding{
  264. opt: opt,
  265. }
  266. c.SetAddrs(opt.Addrs)
  267. return c
  268. }
  269. func (c *ringSharding) OnNewNode(fn func(rdb *Client)) {
  270. c.mu.Lock()
  271. c.onNewNode = append(c.onNewNode, fn)
  272. c.mu.Unlock()
  273. }
  274. // SetAddrs replaces the shards in use, such that you can increase and
  275. // decrease number of shards, that you use. It will reuse shards that
  276. // existed before and close the ones that will not be used anymore.
  277. func (c *ringSharding) SetAddrs(addrs map[string]string) {
  278. c.setAddrsMu.Lock()
  279. defer c.setAddrsMu.Unlock()
  280. cleanup := func(shards map[string]*ringShard) {
  281. for addr, shard := range shards {
  282. if err := shard.Client.Close(); err != nil {
  283. internal.Logger.Printf(context.Background(), "shard.Close %s failed: %s", addr, err)
  284. }
  285. }
  286. }
  287. c.mu.RLock()
  288. if c.closed {
  289. c.mu.RUnlock()
  290. return
  291. }
  292. existing := c.shards
  293. c.mu.RUnlock()
  294. shards, created, unused := c.newRingShards(addrs, existing)
  295. c.mu.Lock()
  296. if c.closed {
  297. cleanup(created)
  298. c.mu.Unlock()
  299. return
  300. }
  301. c.shards = shards
  302. c.rebalanceLocked()
  303. c.mu.Unlock()
  304. cleanup(unused)
  305. }
  306. func (c *ringSharding) newRingShards(
  307. addrs map[string]string, existing *ringShards,
  308. ) (shards *ringShards, created, unused map[string]*ringShard) {
  309. shards = &ringShards{m: make(map[string]*ringShard, len(addrs))}
  310. created = make(map[string]*ringShard) // indexed by addr
  311. unused = make(map[string]*ringShard) // indexed by addr
  312. if existing != nil {
  313. for _, shard := range existing.list {
  314. unused[shard.addr] = shard
  315. }
  316. }
  317. for name, addr := range addrs {
  318. if shard, ok := unused[addr]; ok {
  319. shards.m[name] = shard
  320. delete(unused, addr)
  321. } else {
  322. shard := newRingShard(c.opt, addr)
  323. shards.m[name] = shard
  324. created[addr] = shard
  325. for _, fn := range c.onNewNode {
  326. fn(shard.Client)
  327. }
  328. }
  329. }
  330. for _, shard := range shards.m {
  331. shards.list = append(shards.list, shard)
  332. }
  333. return
  334. }
  335. // Warning: External exposure of `c.shards.list` may cause data races.
  336. // So keep internal or implement deep copy if exposed.
  337. func (c *ringSharding) List() []*ringShard {
  338. c.mu.RLock()
  339. defer c.mu.RUnlock()
  340. if c.closed {
  341. return nil
  342. }
  343. return c.shards.list
  344. }
  345. func (c *ringSharding) Hash(key string) string {
  346. key = hashtag.Key(key)
  347. var hash string
  348. c.mu.RLock()
  349. defer c.mu.RUnlock()
  350. if c.numShard > 0 {
  351. hash = c.hash.Get(key)
  352. }
  353. return hash
  354. }
  355. func (c *ringSharding) GetByKey(key string) (*ringShard, error) {
  356. key = hashtag.Key(key)
  357. c.mu.RLock()
  358. defer c.mu.RUnlock()
  359. if c.closed {
  360. return nil, pool.ErrClosed
  361. }
  362. if c.numShard == 0 {
  363. return nil, errRingShardsDown
  364. }
  365. shardName := c.hash.Get(key)
  366. if shardName == "" {
  367. return nil, errRingShardsDown
  368. }
  369. return c.shards.m[shardName], nil
  370. }
  371. func (c *ringSharding) GetByName(shardName string) (*ringShard, error) {
  372. if shardName == "" {
  373. return c.Random()
  374. }
  375. c.mu.RLock()
  376. defer c.mu.RUnlock()
  377. shard, ok := c.shards.m[shardName]
  378. if !ok {
  379. return nil, errors.New("redis: the shard is not in the ring")
  380. }
  381. return shard, nil
  382. }
  383. func (c *ringSharding) Random() (*ringShard, error) {
  384. return c.GetByKey(strconv.Itoa(rand.Int()))
  385. }
  386. // Heartbeat monitors state of each shard in the ring.
  387. func (c *ringSharding) Heartbeat(ctx context.Context, frequency time.Duration) {
  388. ticker := time.NewTicker(frequency)
  389. defer ticker.Stop()
  390. for {
  391. select {
  392. case <-ticker.C:
  393. var rebalance bool
  394. // note: `c.List()` return a shadow copy of `[]*ringShard`.
  395. for _, shard := range c.List() {
  396. isUp := c.opt.HeartbeatFn(ctx, shard.Client)
  397. if shard.Vote(isUp) {
  398. internal.Logger.Printf(ctx, "ring shard state changed: %s", shard)
  399. rebalance = true
  400. }
  401. }
  402. if rebalance {
  403. c.mu.Lock()
  404. c.rebalanceLocked()
  405. c.mu.Unlock()
  406. }
  407. case <-ctx.Done():
  408. return
  409. }
  410. }
  411. }
  412. // rebalanceLocked removes dead shards from the Ring.
  413. // Requires c.mu locked.
  414. func (c *ringSharding) rebalanceLocked() {
  415. if c.closed {
  416. return
  417. }
  418. if c.shards == nil {
  419. return
  420. }
  421. liveShards := make([]string, 0, len(c.shards.m))
  422. for name, shard := range c.shards.m {
  423. if shard.IsUp() {
  424. liveShards = append(liveShards, name)
  425. }
  426. }
  427. c.hash = c.opt.NewConsistentHash(liveShards)
  428. c.numShard = len(liveShards)
  429. }
  430. func (c *ringSharding) Len() int {
  431. c.mu.RLock()
  432. defer c.mu.RUnlock()
  433. return c.numShard
  434. }
  435. func (c *ringSharding) Close() error {
  436. c.mu.Lock()
  437. defer c.mu.Unlock()
  438. if c.closed {
  439. return nil
  440. }
  441. c.closed = true
  442. var firstErr error
  443. for _, shard := range c.shards.list {
  444. if err := shard.Client.Close(); err != nil && firstErr == nil {
  445. firstErr = err
  446. }
  447. }
  448. c.hash = nil
  449. c.shards = nil
  450. c.numShard = 0
  451. return firstErr
  452. }
  453. //------------------------------------------------------------------------------
  454. // Ring is a Redis client that uses consistent hashing to distribute
  455. // keys across multiple Redis servers (shards). It's safe for
  456. // concurrent use by multiple goroutines.
  457. //
  458. // Ring monitors the state of each shard and removes dead shards from
  459. // the ring. When a shard comes online it is added back to the ring. This
  460. // gives you maximum availability and partition tolerance, but no
  461. // consistency between different shards or even clients. Each client
  462. // uses shards that are available to the client and does not do any
  463. // coordination when shard state is changed.
  464. //
  465. // Ring should be used when you need multiple Redis servers for caching
  466. // and can tolerate losing data when one of the servers dies.
  467. // Otherwise you should use Redis Cluster.
  468. type Ring struct {
  469. cmdable
  470. hooksMixin
  471. opt *RingOptions
  472. sharding *ringSharding
  473. cmdsInfoCache *cmdsInfoCache
  474. heartbeatCancelFn context.CancelFunc
  475. }
  476. func NewRing(opt *RingOptions) *Ring {
  477. if opt == nil {
  478. panic("redis: NewRing nil options")
  479. }
  480. opt.init()
  481. hbCtx, hbCancel := context.WithCancel(context.Background())
  482. ring := Ring{
  483. opt: opt,
  484. sharding: newRingSharding(opt),
  485. heartbeatCancelFn: hbCancel,
  486. }
  487. ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo)
  488. ring.cmdable = ring.Process
  489. ring.initHooks(hooks{
  490. process: ring.process,
  491. pipeline: func(ctx context.Context, cmds []Cmder) error {
  492. return ring.generalProcessPipeline(ctx, cmds, false)
  493. },
  494. txPipeline: func(ctx context.Context, cmds []Cmder) error {
  495. return ring.generalProcessPipeline(ctx, cmds, true)
  496. },
  497. })
  498. go ring.sharding.Heartbeat(hbCtx, opt.HeartbeatFrequency)
  499. return &ring
  500. }
  501. func (c *Ring) SetAddrs(addrs map[string]string) {
  502. c.sharding.SetAddrs(addrs)
  503. }
  504. func (c *Ring) Process(ctx context.Context, cmd Cmder) error {
  505. err := c.processHook(ctx, cmd)
  506. cmd.SetErr(err)
  507. return err
  508. }
  509. // Options returns read-only Options that were used to create the client.
  510. func (c *Ring) Options() *RingOptions {
  511. return c.opt
  512. }
  513. func (c *Ring) retryBackoff(attempt int) time.Duration {
  514. return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
  515. }
  516. // PoolStats returns accumulated connection pool stats.
  517. func (c *Ring) PoolStats() *PoolStats {
  518. // note: `c.List()` return a shadow copy of `[]*ringShard`.
  519. shards := c.sharding.List()
  520. var acc PoolStats
  521. for _, shard := range shards {
  522. s := shard.Client.connPool.Stats()
  523. acc.Hits += s.Hits
  524. acc.Misses += s.Misses
  525. acc.Timeouts += s.Timeouts
  526. acc.TotalConns += s.TotalConns
  527. acc.IdleConns += s.IdleConns
  528. }
  529. return &acc
  530. }
  531. // Len returns the current number of shards in the ring.
  532. func (c *Ring) Len() int {
  533. return c.sharding.Len()
  534. }
  535. // Subscribe subscribes the client to the specified channels.
  536. func (c *Ring) Subscribe(ctx context.Context, channels ...string) *PubSub {
  537. if len(channels) == 0 {
  538. panic("at least one channel is required")
  539. }
  540. shard, err := c.sharding.GetByKey(channels[0])
  541. if err != nil {
  542. // TODO: return PubSub with sticky error
  543. panic(err)
  544. }
  545. return shard.Client.Subscribe(ctx, channels...)
  546. }
  547. // PSubscribe subscribes the client to the given patterns.
  548. func (c *Ring) PSubscribe(ctx context.Context, channels ...string) *PubSub {
  549. if len(channels) == 0 {
  550. panic("at least one channel is required")
  551. }
  552. shard, err := c.sharding.GetByKey(channels[0])
  553. if err != nil {
  554. // TODO: return PubSub with sticky error
  555. panic(err)
  556. }
  557. return shard.Client.PSubscribe(ctx, channels...)
  558. }
  559. // SSubscribe Subscribes the client to the specified shard channels.
  560. func (c *Ring) SSubscribe(ctx context.Context, channels ...string) *PubSub {
  561. if len(channels) == 0 {
  562. panic("at least one channel is required")
  563. }
  564. shard, err := c.sharding.GetByKey(channels[0])
  565. if err != nil {
  566. // TODO: return PubSub with sticky error
  567. panic(err)
  568. }
  569. return shard.Client.SSubscribe(ctx, channels...)
  570. }
  571. func (c *Ring) OnNewNode(fn func(rdb *Client)) {
  572. c.sharding.OnNewNode(fn)
  573. }
  574. // ForEachShard concurrently calls the fn on each live shard in the ring.
  575. // It returns the first error if any.
  576. func (c *Ring) ForEachShard(
  577. ctx context.Context,
  578. fn func(ctx context.Context, client *Client) error,
  579. ) error {
  580. // note: `c.List()` return a shadow copy of `[]*ringShard`.
  581. shards := c.sharding.List()
  582. var wg sync.WaitGroup
  583. errCh := make(chan error, 1)
  584. for _, shard := range shards {
  585. if shard.IsDown() {
  586. continue
  587. }
  588. wg.Add(1)
  589. go func(shard *ringShard) {
  590. defer wg.Done()
  591. err := fn(ctx, shard.Client)
  592. if err != nil {
  593. select {
  594. case errCh <- err:
  595. default:
  596. }
  597. }
  598. }(shard)
  599. }
  600. wg.Wait()
  601. select {
  602. case err := <-errCh:
  603. return err
  604. default:
  605. return nil
  606. }
  607. }
  608. func (c *Ring) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, error) {
  609. // note: `c.List()` return a shadow copy of `[]*ringShard`.
  610. shards := c.sharding.List()
  611. var firstErr error
  612. for _, shard := range shards {
  613. cmdsInfo, err := shard.Client.Command(ctx).Result()
  614. if err == nil {
  615. return cmdsInfo, nil
  616. }
  617. if firstErr == nil {
  618. firstErr = err
  619. }
  620. }
  621. if firstErr == nil {
  622. return nil, errRingShardsDown
  623. }
  624. return nil, firstErr
  625. }
  626. func (c *Ring) cmdShard(cmd Cmder) (*ringShard, error) {
  627. pos := cmdFirstKeyPos(cmd)
  628. if pos == 0 {
  629. return c.sharding.Random()
  630. }
  631. firstKey := cmd.stringArg(pos)
  632. return c.sharding.GetByKey(firstKey)
  633. }
  634. func (c *Ring) process(ctx context.Context, cmd Cmder) error {
  635. var lastErr error
  636. for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
  637. if attempt > 0 {
  638. if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
  639. return err
  640. }
  641. }
  642. shard, err := c.cmdShard(cmd)
  643. if err != nil {
  644. return err
  645. }
  646. lastErr = shard.Client.Process(ctx, cmd)
  647. if lastErr == nil || !shouldRetry(lastErr, cmd.readTimeout() == nil) {
  648. return lastErr
  649. }
  650. }
  651. return lastErr
  652. }
  653. func (c *Ring) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
  654. return c.Pipeline().Pipelined(ctx, fn)
  655. }
  656. func (c *Ring) Pipeline() Pipeliner {
  657. pipe := Pipeline{
  658. exec: pipelineExecer(c.processPipelineHook),
  659. }
  660. pipe.init()
  661. return &pipe
  662. }
  663. func (c *Ring) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
  664. return c.TxPipeline().Pipelined(ctx, fn)
  665. }
  666. func (c *Ring) TxPipeline() Pipeliner {
  667. pipe := Pipeline{
  668. exec: func(ctx context.Context, cmds []Cmder) error {
  669. cmds = wrapMultiExec(ctx, cmds)
  670. return c.processTxPipelineHook(ctx, cmds)
  671. },
  672. }
  673. pipe.init()
  674. return &pipe
  675. }
  676. func (c *Ring) generalProcessPipeline(
  677. ctx context.Context, cmds []Cmder, tx bool,
  678. ) error {
  679. if tx {
  680. // Trim multi .. exec.
  681. cmds = cmds[1 : len(cmds)-1]
  682. }
  683. cmdsMap := make(map[string][]Cmder)
  684. for _, cmd := range cmds {
  685. hash := cmd.stringArg(cmdFirstKeyPos(cmd))
  686. if hash != "" {
  687. hash = c.sharding.Hash(hash)
  688. }
  689. cmdsMap[hash] = append(cmdsMap[hash], cmd)
  690. }
  691. var wg sync.WaitGroup
  692. errs := make(chan error, len(cmdsMap))
  693. for hash, cmds := range cmdsMap {
  694. wg.Add(1)
  695. go func(hash string, cmds []Cmder) {
  696. defer wg.Done()
  697. // TODO: retry?
  698. shard, err := c.sharding.GetByName(hash)
  699. if err != nil {
  700. setCmdsErr(cmds, err)
  701. return
  702. }
  703. hook := shard.Client.processPipelineHook
  704. if tx {
  705. cmds = wrapMultiExec(ctx, cmds)
  706. hook = shard.Client.processTxPipelineHook
  707. }
  708. if err = hook(ctx, cmds); err != nil {
  709. errs <- err
  710. }
  711. }(hash, cmds)
  712. }
  713. wg.Wait()
  714. close(errs)
  715. if err := <-errs; err != nil {
  716. return err
  717. }
  718. return cmdsFirstErr(cmds)
  719. }
  720. func (c *Ring) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error {
  721. if len(keys) == 0 {
  722. return fmt.Errorf("redis: Watch requires at least one key")
  723. }
  724. var shards []*ringShard
  725. for _, key := range keys {
  726. if key != "" {
  727. shard, err := c.sharding.GetByKey(key)
  728. if err != nil {
  729. return err
  730. }
  731. shards = append(shards, shard)
  732. }
  733. }
  734. if len(shards) == 0 {
  735. return fmt.Errorf("redis: Watch requires at least one shard")
  736. }
  737. if len(shards) > 1 {
  738. for _, shard := range shards[1:] {
  739. if shard.Client != shards[0].Client {
  740. err := fmt.Errorf("redis: Watch requires all keys to be in the same shard")
  741. return err
  742. }
  743. }
  744. }
  745. return shards[0].Client.Watch(ctx, fn, keys...)
  746. }
  747. // Close closes the ring client, releasing any open resources.
  748. //
  749. // It is rare to Close a Ring, as the Ring is meant to be long-lived
  750. // and shared between many goroutines.
  751. func (c *Ring) Close() error {
  752. c.heartbeatCancelFn()
  753. return c.sharding.Close()
  754. }
  755. // GetShardClients returns a list of all shard clients in the ring.
  756. // This can be used to create dedicated connections (e.g., PubSub) for each shard.
  757. func (c *Ring) GetShardClients() []*Client {
  758. shards := c.sharding.List()
  759. clients := make([]*Client, 0, len(shards))
  760. for _, shard := range shards {
  761. if shard.IsUp() {
  762. clients = append(clients, shard.Client)
  763. }
  764. }
  765. return clients
  766. }
  767. // GetShardClientForKey returns the shard client that would handle the given key.
  768. // This can be used to determine which shard a particular key/channel would be routed to.
  769. func (c *Ring) GetShardClientForKey(key string) (*Client, error) {
  770. shard, err := c.sharding.GetByKey(key)
  771. if err != nil {
  772. return nil, err
  773. }
  774. return shard.Client, nil
  775. }