#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <new>
#include <vector>
#include "photonmap.h"
#include <assert.h>

PhotonMap::PhotonMap( const int maxPhot ) {
    storedPhotons = 0;
    // Keep the same relation balance() and readFromDisk() maintain, so a
    // query on a map that was never balanced does not read garbage.
    halfStoredPhotons = storedPhotons/2 - 1;
    prevScale = 1;
    maxPhotons = maxPhot;

    // Slot 0 is unused: store() fills photons[1..maxPhotons].
    // Plain new[] reports failure by throwing std::bad_alloc, never by
    // returning NULL, so the nothrow form is needed for this check to work.
    photons = new (std::nothrow) Photon[maxPhotons + 1];
    if ( photons == NULL ) {
        fprintf( stderr, "Out of memory initializing photon map.\n" );
        exit( EXIT_FAILURE );
    }

    // Start with an inverted bounding box so the first store() sets it.
    bboxMin[0] = bboxMin[1] = bboxMin[2] = 1e8f;
    bboxMax[0] = bboxMax[1] = bboxMax[2] = -1e8f;

    // Lookup tables used by photonDir() to decode the 8-bit angles:
    // theta covers [0, pi) and phi covers [0, 2*pi), each in 256 steps.
    for ( int i = 0; i < 256; i++ ) {
        double angle = (double) i * (1.0/256.0) * M_PI;
        costheta[i] = cos( angle );
        sintheta[i] = sin( angle );
        cosphi[i] = cos( 2.0 * angle );
        sinphi[i] = sin( 2.0 * angle );
    }
}

PhotonMap::~PhotonMap()
{
    // The constructor allocated the photon array with new[], so release it
    // with delete[] (deleting a null pointer is a no-op either way).
    // NOTE(review): this class owns a raw array; unless the header disables
    // copying, a copied PhotonMap would delete the same array twice.
    if ( photons != NULL )
        delete [] photons;
}

void
PhotonMap::writeToDisk( const char * fileName ) {
    FILE * file = fopen( fileName, "wb" );

    fwrite( photons, sizeof( Photon ), storedPhotons, file );
    fclose( file );

    printf( "Wrote %d photons to file %s.\n", storedPhotons, fileName );
}

void
PhotonMap::readFromDisk( const char * fileName, int numPhotons ) {
    FILE * file = fopen( fileName, "rb" );

    if ( file == NULL ) {
        fprintf( stderr, "Failed to open photon map file %s.\n", fileName );
    }

    assert( fread( photons, sizeof( Photon ), numPhotons, file ) == numPhotons );
    fclose( file );

    storedPhotons = numPhotons;
    halfStoredPhotons = storedPhotons/2 - 1;

    printf( "Read %d photons from file %s.\n", numPhotons, fileName );
}

 
void
PhotonMap::photonDir( float * dir, const Photon * p ) const {
    // Decode the photon's quantized (theta, phi) pair into a direction
    // vector using the sine/cosine tables filled by the constructor.
    const int t = p->theta;
    const int f = p->phi;

    dir[0] = sintheta[t] * cosphi[f];
    dir[1] = sintheta[t] * sinphi[f];
    dir[2] = costheta[t];
}

void
PhotonMap::irradianceEstimate( float irrad[3], const float pos[3],
    const float normal[3], const float maxDist, const int nPhotons ) const {
   
    irrad[0] = irrad[1] = irrad[2] = 0.0;

    NearestPhotons np;
    np.dist2 = new float[nPhotons + 1];
    np.index = (const Photon **) new (Photon *)[nPhotons + 1];

    np.pos[0] = pos[0]; np.pos[1] = pos[1]; np.pos[2] = pos[2];
    np.max = nPhotons;
    np.found = 0;
    np.gotHeap = 0;
    np.dist2[0] = maxDist * maxDist;

    locatePhotons( &np, 1 );

    if ( np.found < 8 ) {
	delete [] np.dist2;
	delete [] np.index;
	return;
    }

    float pdir[3];

    for ( int i = 1; i <= np.found; i++ ) {
	const Photon * p = np.index[i];
	photonDir( pdir, p );
	if ( (pdir[0]*normal[0] + pdir[1]*normal[1] + pdir[2]*normal[2]) < 0.0f ) {
	    irrad[0] += p->power[0];
	    irrad[1] += p->power[1];
	    irrad[2] += p->power[2];
	}
    }

    const float tmp = (1.0/M_PI) / (np.dist2[0]); // estimate of density
    irrad[0] *= tmp;
    irrad[1] *= tmp;
    irrad[2] *= tmp;

    delete [] np.dist2;
    delete [] np.index;
}

/*
 * Recursively gather the photons nearest to np->pos from the kd-tree stored
 * in heap order in photons[] (node i has children 2i and 2i+1).
 *
 * Candidates go into np->index[1..np->found] and np->dist2[1..np->found].
 * Once np->max candidates are held, they are turned into a max-heap on
 * distance, and np->dist2[0] tracks the largest kept distance.
 */
void 
PhotonMap::locatePhotons( NearestPhotons * const np, const int index ) const {
    const Photon * p = &photons[index];
    float dist1;

    // Last heap index that belongs to the balanced tree.  balance() and
    // readFromDisk() set halfStoredPhotons = n/2 - 1 for an n-node tree, so
    // n <= 2*halfStoredPhotons + 3.  Bounding by that (and by storedPhotons)
    // means the split plane is only read from nodes that were balanced.
    int lastNode = 2 * halfStoredPhotons + 3;
    if ( lastNode > storedPhotons )
        lastNode = storedPhotons;

    // Descend whenever a child index lies inside the tree.  The former test,
    // index < halfStoredPhotons, never expanded the last two internal nodes,
    // so their subtrees were never searched.
    const int left = 2 * index;
    const int right = left + 1;

    if ( left <= lastNode ) {
        dist1 = np->pos[p->plane] - p->pos[p->plane];

        if ( dist1 > 0.0 ) {
            // Query is on the high side of the splitting plane: search that
            // side first, then the low side if the plane is within range.
            if ( right <= lastNode )
                locatePhotons( np, right );
            if ( dist1 * dist1 < np->dist2[0] )
                locatePhotons( np, left );
        }
        else {
            locatePhotons( np, left );
            if ( right <= lastNode && dist1 * dist1 < np->dist2[0] )
                locatePhotons( np, right );
        }
    }

    // Compute squared distance between current photon and np->pos

    dist1 = p->pos[0] - np->pos[0];
    float dist2 = dist1 * dist1;
    dist1 = p->pos[1] - np->pos[1];
    dist2 += dist1 * dist1;
    dist1 = p->pos[2] - np->pos[2];
    dist2 += dist1 * dist1;

    if ( dist2 < np->dist2[0] ) {
        // Found a photon.  Insert in the candidate list.

        if ( np->found < np->max ) {
            // Heap is not full.  Use array.

            np->found++;
            np->dist2[np->found] = dist2;
            np->index[np->found] = p;
        }
        else {
            int j, parent;

            if ( np->gotHeap == 0 ) {
                // Build a max-heap on distance by sifting down every
                // internal node, from the last one up to the root.

                float dst2;
                const Photon * phot;
                int halfFound = np->found >> 1;
                for ( int k = halfFound; k >= 1; k-- ) {
                    parent = k;
                    phot = np->index[k];
                    dst2 = np->dist2[k];
                    while ( parent <= halfFound ) {
                        j = parent + parent;
                        if ( j < np->found && np->dist2[j] < np->dist2[j+1] )
                            j++;
                        if ( dst2 >= np->dist2[j] )
                            break;
                        np->dist2[parent] = np->dist2[j];
                        np->index[parent] = np->index[j];
                        parent = j;
                    }
                    np->dist2[parent] = dst2;
                    np->index[parent] = phot;
                }
                np->gotHeap = 1;
            }

            // Insert new photon into max heap
            // Delete largest element, insert new, and reorder the heap

            parent = 1;
            j = 2;
            while ( j <= np->found ) {
                if ( j < np->found && np->dist2[j] < np->dist2[j+1] )
                    j++;
                if ( dist2 > np->dist2[j] )
                    break;
                np->dist2[parent] = np->dist2[j];
                np->index[parent] = np->index[j];
                parent = j;
                j += j;
            }
            np->index[parent] = p;
            np->dist2[parent] = dist2;

            // Shrink the search radius to the farthest kept candidate.
            np->dist2[0] = np->dist2[1];
        }
    }
}

/*
 * Append one photon (power, position, direction) to the map and grow the
 * bounding box.  The direction is quantized into 8-bit theta/phi indices.
 * Photons are silently dropped once maxPhotons have been stored.
 */
void
PhotonMap::store( const float power[3], const float pos[3], const float dir[3] ) {
    if ( storedPhotons >= maxPhotons )
        return;

    // Photons occupy indices 1..storedPhotons; slot 0 is never used.
    storedPhotons++;
    Photon * const node = &photons[storedPhotons];

    for ( int i = 0; i < 3; i++ ) {
        node->pos[i] = pos[i];

        if ( node->pos[i] < bboxMin[i] )
            bboxMin[i] = node->pos[i];
        if ( node->pos[i] > bboxMax[i] )
            bboxMax[i] = node->pos[i];

        node->power[i] = power[i];
    }

    // Clamp before acos: a normalized direction can drift just outside
    // [-1, 1] through rounding.  acos would then return NaN, and converting
    // NaN to int is undefined behavior.
    float cosTheta = dir[2];
    if ( cosTheta > 1.0f )
        cosTheta = 1.0f;
    else if ( cosTheta < -1.0f )
        cosTheta = -1.0f;

    int theta = (int) ( acos( cosTheta ) * (256.0/M_PI) );
    if ( theta > 255 )
        node->theta = 255;
    else
        node->theta = (unsigned char) theta;

    // atan2 lies in [-pi, pi]; negative angles wrap into [128, 255].
    int phi = (int) ( atan2( dir[1], dir[0] ) * (256.0/(2.0*M_PI)) );
    if ( phi > 255 )
        node->phi = 255;
    else if ( phi < 0 )
        node->phi = (unsigned char) (phi+256);
    else
        node->phi = (unsigned char) phi;
}

/*
 * Multiply the power of every photon stored since the previous call by
 * scale, so that each batch of photons is scaled exactly once.
 */
void
PhotonMap::scalePhotonPower( const float scale ) {
    for ( int i = prevScale; i <= storedPhotons; i++ ) {
        photons[i].power[0] *= scale;
        photons[i].power[1] *= scale;
        photons[i].power[2] *= scale;
    }
    // Resume after the last photon scaled here.  The former
    // `prevScale = storedPhotons` rescaled that photon on the next call.
    prevScale = storedPhotons + 1;
}

void
PhotonMap::balance( void ) {
    if ( storedPhotons > 1 ) {
	Photon ** pa1 = new (Photon *)[storedPhotons + 1];
	Photon ** pa2 = new (Photon *)[storedPhotons + 1];

	for ( int i = 0; i <= storedPhotons; i++ )
	    pa2[i] = &photons[i];

	balanceSegment( pa1, pa2, 1, 1, storedPhotons );
	delete [] pa2;

	int d, j=1, foo=1;
	Photon fooPhoton = photons[j];

	for ( int i = 1; i <= storedPhotons; i++ ) {
	    d = pa1[j] - photons;
	    pa1[j] = NULL;
	    if ( d != foo )
		photons[j] = photons[d];
	    else {
		photons[j] = fooPhoton;

		if ( i < storedPhotons) {
		    for ( ; foo <= storedPhotons; foo++ )
			if ( pa1[foo] != NULL )
			    break;
		    fooPhoton = photons[foo];
		    j = foo;
		}
		continue;
	    }
	    j = d;
	}
	delete [] pa1;
    }

    halfStoredPhotons = storedPhotons/2 - 1;
}

#define swap( ph, a, b ) { Photon * ph2 = ph[a]; ph[a] = ph[b]; ph[b] = ph2; }

/*
 * Partially sort p[start..end] along `axis` so that p[median] holds the
 * element that would be there after a full sort.  Everything before it
 * compares <= and everything after it compares >= on that axis.
 * Quickselect: each pass partitions around the last element of the active
 * range, then narrows the range to the side that contains `median`.
 */
void
PhotonMap::medianSplit( Photon ** p, const int start, const int end,
    const int median, const int axis ) {

    int lo = start;
    int hi = end;

    while ( lo < hi ) {
        const float pivot = p[hi]->pos[axis];
        int a = lo - 1;
        int b = hi;

        while ( true ) {
            // Advance a past elements below the pivot; p[hi] stops it.
            do {
                ++a;
            } while ( p[a]->pos[axis] < pivot );

            // Retreat b past elements above the pivot, never below lo.
            do {
                --b;
            } while ( p[b]->pos[axis] > pivot && b > lo );

            if ( a >= b )
                break;

            Photon * const held = p[a];
            p[a] = p[b];
            p[b] = held;
        }

        // Put the pivot into its final slot a.
        Photon * const held = p[a];
        p[a] = p[hi];
        p[hi] = held;

        if ( a >= median )
            hi = a - 1;
        if ( a <= median )
            lo = a + 1;
    }
}

/*
 * Build the kd-subtree for porg[start..end] rooted at heap slot `index` of
 * pbal (children at 2*index and 2*index+1).  The split uses the axis along
 * which the current bounding box is widest.  The bounding box is narrowed
 * temporarily for each recursive call and restored afterwards.
 */
void
PhotonMap::balanceSegment( Photon ** pbal, Photon ** porg, const int index,
    const int start, const int end ) {
    const int count = end - start + 1;

    // Choose the split position that keeps the heap left-balanced, based on
    // the largest power of two with 4 * pow2 <= count.
    int pow2 = 1;
    while ( 4 * pow2 <= count )
        pow2 *= 2;

    const int median = ( 3 * pow2 <= count ) ? start - 1 + 2 * pow2
                                             : end - pow2 + 1;

    // Split along the widest extent of the bounding box (ties favour z, then y).
    const bool xWidest =
        ( bboxMax[0] - bboxMin[0] ) > ( bboxMax[1] - bboxMin[1] ) &&
        ( bboxMax[0] - bboxMin[0] ) > ( bboxMax[2] - bboxMin[2] );
    const bool yWiderThanZ =
        ( bboxMax[1] - bboxMin[1] ) > ( bboxMax[2] - bboxMin[2] );
    const int axis = xWidest ? 0 : ( yWiderThanZ ? 1 : 2 );

    medianSplit( porg, start, end, median, axis );

    Photon * const node = porg[median];
    pbal[index] = node;
    node->plane = axis;

    // Low side: porg[start..median-1].
    const int lowCount = median - start;
    if ( lowCount == 1 ) {
        pbal[2*index] = porg[start];
    }
    else if ( lowCount > 1 ) {
        const float savedMax = bboxMax[axis];
        bboxMax[axis] = node->pos[axis];
        balanceSegment( pbal, porg, 2*index, start, median-1 );
        bboxMax[axis] = savedMax;
    }

    // High side: porg[median+1..end].
    const int highCount = end - median;
    if ( highCount == 1 ) {
        pbal[2*index+1] = porg[end];
    }
    else if ( highCount > 1 ) {
        const float savedMin = bboxMin[axis];
        bboxMin[axis] = node->pos[axis];
        balanceSegment( pbal, porg, 2*index+1, median+1, end );
        bboxMin[axis] = savedMin;
    }
}
