// GRAPE-5 compatible APIs
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <assert.h>
#include "sse_type.h"
#include "g5util.h"
#ifdef CUTOFF_FORCE
#include "pg5_table.h"
#endif
#ifdef _OPENMP
#include <omp.h>
#endif

// Number of i-particles one device handles per kernel call
// (the value reported by g5_get_number_of_pipelinesMC()).
#define NUM_PIPE 4
// Number of entries in each device's j-particle buffer (jptcl[]).
#define JMEMSIZE 65536

// Size of all per-device tables in this file; can be overridden at build time.
#ifndef MAXDEV
#define MAXDEV 4
#endif
// Working storage of one emulated device: the current i-particle batch,
// the kernel output for that batch, and a j-particle buffer.
// ALIGN64 comes from a project header; presumably it requests 64-byte
// alignment -- TODO confirm.
static struct Ptcl_Mem{
    Ipdata iptcl;            // i-particle inputs, written by g5_set_xiMC()
    Fodata fout;             // kernel results, read back by g5_get_forceMC()
    Jpdata jptcl[JMEMSIZE];  // j-particles, written by g5_set_jpMC()
    int Nbody, pad[15];      // Nbody: j-particle count set by g5_set_nMC(); pad[] presumably keeps the struct size a multiple of 64 bytes -- TODO confirm
} ptcl_mem[MAXDEV] ALIGN64;

// Per-device softening length, stored by g5_set_eps_to_allMC() and
// g5_set_eps2_to_allMC(); squared and handed to the kernel in g5_set_xiMC().
static double Eps[MAXDEV];
// Per-device eta, stored by g5_set_etaMC(); no code visible in this file reads it.
static double Eta[MAXDEV];

// thread id -> device id (filled in by init_envs())
static int Tid2devid[MAXDEV];

#ifdef CUTOFF_FORCE
// Scale factor multiplied into i-particle coordinates in g5_set_xiMC();
// set by pg5_set_xscale().
static double Xscale;
// {xscale, xscale, xscale, 1.0}: multiplied into each packed j-particle in
// g5_set_jpMC(); set by pg5_set_xscale().
static v4sf XMscale;
// Four copies of 1/xscale, passed to gravity_kernel(); set by pg5_set_xscale().
static v4sf Ascale;
// Constant passed to gravity_kernel(); all four lanes hold the same value.
// EXP_BIT is defined outside this file (presumably in pg5_table.h -- TODO confirm).
static v4sf R2cut_xscale2 = {
    (1<<(1+(1<<EXP_BIT))) - 3,
    (1<<(1+(1<<EXP_BIT))) - 3,
    (1<<(1+(1<<EXP_BIT))) - 3,
    (1<<(1+(1<<EXP_BIT))) - 3
};

void pg5_set_xscale(double xscale){
    Xscale = xscale;
    XMscale = (v4sf){xscale, xscale, xscale, 1.0};
    double ascale = 1./xscale;
    Ascale = (v4sf){ascale, ascale, ascale, ascale};
}
#else
// Factors applied to the kernel's acceleration and potential output in
// g5_get_forceMC(). They start at 1.0 and are recomputed from rsqrt_bias()
// on the first call to g5_openMC(). The *V variants hold the same value in
// all four vector lanes.
static float Acc_correct = 1.0;
static float Pot_correct = +1.0;
static v4sf Acc_correctV = {1.0, 1.0, 1.0, 1.0};
static v4sf Pot_correctV = {+1.0, +1.0, +1.0, +1.0};
#endif

/******** GRAPE-5 APIs ********/

// Number of enabled devices. Zero until init_envs() has run; init_envs()
// returns immediately once this is non-zero.
static int    Ndevice = 0;
// Device[ic] is 1 when device ic is enabled and 0 otherwise (set by init_envs()).
static int    Device[MAXDEV];
static int hib_ndevice(void); // number of GRAPE devices available.
// Fills in Ndevice, Device[] and Tid2devid[] (defined at the end of this file).
static void init_envs(void);

/*
 * Return the number of enabled devices.
 * Runs init_envs() first so the count is valid even before g5_open().
 */
int g5_get_number_of_cards(void)
{
    int ncards;

    init_envs();
    ncards = Ndevice;
    return ncards;
}

/*
 * Return the number of pipelines of one device.
 * Every device has NUM_PIPE pipelines, so devid is not consulted.
 */
int g5_get_number_of_pipelinesMC(int devid)
{
    const int npipe = NUM_PIPE;
    return npipe;
}

/*
 * Return the total number of pipelines, summed over all enabled devices.
 *
 * The result is cached in a function-local static after the first
 * non-zero computation.
 */
int g5_get_number_of_pipelines(void)
{
    static int n = 0; /* cached total; 0 means "not computed yet" */
    int ic;
    int total = 0;

    if (n != 0) return n;

    /*
     * Device[] is only filled in by init_envs(). Without this call the
     * function returns 0 when used before g5_open(), and the serial
     * g5_calculate_force_on_x() would then step by 0 and never finish.
     * init_envs() is a no-op once the tables are set up.
     */
    init_envs();

    for (ic = 0; ic < hib_ndevice(); ic++) {
        if (Device[ic] == 0) continue;
        total += g5_get_number_of_pipelinesMC(ic);
    }
    n = total;
    return n;
}

/*
 * Return the j-particle memory size of one device.
 * The value is JMEMSIZE for every devid; whether the device is enabled
 * is deliberately not checked.
 */
int g5_get_jmemsizeMC(int devid)
{
    const int jmemsize = JMEMSIZE;
    return jmemsize;
}

/*
 * Return the j-particle memory size seen by the "standard" API.
 * JPs are shared among all pipelines, so this is always JMEMSIZE
 * regardless of how many devices are enabled.
 */
int g5_get_jmemsize(void)
{
    const int jmemsize = JMEMSIZE;
    return jmemsize;
}

void g5_openMC(int devid)
{
#ifdef PERIODIC_BOUNDARY
    pg5_gen_s2_force_table(SFT_FOR_PP, SFT_FOR_PM);
#else
    static int init_call = 1;
    if(init_call){
        double rsqrt_bias();
        double bias = rsqrt_bias();
        float acc_corr = 1.0 - 3*bias;
        float pot_corr = +(1.0 - bias);
        Acc_correct = acc_corr;
        Pot_correct = pot_corr;
        Acc_correctV = (v4sf){acc_corr, acc_corr, acc_corr, acc_corr}; 
        Pot_correctV = (v4sf){pot_corr, pot_corr, pot_corr, pot_corr}; 
        init_call = 0;
    }
#endif
}

/*
 * Open every enabled device.
 * Sets up the device tables via init_envs() and then calls g5_openMC()
 * for each device whose Device[] entry is non-zero.
 */
void g5_open(void)
{
    int dev;

    init_envs();
    for (dev = 0; dev < hib_ndevice(); ++dev) {
        if (Device[dev] != 0) {
            g5_openMC(dev);
        }
    }
}

/*
 * Close one device.
 * Nothing needs to be released, so this is a no-op.
 */
void g5_closeMC(int devid)
{
    (void)devid; /* intentionally unused */
}

/*
 * Close every enabled device by calling g5_closeMC() on each of them.
 */
void g5_close(void)
{
    int dev = 0;

    while (dev < hib_ndevice()) {
        if (Device[dev] != 0) {
            g5_closeMC(dev);
        }
        dev++;
    }
}

/*
 * Record eta for one device.
 * devid is used directly as an index into Eta[] and is not range-checked.
 */
void g5_set_etaMC(int devid, double eta)
{
    double *slot = Eta + devid;

    *slot = eta;
}

/*
 * Record the same eta for every enabled device.
 */
void g5_set_eta(double eta)
{
    int dev;

    for (dev = 0; dev < hib_ndevice(); ++dev) {
        if (Device[dev] != 0) {
            g5_set_etaMC(dev, eta);
        }
    }
}

/*
 * Record the softening length eps for one device.
 * devid is used directly as an index into Eps[] and is not range-checked.
 */
void g5_set_eps_to_allMC(int devid, double eps)
{
    double *slot = Eps + devid;

    *slot = eps;
}

/*
 * Record the same softening length eps for every enabled device.
 */
void g5_set_eps_to_all(double eps)
{
    int dev = 0;

    while (dev < hib_ndevice()) {
        if (Device[dev] != 0) {
            g5_set_eps_to_allMC(dev, eps);
        }
        dev++;
    }
}

void g5_set_eps2_to_allMC(int devid, double eps2)
{
    Eps[devid] = sqrt(eps2);
}

/*
 * Record the same squared softening length for every enabled device.
 */
void g5_set_eps2_to_all(double eps2)
{
    int dev;

    for (dev = 0; dev < hib_ndevice(); ++dev) {
        if (Device[dev] != 0) {
            g5_set_eps2_to_allMC(dev, eps2);
        }
    }
}


/*
 * Set the coordinate/mass range of one device.
 * This implementation keeps no range state, so the call is a no-op and
 * all arguments are ignored.
 */
void g5_set_rangeMC(int devid, double xmin, double xmax, double mmin)
{
    (void)devid;
    (void)xmin;
    (void)xmax;
    (void)mmin;
}

/*
 * Forward the coordinate/mass range to g5_set_rangeMC() for every
 * enabled device.
 */
void g5_set_range(double xmin, double xmax, double mmin)
{
    int dev = 0;

    while (dev < hib_ndevice()) {
        if (Device[dev] != 0) {
            g5_set_rangeMC(dev, xmin, xmax, mmin);
        }
        dev++;
    }
}


#ifdef PERIODIC_BOUNDARY
/*
 * Placeholder for installing a user-supplied cut-off table on one device.
 *
 * Not implemented: all arguments are ignored. The first call (for any
 * devid) prints a warning to stderr; later calls are silent.
 *
 * TODO: a real implementation should build the table here, via a
 * function that calls pg5_gen_force_table().
 */
void g5_set_cutoff_tableMC(int devid, double (*ffunc)(double), double fcut, double fcor,
                         double (*pfunc)(double), double pcut, double pcor)
{
    static int firstcall = 1;

    if (firstcall != 1) {
        return;
    }
    firstcall = 0;

    fputs("Warning: cut-off function is not implemented in this revision "
          "of G5 pipeline. g5_set_cutoff_tableMC() has no effect.\n", stderr);
}

void g5_set_cutoff_table(double (*ffunc)(double), double fcut, double fcor,
                         double (*pfunc)(double), double pcut, double pcor)
{
    int ic;

    for (ic = 0; ic < hib_ndevice(); ic++) {
	if (g5_cards[ic] == 0) continue;
	g5_set_cutoff_tableMC(ic, ffunc, fcut, fcor, pfunc, pcut, pcor);
    }
}
#endif

void g5_set_nMC(int devid, int n)
{
    struct Ptcl_Mem *pm = ptcl_mem + devid;
    pm->Nbody = n;
}

/*
 * Set the number of j-particles for the "standard" API.
 * The j-particle memory of device 0 is shared among all devices, so only
 * device 0 is updated.
 */
void g5_set_n(int n)
{
    g5_set_nMC(0, n);
}

/*
 * Load ni i-particle positions into device devid.
 *
 * ni must not exceed the device's pipeline count (checked by assert).
 * With CUTOFF_FORCE each coordinate is narrowed to float and multiplied
 * by Xscale. Otherwise coordinates are narrowed to float as they are and
 * every particle also receives the device's squared softening length.
 */
void g5_set_xiMC(int devid, int ni, double (*xi)[3]){
    struct Ptcl_Mem *mem = &ptcl_mem[devid];
    int k;

    assert(ni <= g5_get_number_of_pipelinesMC(devid));
#ifdef CUTOFF_FORCE
    for (k = 0; k < ni; k++) {
        const double *pos = xi[k];

        mem->iptcl.x[k] = (float)pos[0] * Xscale;
        mem->iptcl.y[k] = (float)pos[1] * Xscale;
        mem->iptcl.z[k] = (float)pos[2] * Xscale;
    }
#else
    for (k = 0; k < ni; k++) {
        const double *pos = xi[k];
        const float soft2 = Eps[devid]*Eps[devid];

        mem->iptcl.x[k] = (float)pos[0];
        mem->iptcl.y[k] = (float)pos[1];
        mem->iptcl.z[k] = (float)pos[2];
        mem->iptcl.eps2[k] = soft2;
    }
#endif
}

/*
 * Distribute ni i-particle positions over the enabled devices.
 *
 * Each enabled device receives ceil(ni / Ndevice) particles, in device
 * order; the share is clipped once fewer than that remain, so later
 * devices may receive zero. g5_set_xiMC() asserts that a share does not
 * exceed the per-device pipeline count.
 */
void g5_set_xi(int ni, double (*x)[3])
{
    int ic, i0, nii;

    /*
     * Ndevice stays 0 until init_envs() has enabled at least one device.
     * Dividing by it would be undefined behaviour, and in that state no
     * Device[] entry is set, so there is nothing to load anyway.
     */
    if (Ndevice <= 0) return;

    i0 = 0;
    nii = (ni + Ndevice - 1) / Ndevice;

    for (ic = 0; ic < hib_ndevice(); ic++) {
        if (Device[ic] == 0) continue;
        if (ni < i0 + nii) {
            nii = ni - i0; /* last share(s): only what is left */
        }
        g5_set_xiMC(ic, nii, x + i0);
        i0 += nii;
    }
}

/*
 * Same as g5_set_jpMC(), with the position and mass arguments swapped
 * (positions first, then masses).
 */
void g5_set_xmjMC(int devid, int adr, int nj, double (*xj)[3], double *mj)
{
    /* Plain call: returning an expression from a void function is not valid C. */
    g5_set_jpMC(devid, adr, nj, mj, xj);
}

/*
 * Same as g5_set_jp(), with the position and mass arguments swapped
 * (positions first, then masses).
 */
void g5_set_xmj(int adr, int nj, double (*xj)[3], double *mj){
    /* Plain call: returning an expression from a void function is not valid C. */
    g5_set_jp(adr, nj, mj, xj);
}

void g5_set_jpMC(int devid, int adr, int nj, double *mj, double (*xj)[3])
{
    int j;
    struct Ptcl_Mem *pm = ptcl_mem + devid;

    for(j=adr; j<adr+nj; j++){
#if __GNUC__ ==  4
        v2df pd0 = {xj[j][0], xj[j][2]};
        v2df pd1 = {xj[j][1], mj[j]   };
#else
        v2df pd0, pd1;
        V2DF_GATHER(pd0, xj[j],   xj[j]+2);
        V2DF_GATHER(pd1, xj[j]+1, mj+j);
#endif
        v4sf ps0, ps1;
        ps0 = __builtin_ia32_cvtpd2ps(pd0);
        ps1 = __builtin_ia32_cvtpd2ps(pd1);
        ps0 = __builtin_ia32_unpcklps(ps0, ps1);
#ifdef CUTOFF_FORCE
        *(v4sf *)(pm->jptcl+j) = ps0 * XMscale;
#else
        *(v4sf *)(pm->jptcl+j) = ps0;
#endif
    }
}

/*
 * Store j-particles for the "standard" API.
 * The j-particle memory of device 0 is shared among all devices, so the
 * data always goes to device 0.
 */
void g5_set_jp(int adr, int nj, double *m, double (*x)[3])
{
    g5_set_jpMC(0, adr, nj, m, x);
}
 
/*
 * Run the force kernel on one device.
 * i-particles, j-particles, the j-particle count and the output buffer
 * all come from the device's own Ptcl_Mem entry.
 */
void g5_runMC(int devid)
{
    struct Ptcl_Mem *mem = &ptcl_mem[devid];
#ifdef CUTOFF_FORCE
    void gravity_kernel(pIpdata, pJpdata, pFodata, int, float (*)[2], v4sf, v4sf);

    gravity_kernel(&mem->iptcl, mem->jptcl, &mem->fout, mem->Nbody,
                   Force_table, R2cut_xscale2, Ascale);
#else
    void GravityKernel(pIpdata, pFodata, pJpdata, int);

    GravityKernel(&mem->iptcl, &mem->fout, mem->jptcl, mem->Nbody);
#endif
}

/*
 * Run the force kernel on every enabled device.
 * Each device uses its own i-particles and output buffer, while the
 * j-particles and their count are always taken from device 0, whose
 * j-particle memory is shared among all devices.
 */
void g5_run(void)
{
#ifdef CUTOFF_FORCE
    void gravity_kernel(pIpdata, pJpdata, pFodata, int, float (*)[2], v4sf, v4sf);
#else
    void GravityKernel(pIpdata, pFodata, pJpdata, int);
#endif
    int dev;

    for (dev = 0; dev < hib_ndevice(); dev++) {
        struct Ptcl_Mem *mem;

        if (Device[dev] == 0) continue;
        mem = &ptcl_mem[dev];
#ifdef CUTOFF_FORCE
        gravity_kernel(&mem->iptcl, ptcl_mem[0].jptcl, &mem->fout, ptcl_mem[0].Nbody,
                       Force_table, R2cut_xscale2, Ascale);
#else
        GravityKernel(&mem->iptcl, &mem->fout, ptcl_mem[0].jptcl, ptcl_mem[0].Nbody);
#endif
    }
}

/*
 * Copy the kernel results of device devid into a[0..ni-1] and pot[0..ni-1].
 *
 * With CUTOFF_FORCE the accelerations are copied unchanged and every
 * potential is set to 0. Otherwise accelerations are multiplied by
 * Acc_correct and potentials by Pot_correct. The GCC 4 build does this
 * four lanes at a time and stores results only when 1 <= ni <= 4; any
 * other ni leaves a and pot untouched on that path.
 */
void g5_get_forceMC(int devid, int ni, double (*a)[3], double *pot){
    struct Ptcl_Mem *mem = &ptcl_mem[devid];
#ifdef CUTOFF_FORCE
    int k;
    for (k = 0; k < ni; k++) {
        a[k][0] = (double)mem->fout.ax[k];
        a[k][1] = (double)mem->fout.ay[k];
        a[k][2] = (double)mem->fout.az[k];
        pot[k] = 0.0;
    }
#else
#if __GNUC__ == 4
    v4sf ax = *(v4sf *)(mem->fout.ax) * Acc_correctV;
    v4sf ay = *(v4sf *)(mem->fout.ay) * Acc_correctV;
    v4sf az = *(v4sf *)(mem->fout.az) * Acc_correctV;
    v4sf phi = *(v4sf *)(mem->fout.phi) * Pot_correctV;
    v4sf row[4]; /* after the transpose, row[k] holds (ax, ay, az, phi) of particle k */
    int k;

    v4sf_transpose(&row[0], &row[1], &row[2], &row[3], ax, ay, az, phi);
    if (ni >= 1 && ni <= 4) {
        for (k = 0; k < ni; k++) {
            v4sf_store_dp(row[k], &a[k][0], &a[k][1], &a[k][2], &pot[k]);
        }
    }
#else
    int k;
    for (k = 0; k < ni; k++) {
        a[k][0] = (double)(mem->fout.ax[k] * Acc_correct);
        a[k][1] = (double)(mem->fout.ay[k] * Acc_correct);
        a[k][2] = (double)(mem->fout.az[k] * Acc_correct);
        pot[k] =  (double)(mem->fout.phi[k] * Pot_correct);
    }
#endif
#endif
}

/*
 * Collect the results for ni i-particles from the enabled devices.
 *
 * The particles are split exactly as in g5_set_xi(): each enabled device
 * contributes ceil(ni / Ndevice) results, in device order, with the
 * share clipped once fewer than that remain.
 */
void g5_get_force(int ni, double (*a)[3], double *pot)
{
    int ic, i0, nii;

    /*
     * Ndevice stays 0 until init_envs() has enabled at least one device.
     * Dividing by it would be undefined behaviour, and in that state no
     * Device[] entry is set, so there is nothing to read back anyway.
     */
    if (Ndevice <= 0) return;

    i0 = 0;
    nii = (ni + Ndevice - 1) / Ndevice;

    for (ic = 0; ic < hib_ndevice(); ic++) {
        if (Device[ic] == 0) continue;
        if (ni < i0 + nii) {
            nii = ni - i0; /* last share(s): only what is left */
        }
        g5_get_forceMC(ic, nii, a + i0, pot + i0);
        i0 += nii;
    }
}

#ifdef _OPENMP

void g5_calculate_force_on_x(double (*x)[3], double (*a)[3], double *p, int nitot)
{
    int off;
    static int nthread = 0;
    const int np = g5_get_number_of_pipelinesMC(0); // number of pipelines per device.

#pragma omp parallel for
    for(off=0; off<nitot; off+=np) {
        int ic = Tid2devid[omp_get_thread_num()];
        int ni = np < nitot-off ? np : nitot-off;

        if (nthread == 0) {
            nthread = omp_get_num_threads();
            printf("nthread:%d\n", nthread);
        }
        g5_set_xiMC(ic, ni, x+off);

        /*
         * This part should not be replaced with g5_runMC(ic). 
         * By g5_set_jp(), JPs are stored to Ptcl_Mem of device[0].
         * On the other hand, g5_runMC(ic) reads JPs from Ptcl_Mem of
         * device[ic], which causes wrong calculation result.
         */
        {
            void GravityKernel(pIpdata, pFodata, pJpdata, int);
            pIpdata ip = &ptcl_mem[ic].iptcl;
            pFodata fo = &ptcl_mem[ic].fout;
            pJpdata jp = ptcl_mem[0].jptcl;
            int nbody  = ptcl_mem[0].Nbody;
            GravityKernel(ip, fo, jp, nbody);
        }

        g5_get_forceMC(ic, ni, a+off, p+off);
    }
}

#else

/*
 * Compute forces on nitot i-particles (serial build).
 *
 * The particles are processed in batches as large as the total pipeline
 * count of all enabled devices: each batch is loaded with g5_set_xi(),
 * evaluated with g5_run(), and read back into a[] and p[] with
 * g5_get_force().
 *
 * If no pipeline is available the batch size would be 0 and the loop
 * would never advance; that case is reported and the program exits
 * with status 2.
 */
void g5_calculate_force_on_x(double (*x)[3], double (*a)[3], double *p, int nitot)
{
    int off;
    int np;

    if (nitot <= 0) return; /* nothing to do */

    np = g5_get_number_of_pipelines(); // total number of pipelines of all devices in use.
    if (np <= 0) {
        fprintf(stderr, "g5_calculate_force_on_x: no pipeline available "
                "(no device enabled, or g5_open() not called yet)\n");
        exit(2);
    }

    for (off = 0; off < nitot; off += np) {
        int ni = np < nitot-off ? np : nitot-off;
        g5_set_xi(ni, x + off);
        g5_run();
        g5_get_force(ni, a+off, p+off);
    }
}

#endif


/*
 * local functions.
 */

/*
 * Number of device slots available.
 * This is always the compile-time constant MAXDEV.
 */
static int hib_ndevice(void)
{
    const int ndev = MAXDEV;
    return ndev;
}

/*
 * initialize variables used for "standard" functions.  this
 * initialization is not necessary for "primitive" functions, or even
 * harmful for use in multi-threaded application.
 *
 * Fills in Ndevice, Device[] and Tid2devid[], prints a one-line summary
 * to stderr and, in OpenMP builds, sets the number of threads to
 * Ndevice. Returns immediately if Ndevice is already non-zero.
 *
 * The devices to enable are read from the environment variable GDEVICE
 * (or G5_CARDS, for backward compatibility) as a space-separated list of
 * device ids. If neither variable is set, all devices are enabled.
 * A token that is not a decimal integer, or an id outside
 * [0, hib_ndevice()), is reported on stderr and the program exits with
 * status 2. An id listed more than once is counted once.
 *
 * NOTE(review): a variable that is set but lists no device leaves
 * Ndevice at 0, as before.
 */
static void init_envs(void)
{
    int ic;
    int tid;
    const char *env;

    if (Ndevice != 0) return;

    /* cards are not allocated yet.
       try to allocate cards specified by environment variable "GDEVICE".
       try to allocate all cards, if GDEVICE is not set. */

    env = getenv("GDEVICE");
    if (!env) { // for backward compatibility.
        env = getenv("G5_CARDS");
    }
    if (env) { // parse GDEVICE (or G5_CARDS)
        char *list;
        char *cardno;
        size_t len = strlen(env) + 1;

        /* strtok() writes into its argument, and the string returned by
           getenv() must not be modified: tokenize a private copy. */
        list = malloc(len);
        if (!list) {
            fprintf(stderr, "init_envs: out of memory while parsing GDEVICE (or G5_CARDS)\n");
            exit(2);
        }
        memcpy(list, env, len);

        for (ic = 0; ic < hib_ndevice(); ic++) {
            Device[ic] = 0;
        }
        for (cardno = strtok(list, " "); cardno; cardno = strtok(NULL, " ")) {
            char *end;
            long id = strtol(cardno, &end, 10);

            if (end == cardno || *end != '\0') {
                fprintf(stderr, "GDEVICE (or G5_CARDS) have invalid device_id: %s\n", cardno);
                exit(2);
            }
            if (id < 0 || id >= hib_ndevice()) {
                fprintf(stderr, "GDEVICE (or G5_CARDS) have device_id out of range: %ld\n", id);
                exit(2);
            }
            /* Count each device once, so that Ndevice never exceeds the
               number of enabled entries (and thus never exceeds MAXDEV). */
            if (Device[id] == 0) {
                Device[id] = 1;
                Ndevice++;
            }
        }
        free(list);
    }
    else { // GDEVICE is not set. use all devices.
        Ndevice = hib_ndevice();
        for (ic = 0; ic < hib_ndevice(); ic++) {
            Device[ic] = 1;
        }
    }

    /*
     * create a table for thread id to device id conversion:
     * thread t is mapped to the t-th enabled device.
     */
    for (tid = 0; tid < MAXDEV; tid++) {
        Tid2devid[tid] = 0;
    }
    ic = 0;
    for (tid = 0; tid < Ndevice; tid++) {
        for ( ; ic < hib_ndevice(); ic++) {
            if (Device[ic] == 1) {
                Tid2devid[tid] = ic;
                ic++;
                break;
            }
        }
    }

    /*
     * diagnostics
     */
    fprintf(stderr, "OpenMP:");
#ifdef _OPENMP
    fprintf(stderr, "enabled");
    omp_set_num_threads(Ndevice);
#else
    fprintf(stderr, "disabled");
#endif
    fprintf(stderr, "  Ndevice:%d  MAXDEV:%d  ", Ndevice, MAXDEV);
    fprintf(stderr, "dev[ ");
    for (ic = 0; ic < MAXDEV; ic++) {
        fprintf(stderr, "%d ", Device[ic]);
    }
    fprintf(stderr, "]  ");

    fprintf(stderr, "Tid2devid[ ");
    for (tid = 0; tid < Ndevice; tid++) {
        fprintf(stderr, "%d ", Tid2devid[tid]);
    }
    fprintf(stderr, "]\n");
}
