#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <typeinfo>
#include <vector>

#include "neighbour_joining.h"
#include "sequences.h"
#include "nj_internal_node.h"
#include "exceptions.h"
#include "hi_res_timer.h"


INJTreeNode*  parseFile(const char* filename);

using Data::Sequences;
using namespace Exceptions;
using namespace std;

/*
 * Remembers the inputs; the tree itself is built later by run() or
 * constructTreeFromPhylipFile().
 *
 * filename - output file for run() (NULL = do not write), input Phylip tree
 *            file for constructTreeFromPhylipFile().
 * dm       - distance matrix between the sequences (run() modifies it).
 * seq      - the sequence set the matrix was computed from.
 *
 * Only the pointers are stored; ownership is not handled in this constructor.
 */
NeighbourJoining::NeighbourJoining(const char* filename, DistanceMatrix* dm, Sequences* seq)
    : filename(filename),
      dm(dm),
      seq(seq)
{
}


/*
 * Recursively walks a tree (as returned by parseFile()) and, for every node:
 *  - sets node->parent to `parent` (NULL for the root);
 *  - leaves: appends the leaf to leavesPhyl, resolves its sequence index
 *    (seqNo) by matching sequenceName against the names in `seq`, copies the
 *    sequence length and sets numOfLeaves = 1;
 *  - internal nodes: recurses into both children, then sets numOfLeaves to
 *    the sum of the children's counts (the same invariant run() maintains,
 *    which computeWeights() relies on when dividing by numOfLeaves).
 *
 * Assumes leavesPhyl has room for dm->sequenceNumber entries, as allocated by
 * constructTreeFromPhylipFile().
 *
 * Throws IndexOutOfRangeException* if the tree has more leaves than
 * dm->sequenceNumber, or if a leaf name is not found in `seq`.
 */
void NeighbourJoining::findLeafNumbers(INJTreeNode* node, INJTreeNode* parent) {
    node->parent = parent;
    if (node->isLeaf) {
        NJLeafNode* lnode = static_cast<NJLeafNode*>(node);

        // Guard before writing: without this, a tree with extra leaves would
        // overflow leavesPhyl long before the caller's leaf-count check runs.
        if (nLeavesPhyl >= dm->sequenceNumber)
            throw new IndexOutOfRangeException("Tree contains more leaves than there are sequences in findLeafNumbers()");
        leavesPhyl[nLeavesPhyl++] = lnode;
        lnode->numOfLeaves = 1;

        for (int i = 0; i < seq->getSequenceNumber(); i++)
            if (!strcmp(seq->getSeqName(i), lnode->sequenceName)) {
                lnode->seqNo = i;
                lnode->length = this->seq->getLengths()[i];
                return;
            }
        // NOTE(review): only c_str() of a temporary std::string is passed; if
        // IndexOutOfRangeException keeps the pointer without copying the text,
        // the message dangles after this statement - confirm it copies.
        throw new IndexOutOfRangeException((std::string("Sequence ") + lnode->sequenceName + " hasn't been found in the sequences set").c_str());
    }
    NJInternalNode* inode = static_cast<NJInternalNode*>(node);
    findLeafNumbers(inode->childNodeLeft, node);
    findLeafNumbers(inode->childNodeRight, node);
    inode->numOfLeaves = inode->childNodeLeft->numOfLeaves + inode->childNodeRight->numOfLeaves;
}

/*
 * Alternative to run(): instead of computing the tree from the distance
 * matrix, loads it from the Phylip tree file `filename` via parseFile(),
 * links every leaf to its sequence (findLeafNumbers) and computes
 * seqsWeights from the loaded tree.
 *
 * Throws IndexOutOfRangeException* when the tree's leaves do not match the
 * dm->sequenceNumber sequences; anything thrown by parseFile() or
 * findLeafNumbers() is propagated. The temporary leavesPhyl buffer is
 * released on every exit path and reset to NULL.
 */
void NeighbourJoining::constructTreeFromPhylipFile()
{
    leavesPhyl = new NJLeafNode*[dm->sequenceNumber];
    nLeavesPhyl = 0;
    try
    {
        this->root = parseFile(this->filename);
        findLeafNumbers(this->root, NULL);
        if (nLeavesPhyl != dm->sequenceNumber)
            throw new IndexOutOfRangeException("Not all sequences used in constructTreeFromPhylipFile()");

        // Value-initialise (zero) the weights: computeWeights() accumulates
        // path lengths, so uninitialised memory would corrupt the result.
        seqsWeights = new double[dm->sequenceNumber]();
        computeWeights(leavesPhyl);
    }
    catch (...)
    {
        delete[] leavesPhyl;
        leavesPhyl = NULL;
        throw;
    }
    delete[] leavesPhyl;
    leavesPhyl = NULL;
}


void NeighbourJoining::run()
{
    HiResTimer timer;
    timer.start();

    orphans = new INJTreeNode*[dm->sequenceNumber];
    NJLeafNode* leaves = new NJLeafNode[dm->sequenceNumber]; // needed to compute weights of sequences
    seqsWeights = new double[dm->sequenceNumber];

    for (int i = 0; i < dm->sequenceNumber; i++)
    {
        orphans[i] = &leaves[i];
        orphans[i]->isLeaf = true;
        ((NJLeafNode*)orphans[i])->sequenceName = seq->getSeqName(i);
        ((NJLeafNode*)orphans[i])->seqNo = i;
        ((NJLeafNode*)orphans[i])->numOfLeaves = 1;
        ((NJLeafNode*)orphans[i])->length = seq->getLengths()[i];

        seqsWeights[i] = 0.0;
    }

    int currWeight = 1;
    DistanceMatrix* Q = new DistanceMatrix(seq);

    int xIdx;
    int yIdx;
    double minVal = INT_MAX;
    int r = dm->sequenceNumber;

    // CALCULATING Q MATRIX
    double* sumOfRowsCache = new double[dm->sequenceNumber];
    for (int i = 0; i < dm->sequenceNumber; i++)
        sumOfRowsCache[i] = 0;

    for (int i = 0; i < dm->sequenceNumber; i++)
        for (int j = 0; j < dm->sequenceNumber; j++)
            sumOfRowsCache[i] -= dm->getElement(j, i);

    for (int i = 0; i < dm->sequenceNumber; i++)
    {
        for (int j = 0; j < i; j++)
        {
            double qValue = (r - 2)*dm->getElement(j, i);
            qValue += sumOfRowsCache[i];
            qValue += sumOfRowsCache[j];
            
            Q->setElement(j, i, qValue);
        }
    }

    // IN THIS LOOP WE JOIN NODES TO CREATE TREE
    for(r = dm->sequenceNumber; r>2; r--)
    {
        
        // CHOOSING MIN VALUE FROM Q MATRIX
        minVal = INT_MAX;
        for (int i = 0; i < dm->sequenceNumber; i++)
        {
            if (orphans[i] != NULL)
            {
                for (int j = 0; j < i; j++)
                {
                    if (orphans[j] != NULL)
                    {
                        if(Q->getElement(j,i)<minVal)
                        {
                            xIdx = j;
                            yIdx = i;
                            minVal = Q->getElement(j,i);
                        }
                    }
                }
            }
        }

        // UPDATING Q MATRIX
        // for each cell we add value of deleted elements
        double updatedValue;
        for (int i = 0; i < dm->sequenceNumber; i++)
        {
            if (orphans[i] != NULL)
            {
                for (int j = 0; j < i; j++)
                {
                    if (orphans[j] != NULL)
                    {
                        updatedValue  = Q->getElement(j, i);
                        updatedValue += dm->getElement(xIdx,i);
                        updatedValue += dm->getElement(yIdx,i);
                        updatedValue += dm->getElement(xIdx,j);
                        updatedValue += dm->getElement(yIdx,j);
                        updatedValue -= (r - 2)*dm->getElement(j, i);
                        Q->setElement(j, i,updatedValue);
                    }
                }
            }
        }
        

        // CREATING NEW NODE
        NJInternalNode* newNode = new NJInternalNode();
        newNode->isLeaf = false;

        newNode->value = dm->getElement(xIdx, yIdx);
        newNode->childNodeLeft  = orphans[xIdx];
        newNode->childNodeRight = orphans[yIdx];
        newNode->numOfLeaves = orphans[xIdx]->numOfLeaves + orphans[yIdx]->numOfLeaves;
        orphans[xIdx]->parent = newNode;
        orphans[yIdx]->parent = newNode;


        // CALCULATING VALUES FOR CHILDREN
        double leftDistance  = 0;
        double rightDistance = 0;
        for (int k = 0; k < dm->sequenceNumber; k++)
        {
            if (orphans[k] != NULL)
            {
                leftDistance  += dm->getElement(xIdx, k);
                rightDistance += dm->getElement(yIdx, k);
            }
        }
        
//        if(r>2)
//        {
            newNode->childNodeLeft->value  = 0.5 * dm->getElement(xIdx, yIdx) +
                                             ( 1.0 / (2.0*(r-2.0)) ) *
                                             (leftDistance - rightDistance);
            newNode->childNodeRight->value = 0.5 * dm->getElement(xIdx, yIdx) +
                                             ( 1.0 / (2.0*(r-2.0)) ) *
                                             (rightDistance - leftDistance);
//        }
//        else // last two nodes (because r can't be 2)
//        {
//            newNode->childNodeLeft->value  = 0.5 * dm->getElement(xIdx, yIdx) +
//                                             (leftDistance - rightDistance);
//            newNode->childNodeRight->value = 0.5 * dm->getElement(xIdx, yIdx) +
//                                             (rightDistance - leftDistance);
//        }

            
//        if(orphans[xIdx]->isLeaf)
//        {
//            seqsWeights[xIdx] = currWeight + orphans[xIdx].value;
//        }
//        if(orphans[yIdx]->isLeaf)
//        {
//            seqsWeights[yIdx] = currWeight + orphans[yIdx].value;
//        }
//        currWeight += max(orphans[xIdx].value, orphans[yIdx].value);
        

        // UPDATING K (Distance) MATRIX
        for (int k = 0; k < dm->sequenceNumber; k++)
        {
            if (orphans[k] != NULL)
            {
                dm->setElement(xIdx, k,
                               0.5 * (dm->getElement(xIdx, k) -
                                      newNode->childNodeLeft->value) +
                               0.5 * (dm->getElement(yIdx, k) -
                                      newNode->childNodeRight->value)
                               );
            }
        }

        orphans[xIdx] = newNode;
        orphans[yIdx] = NULL;

        // CONTINUATION OF Q MATRIX UPDATE
        for (int i = 0; i < dm->sequenceNumber; i++)
        {
            if (orphans[i] != NULL)
            {
                for (int j = 0; j < i; j++)
                {
                    if (orphans[j] != NULL)
                    {
                        updatedValue  = Q->getElement(j, i);
                        updatedValue -= dm->getElement(xIdx,i);
                        updatedValue -= dm->getElement(xIdx,j);
                        updatedValue += (r - 3)*dm->getElement(j, i);
                        Q->setElement(j, i,updatedValue);
                    }
                }
            }
        }

        // RECALCULATING Q MATRIX xIdx column (row)
        for (int i = 0; i < dm->sequenceNumber; i++)
        {
            if (orphans[i] != NULL)
            {
                double qValue = (r - 3)*dm->getElement(xIdx, i);
                for (int k = 0; k < dm->sequenceNumber; k++)
                {
                    if (orphans[k] != NULL)
                    {
                        qValue -= dm->getElement(xIdx, k);
                        qValue -= dm->getElement(i, k);
                    }
                }
                Q->setElement(xIdx, i, qValue);
            }
        }


    }// FOR



    // HERE TREE IS READY
    // THE ROOT IS HERE: orphans[xIdx]

//      // Fully binary tree
//    int offset = 0;
//    int totalElementsCount = 30000;
//    char* memForString = (char*)malloc(totalElementsCount);
//    memForString[offset++] = '(';
//    memForString[offset++] = '\n';
//    orphans[xIdx]->ToString(memForString, offset, totalElementsCount);
//    memForString[offset++] = ')';
//    memForString[offset++] = ';';
//    memForString[offset++] = 0;


    // Binary tree with three elements joined to the root node
    int indexes[2];
    int actualIndex = 0;
    for (int i = 0; i < dm->sequenceNumber; i++)
    {
        if (orphans[i] != NULL)
        {
            indexes[actualIndex++] = i;
        }
    }
    if (actualIndex != 2)
    {
        throw new IncorrectInputData("Number of sequences is not big enough to generate tree (must be > 2).\nTree hasn't been generated.");
    }

    NJInternalNode* internalNode;
    INJTreeNode* otherNode;
    if (!orphans[indexes[0]]->isLeaf)
    {
        internalNode = (NJInternalNode*)orphans[indexes[0]];
        otherNode = orphans[indexes[1]];
    }
    else if (!orphans[indexes[1]]->isLeaf)
    {
        internalNode = (NJInternalNode*)orphans[indexes[1]];
        otherNode = orphans[indexes[0]];
    }
    else
    {
        throw new IncorrectInputData("Number of sequences is not big enough to generate tree (must be > 2).\nTree hasn't been generated.");
    }
    otherNode->value = dm->getElement(indexes[0], indexes[1]);

    root = new NJInternalNode();
    NJInternalNode* newNode = new NJInternalNode();
    root->isLeaf = false;
    root->value = 0; // irrelevant here
    ((NJInternalNode*)root)->childNodeLeft  = internalNode;
    ((NJInternalNode*)root)->childNodeRight = otherNode;
    internalNode->parent = root;
    otherNode->parent = root;
    root->parent = NULL;
    root->numOfLeaves = internalNode->numOfLeaves + otherNode->numOfLeaves;



    NJLeafNode* pleaves[dm->sequenceNumber];
    for (int i = 0; i < dm->sequenceNumber; i++)
        pleaves[i] = &leaves[i];
    computeWeights(pleaves);


    timer.stop();
    printf("Computing NJ tree: %dms\n", (int)timer.getElapsedTime());



    // generating string and saving to a file
    if(filename != NULL)
    {
        timer.start();
        
        int offset = 0;
        int totalElementsCount = 30000;
        char* memForString = (char*)malloc(totalElementsCount);
        memForString[offset++] = '(';
        memForString[offset++] = '\n';
        internalNode->childNodeLeft->ToString(memForString, offset, totalElementsCount);
        memForString[offset++] = ',';
        memForString[offset++] = '\n';
        internalNode->childNodeRight->ToString(memForString, offset, totalElementsCount);
        memForString[offset++] = ',';
        memForString[offset++] = '\n';
        otherNode->ToString(memForString, offset, totalElementsCount);
        memForString[offset++] = ')';
        memForString[offset++] = ';';
        memForString[offset++] = 0;

        //printf("%s\n", memForString);


        FILE* file = fopen(filename, "w");
        fprintf(file, "%s", memForString);
        fclose(file);

        free(memForString);

        timer.stop();
        printf("Saving NJ tree to file %s:  %dms\n", filename, (int)timer.getElapsedTime());
    }
}

/*
 * Computes one weight per sequence into seqsWeights[0 .. dm->sequenceNumber).
 *
 * leaves - array of dm->sequenceNumber leaf pointers; the weight of leaves[i]
 *          is stored in seqsWeights[i]. Each leaf must be linked to the root
 *          through `parent` pointers (the root's parent is NULL).
 *
 * Steps:
 *  1. raw weight = sum over the path leaf -> root of value / numOfLeaves;
 *  2. normalise so the weights average 1 (all weights become 1.0 if the raw
 *     sum is 0);
 *  3. if the standard deviation of the weights around 1 is non-zero, scale
 *     the deviations from 1 so that the largest |w - 1| equals that
 *     standard deviation.
 *
 * seqsWeights must already be allocated; its previous contents are ignored.
 */
void NeighbourJoining::computeWeights(NJLeafNode** leaves)
{
    const int count = dm->sequenceNumber;
    if (count <= 0)
        return;

    // 1. Raw weights. Summed into a local so stale or uninitialised contents
    //    of seqsWeights cannot leak into the result.
    double sum = 0;
    for (int i = 0; i < count; i++)
    {
        double weight = 0.0;
        for (INJTreeNode* node = leaves[i]; node != NULL; node = node->parent)
            weight += node->value / node->numOfLeaves;
        seqsWeights[i] = weight;
        sum += weight;
    }

    // 2. Normalise so that the average weight is 1 (i.e. the weights sum to
    //    the number of sequences).
    for (int i = 0; i < count; i++)
    {
        if (sum != 0)
        {
            seqsWeights[i] /= sum;
            seqsWeights[i] *= count;
        }
        else
        {
            seqsWeights[i] = 1.0;
        }
    }

    // 3. Spread of the weights around 1.
    double variance = 0;
    double maxDeviation = 0.0;
    for (int i = 0; i < count; i++)
    {
        const double deviation = seqsWeights[i] - 1.0;
        variance += deviation * deviation;
        const double absDeviation = fabs(deviation);
        if (maxDeviation < absDeviation)
            maxDeviation = absDeviation;
    }

    variance = variance / count;
    const double stdDev = sqrt(variance);

    if (stdDev != 0)
    {
        // stdDev != 0 implies maxDeviation > 0, so this division is safe.
        maxDeviation /= stdDev;
        for (int i = 0; i < count; i++)
            seqsWeights[i] = (seqsWeights[i] - 1.0) / maxDeviation + 1.0;
    }
}

