#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>

#include "photonmap.h"

// Tolerance for comparing a query normal with a photon's stored normal: each
// component may differ by at most |component * VARIANCE| (0.05 == 5%).
#define VARIANCE 0.05
// Value passed as trace_photon's `storage` flag for the first surface hit of
// every photon emitted in build_map; false means that first hit is not stored.
#define INCLUDE_DIRECT_ILLUM false

/* Constructs an empty photon map.
 *
 * max_phot : capacity, i.e. the maximum number of photons store() accepts.
 * type     : map type tag, kept in map_type.
 *
 * The photon array gets one extra slot because store() fills indices
 * 1..max_photons. The process exits if the allocation fails.
 */
//************************************************
PhotonMap :: PhotonMap( const int max_phot, const int type )
//************************************************
{
  max_photons    = max_phot;
  map_type       = type;
  stored_photons = 0;
  prev_scale     = 1;     // first photon index not yet power-scaled

  // Start with an "inverted" box so the first stored photon defines it.
  for (int axis = 0; axis < 3; axis++) {
    bbox_min[axis] =  1e8f;
    bbox_max[axis] = -1e8f;
  }

  photons = (Photon*)malloc( sizeof( Photon ) * ( max_photons+1 ) );
  if (!photons) {
    fprintf(stderr,"Out of memory initializing photon map\n");
    exit(-1);
  }

  // Lookup tables that decode the 8-bit quantised photon directions:
  // theta steps through [0, pi) and phi through [0, 2*pi) in 256 steps.
  for (int k = 0; k < 256; k++) {
    const double angle  = double(k)*(1.0/256.0)*M_PI;
    const double angle2 = 2.0*angle;
    costheta[k] = cos( angle );
    sintheta[k] = sin( angle );
    cosphi[k]   = cos( angle2 );
    sinphi[k]   = sin( angle2 );
  }
}


//*************************
PhotonMap :: ~PhotonMap()
//*************************
{
  // The photon array was obtained with malloc() in the constructor, so it is
  // released with free(). free(NULL) is a no-op; the check only makes that
  // explicit.
  if (photons != NULL)
    free( photons );
}

/* This builds the photon map of the specified type */
//************************************************
void PhotonMap :: build_map( Scene *scene )
//************************************************
{
  Ray ray;
  Spectrum sun;
  Spectrum sky;
  float power[3];
  Vector vector;
  int emitted_photons = 0;

  cout << "Building Photon Map: " << endl;
  PointLight *light = (PointLight*)(scene->lights[0]);
  sun = (light->Power)*4*M_PI;
  sun.ConvertToRGB(power);
  sky = scene->envLight->skyBase*4*M_PI;

  cerr << "Sky: " << sky << endl;
  cerr << "Sun: " << sun << endl;
  cout << "0.00 percent\r";

  int max_photons_sun = max_photons - (int)((float)max_photons*
					   (float)sky.Intensity()/
					   (float)sun.Intensity());
  //sample sun, a point light
  while(stored_photons < max_photons_sun && 
	emitted_photons < 2*max_photons){
    
    //sky-light 
    // x: -9.2 -> 0.4 
    // y: 24.2 (constant)
    // z: 22.2 -> -21

    //side-glass
    // x: -9.2 -> 0.4
    // y: 0 -> 24.2
    // z: 22.2
    float selector = RandomFloat(0,1);
    float rand_x;
    float rand_y;
    float rand_z;
    float test = 1;
    //sky-light
    if(selector < 24.2/43.2){
	rand_x = RandomFloat(0,1)*(9.6)-9.2;
	rand_y = 24.2;
	rand_z = RandomFloat(0,1)*(43.2)-21.;
    }
    //glass
    else{
	rand_x = RandomFloat(0,1)*(9.6)-9.2;
	rand_y = RandomFloat(0,1)*24.2;
	rand_z = 22.2;
    }    

    Vector wo = Vector(rand_x - light->lightPos.x, 
		       rand_y - light->lightPos.y, 
		       rand_z - light->lightPos.z);
    ray = Ray(light->lightPos, wo);
    //emit photon (trace ray)
    trace_photon(power, scene, ray, INCLUDE_DIRECT_ILLUM);
    emitted_photons++;
    if(stored_photons%100==0){
      fprintf(stdout, "%.2f percent\r", (((float)stored_photons 
				   	/ (float)max_photons)*100.));
    }
  }
  //sample sky, env map
  while(stored_photons < max_photons){
    
    //sky-light 
    // x: -9.2 -> 0.4 
    // y: 24.2 (constant)
    // z: 22.2 -> -21  
  
    //side-glass
    // x: -9.2 -> 0.4
    // y: 0 -> 24.2
    // z: 22.2
    float selector = RandomFloat(0,1);
    float rand_x;
    float rand_y;
    float rand_z;
    //sky-light
    if(selector < 24.2/43.2){
      rand_x = RandomFloat(0,1)*(9.6)-9.2;
      rand_y = 24.2;
      rand_z = RandomFloat(0,1)*(43.2)-21.;
    }
    //glass
    else{
      rand_x = RandomFloat(0,1)*(9.6)-9.2;
      rand_y = RandomFloat(0,1)*24.2;
      rand_z = 22.2;
    }

    float test = 1;
    float theta;
    float phi;
    float R;
    
    float y;
    float cosPhi;
    float x;
    float z;
    while(test >= 0){
      float u1 = RandomFloat(0, 1);
      float r = sqrt(u1);
      phi = acos(r);
      float u2 = RandomFloat(0, 1);
      theta = 2.*M_PI*u2;
      //theta = RandomFloat(0,1)*2*M_PI;
      //phi = RandomFloat(0,1)*M_PI/2.;
      R = 1000.;
      
      y = R*sin(phi);
      cosPhi = R*cos(phi);
      x = sin(theta)*cosPhi;
      z = cos(theta)*cosPhi;
      
      if(selector < 24.2/43.2)
	test = rand_y - y;
      else
	test = x - rand_x;
    }
    
    Vector wo = Vector(rand_x - x, 
		       rand_y - y, 
		       rand_z - z);
    ray = Ray(Point(x,y,z), wo);
    //emit photon (trace ray)
    trace_photon(power, scene, ray, INCLUDE_DIRECT_ILLUM);
    emitted_photons++;
    if(stored_photons%100==0){
      fprintf(stdout, "%.2f percent\r", (((float)stored_photons 
				   	/ (float)max_photons)*100.));
    }
  }
 
  cerr << endl;
  cerr << "Stored photons: " << stored_photons << endl;
  cerr << "Emitted photons: " << emitted_photons << endl;
  
  //scale photon_map
  scale_photon_power( 1./emitted_photons );
  
  //balance KD tree
  balance(); 
}

/* This traces photons with Russian Roulette */
//************************************************
void PhotonMap :: trace_photon( float light_power[3], Scene *scene, Ray ray, bool storage )
//************************************************
{
  float P_D[3]; 
  float P_S[3];
  float Pd;
  float Ps;
  float pos[3];
  float dir[3];
  float power[3];
  float normal[3];
  Spectrum diffuse, specular;
  Vector wo;

  //Run intersect test
  Surf surf;
  if(scene->Intersect(ray, &surf)){
    if(Dot(ray.D, surf.dgGeom.N) < 0){
      return;
    }
    //Extract info from intersetion
    BSDF *bsdf = surf.getBSDF();

    //get photon storage info
    Vector norm_vector = -(ray.D.Hat());
    //dir[0] = -ray.D.x;
    //dir[1] = -ray.D.y;
    //dir[2] = -ray.D.z;
    dir[0] = norm_vector.x;
    dir[1] = norm_vector.y;
    dir[2] = norm_vector.z;
    pos[0] = surf.dgGeom.P.x;
    pos[1] = surf.dgGeom.P.y;
    pos[2] = surf.dgGeom.P.z;
    normal[0] = surf.dgGeom.N.x;
    normal[1] = surf.dgGeom.N.y;
    normal[2] = surf.dgGeom.N.z;

    // Normalize the Russian roulette probabilities    

    //REALLY HACKY, need to figure out how to correctly sample
    //multiple BRDFs
    diffuse = bsdf->get_diffuse();
    specular = bsdf->get_specular();

    diffuse.ConvertToRGB(P_D);
    specular.ConvertToRGB(P_S);   

    //Figure out normalization factor for color bleeding
    Pd = (P_D[0] + P_D[1] + P_D[2])/3.;
    Ps = (P_S[0] + P_S[1] + P_S[2])/3.;
 
    //Bounce the photon around using Russian roulette   
    Float epsilon = RandomFloat(0, 1);

    //light_power.ConvertToRGB(power);
    //Diffuse
    if (epsilon < Pd){

      if(storage)
	store(light_power, pos, dir, normal);

      //calculate for next hop
      bsdf->sample_fDiff(-ray.D, &wo);
      Ray ray = Ray(Point(pos[0], pos[1], pos[2]), wo);

      //Adjust light for color bleeding
      power[0] = light_power[0] * P_D[0] / Pd;
      power[1] = light_power[1] * P_D[1] / Pd;
      power[2] = light_power[2] * P_D[2] / Pd;

      //trace new photon
      trace_photon(power, scene, ray, true);
    }
    //Specular
    else if (epsilon < Pd+Ps){
      //reflect, but do not store
      
      //calculate for next hop
      bsdf->sample_fSpec(-ray.D, &wo);
      Ray ray = Ray(Point(pos[0], pos[1], pos[2]), wo);

      //Adjust light for color bleeding
      power[0] = light_power[0] * P_S[0] / Ps;
      power[1] = light_power[1] * P_S[1] / Ps;
      power[2] = light_power[2] * P_S[2] / Ps;

      //trace new photon
      trace_photon(power, scene, ray, true);
    }
    //Absorption - Photon stops
    else{
      if(storage)
	store(light_power, pos, dir, normal);
    }		
  }
}
  
/* photon_dir decodes the quantised theta/phi indices stored in photon p into
 * a direction, using the lookup tables filled by the constructor:
 *   dir = (sin(theta)*cos(phi), sin(theta)*sin(phi), cos(theta))
 */
//*****************************************************************
void PhotonMap :: photon_dir( float *dir, const Photon *p ) const
//*****************************************************************
{
  const int t = p->theta;
  const int f = p->phi;
  dir[2] = costheta[t];
  dir[0] = sintheta[t]*cosphi[f];
  dir[1] = sintheta[t]*sinphi[f];
}


/* sample_photon picks, uniformly at random, one of the (up to nphotons)
 * photons within max_dist of pos whose stored normal agrees with `normal`.
 * The agreement test is the same per-component VARIANCE test that
 * irradiance_estimate uses. The function returns the picked photon's
 * decoded direction, which approximates the `dir` passed to store();
 * trace_photon passes the negated, normalised incoming ray direction.
 *
 * pos      : query position
 * normal   : surface normal at pos
 * max_dist : search radius
 * nphotons : maximum number of nearby photons to consider
 *
 * Returns Vector(normal[0], normal[1], normal[2]) when no suitable photon
 * exists (no photons in range, none with a matching normal, or
 * nphotons <= 0).
 *
 * Each call performs its own nearest-photon query. The scratch buffers are
 * alloca() memory owned by this call, so they are neither cached across
 * calls nor passed to free().
 */
Vector PhotonMap :: sample_photon( const float pos[3], const float normal[3],
                                   const float max_dist, const int nphotons)
{
  const Vector fallback = Vector(normal[0], normal[1], normal[2]);
  if (nphotons <= 0)
    return fallback;

  NearestPhotons np;
  np.dist2 = (float*)alloca( sizeof(float)*(nphotons+1) );
  np.index = (const Photon**)alloca( sizeof(Photon*)*(nphotons+1) );

  np.pos[0] = pos[0]; np.pos[1] = pos[1]; np.pos[2] = pos[2];
  np.max = nphotons;
  np.found = 0;
  np.got_heap = 0;
  np.dist2[0] = max_dist*max_dist;

  locate_photons( &np, 1 );

  // Compact the candidates (1-based) in place so that index[1..matches]
  // holds only photons whose normal is close to the query normal. Picking
  // uniformly among them is equivalent to rejection sampling, but it is
  // bounded and cannot recurse forever.
  int matches = 0;
  for (int i = 1; i <= np.found; i++) {
    const Photon *cand = np.index[i];
    if (fabs(normal[0] - cand->normal[0]) <= fabs(normal[0] * VARIANCE) &&
        fabs(normal[1] - cand->normal[1]) <= fabs(normal[1] * VARIANCE) &&
        fabs(normal[2] - cand->normal[2]) <= fabs(normal[2] * VARIANCE) )
      np.index[++matches] = cand;
  }
  if (matches == 0)
    return fallback;

  // Map a random number onto the 1-based range 1..matches (clamped in case
  // RandomFloat can return exactly 1).
  int pick = 1 + (int)(RandomFloat(0,1)*matches);
  if (pick > matches)
    pick = matches;
  const Photon *p = np.index[pick];

  float x = sintheta[p->theta]*cosphi[p->phi];
  float y = sintheta[p->theta]*sinphi[p->phi];
  float z = costheta[p->theta];
  return Vector(x,y,z);
}

/* irradiance_estimate computes an irradiance estimate
 * at a given surface position
*/
//**********************************************
void PhotonMap :: irradiance_estimate(
  float irrad[3],                // returned irradiance
  const float pos[3],            // surface position
  const float normal[3],         // surface normal at pos
  const float max_dist,          // max distance to look for photons
  const int nphotons ) const     // number of photons to use
//**********************************************
{
  int number_read = 0;
  irrad[0] = irrad[1] = irrad[2] = 0.0;

  NearestPhotons np;
  np.dist2 = (float*)alloca( sizeof(float)*(nphotons+1) );
  np.index = (const Photon**)alloca( sizeof(Photon*)*(nphotons+1) );

  np.pos[0] = pos[0]; np.pos[1] = pos[1]; np.pos[2] = pos[2];
  np.max = nphotons;
  np.found = 0;
  np.got_heap = 0;
  np.dist2[0] = max_dist*max_dist;
  //np.normal[0] = normal[0]; np.normal[1] = normal[1]; np.normal[1] = normal[1];

  // locate the nearest photons
  locate_photons( &np, 1 );

  // if less than 8 photons return
  if (np.found<8)
    return;

  //float pdir[3];

  // sum irradiance from all photons
  for (int i=1; i<=np.found; i++) {
    const Photon *p = np.index[i];
    // the photon_dir call and following if can be omitted (for speed)
    // if the scene does not have any thin surfaces
    //photon_dir( pdir, p );
    /*if ( (pdir[0]*normal[0]+pdir[1]*normal[1]+pdir[2]*normal[2]) < 0.0f ) {
      irrad[0] += p->power[0];
      irrad[1] += p->power[1];
      irrad[2] += p->power[2];
      }*/

    // try not to include drastically varying normals
    if(fabs(normal[0] - p->normal[0]) <= fabs(normal[0] * VARIANCE) &&
       fabs(normal[1] - p->normal[1]) <= fabs(normal[1] * VARIANCE) &&
       fabs(normal[2] - p->normal[2]) <= fabs(normal[2] * VARIANCE) ){
      irrad[0] += p->power[0];
      irrad[1] += p->power[1];
      irrad[2] += p->power[2];
      number_read++;
    }
  }
  // estimate of density
  const float tmp=(1.0f/M_PI)/(np.dist2[0]);

  irrad[0] *= tmp;
  irrad[1] *= tmp;
  irrad[2] *= tmp;
}


/* locate_photons searches the kd-tree subtree rooted at heap position
 * `index` for the photons nearest to np->pos.
 *
 * np->dist2[0] is the current squared search radius. Candidates are kept
 * 1-based in np->index[1..found] / np->dist2[1..found]. Once np->max of them
 * have been collected, they are arranged as a max-heap on squared distance;
 * each closer photon then replaces the farthest one, and dist2[0] shrinks to
 * the new farthest distance.
 *
 * The tree is stored heap-style in photons[1..stored_photons]: the children
 * of node i are 2i and 2i+1, and the split axis is photons[i].plane. A node
 * is descended into whenever its left child exists
 * (2*index <= stored_photons), and out-of-range indices return immediately,
 * so every stored photon is reachable. A bound of index < stored_photons/2-1
 * on the parent never visits the last few bottom-level nodes. The range
 * check also keeps an empty map from reading photons[1].
 */
//******************************************
void PhotonMap :: locate_photons(
  NearestPhotons *const np,
  const int index ) const
//******************************************
{
  if (index > stored_photons)   // empty map, or a child that does not exist
    return;

  const Photon *p = &photons[index];
  float dist1;

  if (2*index <= stored_photons) {   // internal node: at least a left child
    dist1 = np->pos[ p->plane ] - p->pos[ p->plane ];

    if (dist1>0.0) { // if dist1 is positive search right plane
      locate_photons( np, 2*index+1 );
      if ( dist1*dist1 < np->dist2[0] )
        locate_photons( np, 2*index );
    } else {         // dist1 is negative search left first
      locate_photons( np, 2*index );
      if ( dist1*dist1 < np->dist2[0] )
        locate_photons( np, 2*index+1 );
    }
  }

  // compute squared distance between current photon and np->pos

  dist1 = p->pos[0] - np->pos[0];
  float dist2 = dist1*dist1;
  dist1 = p->pos[1] - np->pos[1];
  dist2 += dist1*dist1;
  dist1 = p->pos[2] - np->pos[2];
  dist2 += dist1*dist1;

  if ( dist2 < np->dist2[0] ) {
    // we found a photon :) Insert it in the candidate list

    if ( np->found < np->max ) {
      // heap is not full; use array
      np->found++;
      np->dist2[np->found] = dist2;
      np->index[np->found] = p;
    }
    else {
      int j,parent;

      if (np->got_heap==0) { // Do we need to build the heap?
        // Build a max-heap on dist2 over entries 1..found (bottom-up sift-down)
        float dst2;
        const Photon *phot;
        int half_found = np->found>>1;
        for ( int k=half_found; k>=1; k--) {
          parent=k;
          phot = np->index[k];
          dst2 = np->dist2[k];
          while ( parent <= half_found ) {
            j = parent+parent;
            if (j<np->found && np->dist2[j]<np->dist2[j+1])
              j++;
            if (dst2>=np->dist2[j])
              break;
            np->dist2[parent] = np->dist2[j];
            np->index[parent] = np->index[j];
            parent=j;
          }
          np->dist2[parent] = dst2;
          np->index[parent] = phot;
        }
        np->got_heap = 1;
      }

      // insert new photon into max heap
      // delete largest element, insert new and reorder the heap

      parent=1;
      j = 2;
      while ( j <= np->found ) {
        if ( j < np->found && np->dist2[j] < np->dist2[j+1] )
          j++;
        if ( dist2 > np->dist2[j] )
          break;
        np->dist2[parent] = np->dist2[j];
        np->index[parent] = np->index[j];
        parent = j;
        j += j;
      }
      np->index[parent] = p;
      np->dist2[parent] = dist2;

      // the search radius shrinks to the farthest remaining candidate
      np->dist2[0] = np->dist2[1];
    }
  }
}


/* store appends a photon to the flat array that balance() later turns into
 * the kd-tree.
 *
 * power  : RGB photon power
 * pos    : photon position; also grows the bounding box bbox_min/bbox_max
 * dir    : photon direction (expected to be unit length), stored quantised
 *          to 8-bit theta/phi indices
 * normal : surface normal at pos
 *
 * Photons occupy indices 1..max_photons. Once the map is full, further
 * photons are silently dropped.
 */
//***************************
void PhotonMap :: store(
  const float power[3],
  const float pos[3],
  const float dir[3],
  const float normal[3] )
//***************************
{
  if (stored_photons>=max_photons)
    return;

  stored_photons++;
  Photon *const node = &photons[stored_photons];

  for (int i=0; i<3; i++) {
    node->pos[i] = pos[i];

    if (node->pos[i] < bbox_min[i])
      bbox_min[i] = node->pos[i];
    if (node->pos[i] > bbox_max[i])
      bbox_max[i] = node->pos[i];

    node->power[i] = power[i];
    node->normal[i] = normal[i];
  }

  // acos() is only defined on [-1, 1]. A normalised float vector can have
  // |z| a hair above 1, for which acos() returns NaN and the int conversion
  // below would be undefined, so clamp first.
  float cos_theta = dir[2];
  if (cos_theta > 1.0f)
    cos_theta = 1.0f;
  else if (cos_theta < -1.0f)
    cos_theta = -1.0f;

  int theta = int( acos(cos_theta)*(256.0/M_PI) );
  if (theta>255)
    node->theta = 255;          // theta == pi maps to 256
  else
    node->theta = (unsigned char)theta;

  // atan2 gives [-pi, pi]; negative indices wrap into [128, 255].
  int phi = int( atan2(dir[1],dir[0])*(256.0/(2.0*M_PI)) );
  if (phi>255)
    node->phi = 255;
  else if (phi<0)
    node->phi = (unsigned char)(phi+256);
  else
    node->phi = (unsigned char)phi;
}


/* scale_photon_power multiplies the power of every photon stored since the
 * previous call by `scale` (typically 1/#emitted photons). Call it after
 * each light source has been processed.
 *
 * prev_scale is the index of the first photon not yet scaled; it starts at
 * 1, the first slot store() fills. After photons [prev_scale,
 * stored_photons] are scaled, prev_scale moves one past the last scaled
 * photon, so a later call never scales that photon a second time.
 */
//********************************************************
void PhotonMap :: scale_photon_power( const float scale )
//********************************************************
{
  for (int i=prev_scale; i<=stored_photons; i++) {
    photons[i].power[0] *= scale;
    photons[i].power[1] *= scale;
    photons[i].power[2] *= scale;
  }
  prev_scale = stored_photons+1;
}


/* balance turns the flat photon array photons[1..stored_photons] into a
 * left-balanced kd-tree. The tree is stored heap-style in the same array:
 * the children of slot i are 2i and 2i+1.
 *
 * Call it after all photons are stored and before the map is used for
 * lookups. The process exits if the temporary arrays cannot be allocated,
 * matching the constructor's out-of-memory handling.
 */
//******************************
void PhotonMap :: balance(void)
//******************************
{
  if (stored_photons>1) {
    // Two temporary arrays of photon pointers. pa2 starts in storage order
    // and is partitioned by balance_segment; pa1[i] receives the photon
    // that belongs in tree slot i.
    Photon **pa1 = (Photon**)malloc(sizeof(Photon*)*(stored_photons+1));
    Photon **pa2 = (Photon**)malloc(sizeof(Photon*)*(stored_photons+1));
    if (pa1 == NULL || pa2 == NULL) {
      fprintf(stderr,"Out of memory balancing photon map\n");
      exit(-1);
    }

    for (int i=0; i<=stored_photons; i++)
      pa2[i] = &photons[i];

    balance_segment( pa1, pa2, 1, 1, stored_photons );
    free(pa2);

    // Apply the permutation in place (make a heap) by following its cycles.
    // Slot j is filled from slot d = pa1[j]-photons, and d becomes the next
    // slot to fill. foo_photon holds the original contents of the cycle's
    // first slot (foo) until the cycle closes. Filled slots are marked by
    // setting pa1[j] to NULL.
    int d, j=1, foo=1;
    Photon foo_photon = photons[j];

    for (int i=1; i<=stored_photons; i++) {
      d=pa1[j]-photons;
      pa1[j] = NULL;
      if (d != foo)
        photons[j] = photons[d];
      else {
        photons[j] = foo_photon;

        if (i<stored_photons) {
          for (;foo<=stored_photons; foo++)
            if (pa1[foo] != NULL)
              break;
          foo_photon = photons[foo];
          j = foo;
        }
        continue;
      }
      j = d;
    }
    free(pa1);
  }

  // Set even when no balancing was needed (fewer than two photons).
  half_stored_photons = stored_photons/2-1;
}


// Exchanges the photon pointers ph[a] and ph[b]. This is a macro named
// `swap`; it remains defined for the rest of this file from here on.
#define swap(ph,a,b) { Photon *ph2=ph[a]; ph[a]=ph[b]; ph[b]=ph2; }

// median_split partially orders p[start..end] (inclusive) along `axis`. On
// return, p[median] holds the photon that would be there if the range were
// fully sorted by pos[axis]. Every photon in p[start..median-1] has
// pos[axis] <= that value, and every photon in p[median+1..end] has
// pos[axis] >= it. It is a quickselect that uses the last element of the
// current range as the pivot (inspired by a routine in "Algorithms in C++"
// by Sedgewick).
//*****************************************************************
void PhotonMap :: median_split(
  Photon **p,
  const int start,               // start of photon block in array
  const int end,                 // end of photon block in array
  const int median,              // desired median number
  const int axis )               // axis to split along
//*****************************************************************
{
  int left = start;
  int right = end;

  while ( right > left ) {
    const float v = p[right]->pos[axis];   // pivot: last element of the range
    int i=left-1;
    int j=right;
    for (;;) {
      // advance i past elements smaller than the pivot (p[right] stops it)
      while ( p[++i]->pos[axis] < v )
        ;
      // move j back past elements larger than the pivot, never below left
      while ( p[--j]->pos[axis] > v && j>left )
        ;
      if ( i >= j )
        break;
      swap(p,i,j);
    }

    swap(p,i,right);   // the pivot lands at its final sorted position i
    // continue only in the part that still contains the desired median;
    // when i == median both bounds cross and the loop ends
    if ( i >= median )
      right=i-1;
    if ( i <= median )
      left=i+1;
  }
}

  
// balance_segment recursively builds the subtree of the left-balanced kd-tree
// rooted at heap position `index` from the photons in porg[start..end]
// (inclusive). See "Realistic Image Synthesis Using Photon Mapping"
// (H. W. Jensen), chapter 6, for a full explanation.
//
//   pbal       : output; pbal[index] receives the subtree root, and its
//                children are placed at 2*index and 2*index+1
//   porg       : photon pointers; porg[start..end] is reordered by median_split
//   index      : heap position of this subtree's root
//   start, end : inclusive bounds of the segment in porg
//
// The split axis is the largest extent of bbox_min/bbox_max. While each half
// is built, the box is clipped to the root's coordinate on that axis; it is
// restored afterwards.
//****************************
void PhotonMap :: balance_segment(
  Photon **pbal,
  Photon **porg,
  const int index,
  const int start,
  const int end )
//****************************
{
  //--------------------
  // compute new median
  //--------------------

  // The median is chosen so the tree is left-balanced. `median` first
  // becomes the smallest power of two m with 4*m > n, where n = end-start+1.
  // If 3*m <= n, the left subtree is a full tree of 2*m-1 nodes; otherwise
  // the right subtree is a full tree of m-1 nodes.
  int median=1;
  while ((4*median) <= (end-start+1))
    median += median;

  if ((3*median) <= (end-start+1)) {
    median += median;
    median += start-1;
  } else	
    median = end-median+1;

  //--------------------------
  // find axis to split along
  //--------------------------

  // x if its extent is strictly the largest, else y if larger than z, else z
  int axis=2;
  if ((bbox_max[0]-bbox_min[0])>(bbox_max[1]-bbox_min[1]) &&
      (bbox_max[0]-bbox_min[0])>(bbox_max[2]-bbox_min[2]))
    axis=0;
  else if ((bbox_max[1]-bbox_min[1])>(bbox_max[2]-bbox_min[2]))
    axis=1;

  //------------------------------------------
  // partition photon block around the median
  //------------------------------------------

  median_split( porg, start, end, median, axis );

  // the median photon becomes this subtree's root; record its split axis
  pbal[ index ] = porg[ median ];
  pbal[ index ]->plane = axis;

  //----------------------------------------------
  // recursively balance the left and right block
  //----------------------------------------------

  if ( median > start ) {
    // balance left segment
    if ( start < median-1 ) {
      const float tmp=bbox_max[axis];
      bbox_max[axis] = pbal[index]->pos[axis];
      balance_segment( pbal, porg, 2*index, start, median-1 );
      bbox_max[axis] = tmp;
    } else {
      // a single photon: it is a leaf
      pbal[ 2*index ] = porg[start];
    }
  }

  if ( median < end ) {
    // balance right segment
    if ( median+1 < end ) {
      const float tmp = bbox_min[axis];		
      bbox_min[axis] = pbal[index]->pos[axis];
      balance_segment( pbal, porg, 2*index+1, median+1, end );
      bbox_min[axis] = tmp;
    } else {
      // a single photon: it is a leaf
      pbal[ 2*index+1 ] = porg[end];
    }
  }	
}

