diff --git a/src/journal/compress.c b/src/journal/compress.c index 1b0c01a8fb..626e079fef 100644 --- a/src/journal/compress.c +++ b/src/journal/compress.c @@ -259,10 +259,10 @@ int decompress_blob_lz4(const void *src, uint64_t src_size, int decompress_blob_zstd( const void *src, uint64_t src_size, - void **dst, size_t *dst_alloc_size, size_t* dst_size, size_t dst_max) { + void **dst, size_t *dst_alloc_size, size_t *dst_size, size_t dst_max) { #if HAVE_ZSTD - size_t space; + uint64_t size; assert(src); assert(src_size > 0); @@ -271,38 +271,40 @@ int decompress_blob_zstd( assert(dst_size); assert(*dst_alloc_size == 0 || *dst); - if (src_size > SIZE_MAX/2) /* Overflow? */ - return -ENOBUFS; - space = src_size * 2; - if (dst_max > 0 && space > dst_max) - space = dst_max; + size = ZSTD_getFrameContentSize(src, src_size); + if (IN_SET(size, ZSTD_CONTENTSIZE_ERROR, ZSTD_CONTENTSIZE_UNKNOWN)) + return -EBADMSG; - if (!greedy_realloc(dst, dst_alloc_size, space, 1)) + if (dst_max > 0 && size > dst_max) + size = dst_max; + if (size > SIZE_MAX) + return -E2BIG; + + if (!(greedy_realloc(dst, dst_alloc_size, MAX(ZSTD_DStreamOutSize(), size), 1))) return -ENOMEM; - for (;;) { - size_t k; + _cleanup_(ZSTD_freeDCtxp) ZSTD_DCtx *dctx = ZSTD_createDCtx(); + if (!dctx) + return -ENOMEM; - k = ZSTD_decompress(*dst, *dst_alloc_size, src, src_size); - if (!ZSTD_isError(k)) { - *dst_size = k; - return 0; - } - if (ZSTD_getErrorCode(k) != ZSTD_error_dstSize_tooSmall) - return zstd_ret_to_errno(k); + ZSTD_inBuffer input = { + .src = src, + .size = src_size, + }; + ZSTD_outBuffer output = { + .dst = *dst, + .size = *dst_alloc_size, + }; - if (dst_max > 0 && space >= dst_max) /* Already at max? */ - return -ENOBUFS; - if (space > SIZE_MAX / 2) /* Overflow? */ - return -ENOBUFS; - - space *= 2; - if (dst_max > 0 && space > dst_max) - space = dst_max; - - if (!greedy_realloc(dst, dst_alloc_size, space, 1)) - return -ENOMEM; + size_t k = ZSTD_decompressStream(dctx, &output, &input); + if (ZSTD_isError(k)) { + log_debug("ZSTD decoder failed: %s", ZSTD_getErrorName(k)); + return zstd_ret_to_errno(k); } + assert(output.pos >= size); + + *dst_size = size; + return 0; #else return -EPROTONOSUPPORT; #endif