imdct15: replace the FFT with a faster PFA FFT algorithm
authorRostislav Pehlivanov <atomnuker@gmail.com>
Wed, 4 Jan 2017 09:23:24 +0000 (09:23 +0000)
committerRostislav Pehlivanov <atomnuker@gmail.com>
Thu, 5 Jan 2017 22:32:02 +0000 (22:32 +0000)
This commit replaces the current inefficient non-power-of-two FFT with a
much faster FFT based on the Prime Factor Algorithm.
Although it is already much faster than the old algorithm without SIMD,
the new algorithm makes use of the already very throughouly SIMD'd power
of two FFT, which improves performance even more across all platforms
which we have SIMD support for.

Most of the work was done by Peter Barfuss, who passed the code to me to
implement into the iMDCT and the current codebase. The code for a
5-point and 15-point FFT was derived from the previous implementation,
although it was optimized and simplified, which will make its future
SIMD easier. The 15-point FFT is currently using 6% of the current
overall decoder overhead.

The FFT can now easily be used as a forward transform by simply not
multiplying the 5-point FFT's imaginary component by -1 (which comes
from the fact that changing the complex exponential's angle by -1 also
changes the output by that) and by multiplying the "theta" angle of the
main exptab by -1. Hence the deliberately left multiplication by -1 at
the end.

FATE passes, and performance reports on other platforms/CPUs are
welcome.

Performance comparisons:

iMDCT, PFA:
101127 decicycles in speed,   32765 runs,      3 skips
iMDCT, Old:
211022 decicycles in speed,   32768 runs,      0 skips

Standalone FFT, 300000 transforms of size 960:
    PFA        Old FFT     kiss_fft    libfftw3f
    3.659695s, 15.726912s, 13.300789s, 1.182222s

Being only 3x slower than libfftw3f is a big achievement by itself.

There appears to be something capping the performance in the iMDCT side
of things, possibly during the pre-stage reindexing. However, it is
certainly fast enough for now.

Signed-off-by: Rostislav Pehlivanov <atomnuker@gmail.com>
libavcodec/imdct15.c
libavcodec/imdct15.h

index 7481c02..a6d4249 100644 (file)
@@ -1,5 +1,6 @@
 /*
  * Copyright (c) 2013-2014 Mozilla Corporation
+ * Copyright (c) 2017 Rostislav Pehlivanov <atomnuker@gmail.com>
  *
  * This file is part of FFmpeg.
  *
 #include "libavutil/attributes.h"
 #include "libavutil/common.h"
 
-#include "avfft.h"
 #include "imdct15.h"
-#include "opus.h"
-
-// minimal iMDCT size to make SIMD opts easier
-#define CELT_MIN_IMDCT_SIZE 120
 
 // complex c = a * b
 #define CMUL3(cre, cim, are, aim, bre, bim)          \
@@ -48,37 +44,18 @@ do {                                                 \
 
 #define CMUL(c, a, b) CMUL3((c).re, (c).im, (a).re, (a).im, (b).re, (b).im)
 
-// complex c = a * b
-//         d = a * conjugate(b)
-#define CMUL2(c, d, a, b)                            \
-do {                                                 \
-    float are = (a).re;                              \
-    float aim = (a).im;                              \
-    float bre = (b).re;                              \
-    float bim = (b).im;                              \
-    float rr  = are * bre;                           \
-    float ri  = are * bim;                           \
-    float ir  = aim * bre;                           \
-    float ii  = aim * bim;                           \
-    (c).re =  rr - ii;                               \
-    (c).im =  ri + ir;                               \
-    (d).re =  rr + ii;                               \
-    (d).im = -ri + ir;                               \
-} while (0)
-
 av_cold void ff_imdct15_uninit(IMDCT15Context **ps)
 {
     IMDCT15Context *s = *ps;
-    int i;
 
     if (!s)
         return;
 
-    for (i = 0; i < FF_ARRAY_ELEMS(s->exptab); i++)
-        av_freep(&s->exptab[i]);
+    ff_fft_end(&s->ptwo_fft);
 
+    av_freep(&s->pfa_prereindex);
+    av_freep(&s->pfa_postreindex);
     av_freep(&s->twiddle_exptab);
-
     av_freep(&s->tmp);
 
     av_freep(ps);
@@ -87,14 +64,46 @@ av_cold void ff_imdct15_uninit(IMDCT15Context **ps)
 static void imdct15_half(IMDCT15Context *s, float *dst, const float *src,
                          ptrdiff_t stride, float scale);
 
+static inline int init_pfa_reindex_tabs(IMDCT15Context *s)
+{
+    int i, j;
+    const int b_ptwo = s->ptwo_fft.nbits; /* Bits for the power of two FFTs */
+    const int l_ptwo = 1 << b_ptwo; /* Total length for the power of two FFTs */
+    const int inv_1 = l_ptwo << ((4 - b_ptwo) & 3); /* (2^b_ptwo)^-1 mod 15 */
+    const int inv_2 = 0xeeeeeeef & ((1U << b_ptwo) - 1); /* 15^-1 mod 2^b_ptwo */
+
+    s->pfa_prereindex = av_malloc(15 * l_ptwo * sizeof(*s->pfa_prereindex));
+    if (!s->pfa_prereindex)
+        return 1;
+
+    s->pfa_postreindex = av_malloc(15 * l_ptwo * sizeof(*s->pfa_postreindex));
+    if (!s->pfa_postreindex)
+        return 1;
+
+    /* Pre/Post-reindex */
+    for (i = 0; i < l_ptwo; i++) {
+        for (j = 0; j < 15; j++) {
+            const int q_pre = ((l_ptwo * j)/15 + i) >> b_ptwo;
+            const int q_post = (((j*inv_1)/15) + (i*inv_2)) >> b_ptwo;
+            const int k_pre = 15*i + (j - q_pre*15)*l_ptwo;
+            const int k_post = i*inv_2*15 + j*inv_1 - 15*q_post*l_ptwo;
+            s->pfa_prereindex[i*15 + j] = k_pre;
+            s->pfa_postreindex[k_post] = l_ptwo*j + i;
+        }
+    }
+
+    return 0;
+}
+
 av_cold int ff_imdct15_init(IMDCT15Context **ps, int N)
 {
     IMDCT15Context *s;
     int len2 = 15 * (1 << N);
     int len  = 2 * len2;
-    int i, j;
+    int i;
 
-    if (len2 > CELT_MAX_FRAME_SIZE || len2 < CELT_MIN_IMDCT_SIZE)
+    /* Tested and verified to work on everything in between */
+    if ((N < 2) || (N > 13))
         return AVERROR(EINVAL);
 
     s = av_mallocz(sizeof(*s));
@@ -104,6 +113,13 @@ av_cold int ff_imdct15_init(IMDCT15Context **ps, int N)
     s->fft_n = N - 1;
     s->len4 = len2 / 2;
     s->len2 = len2;
+    s->imdct_half = imdct15_half;
+
+    if (ff_fft_init(&s->ptwo_fft, N - 1, 1) < 0)
+        goto fail;
+
+    if (init_pfa_reindex_tabs(s))
+        goto fail;
 
     s->tmp  = av_malloc_array(len, 2 * sizeof(*s->tmp));
     if (!s->tmp)
@@ -114,27 +130,30 @@ av_cold int ff_imdct15_init(IMDCT15Context **ps, int N)
         goto fail;
 
     for (i = 0; i < s->len4; i++) {
-        s->twiddle_exptab[i].re = cos(2 * M_PI * (i + 0.125 + s->len4) / len);
-        s->twiddle_exptab[i].im = sin(2 * M_PI * (i + 0.125 + s->len4) / len);
+        s->twiddle_exptab[i].re = cos(2 * M_PI * (i + 0.125f + s->len4) / len);
+        s->twiddle_exptab[i].im = sin(2 * M_PI * (i + 0.125f + s->len4) / len);
     }
 
-    for (i = 0; i < FF_ARRAY_ELEMS(s->exptab); i++) {
-        int N = 15 * (1 << i);
-        s->exptab[i] = av_malloc(sizeof(*s->exptab[i]) * FFMAX(N, 19));
-        if (!s->exptab[i])
-            goto fail;
-
-        for (j = 0; j < N; j++) {
-            s->exptab[i][j].re = cos(2 * M_PI * j / N);
-            s->exptab[i][j].im = sin(2 * M_PI * j / N);
+    /* 15-point FFT exptab */
+    for (i = 0; i < 19; i++) {
+        if (i < 15) {
+            double theta = (2.0f * M_PI * i) / 15.0f;
+            s->exptab[i].re = cos(theta);
+            s->exptab[i].im = sin(theta);
+        } else { /* Wrap around to simplify fft15 */
+            s->exptab[i] = s->exptab[i - 15];
         }
     }
 
-    // wrap around to simplify fft15
-    for (j = 15; j < 19; j++)
-        s->exptab[0][j] = s->exptab[0][j - 15];
+    /* 5-point FFT exptab */
+    s->exptab[19].re = cos(2.0f * M_PI / 5.0f);
+    s->exptab[19].im = sin(2.0f * M_PI / 5.0f);
+    s->exptab[20].re = cos(1.0f * M_PI / 5.0f);
+    s->exptab[20].im = sin(1.0f * M_PI / 5.0f);
 
-    s->imdct_half = imdct15_half;
+    /* Invert the phase for an inverse transform, do nothing for a forward transform */
+    s->exptab[19].im *= -1;
+    s->exptab[20].im *= -1;
 
     *ps = s;
 
@@ -145,127 +164,116 @@ fail:
     return AVERROR(ENOMEM);
 }
 
-static void fft5(FFTComplex *out, const FFTComplex *in, ptrdiff_t stride)
+/* Stride is hardcoded to 3 */
+static inline void fft5(const FFTComplex exptab[2], FFTComplex *out,
+                        const FFTComplex *in)
 {
-    // [0] = exp(2 * i * pi / 5), [1] = exp(2 * i * pi * 2 / 5)
-    static const FFTComplex fact[] = { { 0.30901699437494745,  0.95105651629515353 },
-                                       { -0.80901699437494734, 0.58778525229247325 } };
-
-    FFTComplex z[4][4];
-
-    CMUL2(z[0][0], z[0][3], in[1 * stride], fact[0]);
-    CMUL2(z[0][1], z[0][2], in[1 * stride], fact[1]);
-    CMUL2(z[1][0], z[1][3], in[2 * stride], fact[0]);
-    CMUL2(z[1][1], z[1][2], in[2 * stride], fact[1]);
-    CMUL2(z[2][0], z[2][3], in[3 * stride], fact[0]);
-    CMUL2(z[2][1], z[2][2], in[3 * stride], fact[1]);
-    CMUL2(z[3][0], z[3][3], in[4 * stride], fact[0]);
-    CMUL2(z[3][1], z[3][2], in[4 * stride], fact[1]);
-
-    out[0].re = in[0].re + in[stride].re + in[2 * stride].re + in[3 * stride].re + in[4 * stride].re;
-    out[0].im = in[0].im + in[stride].im + in[2 * stride].im + in[3 * stride].im + in[4 * stride].im;
-
-    out[1].re = in[0].re + z[0][0].re + z[1][1].re + z[2][2].re + z[3][3].re;
-    out[1].im = in[0].im + z[0][0].im + z[1][1].im + z[2][2].im + z[3][3].im;
-
-    out[2].re = in[0].re + z[0][1].re + z[1][3].re + z[2][0].re + z[3][2].re;
-    out[2].im = in[0].im + z[0][1].im + z[1][3].im + z[2][0].im + z[3][2].im;
-
-    out[3].re = in[0].re + z[0][2].re + z[1][0].re + z[2][3].re + z[3][1].re;
-    out[3].im = in[0].im + z[0][2].im + z[1][0].im + z[2][3].im + z[3][1].im;
-
-    out[4].re = in[0].re + z[0][3].re + z[1][2].re + z[2][1].re + z[3][0].re;
-    out[4].im = in[0].im + z[0][3].im + z[1][2].im + z[2][1].im + z[3][0].im;
+    FFTComplex z0[4], t[6];
+
+    t[0].re = in[3].re + in[12].re;
+    t[0].im = in[3].im + in[12].im;
+    t[1].im = in[3].re - in[12].re;
+    t[1].re = in[3].im - in[12].im;
+    t[2].re = in[6].re + in[ 9].re;
+    t[2].im = in[6].im + in[ 9].im;
+    t[3].im = in[6].re - in[ 9].re;
+    t[3].re = in[6].im - in[ 9].im;
+
+    out[0].re = in[0].re + in[3].re + in[6].re + in[9].re + in[12].re;
+    out[0].im = in[0].im + in[3].im + in[6].im + in[9].im + in[12].im;
+
+    t[4].re = exptab[0].re * t[2].re - exptab[1].re * t[0].re;
+    t[4].im = exptab[0].re * t[2].im - exptab[1].re * t[0].im;
+    t[0].re = exptab[0].re * t[0].re - exptab[1].re * t[2].re;
+    t[0].im = exptab[0].re * t[0].im - exptab[1].re * t[2].im;
+    t[5].re = exptab[0].im * t[3].re - exptab[1].im * t[1].re;
+    t[5].im = exptab[0].im * t[3].im - exptab[1].im * t[1].im;
+    t[1].re = exptab[0].im * t[1].re + exptab[1].im * t[3].re;
+    t[1].im = exptab[0].im * t[1].im + exptab[1].im * t[3].im;
+
+    z0[0].re = t[0].re - t[1].re;
+    z0[0].im = t[0].im - t[1].im;
+    z0[1].re = t[4].re + t[5].re;
+    z0[1].im = t[4].im + t[5].im;
+
+    z0[2].re = t[4].re - t[5].re;
+    z0[2].im = t[4].im - t[5].im;
+    z0[3].re = t[0].re + t[1].re;
+    z0[3].im = t[0].im + t[1].im;
+
+    out[1].re = in[0].re + z0[3].re;
+    out[1].im = in[0].im + z0[0].im;
+    out[2].re = in[0].re + z0[2].re;
+    out[2].im = in[0].im + z0[1].im;
+    out[3].re = in[0].re + z0[1].re;
+    out[3].im = in[0].im + z0[2].im;
+    out[4].re = in[0].re + z0[0].re;
+    out[4].im = in[0].im + z0[3].im;
 }
 
-static void fft15(IMDCT15Context *s, FFTComplex *out, const FFTComplex *in,
-                  ptrdiff_t stride)
+static inline void fft15(const FFTComplex exptab[22], FFTComplex *out,
+                         const FFTComplex *in, size_t stride)
 {
-    const FFTComplex *exptab = s->exptab[0];
-    FFTComplex tmp[5];
-    FFTComplex tmp1[5];
-    FFTComplex tmp2[5];
     int k;
+    FFTComplex tmp1[5], tmp2[5], tmp3[5];
 
-    fft5(tmp,  in,              stride * 3);
-    fft5(tmp1, in +     stride, stride * 3);
-    fft5(tmp2, in + 2 * stride, stride * 3);
+    fft5(exptab + 19, tmp1, in + 0);
+    fft5(exptab + 19, tmp2, in + 1);
+    fft5(exptab + 19, tmp3, in + 2);
 
     for (k = 0; k < 5; k++) {
-        FFTComplex t1, t2;
-
-        CMUL(t1, tmp1[k], exptab[k]);
-        CMUL(t2, tmp2[k], exptab[2 * k]);
-        out[k].re = tmp[k].re + t1.re + t2.re;
-        out[k].im = tmp[k].im + t1.im + t2.im;
-
-        CMUL(t1, tmp1[k], exptab[k + 5]);
-        CMUL(t2, tmp2[k], exptab[2 * (k + 5)]);
-        out[k + 5].re = tmp[k].re + t1.re + t2.re;
-        out[k + 5].im = tmp[k].im + t1.im + t2.im;
-
-        CMUL(t1, tmp1[k], exptab[k + 10]);
-        CMUL(t2, tmp2[k], exptab[2 * k + 5]);
-        out[k + 10].re = tmp[k].re + t1.re + t2.re;
-        out[k + 10].im = tmp[k].im + t1.im + t2.im;
+        FFTComplex t[2];
+
+        CMUL(t[0], tmp2[k], exptab[k]);
+        CMUL(t[1], tmp3[k], exptab[2 * k]);
+        out[stride*k].re = tmp1[k].re + t[0].re + t[1].re;
+        out[stride*k].im = tmp1[k].im + t[0].im + t[1].im;
+
+        CMUL(t[0], tmp2[k], exptab[k + 5]);
+        CMUL(t[1], tmp3[k], exptab[2 * (k + 5)]);
+        out[stride*(k + 5)].re = tmp1[k].re + t[0].re + t[1].re;
+        out[stride*(k + 5)].im = tmp1[k].im + t[0].im + t[1].im;
+
+        CMUL(t[0], tmp2[k], exptab[k + 10]);
+        CMUL(t[1], tmp3[k], exptab[2 * k + 5]);
+        out[stride*(k + 10)].re = tmp1[k].re + t[0].re + t[1].re;
+        out[stride*(k + 10)].im = tmp1[k].im + t[0].im + t[1].im;
     }
 }
 
-/*
- * FFT of the length 15 * (2^N)
- */
-static void fft_calc(IMDCT15Context *s, FFTComplex *out, const FFTComplex *in,
-                     int N, ptrdiff_t stride)
-{
-    if (N) {
-        const FFTComplex *exptab = s->exptab[N];
-        const int len2 = 15 * (1 << (N - 1));
-        int k;
-
-        fft_calc(s, out,        in,          N - 1, stride * 2);
-        fft_calc(s, out + len2, in + stride, N - 1, stride * 2);
-
-        for (k = 0; k < len2; k++) {
-            FFTComplex t;
-
-            CMUL(t, out[len2 + k], exptab[k]);
-
-            out[len2 + k].re = out[k].re - t.re;
-            out[len2 + k].im = out[k].im - t.im;
-
-            out[k].re += t.re;
-            out[k].im += t.im;
-        }
-    } else
-        fft15(s, out, in, stride);
-}
-
 static void imdct15_half(IMDCT15Context *s, float *dst, const float *src,
                          ptrdiff_t stride, float scale)
 {
+    FFTComplex fft15in[15];
     FFTComplex *z = (FFTComplex *)dst;
-    const int len8 = s->len4 / 2;
-    const float *in1 = src;
-    const float *in2 = src + (s->len2 - 1) * stride;
-    int i;
-
-    for (i = 0; i < s->len4; i++) {
-        FFTComplex tmp = { *in2, *in1 };
-        CMUL(s->tmp[i], tmp, s->twiddle_exptab[i]);
-        in1 += 2 * stride;
-        in2 -= 2 * stride;
+    int i, j, len8 = s->len4 >> 1, l_ptwo = 1 << s->ptwo_fft.nbits;
+    const float *in1 = src, *in2 = src + (s->len2 - 1) * stride;
+
+    /* Reindex input, putting it into a buffer and doing an Nx15 FFT */
+    for (i = 0; i < l_ptwo; i++) {
+        for (j = 0; j < 15; j++) {
+            const int k = s->pfa_prereindex[i*15 + j];
+            FFTComplex tmp = { *(in2 - 2*k*stride), *(in1 + 2*k*stride) };
+            CMUL(fft15in[j], tmp, s->twiddle_exptab[k]);
+        }
+        fft15(s->exptab, s->tmp + s->ptwo_fft.revtab[i], fft15in, l_ptwo);
     }
 
-    fft_calc(s, z, s->tmp, s->fft_n, 1);
+    /* Then a 15xN FFT (where N is a power of two) */
+    for (i = 0; i < 15; i++)
+        s->ptwo_fft.fft_calc(&s->ptwo_fft, s->tmp + l_ptwo*i);
 
+    /* Reindex again, apply twiddles and output */
     for (i = 0; i < len8; i++) {
-        float r0, i0, r1, i1;
-
-        CMUL3(r0, i1, z[len8 - i - 1].im, z[len8 - i - 1].re,  s->twiddle_exptab[len8 - i - 1].im, s->twiddle_exptab[len8 - i - 1].re);
-        CMUL3(r1, i0, z[len8 + i].im,     z[len8 + i].re,      s->twiddle_exptab[len8 + i].im,     s->twiddle_exptab[len8 + i].re);
-        z[len8 - i - 1].re = scale * r0;
-        z[len8 - i - 1].im = scale * i0;
-        z[len8 + i].re     = scale * r1;
-        z[len8 + i].im     = scale * i1;
+        float re0, im0, re1, im1;
+        const int i0 = len8 + i, i1 = len8 - i - 1;
+        const int s0 = s->pfa_postreindex[i0], s1 = s->pfa_postreindex[i1];
+
+        CMUL3(re0, im1, s->tmp[s1].im, s->tmp[s1].re,  s->twiddle_exptab[i1].im, s->twiddle_exptab[i1].re);
+        CMUL3(re1, im0, s->tmp[s0].im, s->tmp[s0].re,  s->twiddle_exptab[i0].im, s->twiddle_exptab[i0].re);
+        z[i1].re = scale * re0;
+        z[i1].im = scale * im0;
+        z[i0].re = scale * re1;
+        z[i0].im = scale * im1;
     }
 }
index 7a58aac..a31f11e 100644 (file)
 
 #include <stddef.h>
 
-#include "avfft.h"
+#include "fft.h"
 
 typedef struct IMDCT15Context {
     int fft_n;
     int len2;
     int len4;
+    int *pfa_prereindex;
+    int *pfa_postreindex;
+
+    FFTContext ptwo_fft;
 
     FFTComplex *tmp;
 
     FFTComplex *twiddle_exptab;
 
-    FFTComplex *exptab[6];
+    /* 0 - 18: fft15 twiddles, 19 - 20: fft5 twiddles */
+    FFTComplex exptab[21];
 
     /**
      * Calculate the middle half of the iMDCT