您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

369 行
11 KiB

  1. //go:build amd64 && !appengine && !noasm && gc
  2. // +build amd64,!appengine,!noasm,gc
  3. package zstd
  4. import (
  5. "fmt"
  6. "github.com/klauspost/compress/internal/cpuinfo"
  7. )
  8. type decodeSyncAsmContext struct {
  9. llTable []decSymbol
  10. mlTable []decSymbol
  11. ofTable []decSymbol
  12. llState uint64
  13. mlState uint64
  14. ofState uint64
  15. iteration int
  16. litRemain int
  17. out []byte
  18. outPosition int
  19. literals []byte
  20. litPosition int
  21. history []byte
  22. windowSize int
  23. ll int // set on error (not for all errors, please refer to _generate/gen.go)
  24. ml int // set on error (not for all errors, please refer to _generate/gen.go)
  25. mo int // set on error (not for all errors, please refer to _generate/gen.go)
  26. }
// sequenceDecs_decodeSync_amd64 implements the main loop of sequenceDecs.decodeSync in x86 asm.
//
// The function has no Go body. The returned int is a status code: noError on
// success, otherwise one of the error* constants, in which case ctx.ll, ctx.ml
// and ctx.mo may describe the failing sequence (not for all errors).
//
// decodeSyncSimple only calls this variant when its useSafe flag is false.
//
// Please refer to seqdec_generic.go for the reference implementation.
//
//go:noescape
func sequenceDecs_decodeSync_amd64(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int
// sequenceDecs_decodeSync_bmi2 implements the main loop of sequenceDecs.decodeSync in x86 asm with BMI2 extensions.
//
// The function has no Go body. The returned int is a status code: noError on
// success, otherwise one of the error* constants.
//
// decodeSyncSimple only calls this variant when cpuinfo.HasBMI2 reports true
// and its useSafe flag is false.
//
//go:noescape
func sequenceDecs_decodeSync_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int
// sequenceDecs_decodeSync_safe_amd64 implements the main loop of
// sequenceDecs.decodeSync in x86 asm, like sequenceDecs_decodeSync_amd64, but
// does not write more than output buffer.
//
// The function has no Go body. The returned int is a status code: noError on
// success, otherwise one of the error* constants.
//
// decodeSyncSimple calls this variant when cpuinfo.HasBMI2 reports false and
// its useSafe flag is true.
//
//go:noescape
func sequenceDecs_decodeSync_safe_amd64(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int
// sequenceDecs_decodeSync_safe_bmi2 implements the main loop of
// sequenceDecs.decodeSync in x86 asm with BMI2 extensions, like
// sequenceDecs_decodeSync_bmi2, but does not write more than output buffer.
//
// The function has no Go body. The returned int is a status code: noError on
// success, otherwise one of the error* constants.
//
// decodeSyncSimple calls this variant when cpuinfo.HasBMI2 reports true and
// its useSafe flag is true.
//
//go:noescape
func sequenceDecs_decodeSync_safe_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int
  41. // decode sequences from the stream with the provided history but without a dictionary.
  42. func (s *sequenceDecs) decodeSyncSimple(hist []byte) (bool, error) {
  43. if len(s.dict) > 0 {
  44. return false, nil
  45. }
  46. if s.maxSyncLen == 0 && cap(s.out)-len(s.out) < maxCompressedBlockSize {
  47. return false, nil
  48. }
  49. // FIXME: Using unsafe memory copies leads to rare, random crashes
  50. // with fuzz testing. It is therefore disabled for now.
  51. const useSafe = true
  52. /*
  53. useSafe := false
  54. if s.maxSyncLen == 0 && cap(s.out)-len(s.out) < maxCompressedBlockSizeAlloc {
  55. useSafe = true
  56. }
  57. if s.maxSyncLen > 0 && cap(s.out)-len(s.out)-compressedBlockOverAlloc < int(s.maxSyncLen) {
  58. useSafe = true
  59. }
  60. if cap(s.literals) < len(s.literals)+compressedBlockOverAlloc {
  61. useSafe = true
  62. }
  63. */
  64. br := s.br
  65. maxBlockSize := maxCompressedBlockSize
  66. if s.windowSize < maxBlockSize {
  67. maxBlockSize = s.windowSize
  68. }
  69. ctx := decodeSyncAsmContext{
  70. llTable: s.litLengths.fse.dt[:maxTablesize],
  71. mlTable: s.matchLengths.fse.dt[:maxTablesize],
  72. ofTable: s.offsets.fse.dt[:maxTablesize],
  73. llState: uint64(s.litLengths.state.state),
  74. mlState: uint64(s.matchLengths.state.state),
  75. ofState: uint64(s.offsets.state.state),
  76. iteration: s.nSeqs - 1,
  77. litRemain: len(s.literals),
  78. out: s.out,
  79. outPosition: len(s.out),
  80. literals: s.literals,
  81. windowSize: s.windowSize,
  82. history: hist,
  83. }
  84. s.seqSize = 0
  85. startSize := len(s.out)
  86. var errCode int
  87. if cpuinfo.HasBMI2() {
  88. if useSafe {
  89. errCode = sequenceDecs_decodeSync_safe_bmi2(s, br, &ctx)
  90. } else {
  91. errCode = sequenceDecs_decodeSync_bmi2(s, br, &ctx)
  92. }
  93. } else {
  94. if useSafe {
  95. errCode = sequenceDecs_decodeSync_safe_amd64(s, br, &ctx)
  96. } else {
  97. errCode = sequenceDecs_decodeSync_amd64(s, br, &ctx)
  98. }
  99. }
  100. switch errCode {
  101. case noError:
  102. break
  103. case errorMatchLenOfsMismatch:
  104. return true, fmt.Errorf("zero matchoff and matchlen (%d) > 0", ctx.ml)
  105. case errorMatchLenTooBig:
  106. return true, fmt.Errorf("match len (%d) bigger than max allowed length", ctx.ml)
  107. case errorMatchOffTooBig:
  108. return true, fmt.Errorf("match offset (%d) bigger than current history (%d)",
  109. ctx.mo, ctx.outPosition+len(hist)-startSize)
  110. case errorNotEnoughLiterals:
  111. return true, fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available",
  112. ctx.ll, ctx.litRemain+ctx.ll)
  113. case errorNotEnoughSpace:
  114. size := ctx.outPosition + ctx.ll + ctx.ml
  115. if debugDecoder {
  116. println("msl:", s.maxSyncLen, "cap", cap(s.out), "bef:", startSize, "sz:", size-startSize, "mbs:", maxBlockSize, "outsz:", cap(s.out)-startSize)
  117. }
  118. return true, fmt.Errorf("output (%d) bigger than max block size (%d)", size-startSize, maxBlockSize)
  119. default:
  120. return true, fmt.Errorf("sequenceDecs_decode returned erronous code %d", errCode)
  121. }
  122. s.seqSize += ctx.litRemain
  123. if s.seqSize > maxBlockSize {
  124. return true, fmt.Errorf("output (%d) bigger than max block size (%d)", s.seqSize, maxBlockSize)
  125. }
  126. err := br.close()
  127. if err != nil {
  128. printf("Closing sequences: %v, %+v\n", err, *br)
  129. return true, err
  130. }
  131. s.literals = s.literals[ctx.litPosition:]
  132. t := ctx.outPosition
  133. s.out = s.out[:t]
  134. // Add final literals
  135. s.out = append(s.out, s.literals...)
  136. if debugDecoder {
  137. t += len(s.literals)
  138. if t != len(s.out) {
  139. panic(fmt.Errorf("length mismatch, want %d, got %d", len(s.out), t))
  140. }
  141. }
  142. return true, nil
  143. }
  144. // --------------------------------------------------------------------------------
  145. type decodeAsmContext struct {
  146. llTable []decSymbol
  147. mlTable []decSymbol
  148. ofTable []decSymbol
  149. llState uint64
  150. mlState uint64
  151. ofState uint64
  152. iteration int
  153. seqs []seqVals
  154. litRemain int
  155. }
  156. const noError = 0
  157. // error reported when mo == 0 && ml > 0
  158. const errorMatchLenOfsMismatch = 1
  159. // error reported when ml > maxMatchLen
  160. const errorMatchLenTooBig = 2
  161. // error reported when mo > available history or mo > s.windowSize
  162. const errorMatchOffTooBig = 3
  163. // error reported when the sum of literal lengths exeeceds the literal buffer size
  164. const errorNotEnoughLiterals = 4
  165. // error reported when capacity of `out` is too small
  166. const errorNotEnoughSpace = 5
// sequenceDecs_decode_amd64 implements the main loop of sequenceDecs in x86 asm.
//
// The function has no Go body. The returned int is a status code: noError on
// success, otherwise one of the error* constants.
//
// sequenceDecs.decode calls this variant when cpuinfo.HasBMI2 reports false
// and its 56-bit check (lte56bits) does not hold.
//
// Please refer to seqdec_generic.go for the reference implementation.
//
//go:noescape
func sequenceDecs_decode_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
// sequenceDecs_decode_56_amd64 implements the main loop of sequenceDecs in x86 asm.
//
// The function has no Go body. The returned int is a status code: noError on
// success, otherwise one of the error* constants.
//
// sequenceDecs.decode calls this variant when cpuinfo.HasBMI2 reports false
// and s.maxBits plus the three actual table logs is at most 56.
//
// Please refer to seqdec_generic.go for the reference implementation.
//
//go:noescape
func sequenceDecs_decode_56_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
// sequenceDecs_decode_bmi2 implements the main loop of sequenceDecs in x86 asm with BMI2 extensions.
//
// The function has no Go body. The returned int is a status code: noError on
// success, otherwise one of the error* constants.
//
// sequenceDecs.decode calls this variant when cpuinfo.HasBMI2 reports true
// and its 56-bit check (lte56bits) does not hold.
//
//go:noescape
func sequenceDecs_decode_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
// sequenceDecs_decode_56_bmi2 implements the main loop of sequenceDecs in x86 asm with BMI2 extensions.
//
// The function has no Go body. The returned int is a status code: noError on
// success, otherwise one of the error* constants.
//
// sequenceDecs.decode calls this variant when cpuinfo.HasBMI2 reports true
// and s.maxBits plus the three actual table logs is at most 56.
//
//go:noescape
func sequenceDecs_decode_56_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
  183. // decode sequences from the stream without the provided history.
  184. func (s *sequenceDecs) decode(seqs []seqVals) error {
  185. br := s.br
  186. maxBlockSize := maxCompressedBlockSize
  187. if s.windowSize < maxBlockSize {
  188. maxBlockSize = s.windowSize
  189. }
  190. ctx := decodeAsmContext{
  191. llTable: s.litLengths.fse.dt[:maxTablesize],
  192. mlTable: s.matchLengths.fse.dt[:maxTablesize],
  193. ofTable: s.offsets.fse.dt[:maxTablesize],
  194. llState: uint64(s.litLengths.state.state),
  195. mlState: uint64(s.matchLengths.state.state),
  196. ofState: uint64(s.offsets.state.state),
  197. seqs: seqs,
  198. iteration: len(seqs) - 1,
  199. litRemain: len(s.literals),
  200. }
  201. s.seqSize = 0
  202. lte56bits := s.maxBits+s.offsets.fse.actualTableLog+s.matchLengths.fse.actualTableLog+s.litLengths.fse.actualTableLog <= 56
  203. var errCode int
  204. if cpuinfo.HasBMI2() {
  205. if lte56bits {
  206. errCode = sequenceDecs_decode_56_bmi2(s, br, &ctx)
  207. } else {
  208. errCode = sequenceDecs_decode_bmi2(s, br, &ctx)
  209. }
  210. } else {
  211. if lte56bits {
  212. errCode = sequenceDecs_decode_56_amd64(s, br, &ctx)
  213. } else {
  214. errCode = sequenceDecs_decode_amd64(s, br, &ctx)
  215. }
  216. }
  217. if errCode != 0 {
  218. i := len(seqs) - ctx.iteration - 1
  219. switch errCode {
  220. case errorMatchLenOfsMismatch:
  221. ml := ctx.seqs[i].ml
  222. return fmt.Errorf("zero matchoff and matchlen (%d) > 0", ml)
  223. case errorMatchLenTooBig:
  224. ml := ctx.seqs[i].ml
  225. return fmt.Errorf("match len (%d) bigger than max allowed length", ml)
  226. case errorNotEnoughLiterals:
  227. ll := ctx.seqs[i].ll
  228. return fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available", ll, ctx.litRemain+ll)
  229. }
  230. return fmt.Errorf("sequenceDecs_decode_amd64 returned erronous code %d", errCode)
  231. }
  232. if ctx.litRemain < 0 {
  233. return fmt.Errorf("literal count is too big: total available %d, total requested %d",
  234. len(s.literals), len(s.literals)-ctx.litRemain)
  235. }
  236. s.seqSize += ctx.litRemain
  237. if s.seqSize > maxBlockSize {
  238. return fmt.Errorf("output (%d) bigger than max block size (%d)", s.seqSize, maxBlockSize)
  239. }
  240. err := br.close()
  241. if err != nil {
  242. printf("Closing sequences: %v, %+v\n", err, *br)
  243. }
  244. return err
  245. }
  246. // --------------------------------------------------------------------------------
  247. type executeAsmContext struct {
  248. seqs []seqVals
  249. seqIndex int
  250. out []byte
  251. history []byte
  252. literals []byte
  253. outPosition int
  254. litPosition int
  255. windowSize int
  256. }
// sequenceDecs_executeSimple_amd64 implements the main loop of sequenceDecs.executeSimple in x86 asm.
//
// The function has no Go body.
//
// Returns false if a match offset is too big. In that case executeSimple reads
// ctx.seqIndex and ctx.outPosition to build its error message.
//
// executeSimple calls this variant when s.literals has at least
// compressedBlockOverAlloc bytes of spare capacity.
//
// Please refer to seqdec_generic.go for the reference implementation.
//
//go:noescape
func sequenceDecs_executeSimple_amd64(ctx *executeAsmContext) bool
// sequenceDecs_executeSimple_safe_amd64 implements the main loop of
// sequenceDecs.executeSimple in x86 asm, like sequenceDecs_executeSimple_amd64,
// but with safe memcopies.
//
// The function has no Go body.
//
// Returns false if a match offset is too big.
//
// executeSimple calls this variant when s.literals has less than
// compressedBlockOverAlloc bytes of spare capacity.
//
//go:noescape
func sequenceDecs_executeSimple_safe_amd64(ctx *executeAsmContext) bool
  267. // executeSimple handles cases when dictionary is not used.
  268. func (s *sequenceDecs) executeSimple(seqs []seqVals, hist []byte) error {
  269. // Ensure we have enough output size...
  270. if len(s.out)+s.seqSize+compressedBlockOverAlloc > cap(s.out) {
  271. addBytes := s.seqSize + len(s.out) + compressedBlockOverAlloc
  272. s.out = append(s.out, make([]byte, addBytes)...)
  273. s.out = s.out[:len(s.out)-addBytes]
  274. }
  275. if debugDecoder {
  276. printf("Execute %d seqs with literals: %d into %d bytes\n", len(seqs), len(s.literals), s.seqSize)
  277. }
  278. var t = len(s.out)
  279. out := s.out[:t+s.seqSize]
  280. ctx := executeAsmContext{
  281. seqs: seqs,
  282. seqIndex: 0,
  283. out: out,
  284. history: hist,
  285. outPosition: t,
  286. litPosition: 0,
  287. literals: s.literals,
  288. windowSize: s.windowSize,
  289. }
  290. var ok bool
  291. if cap(s.literals) < len(s.literals)+compressedBlockOverAlloc {
  292. ok = sequenceDecs_executeSimple_safe_amd64(&ctx)
  293. } else {
  294. ok = sequenceDecs_executeSimple_amd64(&ctx)
  295. }
  296. if !ok {
  297. return fmt.Errorf("match offset (%d) bigger than current history (%d)",
  298. seqs[ctx.seqIndex].mo, ctx.outPosition+len(hist))
  299. }
  300. s.literals = s.literals[ctx.litPosition:]
  301. t = ctx.outPosition
  302. // Add final literals
  303. copy(out[t:], s.literals)
  304. if debugDecoder {
  305. t += len(s.literals)
  306. if t != len(out) {
  307. panic(fmt.Errorf("length mismatch, want %d, got %d, ss: %d", len(out), t, s.seqSize))
  308. }
  309. }
  310. s.out = out
  311. return nil
  312. }