#include "photonmap.h"

// Constant factor HashKey() uses when folding the per-coordinate values of
// an ANN point together into a single hash code.
#define Multiplier -1664117991L

// Key extractor handed to the photon hashtable: a photon is filed under the
// ANN point that holds its position, i.e. the same entry of photonArray that
// the K-D tree indexes.
ANNpoint GetKeyFromPhoton(Photon *photon)
{
  ANNpoint key = photon->annPosition;
  return key;
}

// Hash function for the photon hashtable.  It combines the three coordinates
// (x, y, z) of an ANN point into one unsigned hash code.
//
// Each coordinate is hashed from its exact bit pattern (FNV-1a over its
// bytes), and the per-axis hashes are folded together with Multiplier.
// Hashing the bits, rather than truncating each coordinate to an integer,
// avoids undefined behaviour for negative or out-of-range coordinates.  It
// also spreads photons that share a unit cube across different buckets
// instead of piling them into one.
unsigned int HashKey(ANNpoint key)
{
  unsigned int hashCode = 0;

  for (int axis = 0; axis < 3; axis++)
    {
      // Adding +0.0 maps -0.0 onto +0.0, so numerically equal coordinates
      // always produce identical bytes.
      double coord = key[axis] + 0.0;
      // Reading an object's bytes through unsigned char is well-defined.
      const unsigned char *bytes =
        reinterpret_cast<const unsigned char *>(&coord);

      unsigned int axisHash = 2166136261u;          // FNV-1a offset basis
      for (unsigned int b = 0; b < sizeof(coord); b++)
        axisHash = (axisHash ^ bytes[b]) * 16777619u; // FNV-1a prime

      hashCode = hashCode * (unsigned int) Multiplier + axisHash;
    }

  return hashCode;
}

// Creates an empty photon map for `scene`.  The photons themselves are shot
// later by GeneratePhotonMap().
//
//   scene             - scene whose lights and geometry are used
//   type              - which map to build (GlobalMap, CausticMap, IndirectMap)
//   photonsPerLight   - photons emitted per light; this also sizes photonArray
//   primSpheres       - per-sphere scale factors used when aiming caustic
//                       photons (presumably bounding-sphere radii)
//   primSphereOffsets - per-sphere offsets added to the aim point
//                       (presumably sphere centres)
//   weights           - per-sphere selection weights for caustic sampling
//
// The three vectors are borrowed, not copied.  They must outlive the map.
PhotonMap::PhotonMap(Scene *scene, MapType type, int photonsPerLight,
		     vector < float > *primSpheres,
		     vector < Vector > *primSphereOffsets,
		     vector < float > *weights)
{
  // Hacked until I figure out how to generate a list of all lights in the
  // scene: only the first primitive's light list is consulted.  The size
  // check keeps an empty scene from indexing past the end of primitives.
  if (scene->primitives.size() > 0)
    {
      list <Light *> &lights = scene->primitives[0]->attributes->Lights;
      list <Light *> ::const_iterator lightIterator = lights.begin();
      while (lightIterator != lights.end())
	{
	  allLights.push_back(*lightIterator);
	  lightIterator++;
	}
    }

  this->primSpheres = primSpheres;
  this->primSphereOffsets = primSphereOffsets;
  this->weights = weights;
  this->photonsPerLight = photonsPerLight;
  this->scene = scene;
  this->type = type;

  // At most photonsPerLight photons are stored per light.
  photonArray = annAllocPts(photonsPerLight * (int) allLights.size(), 3);
  photonTable = new HashTable <ANNpoint, Photon *> (GetKeyFromPhoton, HashKey);

  // Built by GeneratePhotonMap().  Start at NULL so the destructor's delete
  // is safe even if the map was never generated.
  photonTree = NULL;

  // Neighbour-query buffers are sized later by SetMaxN().
  maxNphotons = 0;
  distances = NULL;
  nearestPoints = NULL;
  annReferencePoint = annAllocPt(3);
  nearestPhotons = NULL;

  fprintf(stderr, "Photon map initialized.\n");
}

// Releases what the map allocated: the photon hashtable, the K-D tree, the
// ANN point storage, the neighbour-query buffers and the query point.
//
// NOTE(review): photonTree must be NULL or a valid tree at this point.  The
// Photon objects created by GeneratePhotonMap() are freed only if
// HashTable's destructor deletes its items; confirm.
PhotonMap::~PhotonMap()
{
  delete photonTable;
  delete photonTree;
  annDeallocPts(photonArray);

  // These buffers come from new[] in SetMaxN(), or are still NULL, so they
  // must be released with delete[].  Scalar delete on an array is undefined
  // behaviour.
  delete [] distances;
  delete [] nearestPoints;
  annDeallocPt(annReferencePoint);
  delete [] nearestPhotons;
}

// Shoots photons from every isotropic point light in allLights and traces
// each one through the scene using Russian roulette.  Stored photons are
// recorded in two places:
//   - their positions go into photonArray, which is then indexed by the ANN
//     K-D tree photonTree;
//   - the full Photon records go into photonTable, keyed by that position.
//
// Map types:
//   GlobalMap   - photons leave in uniformly random directions and are
//                 stored where they are absorbed.
//   IndirectMap - as GlobalMap, but a photon absorbed at its first hit
//                 (before any bounce) is not stored.
//   CausticMap  - photons are aimed at random points inside the caustic
//                 spheres (primSpheres/primSphereOffsets, picked via
//                 weights).  The first surface hit must generate caustics.
//                 Photons are always transmitted through such surfaces and
//                 are stored on the first other surface they reach.
//
// For CausticMap the per-light counter only advances when a photon is
// stored, so each light contributes exactly photonsPerLight photons.  In the
// other modes each light gets photonsPerLight attempts.  Either way at most
// photonsPerLight photons are stored per light, which is the size the
// constructor gave photonArray.
//
// NOTE(review): in CausticMap mode the per-light loop never ends if no
// photon can ever be stored (e.g. no caustic-generating surface is
// reachable).  An empty primSpheres is still indexed at [0].  Calling this
// function twice leaks the previous photonTree.
void PhotonMap::GeneratePhotonMap()
{
  int numPoints = 0;
  list <Light *> ::const_iterator lightIterator = allLights.begin();

  if (type == GlobalMap) {
    fprintf(stderr, "Emitting global photons...\n");
  } else if (type == CausticMap) {
    fprintf(stderr, "Emitting caustic photons...\n");
  } else if (type == IndirectMap) {
    fprintf(stderr, "Emitting indirect photons...\n");
  }

  // Iterate over all isotropic point light sources.  Lights of any other
  // type are skipped.
  while (lightIterator != allLights.end())
    {
      const Light *light = *lightIterator;

      if (light->type == IsotropicPoint)
	// i advances at the bottom of the body: on every attempt for the
	// global/indirect maps, but only on a stored photon for the caustic map.
	for (int i = 0; i < photonsPerLight; )
	  {
	    HitInfo hitInfo;
	    int numTransmissions = 0;
	    IsotropicPointLight *ipLight = (IsotropicPointLight *) light;
	    Ray photonRay;

	    // Generate the initial photon ray
	    if (type != CausticMap)
	      {

		// Uniformly random direction: rejection-sample a point in the
		// unit ball, then normalize it.
		Float x = 1, y = 1, z = 1;
		while (x * x + y * y + z * z > 1)
		  {
		    x = RandomFloat(-1, 1);
		    y = RandomFloat(-1, 1);
		    z = RandomFloat(-1, 1);
		  }

		Vector photonDirection(x, y, z);
		photonDirection.Normalize();
		photonRay = Ray(ipLight->lightPos, photonDirection);
	      }
	    else
	      {
		// NOTE(review): this HitInfo shadows the outer one and is unused.
		HitInfo hitInfo;

		// Choose a random caustic bounding sphere to sample, weighted
		// by *weights in the #if 1 branch.  The loop walks the spheres
		// from last to first, subtracting each weight from p, and picks
		// the first sphere with u >= p.  Provided the weights sum to 1
		// and RandomFloat() is uniform on [0,1) (presumably both hold),
		// sphere k is chosen with probability (*weights)[k].  Index 0
		// is the fallback.  The loop's i shadows the photon counter i.

#if 1
		int index = 0;
		float u = RandomFloat();
		float p = 1.;
		for (int i = primSpheres->size() - 1; i >= 0; i--) {
		  p -= (*weights)[i];
		  if (u >= p) {
		    index = i;
		    break;
		  }
		}
#else
		int index = (int) (RandomFloat() * primSpheres->size());
#endif
#if 1
		// Random aim point: a point in the unit ball scaled by
		// (*primSpheres)[index] (presumably the sphere radius) ...
		Float x = 1, y = 1, z = 1;
		while (x * x + y * y + z * z > 1)
		  {
		    x = RandomFloat(-1, 1);
		    y = RandomFloat(-1, 1);
		    z = RandomFloat(-1, 1);
		  }

		Point P;
		P.x = x * (*primSpheres)[index];
		P.y = y * (*primSpheres)[index];
		P.z = z * (*primSpheres)[index];
#endif
		// ... then offset by (*primSphereOffsets)[index] (presumably
		// the sphere centre).  The photon is fired toward that point.
		Point intersection = P + (*primSphereOffsets)[index];
		Vector photonDirection(intersection - ipLight->lightPos);
		photonDirection.Normalize();
		photonRay = Ray(ipLight->lightPos, photonDirection);
	      }

	    //fprintf(stderr, "Hit a bounding sphere\n");

	    // Generate the photon, making sure that it always first hits
	    // a surface capable of generating caustics if in caustic map
	    // mode
	    Float maxt = INFINITY;
	    if (scene->Intersect(photonRay, 1e-4, &maxt, &hitInfo) &&
		(type != CausticMap || 
		 (hitInfo.hitPrim->attributes->Surface->generatesCaustics)))
	      {
		bool addPhotonToMap = false;
		bool isIndirectPhoton = false;
		Material *surface = hitInfo.hitPrim->attributes->Surface;

		// Product of the BRDF factors picked up along the path.
		Spectrum powerMod(1.);

		// Bounce the photon around using Russian roulette.
		// NOTE(review): for the global/indirect maps there is no bounce
		// cap.  The loop ends only on absorption or when the ray escapes.
		while (true)
		  {
		    ShadeContext shadeContext(&hitInfo, -photonRay.D);
		    // NOTE(review): brdf is never freed in this function.  Confirm
		    // whether Shade() hands ownership to the caller.
		    BRDF *brdf = hitInfo.hitPrim->attributes->Surface->Shade(shadeContext);
		    surface = hitInfo.hitPrim->attributes->Surface;

		    // Normalize the Russian roulette probabilities in
		    // preparation for generating the global map.  This
		    // rewrites the shared Material in place, on every bounce.
		    Float total;
		    if (type != CausticMap)
		      {
			total = surface->pAbsorption + surface->pDiffuse + surface->pSpecular + surface->pTransmission;
			surface->pAbsorption /= total;
			surface->pDiffuse /= total;
			surface->pSpecular /= total;
			surface->pTransmission /= total;
		      }
		    maxt = INFINITY;

		    // Reverse D so it points from the hit point back toward
		    // where the photon came from.
		    photonRay.D = -photonRay.D;

		    if (type != CausticMap)
		      {
			// For the global map, randomly select absorption,
			// diffuse reflection, specular reflection or
			// transmission for the photon, and handle each case
			// separately
			Float epsilon = RandomFloat(0, 1);
			if (epsilon < surface->pAbsorption)
			{
			    // absorption

			    // Indirect map: a photon absorbed at its first hit
			    // (no bounce yet) is not stored.
			    if (type == IndirectMap && isIndirectPhoton == false) {
				addPhotonToMap = false;
				break;
			    }

			    if (type == IndirectMap && isIndirectPhoton) {
			      //fprintf(stderr, "... Adding Indirect photon ...\n");
			      addPhotonToMap = true;
			      break;
			    }

			    addPhotonToMap = true;
			    break;
			}
			// it can only be an Indirect photon by now...
			else if (epsilon < surface->pAbsorption + surface->pDiffuse) {
			  // diffuse reflection
			  // NOTE(review): this mirrors the incoming direction about
			  // nm rather than sampling a diffuse direction, and it
			  // assumes nm is unit length.
			  Normal nm = hitInfo.hitPrim->attributes->ObjectToWorld(hitInfo.NgObj);
			  photonRay.D = -photonRay.D - 2.0 * Dot(nm, -photonRay.D) * Vector(nm);
			  powerMod *= brdf->fr(photonRay.D);
			}
			else if (epsilon < surface->pAbsorption + surface->pDiffuse + surface->pSpecular) {
			  // specular reflection
			  powerMod *= brdf->SampleSpecular(1, &photonRay.D);
			}
			else {
			  // transmission
			  powerMod *= brdf->SampleSpecular(0, &photonRay.D);
			}			

			//fprintf(stderr, "=== We have an indirect one! ===\n");
			// Any bounce that is not an absorption makes the photon
			// indirect.
			isIndirectPhoton = true;
		      }
		    else
		      {
			// For the caustic map, always transmit the photon
			// if it hits a surface that is capable of
			// generating caustics (i.e., a transmissive
			// surface), and otherwise always absorb the
			// photon
			if (surface->generatesCaustics) {
			  powerMod *= brdf->SampleSpecular(0, &photonRay.D);
			}
			else
			  {
			    //fprintf(stderr, "... Adding Photon! (%d)...\n", numPoints);
			    addPhotonToMap = true;
			    break;
			  }

			// Reject photons that bounce back and forth
			// infinitely between transmissive surfaces.  The
			// photon is discarded on its 26th transmission.
			if (numTransmissions++ >= 25)
			  break;
		      }

		    // Continue the path from the current hit point.
		    photonRay.O = hitInfo.hitPrim->attributes->ObjectToWorld(hitInfo.Pobj);
		    maxt = INFINITY;

		    if (!scene->Intersect(photonRay, 1e-4, &maxt, &hitInfo))
		      {
			//fprintf(stderr, "Hit something, but then missed scene\n");
			addPhotonToMap = false;
			break;
		      }
		  }

		// If the final photon hits an object, store it in an
		// array to be used for constructing the K-D tree
		if (addPhotonToMap)
		  {
		    Photon *photon = new Photon;
		      
		    photon->position = hitInfo.hitPrim->attributes->ObjectToWorld(hitInfo.Pobj);
		    // Unnormalized vector from the stored position back to the
		    // origin of the last ray segment.
		    photon->incidentDirection = photonRay.O - photon->position;
		    //photon->isIndirectPhoton = isIndirectPhoton;
		    // The light's intensity is split evenly over its photons and
		    // scaled by the BRDF factors gathered along the path.
		    photon->power = ipLight->I / photonsPerLight;
		    photon->power *= powerMod;

		    ANNpoint annPoint = photonArray[numPoints];
		    annPoint[0] = photon->position.x;
		    annPoint[1] = photon->position.y;
		    annPoint[2] = photon->position.z;
		    photon->annPosition = annPoint;
		      
		    // Also, because the K-D tree only stores photon
		    // locations, store the rest of the photon data in a
		    // hashtable keyed by photon location
		    photonTable->Insert(photon);
		    numPoints++;

		    // Progress output: 'o' every 10000 photons, '.' at every
		    // other multiple of 1000.
		    if (numPoints % 10000 == 0) cerr << 'o';
		    else if (numPoints % 1000 == 0) cerr << '.';

		    if (type == CausticMap)
		      i++;
		  }
	      }

	    if (type != CausticMap)
	      i++;
	  } // Loop over all photons emitted from a light

      lightIterator++;
    } // Loop over all lights

  // Generate K-D tree of photons
  photonTree = new ANNkd_tree(photonArray, numPoints, 3);

  fprintf(stderr, "\n");
  if (type == GlobalMap) {
    fprintf(stderr, "Map of %d global photons generated successfully.\n", photonTable->NumItems());
  } else if (type == CausticMap) {
    fprintf(stderr, "Map of %d caustic photons generated successfully.\n", photonTable->NumItems());
  } else if (type == IndirectMap) {
    fprintf(stderr, "Map of %d indirect photons generated successfully.\n", photonTable->NumItems());
  }
}


// Returns the distance value ANN reports between `ref` and the single
// nearest stored photon.  It requires photonTree, which GeneratePhotonMap()
// builds.
//
// NOTE(review): ANN's annkSearch reports squared Euclidean distances, so
// this is presumably the squared distance.  Confirm before comparing it
// against a radius.
float PhotonMap::GetClosestDistance(Point ref)
{
  // One-element result buffers live on the stack.  There is nothing to free
  // and no heap traffic per query.
  ANNdist distance[1];
  ANNidx nearestIndex[1];

  // A private query point, rather than the annReferencePoint member, leaves
  // the state used by GetNearestPhotons() untouched.
  ANNpoint queryPoint = annAllocPt(3);
  queryPoint[0] = ref.x;
  queryPoint[1] = ref.y;
  queryPoint[2] = ref.z;

  photonTree->annkSearch(queryPoint, 1, nearestIndex, distance);

  // The query point was previously leaked on every call.
  annDeallocPt(queryPoint);

  return distance[0];
}


// Looks up the photons closest to referencePoint.
//
// On entry, *numPhotons is the number of neighbours wanted.  If it exceeds
// the limit set by SetMaxN(), it is lowered to that limit and written back.
// The return value is the map's own nearestPhotons buffer.  Its first
// *numPhotons entries are the matching photons, each with its `distance`
// field set to the value ANN reported.  The buffer is reused, and so
// overwritten, by the next call.  If a hashtable lookup fails, an error is
// printed and that slot is left as Find() wrote it.
Photon **PhotonMap::GetNearestPhotons(Point referencePoint, int *numPhotons)
{
  int wanted = *numPhotons;
  if (wanted > maxNphotons)
    {
      wanted = maxNphotons;
      *numPhotons = wanted;
    }

  annReferencePoint[0] = referencePoint.x;
  annReferencePoint[1] = referencePoint.y;
  annReferencePoint[2] = referencePoint.z;

  photonTree->annkSearch(annReferencePoint, wanted, nearestPoints, distances);

  // The tree returns indices into photonArray.  The full photon records are
  // recovered from the hashtable keyed by those points.
  for (int k = 0; k < wanted; ++k)
    {
      Photon **slot = nearestPhotons + k;
      if (!photonTable->Find(photonArray[nearestPoints[k]], slot))
        {
          fprintf(stderr, "Error: Photon map was not generated correctly.\n");
          continue;
        }
      (*slot)->distance = distances[k];
    }

  return nearestPhotons;
}

// Sizes the scratch buffers used by GetNearestPhotons().  numPhotons is
// clamped to the number of photons in the hashtable, and the result becomes
// the per-query neighbour limit (maxNphotons).  The function may be called
// again to resize.  Buffers from a previous call are freed rather than
// leaked.
//
// NOTE(review): any pointer a caller kept from an earlier
// GetNearestPhotons() call becomes invalid after a re-size.
void PhotonMap::SetMaxN(int numPhotons)
{
	Assert(numPhotons > 0);
	if (numPhotons > photonTable->NumItems()) {
	  numPhotons = photonTable->NumItems();
	}
	maxNphotons = numPhotons;

	// Release buffers from an earlier call.  The constructor sets these
	// pointers to NULL, and delete[] on NULL is a no-op, so the first call
	// is safe too.
	delete [] distances;
	delete [] nearestPoints;
	delete [] nearestPhotons;

	distances = new ANNdist[numPhotons];
	nearestPoints = new ANNidx[numPhotons];
	nearestPhotons = new Photon *[numPhotons];
}
