2012年2月18日 星期六

TV-L1 based Optical Flow using GPU

Mr. Chris McClanahan implemented TV-L1 based optical flow on top of LibJacket (see here). However, it is not compatible with newer versions of LibJacket... well, there is no more LibJacket. AccelerEyes has since announced ArrayFire, whose API differs substantially from LibJacket's. Thus, I modified his work as follows:

//
// Chris McClanahan - 2011
// Modified by Chao-Hui Huang 2012
//
// Adapted from: http://gpu4vision.icg.tugraz.at/index.php?content=downloads.php
//   "An Improved Algorithm for TV-L1 Optical Flow"
//
// More info: http://mcclanahoochie.com/blog/portfolio/gpu-tv-l1-optical-flow-with-libjacket/
//

#include <algorithm>
#include <fstream>
#include <iostream>
#include <vector>

#include <math.h>
#include <stdio.h>
#include <string.h>

#include <opencv/cv.h>
#include <opencv/cxcore.h>
#include <opencv/highgui.h>

#include <arrayfire.h>

using namespace std;
using namespace cv;
using namespace af;

// control
const float pfactor = 0.7;    // scale each pyr level by this amount
const int max_plevels = 9;    // number of pyramid levels
const int max_iters = 6;      // u v w update loop
const float lambda = 40;      // smoothness constraint
const int max_warps = 3;      // warping u v warping
const int min_img_sz = 20;    // min mxn img in pyramid
#define TIMING 0              // warmup, then average multiple runs

// functions
int  grab_frame(Mat& img, char* filename);
void create_pyramids(array& im1, array& im2, array& pyr1, array& pyr2);
void process_pyramids(array& pyr1, array& pyr2, array& u, array& v);
void tv_l1_dual(array& u, array& v, array& p, array& w, array& I1, array& I2, int level);
void optical_flow_tvl1(Mat& img1, Mat& img2, Mat& u, Mat& v);
void display_flow(array& I2, array& u, array& v);
void MatToFloat(const Mat& thing, float* thing2);
void FloatToMat(float const* thing, Mat& thing2);

// misc
int plevels = max_plevels;
const int n_dual_vars = 6;
static int cam_init = 0;
static int pyr_init = 0;
VideoCapture  capture;
int pyr_M[max_plevels + 1];
int pyr_N[max_plevels + 1];
array pyr1, pyr2;

// macros
#define MSG(msg,...) do {                                   \
                fprintf(stdout,__FILE__":%d(%s) " msg "\n",     \
                        __LINE__, __FUNCTION__, ##__VA_ARGS__); \
                fflush(stdout);                                 \
            } while (0)

#define M_PI 3.14159265358979323846

// ===== main =====
void optical_flow_tvl1(Mat& img1, Mat& img2, Mat& mu, Mat& mv) {

    // extract cv image 1
    Mat mi1(img1.rows, img1.cols, CV_8UC1);
    cvtColor(img1.t(), mi1, CV_BGR2GRAY);
    mi1.convertTo(mi1, CV_32FC1);
    float* fi1 = (float*)mi1.data;
    array I1 = array(img1.rows, img1.cols, fi1) / 255.0f;

    // extract cv image 2
    Mat mi2(img2.rows, img2.cols, CV_8UC1);
    cvtColor(img2.t(), mi2, CV_BGR2GRAY);
    mi2.convertTo(mi2, CV_32FC1);
    float* fi2 = (float*)mi2.data;
    array I2 = array(img2.rows, img2.cols, fi2) / 255.0f;

#if TIMING
    // runs
    int nruns = 4;
    // warmup
    create_pyramids(I1, I2, pyr1, pyr2);
    f32 ou, ov;
    process_pyramids(pyr1, pyr2, ou, ov);
    // timing
    timer::tic();
    for (int i = 0; i < nruns; ++i) {
        create_pyramids(I1, I2, pyr1, pyr2);
        process_pyramids(pyr1, pyr2, ou, ov);
    }
    MSG("fps: %f", 1.0f / (timer::toc() / (float)nruns));
#else
    // timing
    timer::tic();
    // pyramids
    create_pyramids(I1, I2, pyr1, pyr2);
    // flow
    array ou, ov;
    process_pyramids(pyr1, pyr2, ou, ov);
    // timing
    MSG("fps: %f", 1.0f / (timer::toc()));
#endif

    // output
#if 1
    // to opencv
    FloatToMat((float*)ou.T().host(), mu);
    FloatToMat((float*)ov.T().host(), mv);
#else
    // to libjacket
    display_flow(I2, ou, ov);
#endif
}


void MatToFloat(const Mat& thing, float* thing2) {
    int tmp = 0;
    for (int i = 0; i < thing.rows; i++) {
        const float* fptr = thing.ptr<float>(i);
        for (int j = 0; j < thing.cols; j++)
        { thing2[tmp++] = fptr[j]; }
    }
}


void FloatToMat(float const* thing, Mat& thing2) {
    int tmp = 0;
    for (int i = 0; i < thing2.rows; ++i) {
        float* fptr = thing2.ptr<float>(i);
        for (int j = 0; j < thing2.cols; ++j)
        { fptr[j] = thing[tmp++]; }
    }
}


void display_flow(array& I2, array& u, array& v) {
#if 1
    // show in libjacket
    palette("bone");
    subfigure(2, 2, 1); imgplot(I2);                  title("input");
    subfigure(2, 2, 2); imgplot(u);                   title("u");
    subfigure(2, 2, 3); imgplot(v);                   title("v");
    subfigure(2, 2, 4); imgplot((abs(v) + abs(u)));   title("u+v");
    // int M = I2.dims()[0];
    // int N = I2.dims()[1];
    // f32 idx, idy; meshgrid(idx, idy, f32(seq(0,N-1,3)), f32(seq(0,M-1,3)));
    // quiver(idx,idy,u,v);
    draw();
#else
    // show in opencv
    int M = I2.dims()[0];
    int N = I2.dims()[1];
    Mat mu(M, N, CV_32FC1);
    Mat mv(M, N, CV_32FC1);
    FloatToMat(u.T().host(), mu);
    FloatToMat(v.T().host(), mv);
    imshow("u", mu);
    imshow("v", mv);
#endif
}


void display_flow(const Mat& u, const Mat& v) {
#if 0
    cv::Mat magnitude, angle, bgr;
    cv::cartToPolar(u, v, magnitude, angle, true);
    double mag_max, mag_min;
    cv::minMaxLoc(magnitude, &mag_min, &mag_max);
    magnitude.convertTo(magnitude, -1, 1.0 / mag_max);
    cv::Mat _hsv[3], hsv_image;
    _hsv[0] = angle;
    _hsv[1] = Mat::ones(angle.size(), CV_32F);
    _hsv[2] = magnitude;
    cv::merge(_hsv, 3, hsv_image);
#else
    cv::Mat magnitude, angle, bgr;
    Mat hsv_image(u.rows, u.cols, CV_8UC3);
    for (int i = 0; i < u.rows; ++i) {
        const float* x_ptr = u.ptr<float>(i);
        const float* y_ptr = v.ptr<float>(i);
        uchar* hsv_ptr = hsv_image.ptr<uchar>(i);
        for (int j = 0; j < u.cols; ++j, hsv_ptr += 3, ++x_ptr, ++y_ptr) {
            hsv_ptr[0] = (uchar)((atan2f(*y_ptr, *x_ptr) / M_PI + 1) * 90);
            hsv_ptr[1] = hsv_ptr[2] = (uchar) std::min<float>(
                                          sqrtf(*y_ptr * *y_ptr + *x_ptr * *x_ptr) * 20, 255.0);
        }
    }
#endif
    cv::cvtColor(hsv_image, bgr, CV_HSV2BGR);
    cv::imshow("optical flow", bgr);
}


int grab_frame(Mat& img, char* filename) {

    // camera/image setup
    if (!cam_init) {
        if (filename != NULL) {
            capture.open(filename);
        } else {
            float rescale = 0.615;
            int w = 640 * rescale;
            int h = 480 * rescale;
            capture.open(0); //try to open
            capture.set(CV_CAP_PROP_FRAME_WIDTH, w);  capture.set(CV_CAP_PROP_FRAME_HEIGHT, h);
        }
        if (!capture.isOpened()) { cerr << "open video device fail\n" << endl; return 0; }
        capture >> img; capture >> img;
        if (img.empty()) { cout << "load image fail " << endl; return 0; }
        namedWindow("cam", CV_WINDOW_KEEPRATIO);
        printf(" img = %d x %d \n", img.cols, img.rows);
        cam_init = 1;
    }

    // get frames
    capture.grab();
    capture.retrieve(img);
    imshow("cam", img);

    if (waitKey(10) >= 0) { return 0; }
    else { return 1; }
}


// Fills pyr_M/pyr_N (rows/cols of each level) by shrinking the size of im1 by
// pfactor per level, and lowers the global plevels so that only levels at
// least min_img_sz on each side are used. At least one level (the full-size
// image) is always kept; otherwise a tiny input would leave plevels == 0 and
// produce an empty flow field.
void gen_pyramid_sizes(array& im1) {
    dim4 mnk = im1.dims();
    float sM = mnk[0];
    float sN = mnk[1];
    // store resizing
    for (int level = 0; level <= plevels; ++level) {
        if (level > 0) {
            sM *= pfactor;
            sN *= pfactor;
        }
        pyr_M[level] = (int)(sM + 0.5f);
        pyr_N[level] = (int)(sN + 0.5f);
        MSG(" pyr %d: %d x %d ", level, (int)sM, (int)sN);
        if (sM < min_img_sz || sN < min_img_sz) {
            // this level is too small: use only the levels before it
            plevels = (level > 0) ? level : 1;
            break;
        }
    }
}

void create_pyramids(array& im1, array& im2, array& pyr1, array& pyr2) {

    if (!pyr_init) {
        // list of h,w
        gen_pyramid_sizes(im1);

        // init
        pyr1 = zeros(pyr_M[0], pyr_N[0], plevels);
        pyr2 = zeros(pyr_M[0], pyr_N[0], plevels);
        pyr_init = 1;
    }

    // create
    for (int level = 0; level < plevels; level++) {
        if (level == 0) {
            pyr1(span, span, level) = im1;
            pyr2(span, span, level) = im2;
        } else {
            seq spyi = seq(pyr_M[level - 1]);
            seq spxi = seq(pyr_N[level - 1]);
            array small1 = resize(pyr1(spyi, spxi, level - 1), pyr_M[level], pyr_N[level], AF_RSZ_Bilinear);
            array small2 = resize(pyr2(spyi, spxi, level - 1), pyr_M[level], pyr_N[level], AF_RSZ_Bilinear);
            seq spyo = seq(pyr_M[level]);
            seq spxo = seq(pyr_N[level]);
            pyr1(spyo, spxo, level) = small1;
            pyr2(spyo, spxo, level) = small2;
        }
    }
}


void process_pyramids(array& pyr1, array& pyr2, array& ou, array& ov) {
    array p, u, v, w;

    // pyramid loop
    for (int level = plevels - 1; level >= 0; level--) {
        if (level == plevels - 1) {
            u  = zeros(pyr_M[level], pyr_N[level]);
            v  = zeros(pyr_M[level], pyr_N[level]);
            w  = zeros(pyr_M[level], pyr_N[level]);
            p  = zeros(pyr_M[level], pyr_N[level], n_dual_vars);
        } else {
            float rescale_u =  pyr_N[level + 1] / (float)pyr_N[level];
            float rescale_v =  pyr_M[level + 1] / (float)pyr_M[level];
            // propagate
            array u_ =  resize(u, pyr_M[level], pyr_N[level], AF_RSZ_Bilinear) * rescale_u;
            array v_ =  resize(v, pyr_M[level], pyr_N[level], AF_RSZ_Bilinear) * rescale_v;
            array w_ =  resize(w, pyr_M[level], pyr_N[level], AF_RSZ_Bilinear);
            array p_ = zeros(pyr_M[level], pyr_N[level], n_dual_vars);
            gfor(array ndv, n_dual_vars) {
                p_(span, span, ndv) = resize(p(span, span, ndv), pyr_M[level], pyr_N[level], AF_RSZ_Bilinear);
            }
            u = u_;  v = v_;  p = p_;  w = w_;
        }

        // extract
        seq spy = seq(pyr_M[level]);
        seq spx = seq(pyr_N[level]);
        array I1 = pyr1(spy, spx, level);
        array I2 = pyr2(spy, spx, level);

        // ===== core ====== //
        tv_l1_dual(u, v, p, w, I1, I2, level);
        // ===== ==== ====== //
    }

    // output
    ou = u;
    ov = v;
}


void warping(array& Ix, array& Iy, array& It, array& I1, array& I2, array& u, array& v) {

    dim4 mnk = I2.dims();
    int M = mnk[0];
    int N = mnk[1];
    array idx = tile(array(seq(N)).T(), M, 1) + 1;
    array idy = tile(array(seq(M)), 1, N) + 1;
    /* ^ BUG: idx idy should ideally be [0-N); ^ */

    array idxx0 = idx + u;
    array idyy0 = idy + v;
    array idxx = max(1, min(N - 1, idxx0));
    array idyy = max(1, min(M - 1, idyy0));

    // interp2 based warp ()
    It = interp(idy, idx, I2, idyy, idxx) - I1;

    // interp2 based warp ()
    array idxm = max(1, min(N - 1, idxx - 1.f));
    array idxp = max(1, min(N - 1, idxx + 1.f));
    array idym = max(1, min(M - 1, idyy - 1.f));
    array idyp = max(1, min(M - 1, idyy + 1.f));
    Ix = interp(idy, idx, I2, idy, idxp) - interp(idy, idx, I2, idy, idxm);
    Iy = interp(idy, idx, I2, idyp, idx) - interp(idy, idx, I2, idym, idx);
    /* ^ BUG: interp2 should be cubic; that may fix things; ^ */
}


// Discrete divergence of the vector field (I0x, I0y): backward differences
// with the boundary handling that makes it the negative adjoint of dxyp()
// (x = columns / dim 1, y = rows / dim 0):
//   d/dx: col 0 -> I0x(:,0), cols 1..N-2 -> I0x(:,j) - I0x(:,j-1), col N-1 -> -I0x(:,N-2)
//   d/dy: same along rows with I0y.
// The previous version subtracted a slice from itself (interior was always 0)
// and differenced x along rows, inconsistent with dxyp().
void dxym(array& Id, array I0x, array I0y) {
    dim4 mnk = I0x.dims();
    int M = mnk[0];
    int N = mnk[1];

    // x part: x0 = I0x with last column zeroed, x1 = I0x shifted right by one
    array x0 = zeros(M, N);
    array x1 = zeros(M, N);
    if (N > 1) {
        x0(span, seq(0, N - 2)) = I0x(span, seq(0, N - 2));
        x1(span, seq(1, N - 1)) = I0x(span, seq(0, N - 2));
    }

    // y part: y0 = I0y with last row zeroed, y1 = I0y shifted down by one
    array y0 = zeros(M, N);
    array y1 = zeros(M, N);
    if (M > 1) {
        y0(seq(0, M - 2), span) = I0y(seq(0, M - 2), span);
        y1(seq(1, M - 1), span) = I0y(seq(0, M - 2), span);
    }

    Id = (x0 - x1) + (y0 - y1);
}


// Forward differences of I0 with a zero last column/row:
//   Ix(:, j) = I0(:, j+1) - I0(:, j)   (columns, dim 1)
//   Iy(i, :) = I0(i+1, :) - I0(i, :)   (rows, dim 0)
void dxyp(array& Ix, array& Iy, array& I0) {
    const dim4 dims = I0.dims();
    const int rows = dims[0];
    const int cols = dims[1];

    // copy of I0 whose first cols-1 columns hold the next column's values
    array next_col = I0;
    next_col(span, seq(0, cols - 2)) = I0(span, seq(1, cols - 1));

    // copy of I0 whose first rows-1 rows hold the next row's values
    array next_row = I0;
    next_row(seq(0, rows - 2), span) = I0(seq(1, rows - 1), span);

    Ix = next_col - I0;
    Iy = next_row - I0;
}


// One pyramid level of the TV-L1 primal-dual solver; u, v, w and p are
// updated in place.
//   u, v   : flow components (u is added to column coordinates, v to row
//            coordinates inside warping())
//   p      : dual variables, slices (u_x, u_y, v_x, v_y, w_x, w_y)
//   w      : auxiliary field entering the data residual as gamma * w
//            (presumably illumination compensation, as in the referenced
//            "Improved TV-L1" paper -- confirm)
//   I1, I2 : the two images at this level
//   level  : not used
// For each of max_warps outer iterations, I2 is re-linearized around the
// current flow, then max_iters primal-dual steps are taken with
// tau = sigma = 1/sqrt(8); u and v are 3x3 median filtered after each warp.
// ArrayFire exceptions are printed and rethrown.
void tv_l1_dual(array& u, array& v, array& p, array& w, array& I1, array& I2, int level) {
    try {
    // primal / dual step sizes
    float L = sqrtf(8.0f);
    float tau   = 1 / L;
    float sigma = 1 / L;

    // eps_*: damping of the dual updates; gamma: weight of w in the data term
    float eps_u = 0.01f;
    float eps_w = 0.01f;
    float gamma = 0.02f;

    // extrapolated ("leading") primal variables used by the dual step
    array u_ = u;
    array v_ = v;
    array w_ = w;

    for (int j = 0; j < max_warps; j++) {

        // linearization point for this warp
        array u0 = u;
        array v0 = v;

        // warping
        array Ix, Iy, It;   warping(Ix, Iy, It, I1, I2, u0, v0);

        // gradients: squared gradient magnitude (+ gamma^2), floored to avoid division by ~0
        array I_grad_sqr = max(float(1e-6), array(pow(Ix, 2) + pow(Iy, 2) + gamma * gamma));

        // inner loop
        for (int k = 0; k < max_iters; ++k) {

            // dual =====

            // shifts (forward differences of the extrapolated primal variables)
            array u_x, u_y;    dxyp(u_x, u_y, u_);
            array v_x, v_y;    dxyp(v_x, v_y, v_);
            array w_x, w_y;    dxyp(w_x, w_y, w_);

            // update dual: gradient ascent step, damped by 1 + sigma * eps
            p(span, span, 0) = (p(span, span, 0) + sigma * u_x) / (1 + sigma * eps_u);
            p(span, span, 1) = (p(span, span, 1) + sigma * u_y) / (1 + sigma * eps_u);
            p(span, span, 2) = (p(span, span, 2) + sigma * v_x) / (1 + sigma * eps_u);
            p(span, span, 3) = (p(span, span, 3) + sigma * v_y) / (1 + sigma * eps_u);

            p(span, span, 4) = (p(span, span, 4) + sigma * w_x) / (1 + sigma * eps_w);
            p(span, span, 5) = (p(span, span, 5) + sigma * w_y) / (1 + sigma * eps_w);

            // normalize: project the (u, v) duals jointly onto the unit ball
            array reprojection = max(1, sqrt(pow(p(span, span, 0), 2) + pow(p(span, span, 1), 2) +
                                           pow(p(span, span, 2), 2) + pow(p(span, span, 3), 2)));

            p(span, span, 0) = p(span, span, 0) / reprojection;
            p(span, span, 1) = p(span, span, 1) / reprojection;
            p(span, span, 2) = p(span, span, 2) / reprojection;
            p(span, span, 3) = p(span, span, 3) / reprojection;

            // ...and the w duals separately
            reprojection = max(1, sqrt(pow(p(span, span, 4), 2) + pow(p(span, span, 5), 2)));

            p(span, span, 4) = p(span, span, 4) / reprojection;
            p(span, span, 5) = p(span, span, 5) / reprojection;

            // primal =====

            // divergence of the dual fields
            array div_u;   dxym(div_u, p(span, span, 0), p(span, span, 1));
            array div_v;   dxym(div_v, p(span, span, 2), p(span, span, 3));
            array div_w;   dxym(div_w, p(span, span, 4), p(span, span, 5));

            // old (kept for the extrapolation step below)
            u_ = u;
            v_ = v;
            w_ = w;

            // update: primal descent step
            u = u + tau * div_u;
            v = v + tau * div_v;
            w = w + tau * div_w;

            // indexing: linearized data residual rho, split pointwise into
            // rho < -t, rho > t and |rho| <= t with t = tau * lambda * I_grad_sqr
            array rho  = It + mul((u - u0), Ix) + mul((v - v0), Iy) + gamma * w;
            array idx1 = rho      <  -tau * lambda * I_grad_sqr;
            array idx2 = rho      >   tau * lambda * I_grad_sqr;
            array idx3 = abs(rho) <=  tau * lambda * I_grad_sqr;

            // rho < -t: step along +gradient by tau * lambda
            u = u + tau * lambda * (mul(Ix, idx1)) ;
            v = v + tau * lambda * (mul(Iy, idx1)) ;
            w = w + tau * lambda * gamma * idx1;

            // rho > t: step along -gradient by tau * lambda
            u = u - tau * lambda * (mul(Ix, idx2)) ;
            v = v - tau * lambda * (mul(Iy, idx2)) ;
            w = w - tau * lambda * gamma * idx2;

            // |rho| <= t: step that drives the linearized residual to zero
            u = u - mul(rho, mul(idx3, Ix / I_grad_sqr));
            v = v - mul(rho, mul(idx3, Iy / I_grad_sqr));
            w = w - mul(rho, mul(idx3, gamma / I_grad_sqr));

            // propagate: extrapolation x_ = 2 * x_new - x_old
            u_ = 2 * u - u_;
            v_ = 2 * v - v_;
            w_ = 2 * w - w_;

        }

        // output: 3x3 median filter on the flow after each warp (w is not filtered)
        // const unsigned hw[] = {3, 3};
        u = medfilt(u, 3, 3);
        v = medfilt(v, 3, 3);

    } /* j < warps */


    } catch (af::exception& e) {
        cout << e.what() << endl;
        throw;
    }


}


// =======================================

int main(int argc, char* argv[]) {

    // video file or usb camera
    Mat cam_img, prev_img, disp_u, disp_v;
    int is_images = 0;
    if (argc == 2) { grab_frame(prev_img, argv[1]); } // video
    else if (argc == 3) {
        prev_img = imread(argv[1]); cam_img = imread(argv[2]);
        is_images = 1;
    } else { grab_frame(prev_img, NULL); } // usb camera

    // results
    int mm = prev_img.rows;  int nn = prev_img.cols;
    disp_u = Mat::zeros(mm, nn, CV_32FC1);
    disp_v = Mat::zeros(mm, nn, CV_32FC1);
    printf("img %d x %d \n", mm, nn);

    // process main
    if (is_images) {
        // show
        imshow("i", cam_img);
        // process files
        optical_flow_tvl1(prev_img, cam_img, disp_u, disp_v);
        // show
        // imshow("u", disp_u);
        // imshow("v", disp_v);
        display_flow(disp_u, disp_v);
        waitKey(0);
        // // write
        // writeFlo(disp_u, disp_v);
    } else {
        // process loop
        while (grab_frame(cam_img, NULL)) {
            try {
                // process
                optical_flow_tvl1(prev_img, cam_img, disp_u, disp_v);
                // frames
                prev_img = cam_img.clone();
                // show
                // imshow("u", disp_u);
                // imshow("v", disp_v);
                display_flow(disp_u, disp_v);
            } catch (af::exception& e) {
                cout << e.what() << endl;
                throw;
            }
        }
    }

    return 0;
}