Nelze vybrat více než 25 témat Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.
 
 
 
 

642 řádky
15 KiB

  1. // Copyright 2019+ Klaus Post. All rights reserved.
  2. // License information can be found in the LICENSE file.
  3. // Based on work by Yann Collet, released under BSD License.
  4. package zstd
  5. import (
  6. "crypto/rand"
  7. "fmt"
  8. "io"
  9. rdebug "runtime/debug"
  10. "sync"
  11. "github.com/klauspost/compress/zstd/internal/xxhash"
  12. )
// Encoder provides encoding to Zstandard.
// An Encoder can be used for either compressing a stream via the
// io.WriteCloser interface supported by the Encoder or as multiple independent
// tasks via the EncodeAll function.
// Smaller encodes are encouraged to use the EncodeAll function.
// Use NewWriter to create a new instance.
type Encoder struct {
	// o holds the options applied by NewWriter.
	o encoderOptions
	// encoders is the pool of encoders borrowed by EncodeAll; it is created
	// lazily (through init) with e.o.concurrent entries.
	encoders chan encoder
	// state is the streaming state used by Write, ReadFrom, Flush and Close.
	state encoderState
	// init guards the one-time creation of the encoders pool.
	init sync.Once
}
// encoder is the internal block-encoder interface; instances are obtained
// from e.o.encoder() and used both for streaming and for EncodeAll.
type encoder interface {
	// Encode encodes src into blk (presumably keeping match history across
	// calls, in contrast to EncodeNoHist — confirm against implementations).
	Encode(blk *blockEnc, src []byte)
	// EncodeNoHist encodes src into blk; EncodeAll uses it for single-block
	// input when no dictionary is configured.
	EncodeNoHist(blk *blockEnc, src []byte)
	// Block returns the encoder's working block.
	Block() *blockEnc
	// CRC returns the digest that callers feed with the uncompressed input.
	CRC() *xxhash.Digest
	// AppendCRC appends the frame checksum to the given slice.
	AppendCRC([]byte) []byte
	// WindowSize returns the window size to advertise in the frame header
	// for the given content size.
	WindowSize(size int64) int32
	// UseBlock hands a block to the encoder; the streaming path uses it to
	// transfer recent offsets to the next block.
	UseBlock(*blockEnc)
	// Reset prepares the encoder for a new frame with optional dictionary d;
	// singleBlock is true when the whole input is encoded as one block.
	Reset(d *dict, singleBlock bool)
}
// encoderState is the per-stream state of an Encoder used through the
// io.Writer / io.ReaderFrom interfaces. Reset clears it for a new stream.
type encoderState struct {
	// w is the destination of the compressed stream.
	w io.Writer
	// filling collects input until a full block (e.o.blockSize) is available.
	filling []byte
	// current is the block handed to the encode goroutine (concurrent mode);
	// it is also used as scratch output when the whole stream fits in one block.
	current []byte
	// previous is the spare buffer that is rotated back in as filling.
	previous []byte
	// encoder is the encoder used for this stream.
	encoder encoder
	// writing is the block currently owned by the writer goroutine
	// (only allocated when e.o.concurrent > 1).
	writing *blockEnc
	// err is the first error seen while encoding/writing; it stops the stream.
	err error
	// writeErr is the error reported by the writer goroutine.
	writeErr error
	// nWritten counts compressed bytes written to w.
	nWritten int64
	// nInput counts uncompressed bytes handed to the block encoder.
	nInput int64
	// frameContentSize is the declared content size; 0 means not set.
	frameContentSize int64
	// headerWritten is set once the frame header (or a complete frame) is out.
	headerWritten bool
	// eofWritten is set once the last block of the frame has been produced.
	eofWritten bool
	// fullFrameWritten is set when the stream was emitted as one complete
	// frame (via EncodeAll) or intentionally as nothing at all.
	fullFrameWritten bool

	// This waitgroup indicates an encode is running.
	wg sync.WaitGroup
	// This waitgroup indicates we have a block encoding/writing.
	wWg sync.WaitGroup
}
  55. // NewWriter will create a new Zstandard encoder.
  56. // If the encoder will be used for encoding blocks a nil writer can be used.
  57. func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) {
  58. initPredefined()
  59. var e Encoder
  60. e.o.setDefault()
  61. for _, o := range opts {
  62. err := o(&e.o)
  63. if err != nil {
  64. return nil, err
  65. }
  66. }
  67. if w != nil {
  68. e.Reset(w)
  69. }
  70. return &e, nil
  71. }
  72. func (e *Encoder) initialize() {
  73. if e.o.concurrent == 0 {
  74. e.o.setDefault()
  75. }
  76. e.encoders = make(chan encoder, e.o.concurrent)
  77. for i := 0; i < e.o.concurrent; i++ {
  78. enc := e.o.encoder()
  79. e.encoders <- enc
  80. }
  81. }
  82. // Reset will re-initialize the writer and new writes will encode to the supplied writer
  83. // as a new, independent stream.
  84. func (e *Encoder) Reset(w io.Writer) {
  85. s := &e.state
  86. s.wg.Wait()
  87. s.wWg.Wait()
  88. if cap(s.filling) == 0 {
  89. s.filling = make([]byte, 0, e.o.blockSize)
  90. }
  91. if e.o.concurrent > 1 {
  92. if cap(s.current) == 0 {
  93. s.current = make([]byte, 0, e.o.blockSize)
  94. }
  95. if cap(s.previous) == 0 {
  96. s.previous = make([]byte, 0, e.o.blockSize)
  97. }
  98. s.current = s.current[:0]
  99. s.previous = s.previous[:0]
  100. if s.writing == nil {
  101. s.writing = &blockEnc{lowMem: e.o.lowMem}
  102. s.writing.init()
  103. }
  104. s.writing.initNewEncode()
  105. }
  106. if s.encoder == nil {
  107. s.encoder = e.o.encoder()
  108. }
  109. s.filling = s.filling[:0]
  110. s.encoder.Reset(e.o.dict, false)
  111. s.headerWritten = false
  112. s.eofWritten = false
  113. s.fullFrameWritten = false
  114. s.w = w
  115. s.err = nil
  116. s.nWritten = 0
  117. s.nInput = 0
  118. s.writeErr = nil
  119. s.frameContentSize = 0
  120. }
  121. // ResetContentSize will reset and set a content size for the next stream.
  122. // If the bytes written does not match the size given an error will be returned
  123. // when calling Close().
  124. // This is removed when Reset is called.
  125. // Sizes <= 0 results in no content size set.
  126. func (e *Encoder) ResetContentSize(w io.Writer, size int64) {
  127. e.Reset(w)
  128. if size >= 0 {
  129. e.state.frameContentSize = size
  130. }
  131. }
  132. // Write data to the encoder.
  133. // Input data will be buffered and as the buffer fills up
  134. // content will be compressed and written to the output.
  135. // When done writing, use Close to flush the remaining output
  136. // and write CRC if requested.
  137. func (e *Encoder) Write(p []byte) (n int, err error) {
  138. s := &e.state
  139. for len(p) > 0 {
  140. if len(p)+len(s.filling) < e.o.blockSize {
  141. if e.o.crc {
  142. _, _ = s.encoder.CRC().Write(p)
  143. }
  144. s.filling = append(s.filling, p...)
  145. return n + len(p), nil
  146. }
  147. add := p
  148. if len(p)+len(s.filling) > e.o.blockSize {
  149. add = add[:e.o.blockSize-len(s.filling)]
  150. }
  151. if e.o.crc {
  152. _, _ = s.encoder.CRC().Write(add)
  153. }
  154. s.filling = append(s.filling, add...)
  155. p = p[len(add):]
  156. n += len(add)
  157. if len(s.filling) < e.o.blockSize {
  158. return n, nil
  159. }
  160. err := e.nextBlock(false)
  161. if err != nil {
  162. return n, err
  163. }
  164. if debugAsserts && len(s.filling) > 0 {
  165. panic(len(s.filling))
  166. }
  167. }
  168. return n, nil
  169. }
// nextBlock will synchronize and start compressing input in e.state.filling.
// If an error has occurred during encoding it will be returned.
//
// final marks the block as the last block of the frame. Before the first
// block the frame header is written, unless the whole stream is available in
// a single final call: then it is encoded synchronously with EncodeAll as one
// complete frame (or, for empty input without fullZero, nothing is written).
// With e.o.concurrent == 1 the block is encoded and written on the calling
// goroutine. Otherwise the three buffers are rotated and the block is encoded
// on a goroutine tracked by s.wg, which starts a writer goroutine tracked by
// s.wWg; in that case nextBlock returns before the block has been written.
func (e *Encoder) nextBlock(final bool) error {
	s := &e.state
	// Wait for current block.
	s.wg.Wait()
	if s.err != nil {
		return s.err
	}
	if len(s.filling) > e.o.blockSize {
		return fmt.Errorf("block > maxStoreBlockSize")
	}
	if !s.headerWritten {
		// If we have a single block encode, do a sync compression.
		if final && len(s.filling) == 0 && !e.o.fullZero {
			// Empty stream and no zero-frame requested: emit nothing at all.
			s.headerWritten = true
			s.fullFrameWritten = true
			s.eofWritten = true
			return nil
		}
		if final && len(s.filling) > 0 {
			// Entire stream is in one buffer: EncodeAll writes header, block,
			// checksum and padding as one complete frame.
			s.current = e.EncodeAll(s.filling, s.current[:0])
			var n2 int
			n2, s.err = s.w.Write(s.current)
			if s.err != nil {
				return s.err
			}
			s.nWritten += int64(n2)
			s.nInput += int64(len(s.filling))
			s.current = s.current[:0]
			s.filling = s.filling[:0]
			s.headerWritten = true
			s.fullFrameWritten = true
			s.eofWritten = true
			return nil
		}

		var tmp [maxHeaderSize]byte
		fh := frameHeader{
			ContentSize:   uint64(s.frameContentSize),
			WindowSize:    uint32(s.encoder.WindowSize(s.frameContentSize)),
			SingleSegment: false,
			Checksum:      e.o.crc,
			DictID:        e.o.dict.ID(),
		}

		dst, err := fh.appendTo(tmp[:0])
		if err != nil {
			return err
		}
		s.headerWritten = true
		// Make sure no writer goroutine is using s.w before writing the header.
		s.wWg.Wait()
		var n2 int
		n2, s.err = s.w.Write(dst)
		if s.err != nil {
			return s.err
		}
		s.nWritten += int64(n2)
	}
	if s.eofWritten {
		// Ensure we only write it once.
		final = false
	}

	if len(s.filling) == 0 {
		// Final block, but no data.
		if final {
			// Terminate the frame with an empty raw block marked as last.
			enc := s.encoder
			blk := enc.Block()
			blk.reset(nil)
			blk.last = true
			blk.encodeRaw(nil)
			s.wWg.Wait()
			_, s.err = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
			s.eofWritten = true
		}
		return s.err
	}

	// SYNC:
	if e.o.concurrent == 1 {
		src := s.filling
		s.nInput += int64(len(s.filling))
		if debugEncoder {
			println("Adding sync block,", len(src), "bytes, final:", final)
		}
		enc := s.encoder
		blk := enc.Block()
		blk.reset(nil)
		enc.Encode(blk, src)
		blk.last = final
		if final {
			s.eofWritten = true
		}

		err := errIncompressible
		// If we got the exact same number of literals as input,
		// assume the literals cannot be compressed.
		// (Only applied to full-size blocks; shorter blocks are always tried.)
		if len(src) != len(blk.literals) || len(src) != e.o.blockSize {
			err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
		}
		switch err {
		case errIncompressible:
			if debugEncoder {
				println("Storing incompressible block as raw")
			}
			blk.encodeRaw(src)
			// In fast mode, we do not transfer offsets, so we don't have to deal with changing them.
		case nil:
		default:
			s.err = err
			return err
		}
		_, s.err = s.w.Write(blk.output)
		s.nWritten += int64(len(blk.output))
		s.filling = s.filling[:0]
		return s.err
	}

	// Move blocks forward.
	// The just-filled buffer becomes src of the encode goroutine. The buffer
	// given to the previous goroutine becomes previous and is only reused as
	// filling after the next call's s.wg.Wait(); by then that encode goroutine
	// has waited on s.wWg for the writer that still reads it.
	s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
	s.nInput += int64(len(s.current))
	s.wg.Add(1)
	go func(src []byte) {
		if debugEncoder {
			println("Adding block,", len(src), "bytes, final:", final)
		}
		defer func() {
			// Convert a panic into a stream error instead of crashing.
			if r := recover(); r != nil {
				s.err = fmt.Errorf("panic while encoding: %v", r)
				rdebug.PrintStack()
			}
			s.wg.Done()
		}()
		enc := s.encoder
		blk := enc.Block()
		enc.Encode(blk, src)
		blk.last = final
		if final {
			s.eofWritten = true
		}
		// Wait for pending writes.
		s.wWg.Wait()
		if s.writeErr != nil {
			s.err = s.writeErr
			return
		}
		// Transfer encoders from previous write block.
		blk.swapEncoders(s.writing)
		// Transfer recent offsets to next.
		enc.UseBlock(s.writing)
		s.writing = blk
		s.wWg.Add(1)
		go func() {
			defer func() {
				if r := recover(); r != nil {
					s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r)
					rdebug.PrintStack()
				}
				s.wWg.Done()
			}()
			err := errIncompressible
			// If we got the exact same number of literals as input,
			// assume the literals cannot be compressed.
			// (Only applied to full-size blocks; shorter blocks are always tried.)
			if len(src) != len(blk.literals) || len(src) != e.o.blockSize {
				err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
			}
			switch err {
			case errIncompressible:
				if debugEncoder {
					println("Storing incompressible block as raw")
				}
				blk.encodeRaw(src)
				// In fast mode, we do not transfer offsets, so we don't have to deal with changing them.
			case nil:
			default:
				s.writeErr = err
				return
			}
			_, s.writeErr = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
		}()
	}(s.current)
	return nil
}
  350. // ReadFrom reads data from r until EOF or error.
  351. // The return value n is the number of bytes read.
  352. // Any error except io.EOF encountered during the read is also returned.
  353. //
  354. // The Copy function uses ReaderFrom if available.
  355. func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) {
  356. if debugEncoder {
  357. println("Using ReadFrom")
  358. }
  359. // Flush any current writes.
  360. if len(e.state.filling) > 0 {
  361. if err := e.nextBlock(false); err != nil {
  362. return 0, err
  363. }
  364. }
  365. e.state.filling = e.state.filling[:e.o.blockSize]
  366. src := e.state.filling
  367. for {
  368. n2, err := r.Read(src)
  369. if e.o.crc {
  370. _, _ = e.state.encoder.CRC().Write(src[:n2])
  371. }
  372. // src is now the unfilled part...
  373. src = src[n2:]
  374. n += int64(n2)
  375. switch err {
  376. case io.EOF:
  377. e.state.filling = e.state.filling[:len(e.state.filling)-len(src)]
  378. if debugEncoder {
  379. println("ReadFrom: got EOF final block:", len(e.state.filling))
  380. }
  381. return n, nil
  382. case nil:
  383. default:
  384. if debugEncoder {
  385. println("ReadFrom: got error:", err)
  386. }
  387. e.state.err = err
  388. return n, err
  389. }
  390. if len(src) > 0 {
  391. if debugEncoder {
  392. println("ReadFrom: got space left in source:", len(src))
  393. }
  394. continue
  395. }
  396. err = e.nextBlock(false)
  397. if err != nil {
  398. return n, err
  399. }
  400. e.state.filling = e.state.filling[:e.o.blockSize]
  401. src = e.state.filling
  402. }
  403. }
  404. // Flush will send the currently written data to output
  405. // and block until everything has been written.
  406. // This should only be used on rare occasions where pushing the currently queued data is critical.
  407. func (e *Encoder) Flush() error {
  408. s := &e.state
  409. if len(s.filling) > 0 {
  410. err := e.nextBlock(false)
  411. if err != nil {
  412. return err
  413. }
  414. }
  415. s.wg.Wait()
  416. s.wWg.Wait()
  417. if s.err != nil {
  418. return s.err
  419. }
  420. return s.writeErr
  421. }
  422. // Close will flush the final output and close the stream.
  423. // The function will block until everything has been written.
  424. // The Encoder can still be re-used after calling this.
  425. func (e *Encoder) Close() error {
  426. s := &e.state
  427. if s.encoder == nil {
  428. return nil
  429. }
  430. err := e.nextBlock(true)
  431. if err != nil {
  432. return err
  433. }
  434. if s.frameContentSize > 0 {
  435. if s.nInput != s.frameContentSize {
  436. return fmt.Errorf("frame content size %d given, but %d bytes was written", s.frameContentSize, s.nInput)
  437. }
  438. }
  439. if e.state.fullFrameWritten {
  440. return s.err
  441. }
  442. s.wg.Wait()
  443. s.wWg.Wait()
  444. if s.err != nil {
  445. return s.err
  446. }
  447. if s.writeErr != nil {
  448. return s.writeErr
  449. }
  450. // Write CRC
  451. if e.o.crc && s.err == nil {
  452. // heap alloc.
  453. var tmp [4]byte
  454. _, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0]))
  455. s.nWritten += 4
  456. }
  457. // Add padding with content from crypto/rand.Reader
  458. if s.err == nil && e.o.pad > 0 {
  459. add := calcSkippableFrame(s.nWritten, int64(e.o.pad))
  460. frame, err := skippableFrame(s.filling[:0], add, rand.Reader)
  461. if err != nil {
  462. return err
  463. }
  464. _, s.err = s.w.Write(frame)
  465. }
  466. return s.err
  467. }
// EncodeAll will encode all input in src and append it to dst.
// This function can be called concurrently, but each call will only run on a single goroutine.
// If empty input is given, nothing is returned, unless WithZeroFrames is specified.
// Encoded blocks can be concatenated and the result will be the combined input stream.
// Data compressed with EncodeAll can be decoded with the Decoder,
// using either a stream or DecodeAll.
//
// Each call borrows an encoder from e.encoders (created on first use) and
// returns it when done. Because there is no error return, EncodeAll panics if
// the frame header cannot be written, a block fails to encode, or padding
// cannot be generated.
func (e *Encoder) EncodeAll(src, dst []byte) []byte {
	if len(src) == 0 {
		if e.o.fullZero {
			// Add frame header.
			fh := frameHeader{
				ContentSize:   0,
				WindowSize:    MinWindowSize,
				SingleSegment: true,
				// Adding a checksum would be a waste of space.
				Checksum: false,
				DictID:   0,
			}
			dst, _ = fh.appendTo(dst)

			// Write raw block as last one only.
			var blk blockHeader
			blk.setSize(0)
			blk.setType(blockTypeRaw)
			blk.setLast(true)
			dst = blk.appendTo(dst)
		}
		return dst
	}
	e.init.Do(e.initialize)
	// Blocks until an encoder is free; at most e.o.concurrent calls run at once.
	enc := <-e.encoders
	defer func() {
		// Release encoder reference to last block.
		// If a non-single block is needed the encoder will reset again.
		e.encoders <- enc
	}()
	// Use single segments when above minimum window and below window size.
	single := len(src) <= e.o.windowSize && len(src) > MinWindowSize
	if e.o.single != nil {
		single = *e.o.single
	}
	fh := frameHeader{
		ContentSize:   uint64(len(src)),
		WindowSize:    uint32(enc.WindowSize(int64(len(src)))),
		SingleSegment: single,
		Checksum:      e.o.crc,
		DictID:        e.o.dict.ID(),
	}

	// If less than 1MB, allocate a buffer up front.
	if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 && !e.o.lowMem {
		dst = make([]byte, 0, len(src))
	}
	dst, err := fh.appendTo(dst)
	if err != nil {
		panic(err)
	}

	// If we can do everything in one block, prefer that.
	if len(src) <= e.o.blockSize {
		enc.Reset(e.o.dict, true)

		// Slightly faster with no history and everything in one block.
		if e.o.crc {
			_, _ = enc.CRC().Write(src)
		}
		blk := enc.Block()
		blk.last = true
		if e.o.dict == nil {
			enc.EncodeNoHist(blk, src)
		} else {
			enc.Encode(blk, src)
		}

		// If we got the exact same number of literals as input,
		// assume the literals cannot be compressed.
		err := errIncompressible
		// Remember the block's own buffer; it is temporarily replaced by dst
		// below and restored afterwards so the pooled block does not keep dst.
		oldout := blk.output
		if len(blk.literals) != len(src) || len(src) != e.o.blockSize {
			// Output directly to dst
			blk.output = dst
			err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
		}

		switch err {
		case errIncompressible:
			if debugEncoder {
				println("Storing incompressible block as raw")
			}
			dst = blk.encodeRawTo(dst, src)
		case nil:
			dst = blk.output
		default:
			panic(err)
		}
		blk.output = oldout
	} else {
		enc.Reset(e.o.dict, false)
		blk := enc.Block()
		for len(src) > 0 {
			todo := src
			if len(todo) > e.o.blockSize {
				todo = todo[:e.o.blockSize]
			}
			src = src[len(todo):]
			if e.o.crc {
				_, _ = enc.CRC().Write(todo)
			}
			// Save the current offsets so they can be restored if this block
			// ends up stored raw (see popOffsets below).
			blk.pushOffsets()
			enc.Encode(blk, todo)
			if len(src) == 0 {
				blk.last = true
			}
			err := errIncompressible
			// If we got the exact same number of literals as input,
			// assume the literals cannot be compressed.
			if len(blk.literals) != len(todo) || len(todo) != e.o.blockSize {
				err = blk.encode(todo, e.o.noEntropy, !e.o.allLitEntropy)
			}

			switch err {
			case errIncompressible:
				if debugEncoder {
					println("Storing incompressible block as raw")
				}
				dst = blk.encodeRawTo(dst, todo)
				blk.popOffsets()
			case nil:
				dst = append(dst, blk.output...)
			default:
				panic(err)
			}
			blk.reset(nil)
		}
	}
	if e.o.crc {
		dst = enc.AppendCRC(dst)
	}
	// Add padding with content from crypto/rand.Reader
	if e.o.pad > 0 {
		add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad))
		dst, err = skippableFrame(dst, add, rand.Reader)
		if err != nil {
			panic(err)
		}
	}
	return dst
}
  608. }