diracdec: fix unchecked byte length
[ffmpeg.git] / libavcodec / diracdec.c
index c473e87..6cb098b 100644 (file)
@@ -32,6 +32,7 @@
 #include "internal.h"
 #include "golomb.h"
 #include "dirac_arith.h"
+#include "dirac_vlc.h"
 #include "mpeg12data.h"
 #include "libavcodec/mpegvideo.h"
 #include "mpegvideoencdsp.h"
@@ -120,11 +121,20 @@ typedef struct Plane {
     SubBand band[MAX_DWT_LEVELS][4];
 } Plane;
 
+/* Used by Low Delay and High Quality profiles */
+typedef struct DiracSlice {
+    GetBitContext gb;
+    int slice_x;
+    int slice_y;
+    int bytes;
+} DiracSlice;
+
 typedef struct DiracContext {
     AVCodecContext *avctx;
     MpegvideoEncDSPContext mpvencdsp;
     VideoDSPContext vdsp;
     DiracDSPContext diracdsp;
+    DiracGolombLUT *reader_ctx;
     DiracVersionInfo version;
     GetBitContext gb;
     AVDiracSeqHeader seq;
@@ -161,6 +171,13 @@ typedef struct DiracContext {
     unsigned num_x;              /* number of horizontal slices               */
     unsigned num_y;              /* number of vertical slices                 */
 
+    uint8_t *thread_buf;         /* Per-thread buffer for coefficient storage */
+    int threads_num_buf;         /* Current # of buffers allocated            */
+    int thread_buf_size;         /* Each thread has a buffer this size        */
+
+    DiracSlice *slice_params_buf;
+    int slice_params_num_buf;
+
     struct {
         unsigned width;
         unsigned height;
@@ -370,6 +387,11 @@ static av_cold int dirac_decode_init(AVCodecContext *avctx)
     s->avctx = avctx;
     s->frame_number = -1;
 
+    s->thread_buf = NULL;
+    s->threads_num_buf = -1;
+    s->thread_buf_size = -1;
+
+    ff_dirac_golomb_reader_init(&s->reader_ctx);
     ff_diracdsp_init(&s->diracdsp);
     ff_mpegvideoencdsp_init(&s->mpvencdsp, avctx);
     ff_videodsp_init(&s->vdsp, 8);
@@ -399,65 +421,29 @@ static av_cold int dirac_decode_end(AVCodecContext *avctx)
     DiracContext *s = avctx->priv_data;
     int i;
 
+    ff_dirac_golomb_reader_end(&s->reader_ctx);
+
     dirac_decode_flush(avctx);
     for (i = 0; i < MAX_FRAMES; i++)
         av_frame_free(&s->all_frames[i].avframe);
 
+    av_freep(&s->thread_buf);
+    av_freep(&s->slice_params_buf);
+
     return 0;
 }
 
-#define SIGN_CTX(x) (CTX_SIGN_ZERO + ((x) > 0) - ((x) < 0))
-
 static inline int coeff_unpack_golomb(GetBitContext *gb, int qfactor, int qoffset)
 {
-    int sign, coeff;
-    uint32_t buf;
-
-    OPEN_READER(re, gb);
-    UPDATE_CACHE(re, gb);
-    buf = GET_CACHE(re, gb);
-
-    if (buf & 0x80000000) {
-        LAST_SKIP_BITS(re,gb,1);
-        CLOSE_READER(re, gb);
-        return 0;
-    }
-
-    if (buf & 0xAA800000) {
-        buf >>= 32 - 8;
-        SKIP_BITS(re, gb, ff_interleaved_golomb_vlc_len[buf]);
-
-        coeff = ff_interleaved_ue_golomb_vlc_code[buf];
-    } else {
-        unsigned ret = 1;
-
-        do {
-            buf >>= 32 - 8;
-            SKIP_BITS(re, gb,
-                           FFMIN(ff_interleaved_golomb_vlc_len[buf], 8));
-
-            if (ff_interleaved_golomb_vlc_len[buf] != 9) {
-                ret <<= (ff_interleaved_golomb_vlc_len[buf] - 1) >> 1;
-                ret  |= ff_interleaved_dirac_golomb_vlc_code[buf];
-                break;
-            }
-            ret = (ret << 4) | ff_interleaved_dirac_golomb_vlc_code[buf];
-            UPDATE_CACHE(re, gb);
-            buf = GET_CACHE(re, gb);
-        } while (ret<0x8000000U && BITS_AVAILABLE(re, gb));
-
-        coeff = ret - 1;
-    }
-
-    coeff = (coeff * qfactor + qoffset) >> 2;
-    sign  = SHOW_SBITS(re, gb, 1);
-    LAST_SKIP_BITS(re, gb, 1);
-    coeff = (coeff ^ sign) - sign;
-
-    CLOSE_READER(re, gb);
+    int coeff = dirac_get_se_golomb(gb);
+    const int sign = FFSIGN(coeff);
+    if (coeff)
+        coeff = sign*((sign * coeff * qfactor + qoffset) >> 2);
     return coeff;
 }
 
+#define SIGN_CTX(x) (CTX_SIGN_ZERO + ((x) > 0) - ((x) < 0))
+
 #define UNPACK_ARITH(n, type) \
     static inline void coeff_unpack_arith_##n(DiracArith *c, int qfactor, int qoffset, \
                                               SubBand *b, type *buf, int x, int y) \
@@ -527,7 +513,7 @@ static inline void codeblock(DiracContext *s, SubBand *b,
         b->quant = quant;
     }
 
-    if (b->quant > 115) {
+    if (b->quant > (DIRAC_MAX_QUANT_INDEX - 1)) {
         av_log(s->avctx, AV_LOG_ERROR, "Unsupported quant %d\n", b->quant);
         b->quant = 0;
         return;
@@ -717,12 +703,12 @@ static void decode_subband(DiracContext *s, GetBitContext *gb, int quant,
     uint8_t *buf2 = b2 ? b2->ibuf + top * b2->stride: NULL;
     int x, y;
 
-    if (quant > 115) {
+    if (quant > (DIRAC_MAX_QUANT_INDEX - 1)) {
         av_log(s->avctx, AV_LOG_ERROR, "Unsupported quant %d\n", quant);
         return;
     }
-    qfactor = ff_dirac_qscale_tab[quant & 0x7f];
-    qoffset = ff_dirac_qoffset_intra_tab[quant & 0x7f] + 2;
+    qfactor = ff_dirac_qscale_tab[quant];
+    qoffset = ff_dirac_qoffset_intra_tab[quant] + 2;
     /* we have to constantly check for overread since the spec explicitly
        requires this, with the meaning that all remaining coeffs are set to 0 */
     if (get_bits_count(gb) >= bits_end)
@@ -750,15 +736,6 @@ static void decode_subband(DiracContext *s, GetBitContext *gb, int quant,
     }
 }
 
-/* Used by Low Delay and High Quality profiles */
-typedef struct DiracSlice {
-    GetBitContext gb;
-    int slice_x;
-    int slice_y;
-    int bytes;
-} DiracSlice;
-
-
 /**
  * Dirac Specification ->
  * 13.5.2 Slices. slice(sx,sy)
@@ -801,52 +778,120 @@ static int decode_lowdelay_slice(AVCodecContext *avctx, void *arg)
     return 0;
 }
 
+typedef struct SliceCoeffs {
+    int left;
+    int top;
+    int tot_h;
+    int tot_v;
+    int tot;
+} SliceCoeffs;
+
+static int subband_coeffs(DiracContext *s, int x, int y, int p,
+                          SliceCoeffs c[MAX_DWT_LEVELS])
+{
+    int level, coef = 0;
+    for (level = 0; level < s->wavelet_depth; level++) {
+        SliceCoeffs *o = &c[level];
+        SubBand *b = &s->plane[p].band[level][3]; /* orientation doens't matter */
+        o->top   = b->height * y / s->num_y;
+        o->left  = b->width  * x / s->num_x;
+        o->tot_h = ((b->width  * (x + 1)) / s->num_x) - o->left;
+        o->tot_v = ((b->height * (y + 1)) / s->num_y) - o->top;
+        o->tot   = o->tot_h*o->tot_v;
+        coef    += o->tot * (4 - !!level);
+    }
+    return coef;
+}
+
 /**
  * VC-2 Specification ->
  * 13.5.3 hq_slice(sx,sy)
  */
-static int decode_hq_slice(AVCodecContext *avctx, void *arg)
+static int decode_hq_slice(DiracContext *s, DiracSlice *slice, uint8_t *tmp_buf)
 {
-    int i, quant, level, orientation, quant_idx;
-    uint8_t quants[MAX_DWT_LEVELS][4];
-    DiracContext *s = avctx->priv_data;
-    DiracSlice *slice = arg;
+    int i, level, orientation, quant_idx;
+    int qfactor[MAX_DWT_LEVELS][4], qoffset[MAX_DWT_LEVELS][4];
     GetBitContext *gb = &slice->gb;
+    SliceCoeffs coeffs_num[MAX_DWT_LEVELS];
 
     skip_bits_long(gb, 8*s->highquality.prefix_bytes);
     quant_idx = get_bits(gb, 8);
 
+    if (quant_idx > DIRAC_MAX_QUANT_INDEX) {
+        av_log(s->avctx, AV_LOG_ERROR, "Invalid quantization index - %i\n", quant_idx);
+        return AVERROR_INVALIDDATA;
+    }
+
     /* Slice quantization (slice_quantizers() in the specs) */
     for (level = 0; level < s->wavelet_depth; level++) {
         for (orientation = !!level; orientation < 4; orientation++) {
-            quant = FFMAX(quant_idx - s->lowdelay.quant[level][orientation], 0);
-            quants[level][orientation] = quant;
+            const int quant = FFMAX(quant_idx - s->lowdelay.quant[level][orientation], 0);
+            qfactor[level][orientation] = ff_dirac_qscale_tab[quant];
+            qoffset[level][orientation] = ff_dirac_qoffset_intra_tab[quant] + 2;
         }
     }
 
     /* Luma + 2 Chroma planes */
     for (i = 0; i < 3; i++) {
-        int64_t length = s->highquality.size_scaler * get_bits(gb, 8);
-        int64_t bits_left = 8 * length;
-        int64_t bits_end = get_bits_count(gb) + bits_left;
+        int coef_num, coef_par, off = 0;
+        int64_t length = s->highquality.size_scaler*get_bits(gb, 8);
+        int64_t bits_end = get_bits_count(gb) + 8*length;
+        const uint8_t *addr = align_get_bits(gb);
 
-        if (bits_end >= INT_MAX) {
+        if (length*8 > get_bits_left(gb)) {
             av_log(s->avctx, AV_LOG_ERROR, "end too far away\n");
             return AVERROR_INVALIDDATA;
         }
 
+        coef_num = subband_coeffs(s, slice->slice_x, slice->slice_y, i, coeffs_num);
+
+        if (s->pshift)
+            coef_par = ff_dirac_golomb_read_32bit(s->reader_ctx, addr,
+                                                  length, tmp_buf, coef_num);
+        else
+            coef_par = ff_dirac_golomb_read_16bit(s->reader_ctx, addr,
+                                                  length, tmp_buf, coef_num);
+
+        if (coef_num > coef_par) {
+            const int start_b = coef_par * (1 << (s->pshift + 1));
+            const int end_b   = coef_num * (1 << (s->pshift + 1));
+            memset(&tmp_buf[start_b], 0, end_b - start_b);
+        }
+
         for (level = 0; level < s->wavelet_depth; level++) {
+            const SliceCoeffs *c = &coeffs_num[level];
             for (orientation = !!level; orientation < 4; orientation++) {
-                decode_subband(s, gb, quants[level][orientation], slice->slice_x, slice->slice_y, bits_end,
-                               &s->plane[i].band[level][orientation], NULL);
+                const SubBand *b1 = &s->plane[i].band[level][orientation];
+                uint8_t *buf = b1->ibuf + c->top * b1->stride + (c->left << (s->pshift + 1));
+
+                /* Change to c->tot_h <= 4 for AVX2 dequantization */
+                const int qfunc = s->pshift + 2*(c->tot_h <= 2);
+                s->diracdsp.dequant_subband[qfunc](&tmp_buf[off], buf, b1->stride,
+                                                   qfactor[level][orientation],
+                                                   qoffset[level][orientation],
+                                                   c->tot_v, c->tot_h);
+
+                off += c->tot << (s->pshift + 1);
             }
         }
+
         skip_bits_long(gb, bits_end - get_bits_count(gb));
     }
 
     return 0;
 }
 
+static int decode_hq_slice_row(AVCodecContext *avctx, void *arg, int jobnr, int threadnr)
+{
+    int i;
+    DiracContext *s = avctx->priv_data;
+    DiracSlice *slices = ((DiracSlice *)arg) + s->num_x*jobnr;
+    uint8_t *thread_buf = &s->thread_buf[s->thread_buf_size*threadnr];
+    for (i = 0; i < s->num_x; i++)
+        decode_hq_slice(s, &slices[i], thread_buf);
+    return 0;
+}
+
 /**
  * Dirac Specification ->
  * 13.5.1 low_delay_transform_data()
@@ -855,14 +900,37 @@ static int decode_lowdelay(DiracContext *s)
 {
     AVCodecContext *avctx = s->avctx;
     int slice_x, slice_y, bufsize;
-    int64_t bytes = 0;
+    int64_t coef_buf_size, bytes = 0;
     const uint8_t *buf;
     DiracSlice *slices;
+    SliceCoeffs tmp[MAX_DWT_LEVELS];
     int slice_num = 0;
 
-    slices = av_mallocz_array(s->num_x, s->num_y * sizeof(DiracSlice));
-    if (!slices)
-        return AVERROR(ENOMEM);
+    if (s->slice_params_num_buf != (s->num_x * s->num_y)) {
+        s->slice_params_buf = av_realloc_f(s->thread_buf, s->num_x * s->num_y, sizeof(DiracSlice));
+        if (!s->slice_params_buf) {
+            av_log(s->avctx, AV_LOG_ERROR, "slice params buffer allocation failure\n");
+            return AVERROR(ENOMEM);
+        }
+        s->slice_params_num_buf = s->num_x * s->num_y;
+    }
+    slices = s->slice_params_buf;
+
+    /* 8 becacuse that's how much the golomb reader could overread junk data
+     * from another plane/slice at most, and 512 because SIMD */
+    coef_buf_size = subband_coeffs(s, s->num_x - 1, s->num_y - 1, 0, tmp) + 8;
+    coef_buf_size = (coef_buf_size << (1 + s->pshift)) + 512;
+
+    if (s->threads_num_buf != avctx->thread_count ||
+        s->thread_buf_size != coef_buf_size) {
+        s->threads_num_buf  = avctx->thread_count;
+        s->thread_buf_size  = coef_buf_size;
+        s->thread_buf       = av_realloc_f(s->thread_buf, avctx->thread_count, s->thread_buf_size);
+        if (!s->thread_buf) {
+            av_log(s->avctx, AV_LOG_ERROR, "thread buffer allocation failure\n");
+            return AVERROR(ENOMEM);
+        }
+    }
 
     align_get_bits(&s->gb);
     /*[DIRAC_STD] 13.5.2 Slices. slice(sx,sy) */
@@ -879,9 +947,8 @@ static int decode_lowdelay(DiracContext *s)
                     if (bytes <= bufsize/8)
                         bytes += buf[bytes] * s->highquality.size_scaler + 1;
                 }
-                if (bytes >= INT_MAX) {
+                if (bytes >= INT_MAX || bytes*8 > bufsize) {
                     av_log(s->avctx, AV_LOG_ERROR, "too many bytes\n");
-                    av_free(slices);
                     return AVERROR_INVALIDDATA;
                 }
 
@@ -898,8 +965,13 @@ static int decode_lowdelay(DiracContext *s)
                     bufsize = 0;
             }
         }
-        avctx->execute(avctx, decode_hq_slice, slices, NULL, slice_num,
-                       sizeof(DiracSlice));
+
+        if (s->num_x*s->num_y != slice_num) {
+            av_log(s->avctx, AV_LOG_ERROR, "too few slices\n");
+            return AVERROR_INVALIDDATA;
+        }
+
+        avctx->execute2(avctx, decode_hq_slice_row, slices, NULL, s->num_y);
     } else {
         for (slice_y = 0; bufsize > 0 && slice_y < s->num_y; slice_y++) {
             for (slice_x = 0; bufsize > 0 && slice_x < s->num_x; slice_x++) {
@@ -933,7 +1005,7 @@ static int decode_lowdelay(DiracContext *s)
             intra_dc_prediction_8(&s->plane[2].band[0][0]);
         }
     }
-    av_free(slices);
+
     return 0;
 }
 
@@ -1743,9 +1815,11 @@ static int dirac_decode_frame_internal(DiracContext *s)
 
     if (s->low_delay) {
         /* [DIRAC_STD] 13.5.1 low_delay_transform_data() */
-        for (comp = 0; comp < 3; comp++) {
-            Plane *p = &s->plane[comp];
-            memset(p->idwt.buf, 0, p->idwt.stride * p->idwt.height);
+        if (!s->hq_picture) {
+            for (comp = 0; comp < 3; comp++) {
+                Plane *p = &s->plane[comp];
+                memset(p->idwt.buf, 0, p->idwt.stride * p->idwt.height);
+            }
         }
         if (!s->zero_res) {
             if ((ret = decode_lowdelay(s)) < 0)