#include "nws_single_gpu_runnable.h"

#include <cstdio>
#include <cstdlib>

#include "nws_algorithm_option.h"
#include "nws_single_gpu_runnable_init.h"
#include "dev_pointers.h"
#include "dev_params.h"
#include "constants.h"


using align::NWSSingleGPURunnable;
using align::NWSSingleGPURunnableInit;
using align::NWSAlgorithmOption;
using align::DevPointers;
using align::DevParams;



// Semiglobal (overlap) Needleman-Wunsch scoring kernel for a batch in which every
// Y sequence has the same length. The height of the last strip (N) is therefore a
// compile-time constant and all inner loops can be fully unrolled.
//
// Template parameters:
//   Y_STEPS        - residue codes packed into one unsigned int of
//                    devPtr.packedSeqs; also the number of H rows computed per
//                    sweep over the X sequence (one "strip").
//   MASK, BITS     - mask and bit width used to unpack one residue code.
//   RESIDUES_COUNT - alphabet size; devPtr.sm is read as a
//                    RESIDUES_COUNT x RESIDUES_COUNT substitution table.
//   N              - number of rows in the last strip; must be in [1, Y_STEPS]
//                    (the host table uses N == Y_STEPS when the Y length is a
//                    multiple of Y_STEPS).
//
// Launch layout: one thread per pair (threads with idx >= pairsToCompute only
// help stage the substitution table). blockDim.x must equal BLOCK_SIZE, because
// idx and the static shared arrays are computed/sized from BLOCK_SIZE.
//
// devPtr.H holds, for this thread, one int per column of X at stride
// devParams.memOffset: the H row just above the current strip. Every column is
// read and then overwritten with the strip's last row. Its contents before the
// first strip are not set here (presumably zeroed by the host - TODO confirm).
//
// Result: devPtr.scores_overlaps[idx] = {best score, position}, starting from
// {0, lengthY}. The best score is the maximum H value over the last row and the
// last column. The position is lengthY - 1 - row (>= 0) for a last-column cell,
// or -(lengthX - 1 - column) (<= 0) for a last-row cell; on ties the cell visited
// later wins. Threads whose X or Y sequence is empty write nothing.
template <int Y_STEPS,
          int MASK,
          int BITS,
          int RESIDUES_COUNT,
          int N>
__global__ void NeedlemanWunschSemiglobalScoreKernel(DevPointers devPtr, DevParams devParams)
{
    int idx = (blockIdx.x * BLOCK_SIZE) + threadIdx.x;

    // Stage the substitution table in shared memory. The strided loop stays
    // correct even if the block has fewer threads than the table has entries.
    __shared__ int sm[RESIDUES_COUNT*RESIDUES_COUNT];
    for (int k = threadIdx.x; k < RESIDUES_COUNT*RESIDUES_COUNT; k += blockDim.x)
        sm[k] = devPtr.sm[k];

    // All threads of the block must reach the barrier, so it precedes the early exits.
    __syncthreads();

    if(idx >= devParams.pairsToCompute)
        return;

    int pairX = devPtr.pair1[idx];
    int pairY = devPtr.pair2[idx];

    unsigned int startX  = devPtr.starts[pairX];
    unsigned int startY  = devPtr.starts[pairY];

    int2 lengthXY;
    lengthXY.x = devPtr.lengths[pairX];
    lengthXY.y = devPtr.lengths[pairY];

    // Nothing to align.
    if((lengthXY.x == 0) || (lengthXY.y == 0))
        return;

    // Per-thread storage (each thread only touches column threadIdx.x):
    // H_shared[i] is the H value of strip row i at the previous column,
    // sharedYSeq[i] is the residue code of strip row i.
    __shared__ int H_shared[Y_STEPS][BLOCK_SIZE];
    __shared__ char sharedYSeq[Y_STEPS][BLOCK_SIZE];

    int H_current = 0;

    int2 score_overlap;
    score_overlap.x = 0;
    score_overlap.y = lengthXY.y;

    // Rows [0, lenY) are processed in full strips of Y_STEPS rows; the last N
    // rows form the final strip below.
    int lenY = lengthXY.y - N;

    for (int y = 0; y < lenY; y += Y_STEPS)
    {
        // .x = H of the row above the strip at the current column,
        // .y = up-left neighbour of the cell being computed.
        int2 H_init_upleft;
        H_init_upleft.x = 0;

        // Column -1 of H is zero for every row of the strip.
        #pragma unroll
        for (int i = 0; i < Y_STEPS; i++)
        {
            H_shared[i][threadIdx.x] = 0;
        }

        // Unpack the strip's Y_STEPS residues of Y from one packed word.
        unsigned int seqY = devPtr.packedSeqs[startY++];
        #pragma unroll
        for (int i = 0; i < Y_STEPS; i++)
        {
            sharedYSeq[i][threadIdx.x] = seqY & MASK;
            seqY >>= BITS;
        }

        // Current packed word of X. Declared outside the column loop so the codes
        // not yet consumed carry over between columns; a loop-local variable is a
        // new, uninitialized object on every iteration where x % Y_STEPS != 0.
        unsigned int XSeq16 = 0;

        for (int x = 0; x < lengthXY.x; x++)
        {
            H_init_upleft.y = H_init_upleft.x;

            if(x % Y_STEPS == 0)
                XSeq16 = devPtr.packedSeqs[startX + x/Y_STEPS];

            int XSeq = XSeq16 & MASK;
            XSeq16 >>= BITS;

            // H of the row above the strip at column x.
            H_init_upleft.x = devPtr.H[idx + x * devParams.memOffset];
            H_current = H_init_upleft.x;

            // Walk down the strip: H_current is the cell above, H_init_upleft.y
            // the up-left cell and H_shared[i] the cell to the left.
            #pragma unroll
            for(int i = 0; i < Y_STEPS; i++)
            {
                H_current = max(H_current - devParams.gapPenalty, H_init_upleft.y + sm[sharedYSeq[i][threadIdx.x] * RESIDUES_COUNT + XSeq]);
                H_current = max(H_current, H_shared[i][threadIdx.x] - devParams.gapPenalty);

                // Left cell of this row becomes the up-left cell of the next row.
                H_init_upleft.y = H_shared[i][threadIdx.x];
                H_shared[i][threadIdx.x] = H_current;
            }

            // The strip's last row is the "row above" for the next strip.
            devPtr.H[idx + x * devParams.memOffset] = H_current;
        }

        // Track the maximum over the last column of H for this strip's rows.
        #pragma unroll
        for(int i = 0; i < Y_STEPS; i++)
        {
            score_overlap.x = max(score_overlap.x, H_shared[i][threadIdx.x]);
            if(score_overlap.x == H_shared[i][threadIdx.x])
                score_overlap.y = lengthXY.y - 1 - y - i;
        }
    }

    // Last strip: the bottom N rows of the matrix. It contains the last row of H,
    // so the last-row maximum is tracked while sweeping the columns.
    {
        int y = lenY;
        int2 H_init_upleft;
        H_init_upleft.x = 0;

        // Column -1 of H is zero for every row of the strip.
        #pragma unroll
        for (int i = 0; i < N; i++)
        {
            H_shared[i][threadIdx.x] = 0;
        }

        // Unpack the last N residues of Y.
        unsigned int seqY = devPtr.packedSeqs[startY++];
        #pragma unroll
        for (int i = 0; i < N; i++)
        {
            sharedYSeq[i][threadIdx.x] = seqY & MASK;
            seqY >>= BITS;
        }

        // See the main loop: must persist across columns.
        unsigned int XSeq16 = 0;

        for (int x = 0; x < lengthXY.x; x++)
        {
            H_init_upleft.y = H_init_upleft.x;

            if(x % Y_STEPS == 0)
                XSeq16 = devPtr.packedSeqs[startX + x/Y_STEPS];

            int XSeq = XSeq16 & MASK;
            XSeq16 >>= BITS;

            // H of the row above the strip at column x (not written back: no
            // strip follows this one).
            H_init_upleft.x = devPtr.H[idx + x * devParams.memOffset];
            H_current = H_init_upleft.x;

            #pragma unroll
            for(int i = 0; i < N; i++)
            {
                H_current = max(H_current - devParams.gapPenalty, H_init_upleft.y + sm[sharedYSeq[i][threadIdx.x] * RESIDUES_COUNT + XSeq]);
                H_current = max(H_current, H_shared[i][threadIdx.x] - devParams.gapPenalty);

                H_init_upleft.y = H_shared[i][threadIdx.x];
                H_shared[i][threadIdx.x] = H_current;
            }

            // H_current is now the cell of the last row of H at column x.
            score_overlap.x = max(score_overlap.x, H_current);
            if(score_overlap.x == H_current)
                score_overlap.y = -(lengthXY.x - 1 - x);
        }

        // Track the maximum over the last column of H for this strip's rows.
        #pragma unroll
        for(int i = 0; i < N; i++)
        {
            score_overlap.x = max(score_overlap.x, H_shared[i][threadIdx.x]);
            if(score_overlap.x == H_shared[i][threadIdx.x])
                score_overlap.y = lengthXY.y - 1 - y - i;
        }
    }

    devPtr.scores_overlaps[idx] = score_overlap;
}

// Host launcher for the equal-Y-length kernel specialised for a last strip of N
// rows. Launches on the default stream (no synchronisation) and counts the launch
// in nws->fasterKernelInvocationCount.
template <int Y_STEPS,
          int MASK,
          int BITS,
          int RESIDUES_COUNT,
          int N>
void RunNeedlemanWunschSemiglobalScoreKernel(int blocks, int block_size, DevPointers devPtr, DevParams devParams, NWSAlgorithmOption *nws) {
    NeedlemanWunschSemiglobalScoreKernel<Y_STEPS, MASK, BITS, RESIDUES_COUNT, N> <<< blocks, block_size >>> (devPtr, devParams);

    // A kernel launch returns no status; configuration errors only show up in the
    // runtime's last-error slot. Peek (rather than Get) so the error is not cleared
    // and any later check by the caller still sees it.
    cudaError_t launchErr = cudaPeekAtLastError();
    if (launchErr != cudaSuccess)
        std::fprintf(stderr, "CUDA error after launching NeedlemanWunschSemiglobalScoreKernel: %s\n",
                     cudaGetErrorString(launchErr));

    nws->fasterKernelInvocationCount++;
}

// Semiglobal (overlap) Needleman-Wunsch scoring kernel for batches whose Y
// sequences differ in length. It computes the same recurrence as
// NeedlemanWunschSemiglobalScoreKernel, but the height of the last strip (ymin)
// is derived per thread at run time, so the last strip's inner loops are not
// unrolled.
//
// Template parameters:
//   Y_STEPS        - residue codes packed into one unsigned int of
//                    devPtr.packedSeqs; also the number of H rows computed per
//                    sweep over the X sequence (one "strip").
//   MASK, BITS     - mask and bit width used to unpack one residue code.
//   RESIDUES_COUNT - alphabet size; devPtr.sm is read as a
//                    RESIDUES_COUNT x RESIDUES_COUNT substitution table.
//
// Launch layout: one thread per pair; blockDim.x must equal BLOCK_SIZE, because
// idx and the static shared arrays are computed/sized from BLOCK_SIZE.
//
// devPtr.H holds, for this thread, one int per column of X at stride
// devParams.memOffset: the H row just above the current strip. It is read for
// every column and overwritten with each full strip's last row. Its contents
// before the first strip are not set here (presumably zeroed by the host - TODO
// confirm).
//
// Result: devPtr.scores_overlaps[idx] = {best score, position}, with the same
// encoding as the fixed-length kernel. The position is lengthY - 1 - row for a
// last-column cell, or -(lengthX - 1 - column) for a last-row cell; ties go to
// the cell visited later. Threads whose X or Y sequence is empty write nothing.
template <int Y_STEPS,
          int MASK,
          int BITS,
          int RESIDUES_COUNT>
__global__ void NeedlemanWunschSemiglobalScoreKernelVariableLength(DevPointers devPtr, DevParams devParams)
{
    int idx = (blockIdx.x * BLOCK_SIZE) + threadIdx.x;

    // Stage the substitution table in shared memory. The strided loop stays
    // correct even if the block has fewer threads than the table has entries.
    __shared__ int sm[RESIDUES_COUNT*RESIDUES_COUNT];
    for (int k = threadIdx.x; k < RESIDUES_COUNT*RESIDUES_COUNT; k += blockDim.x)
        sm[k] = devPtr.sm[k];

    // All threads of the block must reach the barrier, so it precedes the early exits.
    __syncthreads();

    if(idx >= devParams.pairsToCompute)
        return;

    int pairX = devPtr.pair1[idx];
    int pairY = devPtr.pair2[idx];

    unsigned int startX  = devPtr.starts[pairX];
    unsigned int startY  = devPtr.starts[pairY];

    int2 lengthXY;
    lengthXY.x = devPtr.lengths[pairX];
    lengthXY.y = devPtr.lengths[pairY];

    // Nothing to align.
    if((lengthXY.x == 0) || (lengthXY.y == 0))
        return;

    // Per-thread storage (each thread only touches column threadIdx.x):
    // H_shared[i] is the H value of strip row i at the previous column,
    // sharedYSeq[i] is the residue code of strip row i.
    __shared__ int H_shared[Y_STEPS][BLOCK_SIZE];
    __shared__ char sharedYSeq[Y_STEPS][BLOCK_SIZE];

    int H_current = 0;

    int2 score_overlap;
    score_overlap.x = 0;
    score_overlap.y = lengthXY.y;

    // Height of the last strip, always in [1, Y_STEPS]. A Y length that is a
    // multiple of Y_STEPS ends with a full last strip, as in the fixed-length
    // kernel with N == Y_STEPS, instead of an empty one. This has two effects.
    // This thread never reads more than ceil(lengthY / Y_STEPS) packed Y words.
    // Cells are also visited in the same order as in the fixed-length kernel,
    // which decides ties.
    int ymin = (lengthXY.y - 1) % Y_STEPS + 1;
    int lenY = lengthXY.y - ymin;

    for (int y = 0; y < lenY; y += Y_STEPS)
    {
        // .x = H of the row above the strip at the current column,
        // .y = up-left neighbour of the cell being computed.
        int2 H_init_upleft;
        H_init_upleft.x = 0;

        // Column -1 of H is zero for every row of the strip.
        #pragma unroll
        for (int i = 0; i < Y_STEPS; i++)
        {
            H_shared[i][threadIdx.x] = 0;
        }

        // Unpack the strip's Y_STEPS residues of Y from one packed word.
        unsigned int seqY = devPtr.packedSeqs[startY++];
        #pragma unroll
        for (int i = 0; i < Y_STEPS; i++)
        {
            sharedYSeq[i][threadIdx.x] = seqY & MASK;
            seqY >>= BITS;
        }

        // Current packed word of X. Declared outside the column loop so the codes
        // not yet consumed carry over between columns; a loop-local variable is a
        // new, uninitialized object on every iteration where x % Y_STEPS != 0.
        unsigned int XSeq16 = 0;

        for (int x = 0; x < lengthXY.x; x++)
        {
            H_init_upleft.y = H_init_upleft.x;

            if(x % Y_STEPS == 0)
                XSeq16 = devPtr.packedSeqs[startX + x/Y_STEPS];

            int XSeq = XSeq16 & MASK;
            XSeq16 >>= BITS;

            // H of the row above the strip at column x.
            H_init_upleft.x = devPtr.H[idx + x * devParams.memOffset];
            H_current = H_init_upleft.x;

            // Walk down the strip: H_current is the cell above, H_init_upleft.y
            // the up-left cell and H_shared[i] the cell to the left.
            #pragma unroll
            for(int i = 0; i < Y_STEPS; i++)
            {
                H_current = max(H_current - devParams.gapPenalty, H_init_upleft.y + sm[sharedYSeq[i][threadIdx.x] * RESIDUES_COUNT + XSeq]);
                H_current = max(H_current, H_shared[i][threadIdx.x] - devParams.gapPenalty);

                // Left cell of this row becomes the up-left cell of the next row.
                H_init_upleft.y = H_shared[i][threadIdx.x];
                H_shared[i][threadIdx.x] = H_current;
            }

            // The strip's last row is the "row above" for the next strip.
            devPtr.H[idx + x * devParams.memOffset] = H_current;
        }

        // Track the maximum over the last column of H for this strip's rows.
        #pragma unroll
        for(int i = 0; i < Y_STEPS; i++)
        {
            score_overlap.x = max(score_overlap.x, H_shared[i][threadIdx.x]);
            if(score_overlap.x == H_shared[i][threadIdx.x])
                score_overlap.y = lengthXY.y - 1 - y - i;
        }
    }

    // Last strip: the bottom ymin rows of the matrix. It contains the last row of
    // H, so the last-row maximum is tracked while sweeping the columns.
    {
        int y = lenY;
        int2 H_init_upleft;
        H_init_upleft.x = 0;

        // Column -1 of H is zero (only the first ymin rows are used below).
        #pragma unroll
        for (int i = 0; i < Y_STEPS; i++)
        {
            H_shared[i][threadIdx.x] = 0;
        }

        // Unpack the word holding the last ymin residues of Y.
        unsigned int seqY = devPtr.packedSeqs[startY++];
        #pragma unroll
        for (int i = 0; i < Y_STEPS; i++)
        {
            sharedYSeq[i][threadIdx.x] = seqY & MASK;
            seqY >>= BITS;
        }

        // See the main loop: must persist across columns.
        unsigned int XSeq16 = 0;

        for (int x = 0; x < lengthXY.x; x++)
        {
            H_init_upleft.y = H_init_upleft.x;

            if(x % Y_STEPS == 0)
                XSeq16 = devPtr.packedSeqs[startX + x/Y_STEPS];

            int XSeq = XSeq16 & MASK;
            XSeq16 >>= BITS;

            // H of the row above the strip at column x (not written back: no
            // strip follows this one).
            H_init_upleft.x = devPtr.H[idx + x * devParams.memOffset];
            H_current = H_init_upleft.x;

            for(int i = 0; i < ymin; i++)
            {
                H_current = max(H_current - devParams.gapPenalty, H_init_upleft.y + sm[sharedYSeq[i][threadIdx.x] * RESIDUES_COUNT + XSeq]);
                H_current = max(H_current, H_shared[i][threadIdx.x] - devParams.gapPenalty);

                H_init_upleft.y = H_shared[i][threadIdx.x];
                H_shared[i][threadIdx.x] = H_current;
            }

            // H_current is now the cell of the last row of H at column x.
            score_overlap.x = max(score_overlap.x, H_current);
            if(score_overlap.x == H_current)
                score_overlap.y = -(lengthXY.x - 1 - x);
        }

        // Track the maximum over the last column of H for this strip's rows.
        for(int i = 0; i < ymin; i++)
        {
            score_overlap.x = max(score_overlap.x, H_shared[i][threadIdx.x]);
            if(score_overlap.x == H_shared[i][threadIdx.x])
                score_overlap.y = lengthXY.y - 1 - y - i;
        }
    }

    devPtr.scores_overlaps[idx] = score_overlap;
}

// Host launcher for the variable-Y-length kernel. Launches on the default stream
// (no synchronisation) and counts the launch in nws->slowerKernelInvocationCount.
template <int Y_STEPS,
          int MASK,
          int BITS,
          int RESIDUES_COUNT>
void RunNeedlemanWunschSemiglobalScoreKernelVariableLength(int blocks, int block_size, DevPointers devPtr, DevParams devParams, NWSAlgorithmOption *nws) {
    NeedlemanWunschSemiglobalScoreKernelVariableLength<Y_STEPS, MASK, BITS, RESIDUES_COUNT> <<< blocks, block_size >>> (devPtr, devParams);

    // A kernel launch returns no status; configuration errors only show up in the
    // runtime's last-error slot. Peek (rather than Get) so the error is not cleared
    // and any later check by the caller still sees it.
    cudaError_t launchErr = cudaPeekAtLastError();
    if (launchErr != cudaSuccess)
        std::fprintf(stderr, "CUDA error after launching NeedlemanWunschSemiglobalScoreKernelVariableLength: %s\n",
                     cudaGetErrorString(launchErr));

    nws->slowerKernelInvocationCount++;
}




// Signature shared by every host launcher in the dispatch tables below.
typedef void (*KernelInvoker)(int blocks, int blocksize, DevPointers devPtr, DevParams devParams,  NWSAlgorithmOption *nws);
// Launchers for the equal-Y-length kernel, indexed [residuesCount][tail] with
// tail = (Y length) % y_steps[residuesCount]. Entry [r][0] is instantiated with
// N == Y_STEPS (a full last strip); entry [r][k] with N == k.
// Only residue counts 4 (Y_STEPS 16, MASK 3, BITS 2) and 5 (Y_STEPS 10, MASK 7,
// BITS 3) are populated; every other entry is value-initialised to NULL.
KernelInvoker invokers[6][16] = { {}, {}, {}, {}, {
             RunNeedlemanWunschSemiglobalScoreKernel<16,3,2,4,16>,  
             RunNeedlemanWunschSemiglobalScoreKernel<16,3,2,4, 1>,  
             RunNeedlemanWunschSemiglobalScoreKernel<16,3,2,4, 2>,  
             RunNeedlemanWunschSemiglobalScoreKernel<16,3,2,4, 3>,  
             RunNeedlemanWunschSemiglobalScoreKernel<16,3,2,4, 4>,  
             RunNeedlemanWunschSemiglobalScoreKernel<16,3,2,4, 5>,  
             RunNeedlemanWunschSemiglobalScoreKernel<16,3,2,4, 6>,  
             RunNeedlemanWunschSemiglobalScoreKernel<16,3,2,4, 7>,  
             RunNeedlemanWunschSemiglobalScoreKernel<16,3,2,4, 8>,  
             RunNeedlemanWunschSemiglobalScoreKernel<16,3,2,4, 9>,  
             RunNeedlemanWunschSemiglobalScoreKernel<16,3,2,4,10>,  
             RunNeedlemanWunschSemiglobalScoreKernel<16,3,2,4,11>,  
             RunNeedlemanWunschSemiglobalScoreKernel<16,3,2,4,12>,  
             RunNeedlemanWunschSemiglobalScoreKernel<16,3,2,4,13>,  
             RunNeedlemanWunschSemiglobalScoreKernel<16,3,2,4,14>,  
             RunNeedlemanWunschSemiglobalScoreKernel<16,3,2,4,15>
        }, {
             // residuesCount 5: tail < 10, so entries [5][10..15] stay NULL.
             RunNeedlemanWunschSemiglobalScoreKernel<10,7,3,5,10>,
             RunNeedlemanWunschSemiglobalScoreKernel<10,7,3,5, 1>,
             RunNeedlemanWunschSemiglobalScoreKernel<10,7,3,5, 2>,
             RunNeedlemanWunschSemiglobalScoreKernel<10,7,3,5, 3>,
             RunNeedlemanWunschSemiglobalScoreKernel<10,7,3,5, 4>,
             RunNeedlemanWunschSemiglobalScoreKernel<10,7,3,5, 5>,
             RunNeedlemanWunschSemiglobalScoreKernel<10,7,3,5, 6>,
             RunNeedlemanWunschSemiglobalScoreKernel<10,7,3,5, 7>,
             RunNeedlemanWunschSemiglobalScoreKernel<10,7,3,5, 8>,
             RunNeedlemanWunschSemiglobalScoreKernel<10,7,3,5, 9>
        }
    };

// Launchers for the variable-Y-length kernel, indexed by residuesCount; only
// indices 4 and 5 are populated.
KernelInvoker invokersVariableLength[6] = {
        NULL,
        NULL,
        NULL,
        NULL,
        RunNeedlemanWunschSemiglobalScoreKernelVariableLength<16,3,2,4>,
        RunNeedlemanWunschSemiglobalScoreKernelVariableLength<10,7,3,5>
    };

// Strip height (Y_STEPS) per residuesCount, used on the host to compute the
// `tail` index into `invokers`. Only entries 4 (16) and 5 (10) correspond to
// instantiated kernels; the other values have no kernels behind them.
int y_steps[6] = { 32, 32, 32, 16, 16, 10 };

// Launches the semiglobal NW scoring kernel for this GPU's share of the pairs.
//
// Pairs start at index pairIdToStartWith; at most nwa->pairsPerGPU of them are
// processed. The method compares the Y lengths of the first and the last pair in
// that range. If they are equal, the kernel specialised for that length's
// last-strip height is launched; otherwise the variable-length kernel is used.
// This presumes the pairs are ordered by Y length (TODO confirm against the code
// that builds nwa->p). The launch is asynchronous; no synchronisation happens here.
void NWSSingleGPURunnable::runOnOneGPU() {
    int pairsToCompute = MIN(this->nwa->pairsPerGPU, this->nwa->p->pairsCount - this->pairIdToStartWith);

    // Nothing assigned to this GPU: avoid reading pair2[pairIdToStartWith - 1]
    // below and launching a zero-block grid (an invalid configuration).
    if (pairsToCompute <= 0)
        return;

    // One thread per pair; integer ceil-division.
    int blocks = (pairsToCompute + BLOCK_SIZE - 1) / BLOCK_SIZE;

    DevPointers devptr = NWSSingleGPURunnableInit::devPtr[this->tm->getThreadsInfo().gpuNo];
    DevParams devparam;
    devparam.gapPenalty = this->nwa->gapPenalty;
    devparam.memOffset = this->memoryOffset;
    devparam.pairsToCompute = pairsToCompute;

    // residuesCount indexes the dispatch tables directly; reject values outside them.
    int residuesCount = this->nwa->sm->residuesCount;
    const int tableRows = (int)(sizeof(y_steps) / sizeof(y_steps[0]));
    if (residuesCount < 0 || residuesCount >= tableRows) {
        std::fprintf(stderr, "NWSSingleGPURunnable::runOnOneGPU: unsupported residue count %d\n", residuesCount);
        std::abort();
    }

    int longestSeqY = this->nwa->seqs->lengths[ this->nwa->p->pair2[pairIdToStartWith] ];
    int shortestSeqY = this->nwa->seqs->lengths[ this->nwa->p->pair2[pairIdToStartWith + pairsToCompute - 1] ];

    KernelInvoker invoker;
    if (longestSeqY == shortestSeqY)
    {
        // All Y sequences have the same length: use the kernel whose last strip
        // height is fixed at compile time.
        int tail = longestSeqY % y_steps[residuesCount];
        // y_steps may exceed the second table dimension (e.g. 32 vs 16).
        const int tailSlots = (int)(sizeof(invokers[0]) / sizeof(invokers[0][0]));
        invoker = (tail < tailSlots) ? invokers[residuesCount][tail] : NULL;
    }
    else
    {
        invoker = invokersVariableLength[residuesCount];
    }

    // Unpopulated table entries are NULL; calling one would crash without a message.
    if (invoker == NULL) {
        std::fprintf(stderr, "NWSSingleGPURunnable::runOnOneGPU: no kernel instantiated for residue count %d\n", residuesCount);
        std::abort();
    }

    invoker(blocks, BLOCK_SIZE, devptr, devparam, nwa);
}

