#ifndef __MEMORY_KERNELS_GPU_CUH__
#define __MEMORY_KERNELS_GPU_CUH__

#include <vector_types.h>
#include <cuda_runtime.h>

#include "cuda_declarations.h"
#include "data_management.cuh"
#include "thread_manager.h"
#include "sequences.h"
#include "main_cu.h"
#include "exceptions.h"
#include "hi_res_timer.h"

#include "defines_gpu.cuh"

/*******************************************************************************
 * reorderPrePPLsInMemory changes the layout of the global memory allocated    *
 * for prePPL so that consecutive letters of a sequence are stored in          *
 * consecutive global memory cells.                                            *
 *******************************************************************************/

// Edge length of the square shared-memory tile used by reorderPrePPLsInMemory.
// ALIGNMENT_MATCH_BLOCK_X_SIZE is not defined in this section of the file
// (presumably it comes from one of the included headers - confirm).
#define BLOCK_X_SIZE  ALIGNMENT_MATCH_BLOCK_X_SIZE
// Multiplier applied to the position index when reading the input array in
// reorderPrePPLsInMemory; memOffset is declared elsewhere.
#define MEM_OFFSET    memOffset

__global__ void reorderPrePPLsInMemory(unsigned int* input, unsigned int* output, int maxAlignmentLength)
{
    // Transposes prePPL data tile by tile through shared memory.
    //
    //   input  is read    at [sequence + position * MEM_OFFSET]
    //   output is written at [sequence * maxAlignmentLength + position]
    //
    // Launch layout: one-dimensional grid, dim3(n,1,1):
    //  ____   ____   ____   ____
    // |____| |____| |____| |____| ...
    //
    // Every block handles BLOCK_X_SIZE sequences. Both threadIdx.x and
    // threadIdx.y are used as tile indices, so both are expected to run over
    // [0, BLOCK_X_SIZE).
    //
    // No range check is done on the sequence index; the launch configuration
    // has to match the allocated buffers.

    // One padding column (BLOCK_X_SIZE + 1) to avoid bank conflicts.
    __shared__ unsigned int tile[BLOCK_X_SIZE][BLOCK_X_SIZE + 1];

    // First sequence handled by this block.
    const int blockFirstSeq = blockIdx.x * BLOCK_X_SIZE;

    // A thread reads from one sequence and writes to another one:
    // the tile is transposed in between.
    const int srcSeq  = blockFirstSeq + threadIdx.x;
    const int dstSeq  = blockFirstSeq + threadIdx.y;
    const int dstBase = dstSeq * maxAlignmentLength;

    // maxAlignmentLength rounded down to a multiple of BLOCK_X_SIZE:
    // positions below this value are covered by complete tiles.
    const int fullTilesEnd = (maxAlignmentLength / BLOCK_X_SIZE) * BLOCK_X_SIZE;

    // Complete tiles.
    int pos = 0;
    while (pos < fullTilesEnd)
    {
        tile[threadIdx.y][threadIdx.x] = input[srcSeq + (pos + threadIdx.y) * MEM_OFFSET];

        // The whole tile must be loaded before it is read transposed.
        __syncthreads();

        output[dstBase + pos + threadIdx.x] = tile[threadIdx.x][threadIdx.y];

        // The tile must not be overwritten while other threads still read it.
        __syncthreads();

        pos += BLOCK_X_SIZE;
    }

    // Partial tile at the end (when maxAlignmentLength % BLOCK_X_SIZE != 0).
    // Only rows that correspond to existing positions are loaded ...
    const bool loadsTailRow = (fullTilesEnd + threadIdx.y < maxAlignmentLength);
    if (loadsTailRow)
    {
        tile[threadIdx.y][threadIdx.x] = input[srcSeq + (fullTilesEnd + threadIdx.y) * MEM_OFFSET];
    }

    __syncthreads();

    // ... and only those rows are read back (as columns) and stored.
    const bool storesTailColumn = (fullTilesEnd + threadIdx.x < maxAlignmentLength);
    if (storesTailColumn)
    {
        output[dstBase + fullTilesEnd + threadIdx.x] = tile[threadIdx.x][threadIdx.y];
    }
}

// End of the macros used by reorderPrePPLsInMemory.
#undef BLOCK_X_SIZE
#undef MEM_OFFSET

// When defined, sortAndMergePrePPL keeps rows that have finished (or have
// nothing to do) inside its main loop until the whole block is done, so that
// every thread of the block reaches every __syncthreads().
#define PROPER_SYNC_THREAD 1

// Block dimensions assumed by sortAndMergePrePPL and PrePPLtoPPL
// (both are set to the same value).
#define BLOCK_X_SIZE  ALIGNMENT_MATCH_BLOCK_X_SIZE
#define BLOCK_Y_SIZE  ALIGNMENT_MATCH_BLOCK_X_SIZE
// NOTE(review): MEM_OFFSET is defined here but not referenced by the kernels
// visible between this point and its #undef - confirm it is still needed.
#define MEM_OFFSET    memOffset
// Index of the "X" sequence handled by a thread row.
// NOTE(review): combines blockDim.x with threadIdx.y, which is only a dense
// index when blockDim.x == blockDim.y - confirm against the launch code.
#define seqXNo       (blockIdx.x * blockDim.x + threadIdx.y)
// Index of the "Y" sequence: one grid row per sequence.
#define seqYNo       (blockIdx.y)
// Upper bound for the kernel parameter K; shared arrays are sized MAX_K+1.
#define MAX_K        11
// Window size used in the index computations; winSize is declared elsewhere.
#define WIN_SIZE     winSize

// sortAndMergePrePPL
//
// For every sequence pair (seqXNo, seqYNo) handled by the block, merges the
// K+1 input alignments of that pair into one output alignment:
//   - in every step the head elements of the K+1 inputs are compared and the
//     one with the smallest sort key is consumed (k-way merge),
//   - elements whose low byte is 0xFF are consumed but not emitted,
//   - an element that equals the previously emitted one in its upper 24 bits
//     is merged with it (the max of both words is kept),
//   - the output is terminated with 0xFFFFFFFF,
//   - the output length (including the terminator, rounded up to a multiple
//     of 16) is stored in lengths[seqYNo*WIN_SIZE + seqXNo].
//
// Parameters:
//   input              - K+1 planes of WIN_SIZE*WIN_SIZE*maxAlignmentLength
//                        elements; plane k holds the k-th input alignment of
//                        every pair. An input is terminated by 0xFFFFFFFF.
//   output             - maxAlignmentLength*(K+1) elements per pair.
//   maxAlignmentLength - number of elements reserved per input alignment.
//   K                  - number of input alignments minus one.
//                        NOTE(review): the shared arrays are sized MAX_K+1 and
//                        K is not checked against MAX_K in this kernel -
//                        confirm the caller guarantees K <= MAX_K.
//   windowX, windowY   - only used to decide whether the pair lies over the
//                        diagonal (see overDiagonal).
//   lengths            - receives one length per pair.
//   sequenceCount      - pairs with seqXNo or seqYNo >= sequenceCount are
//                        treated as empty.
//
// Element layout as used by this kernel: bits 31..20 and bits 19..8 are the two
// fields that get swapped to build the sort key (the x and y indices according
// to the comments below), bits 7..0 are the value called "% identity" below.
//
// tex1Starts / tex2Starts are declared elsewhere; they are only used to derive
// the two sequence lengths of the pair.
__global__ void sortAndMergePrePPL(unsigned int* input, unsigned int* output, int maxAlignmentLength, short K, unsigned int windowX, unsigned int windowY, unsigned int* lengths, unsigned int sequenceCount)
{

    // block size should be:  num_of_threads_in_warp x num_of_alignments_processed
    // here is:               BLOCK_X_SIZE           x BLOCK_Y_SIZE
    // on Fermi could be e.g. 32                     x 16

    
    // blocks must be launched in a grid with shape: dim3(WIN_SIZE/BLOCK_Y_SIZE,WIN_SIZE,1)
    //  _____   _____   _____   _____
    // |_____| |_____| |_____| |_____|
    //  _____   _____   _____   _____
    // |_____| |_____| |_____| |_____|...
    //
    // each block merges and sorts BLOCK_Y_SIZE sequences using shared memory


    // true when the pair lies over the diagonal: either in a diagonal window
    // with seqXNo > seqYNo, or in a window with windowX > windowY.
    // For such pairs the input is read from the transposed position (see below).
    bool overDiagonal = ((windowX == windowY) && (seqXNo > seqYNo)) || (windowX > windowY);

//    if (seqYNo >= sequenceCount)
//        return;
//    if (seqXNo >= sequenceCount)
//        return;
    
    // shmInputAlignments[no_of_alignment_processed][input_alignment_no][index_within_alignment]
    //
    // input_alignment_no: 0 = NW, 1 = first local alignment, 2 = second local...
    // threadIdx.y: indicates (within a block) which alignment is processed
    //
    // e.g. shmInputAlignments[threadIdx.y][0..K][threadIdx.x]
    //
    // Element [..][k][0] is always the current head of the k-th input.
    __shared__ unsigned int shmInputAlignments[BLOCK_Y_SIZE][MAX_K+1][BLOCK_X_SIZE + 1];

    // output cache: one buffer of BLOCK_X_SIZE elements per processed alignment
    __shared__ unsigned int shmOutputAlignment[BLOCK_Y_SIZE][BLOCK_X_SIZE];

    // scratch space for the sort keys and the min reduction
    __shared__ unsigned int shmTmp[BLOCK_Y_SIZE][BLOCK_X_SIZE];

    // .x - how many elements are left in each of the K+1 read buffers
    // .y - reading position (caret) in global memory for input alignments
    __shared__ ushort2 shmInLeftCount_globalCarret[BLOCK_Y_SIZE][MAX_K+1];

    // .x - writing position (caret) in shared memory for output alignments
    // .y - writing position (caret) in global memory for output alignments
    __shared__ ushort2 shmOutLocalCarret_globalCarret[BLOCK_Y_SIZE];

    #ifdef PROPER_SYNC_THREAD
    //determines if computing of a given alignment has been finished
    __shared__ bool running[BLOCK_Y_SIZE];
    // set whenever some row stops running; triggers the "is the whole block
    // finished?" check at the end of the main loop
    __shared__ bool mayBeFinished;
    #endif


    // lengths of the two sequences of the pair (only tested for <= 0 / == 0)
    short2 lengthXY;

    unsigned int* myInputAlign;
    // the output is always stored at position (seqYNo, seqXNo)
    unsigned int* myOutputAlign = &output[seqYNo*WIN_SIZE*maxAlignmentLength*(K+1) + seqXNo*maxAlignmentLength*(K+1)];

    if (!overDiagonal)
    {
        myInputAlign = &input[seqYNo*WIN_SIZE*maxAlignmentLength + seqXNo*maxAlignmentLength];

        lengthXY.x = tex1Starts[seqXNo + 1] - tex1Starts[seqXNo];
        lengthXY.y = tex2Starts[seqYNo + 1] - tex2Starts[seqYNo];
    }
    else
    {
        // over the diagonal: read the input of the transposed pair
        myInputAlign = &input[seqXNo*WIN_SIZE*maxAlignmentLength + seqYNo*maxAlignmentLength];

        lengthXY.x = tex1Starts[seqYNo + 1] - tex1Starts[seqYNo];
        lengthXY.y = tex2Starts[seqXNo + 1] - tex2Starts[seqXNo];
    }

    #ifndef PROPER_SYNC_THREAD
    // Variant without PROPER_SYNC_THREAD: rows with nothing to do leave the
    // kernel immediately.
    if((lengthXY.x <= 0) || (lengthXY.y <= 0) || (seqXNo >= sequenceCount) || (seqYNo >= sequenceCount))//if there is nothing to do -> quit
    {
        if(threadIdx.x == 0)
        {
            lengths[seqYNo*WIN_SIZE + seqXNo] = 0; //just to avoid garbage in memory
        }
        return;
    }
    #endif


    // Initial fill of the K+1 read buffers: the first BLOCK_X_SIZE elements of
    // every input alignment (one element per thread of the row).
    for(short k=0; k<=K; k++)
    {
        shmInputAlignments[threadIdx.y][k][threadIdx.x] = myInputAlign[threadIdx.x + k*WIN_SIZE*WIN_SIZE*maxAlignmentLength];
    }
    // Every buffer is full and the next chunk starts at BLOCK_X_SIZE.
    // NOTE(review): indexes [threadIdx.x] for threadIdx.x <= K into an array of
    // MAX_K+1 entries - relies on K <= MAX_K.
    if(threadIdx.x <= K)
    {
        shmInLeftCount_globalCarret[threadIdx.y][threadIdx.x].x = BLOCK_X_SIZE;
        shmInLeftCount_globalCarret[threadIdx.y][threadIdx.x].y = BLOCK_X_SIZE;
    }
    if(threadIdx.x == 0)
    {
        // the output cache is empty and nothing has been flushed yet
        shmOutLocalCarret_globalCarret[threadIdx.y].x = 0;
        shmOutLocalCarret_globalCarret[threadIdx.y].y = 0;

        #ifdef PROPER_SYNC_THREAD
        // Rows with nothing to do do not leave the kernel; they are marked as
        // not running and keep taking part in the barriers.
        // NOTE(review): thread 0 of every row writes mayBeFinished here, possibly
        // with different values and without ordering between rows - confirm
        // that either outcome is acceptable.
        if((lengthXY.x <= 0) || (lengthXY.y <= 0) || (seqXNo >= sequenceCount) || (seqYNo >= sequenceCount))//if there is nothing to do -> quit
        {
            running[threadIdx.y] = false;
            mayBeFinished = true;
            // Only this thread's copy is zeroed; it is the same thread
            // (threadIdx.x == 0) that tests it when the length is written.
            lengthXY.x = 0;
            lengthXY.y = 0;
        }
        else
        {
            running[threadIdx.y] = true;
            mayBeFinished = false;
        }
        #endif


    }

    // index (0..K) of the input alignment whose head element was selected;
    // also used as a temporary while the sort key is built
    unsigned int tmp;
    // temporary used while an input buffer is shifted
    unsigned int shiftTmp;

// The block comment below is a disabled earlier draft of the merge loop.
// It is not compiled.
/*
    while(true)
    {
        counter++;//todo ??
        shmTmp[threadIdx.y][threadIdx.x] = 0xFFFFFFFF; //because we search for min value
                                                       //and there are more threads (threadIdx.x)
                                                       //than K+1
        if(threadIdx.x <= K)
        {
            //reading K+1 first elements to merge
            //shmTmp[threadIdx.y][threadIdx.x] = shmInputAlignments[threadIdx.y][threadIdx.x][0];

            shmTmp[threadIdx.y][threadIdx.x] =
            myInputAlign[shmInLeftCount_globalCarret[threadIdx.y][threadIdx.x].y  + threadIdx.x*WIN_SIZE*WIN_SIZE*maxAlignmentLength];
        }



        shmTmp[threadIdx.y][threadIdx.x] &= 0xFFFFFF00;
        shmTmp[threadIdx.y][threadIdx.x] |= threadIdx.x; //storing input sequence no

        // We search for the minimal coordinate index (x and y) from first
        // elements of each input alignment (they are in shmTmp).
        // Binary search:
        for(short i=BLOCK_X_SIZE/2; i>0; i/=2)
        {
            if(threadIdx.x < i)
                shmTmp[threadIdx.y][threadIdx.x] = min(shmTmp[threadIdx.y][threadIdx.x], shmTmp[threadIdx.y][threadIdx.x + i]);
        }
        // At the end in shmTmp[threadIdx.y][0] is the minimal value
        // for threadIdx.y'th alignment.

        tmp = shmTmp[threadIdx.y][0] & 0x000000FF; //retrieving input sequence no (one from K+1)

        if( (shmTmp[threadIdx.y][0] & 0xFFFFFF00) == 0xFFFFFF00)
        {
            // end of input sources
            break;
        }


        // if the element is valid (no 0xFF as % identity)
        // then we append/add it to output
        if(threadIdx.x == 0)
        {

//            myOutputAlign[shmOutLocalCarret_globalCarret[threadIdx.y].y + threadIdx.x] = 5;
//                    //shmTmp[threadIdx.y][0];
//            shmInLeftCount_globalCarret[threadIdx.y][tmp].y++;
//            shmOutLocalCarret_globalCarret[threadIdx.y].y++;

            shmOutputAlignment[threadIdx.y][shmOutLocalCarret_globalCarret[threadIdx.y].x] =
                    shmTmp[threadIdx.y][0];
            shmInLeftCount_globalCarret[threadIdx.y][tmp].y++;
            shmOutLocalCarret_globalCarret[threadIdx.y].x++;

        }

        __syncthreads();
        myOutputAlign[shmOutLocalCarret_globalCarret[threadIdx.y].y + threadIdx.x] = shmOutLocalCarret_globalCarret[threadIdx.y].x;

        //__syncthreads();

        if(shmOutLocalCarret_globalCarret[threadIdx.y].x == BLOCK_X_SIZE)
        {

                //myOutputAlign[shmOutLocalCarret_globalCarret[threadIdx.y].y + threadIdx.x] =
                //    shmOutputAlignment[threadIdx.y][threadIdx.x];

            if(threadIdx.x == 0)
            {
                shmOutLocalCarret_globalCarret[threadIdx.y].y += BLOCK_X_SIZE;
                shmOutLocalCarret_globalCarret[threadIdx.y].x = 0;
            }
        }

    }

//    if(threadIdx.x == 0)
//        myOutputAlign[shmOutLocalCarret_globalCarret[threadIdx.y].y + threadIdx.x] = 0xFFFFFFFF;

    //flush output cache
    if(shmOutLocalCarret_globalCarret[threadIdx.y].x < BLOCK_X_SIZE)
    {
        if(threadIdx.x == 0)
        {
            //0xFFFFFFFF - end of alignment
            shmOutputAlignment[threadIdx.y][shmOutLocalCarret_globalCarret[threadIdx.y].x] = 0xFFFFFFFF;
        }
    }

    myOutputAlign[shmOutLocalCarret_globalCarret[threadIdx.y].y + threadIdx.x] =
        shmOutputAlignment[threadIdx.y][threadIdx.x];

    if(shmOutLocalCarret_globalCarret[threadIdx.y].x == BLOCK_X_SIZE)
    {
        if(threadIdx.x == 0)
        {
            //0xFFFFFFFF - end of alignment
            myOutputAlign[shmOutLocalCarret_globalCarret[threadIdx.y].y + BLOCK_X_SIZE] = 0xFFFFFFFF;
        }
    }
*/

    // The reason why we use __syncthreads() so often is that
    // shared memory written by one thread is visible to other
    // threads after this barrier.
    // And because we use __syncthreads(), we must ensure that all
    // threads within the block finish their execution at the
    // same time (see the 'running' array). Otherwise the behaviour
    // of the kernel is undefined (it may hang).

    // Main merge loop: every iteration consumes exactly one head element from
    // one of the K+1 inputs of every running row.
    while(true)
    {        
        shmTmp[threadIdx.y][threadIdx.x] = 0xFFFFFFFF; //because we search for min value
                                                       //and there are more threads (threadIdx.x)
                                                       //than K+1
        // This barrier also makes the shared-memory initialisation done before
        // the loop visible to the whole block in the first iteration.
        __syncthreads();

        if(threadIdx.x <= K)
        {
            //reading K+1 first elements to merge
            shmTmp[threadIdx.y][threadIdx.x] = shmInputAlignments[threadIdx.y][threadIdx.x][0];
        }


        // Building the sort key: bits 31..20 and bits 19..8 are exchanged.
        // 'tmp' is only a temporary here.
//        if(!overDiagonal) <-- this "if" is commented out to ALWAYS sort by the second index
//        {
//
//            // if we are over diagonal of PL array
//            // then we exchange x coordinate with y
//            // to sort by right index (x or y)
            tmp = (shmTmp[threadIdx.y][threadIdx.x] & 0xFFF00000) >> 12;
            shmTmp[threadIdx.y][threadIdx.x] <<= 12;
            shmTmp[threadIdx.y][threadIdx.x] &= 0xFFF00000;
            shmTmp[threadIdx.y][threadIdx.x] |= tmp;
//
//        }

        // The low byte of the key is replaced by the number of the input the
        // element comes from, so that the winner of the reduction tells us
        // which input to consume. For threads with threadIdx.x > K the key
        // stays 0xFFFFFF00 | threadIdx.x, i.e. not smaller than any real key.
        shmTmp[threadIdx.y][threadIdx.x] &= 0xFFFFFF00;
        shmTmp[threadIdx.y][threadIdx.x] |= threadIdx.x; //storing input sequence no

        __syncthreads();

        // We search for the minimal coordinate index (x and y) from first
        // elements of each input alignment (they are in shmTmp).
        // Tree reduction with min(): the active range is halved in every step,
        // with a block-wide barrier after each step.
        for(short i=BLOCK_X_SIZE/2; i>0; i/=2)
        {
            if(threadIdx.x < i)
                shmTmp[threadIdx.y][threadIdx.x] = min(shmTmp[threadIdx.y][threadIdx.x], shmTmp[threadIdx.y][threadIdx.x + i]);

            __syncthreads();
        }
        // At the end in shmTmp[threadIdx.y][0] is the minimal value
        // for threadIdx.y'th alignment.


        // Stage 1: select the input to consume and append/merge its head
        // element into the output cache.
        #ifdef PROPER_SYNC_THREAD
        if(running[threadIdx.y])
        {
        #endif
            tmp = shmTmp[threadIdx.y][0] & 0x000000FF; //retrieving input sequence no (one from K+1)

            // The smallest head is the terminator -> all inputs are exhausted.
            if(shmInputAlignments[threadIdx.y][tmp][0] == 0xFFFFFFFF)
            {
                // end of input sources
                #ifdef PROPER_SYNC_THREAD
                // The row is only marked as finished; it stays in the loop
                // until the whole block is finished.
                if(threadIdx.x == 0)
                {
                    running[threadIdx.y] = false;
                    mayBeFinished = true;
                }
                #else
                break;
                #endif
            }

            
            // if the element is valid (no 0xFF as % identity)
            // then we append/add it to output
            if( (threadIdx.x == 0) && ((shmInputAlignments[threadIdx.y][tmp][0] & 0x000000FF) != 0x000000FF) )
            {
                // if x and y indices in consecutive elements are the same
                // we have to merge them (add % of identity)

                if( (shmOutLocalCarret_globalCarret[threadIdx.y].x > 0) &&
                    ((shmInputAlignments[threadIdx.y][tmp][0] & 0xFFFFFF00) ==
                     (shmOutputAlignment[threadIdx.y][shmOutLocalCarret_globalCarret[threadIdx.y].x - 1] & 0xFFFFFF00)) )
                {
                    // In the article they say that during the merging process
                    // one has to add the weights of merged items.
                    // But T-Coffee seems to take the max value instead.

                    // add version
                    //shmOutputAlignment[threadIdx.y][shmOutLocalCarret_globalCarret[threadIdx.y].x - 1] +=
                    //        shmInputAlignments[threadIdx.y][tmp][0] & 0x000000FF;

                    // max version
                    shmOutputAlignment[threadIdx.y][shmOutLocalCarret_globalCarret[threadIdx.y].x - 1] =
                            max(shmInputAlignments[threadIdx.y][tmp][0],
                                shmOutputAlignment[threadIdx.y][shmOutLocalCarret_globalCarret[threadIdx.y].x - 1]);
                }
                else
                {
                    // If the cache is already full the element is not stored
                    // here; the caret still advances to BLOCK_X_SIZE+1, which
                    // triggers the flush in stage 2, where the element is
                    // placed at the beginning of the emptied cache.
                    if(shmOutLocalCarret_globalCarret[threadIdx.y].x < BLOCK_X_SIZE)
                    {
                        shmOutputAlignment[threadIdx.y][shmOutLocalCarret_globalCarret[threadIdx.y].x] =
                            shmInputAlignments[threadIdx.y][tmp][0];
                    }
                    shmOutLocalCarret_globalCarret[threadIdx.y].x++;
                }
            }
        #ifdef PROPER_SYNC_THREAD
        }
        #endif
        
        __syncthreads();

        // Stage 2: flush the output cache if it has overflowed and account for
        // the consumed input element.
        #ifdef PROPER_SYNC_THREAD
        if(running[threadIdx.y])
        {
        #endif
            //if output cache is full we have to write it to the global memory
            // (the flush happens one element late, at BLOCK_X_SIZE+1 - presumably
            // so that the last cached element can still be merged with its
            // successor; confirm)
            if(shmOutLocalCarret_globalCarret[threadIdx.y].x == (BLOCK_X_SIZE+1))
            {
                myOutputAlign[shmOutLocalCarret_globalCarret[threadIdx.y].y + threadIdx.x] =
                        shmOutputAlignment[threadIdx.y][threadIdx.x];

                // NOTE(review): thread 0 updates the carets and slot [0] below
                // while the other threads of the row may still be executing the
                // store above; no barrier separates them - confirm this is safe
                // on the targeted architectures.
                if(threadIdx.x == 0)
                {
                    shmOutLocalCarret_globalCarret[threadIdx.y].y += BLOCK_X_SIZE;

                    // the element that did not fit becomes the first one
                    shmOutputAlignment[threadIdx.y][0] = shmInputAlignments[threadIdx.y][tmp][0];
                    shmOutLocalCarret_globalCarret[threadIdx.y].x = 1;
                }
            }

            // one element of the tmp'th read buffer has been consumed
            if(threadIdx.x == 0)
                shmInLeftCount_globalCarret[threadIdx.y][tmp].x--;
        #ifdef PROPER_SYNC_THREAD
        }
        #endif

        __syncthreads();

        // Stage 3: advance the consumed input - either refill its buffer from
        // global memory or shift it by one element.
        #ifdef PROPER_SYNC_THREAD
        if(running[threadIdx.y])
        {
        #endif
            if(shmInLeftCount_globalCarret[threadIdx.y][tmp].x == 0)
            {
                //fetching input data from the global memory
                // NOTE(review): thread 0 modifies .x and .y right below while
                // other threads of the row read them here without a barrier in
                // between - confirm.
                shmInputAlignments[threadIdx.y][tmp][threadIdx.x] = myInputAlign[threadIdx.x + shmInLeftCount_globalCarret[threadIdx.y][tmp].y  + tmp*WIN_SIZE*WIN_SIZE*maxAlignmentLength];

                if(threadIdx.x == 0)
                {
                    shmInLeftCount_globalCarret[threadIdx.y][tmp].x =  BLOCK_X_SIZE;
                    shmInLeftCount_globalCarret[threadIdx.y][tmp].y += BLOCK_X_SIZE;
                }
            }
            else
            {
                //shifting tmp'th input buffer
                // Every thread moves its element one slot to the left, so the
                // next element becomes the head [0].
                // NOTE(review): correct only if all threads of the row read
                // before any of them writes; there is no barrier between the
                // read and the write - presumably relies on the row executing
                // in lockstep; confirm for GPUs with independent thread
                // scheduling.
                if(threadIdx.x > 0)
                {
                    shiftTmp = shmInputAlignments[threadIdx.y][tmp][threadIdx.x];
                    shmInputAlignments[threadIdx.y][tmp][threadIdx.x - 1] = shiftTmp;
                }
            }
        #ifdef PROPER_SYNC_THREAD
        }


        //checking if all threads have already finished the job
        // mayBeFinished is shared, so the barrier inside this branch is meant
        // to be reached by either all threads of the block or none of them.
        if(mayBeFinished)
        {
            bool stillRunning = running[0];
            for(short i=1; i<BLOCK_Y_SIZE; i++)
                stillRunning |= running[i];

            __syncthreads();

            //the entire block of threads leaves at the same time...
            if(!stillRunning)
                break;

            // some row is still working: re-arm the flag until the next row
            // finishes
            if(threadIdx.x == 0)
                mayBeFinished = false;
            //__syncthreads();
        }
        #endif

    }//while(true)



    
    //flush output cache
    // If there is room left in the cache, the terminator is appended to it ...
    if(shmOutLocalCarret_globalCarret[threadIdx.y].x < BLOCK_X_SIZE)
    {
        if(threadIdx.x == 0)
        {
            //0xFFFFFFFF - end of alignment
            shmOutputAlignment[threadIdx.y][shmOutLocalCarret_globalCarret[threadIdx.y].x] = 0xFFFFFFFF;
        }
    }

    __syncthreads();
    // ... and the whole cache (all BLOCK_X_SIZE slots, including slots behind
    // the terminator) is written to global memory.
    myOutputAlign[shmOutLocalCarret_globalCarret[threadIdx.y].y + threadIdx.x] =
        shmOutputAlignment[threadIdx.y][threadIdx.x];

    // If the cache was exactly full, the terminator is written directly to
    // global memory right behind the flushed data.
    if(shmOutLocalCarret_globalCarret[threadIdx.y].x == BLOCK_X_SIZE)
    {
        if(threadIdx.x == 0)
        {
            //0xFFFFFFFF - end of alignment
            myOutputAlign[shmOutLocalCarret_globalCarret[threadIdx.y].y + BLOCK_X_SIZE] = 0xFFFFFFFF;
        }
    }

    //writing length of the alignment
    if(threadIdx.x == 0)
    {
        // cached elements + already flushed elements + 1 for the terminator
        int lengthAlignedTo16 = shmOutLocalCarret_globalCarret[threadIdx.y].x +
                                            shmOutLocalCarret_globalCarret[threadIdx.y].y + 1;
        
        if((lengthXY.x == 0) || (lengthXY.y == 0)) //because of PROPER_SYNC_THREAD
        {
            lengthAlignedTo16 = 0;
        }
        else
        {
            // round up to a multiple of 16
            // NOTE(review): literal 16, not BLOCK_X_SIZE - confirm this is intended.
            lengthAlignedTo16 = ((lengthAlignedTo16 - 1) / 16) * 16 + 16;
        }
        //we write length of alignment to the global memory
        lengths[seqYNo*WIN_SIZE + seqXNo] = lengthAlignedTo16;
    }

//==============================================================================




    
//    //printing the first 16 elements of the local alignment
//    if ((threadIdx.y == 0) && (threadIdx.x <= K))
//        myOutputAlign[threadIdx.x] = shmInputAlignments[threadIdx.y][threadIdx.x][0];



//    shmInputAlignments[threadIdx.y][1][threadIdx.x] = 0;
//    shmOutputAlignment[threadIdx.y][threadIdx.x] = 2;
//    shmTmp[threadIdx.y][threadIdx.x] = 2;
//    shmInLeftCount_globalCarret[threadIdx.y][threadIdx.x].x = 2;
//    shmOutLocalCarret_globalCarret[threadIdx.y].x = 2;

}

// Macros that are only needed by sortAndMergePrePPL are removed here.
// BLOCK_X_SIZE, BLOCK_Y_SIZE, seqXNo, seqYNo and WIN_SIZE stay defined because
// PrePPLtoPPL below still uses them.
#ifdef PROPER_SYNC_THREAD
#undef PROPER_SYNC_THREAD
#endif

#undef MEM_OFFSET
#undef MAX_K



/*******************************************************************************
 * PrePPLtoPPL converts the prePPL array into the PPL array.                   *
 * PrePPL contains sorted and merged alignments, but there is a lot of unused  *
 * space between different alignments.                                         *
 * The PPL array does not have this free space (lengths are only rounded up    *
 * to a multiple of 16 elements).                                              *
 *******************************************************************************/

__global__ void PrePPLtoPPL(unsigned int* input, unsigned int* output, unsigned int* starts, int maxAlignmentLength, short K)
{
    // Copies every alignment from its fixed-size slot in 'input' to the
    // position given by 'starts' in 'output'.
    //
    //   input  - one slot of maxAlignmentLength*(K+1) elements per sequence pair
    //   output - destination array
    //   starts - starts[p] is the output position of the alignment of pair p and
    //            starts[p + 1] is the position where copying stops
    //
    // block size should be:  num_of_threads_in_warp x num_of_alignments_processed
    // here is:               BLOCK_X_SIZE           x BLOCK_Y_SIZE
    // on Fermi could be e.g. 32                     x 16
    //
    // blocks must be launched in a grid with shape: dim3(WIN_SIZE/BLOCK_Y_SIZE,WIN_SIZE,1)
    //  _____   _____   _____   _____
    // |_____| |_____| |_____| |_____|
    //  _____   _____   _____   _____
    // |_____| |_____| |_____| |_____|...
    //
    // Each thread row (fixed threadIdx.y) copies one alignment, BLOCK_X_SIZE
    // elements at a time. No bounds checks are done: the copy always moves whole
    // chunks of BLOCK_X_SIZE elements.

    // Block-local copy of the BLOCK_Y_SIZE + 1 entries of 'starts' that the
    // rows of this block need.
    __shared__ unsigned int shmStarts[BLOCK_Y_SIZE + 1];

    // The first thread row loads the entries; in that row threadIdx.y == 0, so
    // the index computed here is the one of the first pair of the block.
    if(threadIdx.y == 0)
    {
        const int firstPairOfBlock = seqYNo * WIN_SIZE + seqXNo;

        shmStarts[threadIdx.x] = starts[firstPairOfBlock + threadIdx.x];

        // one more entry: the end position of the last alignment of the block
        if(threadIdx.x == 0)
        {
            shmStarts[BLOCK_Y_SIZE] = starts[firstPairOfBlock + BLOCK_Y_SIZE];
        }
    }

    __syncthreads();

    // Slot of this row's alignment in the input array.
    unsigned int* alignmentSlot = &input[seqYNo*WIN_SIZE*maxAlignmentLength*(K+1) + seqXNo*maxAlignmentLength*(K+1)];

    // Output range assigned to this row.
    unsigned int writePos = shmStarts[threadIdx.y];
    const unsigned int writeLimit = shmStarts[threadIdx.y + 1];

    // Read position inside the slot; advances in step with writePos.
    int readPos = 0;

    while(writePos < writeLimit)
    {
        output[writePos + threadIdx.x] = alignmentSlot[readPos + threadIdx.x];

        writePos += BLOCK_X_SIZE;
        readPos  += BLOCK_X_SIZE;
    }
}

// Remove the remaining helper macros so that they do not leak into files that
// include this header.
#undef BLOCK_X_SIZE
#undef BLOCK_Y_SIZE
#undef seqXNo
#undef seqYNo
#undef WIN_SIZE


#endif
