diff --git a/src/journal/compress.c b/src/journal/compress.c index 626e079fef..a59c2b7a88 100644 --- a/src/journal/compress.c +++ b/src/journal/compress.c @@ -458,9 +458,6 @@ int decompress_startswith_zstd( const void *prefix, size_t prefix_len, uint8_t extra) { #if HAVE_ZSTD - _cleanup_(ZSTD_freeDCtxp) ZSTD_DCtx *dctx = NULL; - size_t k; - assert(src); assert(src_size > 0); assert(buffer); @@ -468,7 +465,14 @@ int decompress_startswith_zstd( assert(prefix); assert(*buffer_size == 0 || *buffer); - dctx = ZSTD_createDCtx(); + uint64_t size = ZSTD_getFrameContentSize(src, src_size); + if (IN_SET(size, ZSTD_CONTENTSIZE_ERROR, ZSTD_CONTENTSIZE_UNKNOWN)) + return -EBADMSG; + + if (size < prefix_len + 1) + return 0; /* Decompressed text too short to match the prefix and extra */ + + _cleanup_(ZSTD_freeDCtxp) ZSTD_DCtx *dctx = ZSTD_createDCtx(); if (!dctx) return -ENOMEM; @@ -483,30 +487,17 @@ int decompress_startswith_zstd( .dst = *buffer, .size = *buffer_size, }; + size_t k; - for (;;) { - k = ZSTD_decompressStream(dctx, &output, &input); - if (ZSTD_isError(k)) { - log_debug("ZSTD decoder failed: %s", ZSTD_getErrorName(k)); - return zstd_ret_to_errno(k); - } - - if (output.pos >= prefix_len + 1) - return memcmp(*buffer, prefix, prefix_len) == 0 && - ((const uint8_t*) *buffer)[prefix_len] == extra; - - if (input.pos >= input.size) - return 0; - - if (*buffer_size > SIZE_MAX/2) - return -ENOBUFS; - - if (!(greedy_realloc(buffer, buffer_size, *buffer_size * 2, 1))) - return -ENOMEM; - - output.dst = *buffer; - output.size = *buffer_size; + k = ZSTD_decompressStream(dctx, &output, &input); + if (ZSTD_isError(k)) { + log_debug("ZSTD decoder failed: %s", ZSTD_getErrorName(k)); + return zstd_ret_to_errno(k); } + assert(output.pos >= prefix_len + 1); + + return memcmp(*buffer, prefix, prefix_len) == 0 && + ((const uint8_t*) *buffer)[prefix_len] == extra; #else return -EPROTONOSUPPORT; #endif diff --git a/src/journal/test-compress.c b/src/journal/test-compress.c index 0990f7604d..f50fb0acea 100644 --- a/src/journal/test-compress.c +++ b/src/journal/test-compress.c @@ -232,6 +232,8 @@ static void test_lz4_decompress_partial(void) { int r; _cleanup_free_ char *huge = NULL; + log_debug("/* %s */", __func__); + assert_se(huge = malloc(HUGE_SIZE)); memcpy(huge, "HUGE=", STRLEN("HUGE=")); memset(&huge[STRLEN("HUGE=")], 'x', HUGE_SIZE - STRLEN("HUGE=") - 1);