#include "heightfield.h"

// Absolute value for floats (non-negative inputs, including -0.0, pass through unchanged).
static inline float Abs (float x)
{
	if (x >= 0)
		return x;
	return -x;
}
// +1 for x >= 0 (zero counts as positive), -1 otherwise.
static inline float Sign (float x)
{
	if (x >= 0)
		return 1.f;
	return -1.f;
}


// Row-major index helpers. Both expand to expressions that read `nx`, so they
// only work where an `nx` is in scope (i.e. the Heightfield members below).
//   VERT(x,y): index of height sample / vertex (x,y) in z[] and n[] (nx per row)
//   CELL(x,y): index of grid cell (x,y) in smoothz[] (nx-1 cells per row)
#define VERT(x,y) ((y)*nx + (x))
#define CELL(x,y) ((y)*(nx-1) + (x))

// Builds a heightfield over the object-space unit square from an x-by-y grid
// of height samples `zs` (copied; x varies fastest, see VERT). Afterwards the
// z range, per-vertex normals and smoothed cell heights are precomputed.
// NOTE(review): this class owns raw arrays released in ~Heightfield; unless the
// header disables copying, a copied Heightfield would double-delete — confirm.
Heightfield::Heightfield(const Transform &o2w, int x, int y, float *zs)
	: Shape(o2w)
{
	nx = x;
	ny = y;
	// spacing between neighbouring vertices in object space
	dx = 1./(nx-1);
	dy = 1./(ny-1);

	// keep a private copy of the caller's samples
	const int sampleCount = nx*ny;
	z = new float[sampleCount];
	memcpy(z, zs, sampleCount*sizeof(float));

	// derived data (order matters only in that all of it reads z)
	BoundZ();
	ConstructVertexNormals();
	ConstructSmoothApproximation();
}


void Heightfield::BoundZ (void)
{
	minz = z[0];
	maxz = z[0];
	for (int i = 1; i < nx*ny; ++i) {
		if (z[i] < minz) minz = z[i];
		if (z[i] > maxz) maxz = z[i];
	}
}


void Heightfield::ConstructVertexNormals (void)
{
	n = new Normal[nx*ny];
	for (int j=0; j < ny-1; ++j) {
		for (int i=0; i < nx-1; ++i) {
			int pos = VERT(i,j);
			float z00 = z[pos], z10 = z[pos+1], z01 = z[pos+nx], z11 = z[pos+nx+1];
			// first triangle in the square
			Normal triNormal((z00-z10)*(nx-1), (z00-z01)*(ny-1), 1);
			triNormal.Normalize();
			n[pos] += triNormal;
			n[pos+1] += triNormal;
			n[pos+nx] += triNormal;
			n[pos+nx+1] += triNormal;
			// second triangle in the square
			triNormal = Normal((z01-z11)*(nx-1), (z10-z11)*(ny-1), 1);
			triNormal.Normalize();
			n[pos] += triNormal;
			n[pos+1] += triNormal;
			n[pos+nx] += triNormal;
			n[pos+nx+1] += triNormal;
		}
	}
	for (int i=0; i < nx*ny; ++i)
	   n[i].Normalize();
	// now double-check that the vertex normals are consistent with triangle normals
	bool consistent;
	int iterations = 0;
	do {
		consistent = true;
		++iterations;
		for (int j=0; j < ny-1; ++j) {
			for (int i=0; i < nx-1; ++i) {
				int pos = VERT(i,j);
				float z00 = z[pos], z10 = z[pos+1], z01 = z[pos+nx], z11 = z[pos+nx+1];
				// first triangle in the square
				Normal triNormal((z00-z10)*(nx-1), (z00-z01)*(ny-1), 1);
				triNormal.Normalize();
				if (Dot(n[pos],triNormal) <= 0) {
					consistent = false;
					n[pos] += .1*triNormal;
				}
				if (Dot(n[pos+1],triNormal) <= 0) {
					consistent = false;
					n[pos+1] += .1*triNormal;
				}
				if (Dot(n[pos+nx],triNormal) <= 0) {
					consistent = false;
					n[pos+nx] += .1*triNormal;
				}
				if (Dot(n[pos+nx+1],triNormal) <= 0) {
					consistent = false;
					n[pos+nx+1] += .1*triNormal;
				}
				// second triangle in the square
				triNormal = Normal((z01-z11)*(nx-1), (z10-z11)*(ny-1), 1);
				triNormal.Normalize();
				if (Dot(n[pos],triNormal) <= 0) {
					consistent = false;
					n[pos] += .1*triNormal;
				}
				if (Dot(n[pos+1],triNormal) <= 0) {
					consistent = false;
					n[pos+1] += .1*triNormal;
				}
				if (Dot(n[pos+nx],triNormal) <= 0) {
					consistent = false;
					n[pos+nx] += .1*triNormal;
				}
				if (Dot(n[pos+nx+1],triNormal) <= 0) {
					consistent = false;
					n[pos+nx+1] += .1*triNormal;
				}
			}
		}
		if (!consistent && iterations%5 == 0)
			Info("Heightfield: adjusting vertex normals to be consistent (iteration %d)\n", iterations);
	} while (!consistent);
	// and renormalize
	for (int i=0; i < nx*ny; ++i)
	   n[i].Normalize();
}


void Heightfield::ConstructSmoothApproximation (int iterations)
{
	smoothz = new float[(nx-1)*(ny-1)];
	for (int j=0; j < ny-1; ++j)
		for (int i=0; i < nx-1; ++i)
			smoothz[CELL(i,j)] = .25*(z[VERT(i,j)]+z[VERT(i+1,j)]+z[VERT(i,j+1)]+z[VERT(i+1,j+1)]);
	for (int its = 0; its < iterations; ++its) {
		for (int j=0; j < ny-1; ++j) {
			smoothz[CELL(0,j)] = .5*(smoothz[CELL(0,j)]+smoothz[CELL(1,j)]);
			for (int i=1; i < nx-2; ++i)
				smoothz[CELL(i,j)] = .333333333333333333333333*(smoothz[CELL(i-1,j)]+smoothz[CELL(i,j)]+smoothz[CELL(i+1,j)]);
			smoothz[CELL(nx-1,j)] = .5*(smoothz[CELL(nx-2,j)]+smoothz[CELL(nx-1,j)]);
		}
		for (int i=0; i < nx-1; ++i)
			smoothz[CELL(i,0)] = .5*(smoothz[CELL(i,0)]+smoothz[CELL(i,1)]);
		for (int j=1; j < ny-2; ++j)
			for (int i=0; i < nx-1; ++i)
				smoothz[CELL(i,j)] = .333333333333333333333333*(smoothz[CELL(i,j-1)]+smoothz[CELL(i,j)]+smoothz[CELL(i,j+1)]);
		for (int i=0; i < nx-1; ++i)
			smoothz[CELL(i,ny-2)] = .5*(smoothz[CELL(i,ny-3)]+smoothz[CELL(i,ny-2)]);
	}
	// now find bounds on the approximation
	maxSmoothError = 0;
	for (int j=0; j < ny-1; ++j) {
		for (int i=0; i < nx-1; ++i) {
			maxSmoothError = max (maxSmoothError, Abs(smoothz[CELL(i,j)]-z[VERT(i,j)]));
			maxSmoothError = max (maxSmoothError, Abs(smoothz[CELL(i,j)]-z[VERT(i+1,j)]));
			maxSmoothError = max (maxSmoothError, Abs(smoothz[CELL(i,j)]-z[VERT(i,j+1)]));
			maxSmoothError = max (maxSmoothError, Abs(smoothz[CELL(i,j)]-z[VERT(i+1,j+1)]));
		}
	}
	maxSmoothGradX = 0;
	for (int j=0; j < ny-1; ++j)
		for (int i=1; i < nx-1; ++i)
			maxSmoothGradX = max (maxSmoothGradX, Abs(smoothz[CELL(i,j)]-smoothz[CELL(i-1,j)]));
	maxSmoothGradY = 0;
	for (int j=1; j < ny-1; ++j)
		for (int i=0; i < nx-1; ++i)
			maxSmoothGradY = max (maxSmoothGradY, Abs(smoothz[CELL(i,j)]-smoothz[CELL(i,j-1)]));
}


// Releases the arrays allocated by the constructor (z), ConstructVertexNormals
// (n) and ConstructSmoothApproximation (smoothz).
Heightfield::~Heightfield()
{
	delete[] smoothz;
	delete[] n;
	delete[] z;
}


// Object-space bounds: the unit square in x/y, [minz, maxz] in z.
BBox Heightfield::Bound() const
{
	const Point lowCorner(0, 0, minz);
	const Point highCorner(1, 1, maxz);
	return BBox(lowCorner, highCorner);
}


// This shape supplies its own Intersect/IntersectP (defined below).
bool Heightfield::CanIntersect() const { return true; }


bool Heightfield::PrepareLocalRay (const Ray &worldray, Ray &ray, Float &invDx, Float &invDy) const
{
	ray = WorldToObject(worldray);
	// scale so that object is on integer grid
	ray.O.x *= nx-1;
	ray.D.x *= nx-1;
	ray.O.y *= ny-1;
	ray.D.y *= ny-1;
	// advance to where we enter bounding box
	if (ray.O.x >= 0 && ray.O.x <= nx-1 && ray.O.y >= 0 && ray.O.y <= ny-1 && ray.O.z >= minz && ray.O.z <= maxz) {
		invDx = 1./ray.D.x;
		invDy = 1./ray.D.y;
	} else { // otherwise we clip to the bounding box...
		invDx = 1./ray.D.x; 						// check x
		Float tNear = -ray.O.x * invDx;
		Float tFar = (nx-1 - ray.O.x) * invDx;
		if (tNear > tFar) swap(tNear, tFar);
		ray.mint = max(tNear, ray.mint);
		ray.maxt = min(tFar, ray.maxt);
		if (ray.mint > ray.maxt) return false;
		invDy = 1./ray.D.y;							// check y
		tNear = -ray.O.y * invDy;
		tFar = (ny-1 - ray.O.y) * invDy;
		if (tNear > tFar) swap(tNear, tFar);
		ray.mint = max(tNear, ray.mint);
		ray.maxt = min(tFar, ray.maxt);
		if (ray.mint > ray.maxt) return false;
		Float invRayDir = 1. / ray.D.z; 			// check z
		tNear = (minz - ray.O.z) * invRayDir;
		tFar  = (maxz - ray.O.z) * invRayDir;
		if (tNear > tFar) swap(tNear, tFar);
		ray.mint = max(tNear, ray.mint);
		ray.maxt = min(tFar, ray.maxt);
		if (ray.mint > ray.maxt) return false;
	}
	return true;
}


bool Heightfield::StepUsingSmooth (int &i, int &j, Ray &ray) const
{
	for (;;) {
		Float distFromSmooth = ray.O.z + ray.mint*ray.D.z - smoothz[CELL(i,j)];
		if (distFromSmooth > 1.5*maxSmoothError || distFromSmooth < -1.5*maxSmoothError) {
			Float maxChange= Sign(ray.D.z)*(ray.D.z - Sign(distFromSmooth)*(maxSmoothGradX*Abs(ray.D.x) + maxSmoothGradY*Abs(ray.D.y)));
			if (maxChange <= 0) return false;
			ray.mint += (Abs(distFromSmooth) - maxSmoothError)/maxChange;
			i = int(ray.O.x + ray.mint*ray.D.x);
			if (i < 0 || i >= nx-1) return false;
			j = int(ray.O.y + ray.mint*ray.D.y);
			if (j < 0 || j >= ny-1) return false;
		} else
			return true;
	}
}


bool Heightfield::Intersect (const Ray &r, DifferentialGeometry *dg) const
{
	Ray ray;
	Float invDx, invDy;
	if (!PrepareLocalRay (r, ray, invDx, invDy)) return false;

	// now trace through the grid
	Float x = ray.O.x + ray.mint*ray.D.x;
	Float y = ray.O.y + ray.mint*ray.D.y;
	int i = int(x), j = int(y);
	if (ray.D.x >= 0) {
		if (ray.D.y >= 0) {
			while (i < nx-1 && j < ny-1) { // increasing i and increasing j //////////////////////////////
				if (!StepUsingSmooth (i, j, ray)) return false;
				if (LocalIntersect (i, j, ray, dg)) {
					r.maxt = ray.maxt;
					return true;
				}
				if ((i+1-x)*invDx <= (j+1-y)*invDy) {
					++i;
					ray.mint = (i-ray.O.x)*invDx;
					if (ray.mint > ray.maxt) return false;
					x = i;
					y = ray.O.y + ray.mint*ray.D.y;
				} else {
					++j;
					ray.mint = (j-ray.O.y)*invDy;
					if (ray.mint > ray.maxt) return false;
					y = j;
					x = ray.O.x + ray.mint*ray.D.x;
				}
			}
		} else {
			while (i < nx-1 && j >= 0) { // increasing i and decreasing j ////////////////////////////////
				if (!StepUsingSmooth (i, j, ray)) return false;
				if (LocalIntersect (i, j, ray, dg)) {
					r.maxt = ray.maxt;
					return true;
				}
				if ((i+1-x)*invDx <= (j-y)*invDy) {
					++i;
					ray.mint = (i-ray.O.x)*invDx;
					if (ray.mint > ray.maxt) return false;
					x = i;
					y = ray.O.y + ray.mint*ray.D.y;
				} else {
					ray.mint = (j-ray.O.y)*invDy;
					if (ray.mint > ray.maxt) return false;
					y = j;
					x = ray.O.x + ray.mint*ray.D.x;
					--j;
				}
			}
		}
	} else {
		if (ray.D.y >= 0) {
			while (i >= 0 && j < ny-1) { // decreasing i and increasing j ////////////////////////////////
				if (!StepUsingSmooth (i, j, ray)) return false;
				if (LocalIntersect (i, j, ray, dg)) {
					r.maxt = ray.maxt;
					return true;
				}
				if ((i-x)*invDx <= (j+1-y)*invDy) {
					ray.mint = (i-ray.O.x)*invDx;
					if (ray.mint > ray.maxt) return false;
					x = i;
					y = ray.O.y + ray.mint*ray.D.y;
					--i;
				} else {
					++j;
					ray.mint = (j-ray.O.y)*invDy;
					if (ray.mint > ray.maxt) return false;
					y = j;
					x = ray.O.x + ray.mint*ray.D.x;
				}
			}
		} else {
			while (i >= 0 && j >= 0) { // decreasing i and decreasing j /////////////////////////////////
				if (!StepUsingSmooth (i, j, ray)) return false;
				if (LocalIntersect (i, j, ray, dg)) {
					r.maxt = ray.maxt;
					return true;
				}
				if ((i-x)*invDx <= (j-y)*invDy) {
					ray.mint = (i-ray.O.x)*invDx;
					if (ray.mint > ray.maxt) return false;
					x = i;
					y = ray.O.y + ray.mint*ray.D.y;
					--i;
				} else {
					ray.mint = (j-ray.O.y)*invDy;
					if (ray.mint > ray.maxt) return false;
					y = j;
					x = ray.O.x + ray.mint*ray.D.x;
					--j;
				}
			}
		}
	}
	return false;
}


bool Heightfield::IntersectP(const Ray &r) const
{
	Ray ray;
	Float invDx, invDy;
	if (!PrepareLocalRay (r, ray, invDx, invDy)) return false;

	// now trace through the grid
	Float x = ray.O.x + ray.mint*ray.D.x;
	Float y = ray.O.y + ray.mint*ray.D.y;
	int i = int(x), j = int(y);
	if (ray.D.x >= 0) {
		if (ray.D.y >= 0) {
			while (i < nx-1 && j < ny-1) { // increasing i and increasing j //////////////////////////////
				if (!StepUsingSmooth (i, j, ray)) return false;
				if (LocalIntersectP (i, j, ray)) return true;
				if ((i+1-x)*invDx <= (j+1-y)*invDy) {
					++i;
					ray.mint = (i-ray.O.x)*invDx;
					if (ray.mint > ray.maxt) return false;
					x = i;
					y = ray.O.y + ray.mint*ray.D.y;
				} else {
					++j;
					ray.mint = (j-ray.O.y)*invDy;
					if (ray.mint > ray.maxt) return false;
					y = j;
					x = ray.O.x + ray.mint*ray.D.x;
				}
			}
		} else {
			while (i < nx-1 && j >= 0) { // increasing i and decreasing j ////////////////////////////////
				if (!StepUsingSmooth (i, j, ray)) return false;
				if (LocalIntersectP (i, j, ray)) return true;
				if ((i+1-x)*invDx <= (j-y)*invDy) {
					++i;
					ray.mint = (i-ray.O.x)*invDx;
					if (ray.mint > ray.maxt) return false;
					x = i;
					y = ray.O.y + ray.mint*ray.D.y;
				} else {
					ray.mint = (j-ray.O.y)*invDy;
					if (ray.mint > ray.maxt) return false;
					y = j;
					x = ray.O.x + ray.mint*ray.D.x;
					--j;
				}
			}
		}
	} else {
		if (ray.D.y >= 0) {
			while (i >= 0 && j < ny-1) { // decreasing i and increasing j ////////////////////////////////
				if (!StepUsingSmooth (i, j, ray)) return false;
				if (LocalIntersectP (i, j, ray)) return true;
				if ((i-x)*invDx <= (j+1-y)*invDy) {
					ray.mint = (i-ray.O.x)*invDx;
					if (ray.mint > ray.maxt) return false;
					x = i;
					y = ray.O.y + ray.mint*ray.D.y;
					--i;
				} else {
					++j;
					ray.mint = (j-ray.O.y)*invDy;
					if (ray.mint > ray.maxt) return false;
					y = j;
					x = ray.O.x + ray.mint*ray.D.x;
				}
			}
		} else {
			while (i >= 0 && j >= 0) { // decreasing i and decreasing j /////////////////////////////////
				if (!StepUsingSmooth (i, j, ray)) return false;
				if (LocalIntersectP (i, j, ray)) return true;
				if ((i-x)*invDx <= (j-y)*invDy) {
					ray.mint = (i-ray.O.x)*invDx;
					if (ray.mint > ray.maxt) return false;
					x = i;
					y = ray.O.y + ray.mint*ray.D.y;
					--i;
				} else {
					ray.mint = (j-ray.O.y)*invDy;
					if (ray.mint > ray.maxt) return false;
					y = j;
					x = ray.O.x + ray.mint*ray.D.x;
					--j;
				}
			}
		}
	}
	return false;
}


bool Heightfield::LocalIntersect(int i, int j, const Ray &ray, DifferentialGeometry *dg) const
{
	Assert(0<=i && i<nx-1 && 0<=j && j<ny-1);
	const Float Ox=ray.O.x-i, Oy=ray.O.y-j;
	const Float Dx=ray.D.x, Dy=ray.D.y;
	const int pos=VERT(i,j);
	const Float z00 = z[pos], z10 = z[pos+1], z01 = z[pos+nx];
	const Float zd0 = z10-z00, z0d = z01-z00;
	Float numer, denom, t, x, y;
	bool intersect=false;
	//Normal trinorm;

	// intersect with first triangle
	numer = ray.O.z - z00 - Ox*zd0 - Oy*z0d;
	denom = -ray.D.z + Dx*zd0 + Dy*z0d;
	if ((denom > 0 && numer > ray.mint*denom && numer <= ray.maxt*denom)
		|| (denom < 0 && numer < ray.mint*denom && numer >= ray.maxt*denom)) {
		t = numer/denom;
		x = Ox + t*Dx;
		y = Oy + t*Dy;
		if (x > -1e-6 && y > -1e-6 && x+y < 1+1e-6) {
			intersect = true;
			ray.maxt = t;
			//trinorm = Normal(-zd0*(nx-1),-z0d*(ny-1),1).Hat();
		}
	}

	// intersect with second triangle
	const Float z11 = z[pos+nx+1];
	const Float zd1 = z11-z01, z1d = z11-z10;
	numer = ray.O.z + z1d - z01 - Ox*zd1 - Oy*z1d;
	denom = -ray.D.z + Dx*zd1 + Dy*z1d;
	if ((denom > 0 && numer > ray.mint*denom && numer <= ray.maxt*denom)
		|| (denom < 0 && numer < ray.mint*denom && numer >= ray.maxt*denom)) {
		Float newt = numer/denom;
		Float newx = Ox + newt*Dx;
		Float newy = Oy + newt*Dy;
		if (newx < 1+1e-6 && newy < 1+1e-6 && newx+newy > 1-1e-6) {
			t = newt;
			x = newx;
			y = newy;
			intersect = true;
			ray.maxt = newt;
			//trinorm = Normal(-zd0*(nx-1),-z0d*(ny-1),1).Hat();
		}
	}

	// now handle the intersection if there was one
	if (intersect) {
		Float realx = dx*(x+i), realy = dy*(y+j);
		Float z = ray.O.z + t*ray.D.z;
		dg->P = ObjectToWorld (Point (realx, realy, z));
		Normal objectN = ((1-x)*(1-y))*n[pos]+(x*(1-y))*n[pos+1]+((1-x)*y)*n[pos+nx]+(x*y)*n[pos+nx+1];
		// + 40*x*(1-x)*y*(1-y)*Abs(1-x-y)*trinorm;
		//if (Dot(objectN,trinorm) <= 0) {
		//	printf("normal problem %g:   (%f, %f, %f)  versus triangle  (%f, %f, %f)\n",Dot(objectN,trinorm),
		//		objectN.x,objectN.y,objectN.z,trinorm.x,trinorm.y,trinorm.z);
		//	objectN = trinorm;
		//}
		dg->N = ObjectToWorld(objectN).Hat();
		dg->S = ObjectToWorld (Vector (objectN.z, 0, -objectN.x)).Hat();
		dg->T = Cross (dg->N, dg->S);
		dg->u = realx;
		dg->v = realy;
		return true;
	} else
		return false;
}


bool Heightfield::LocalIntersectP(int i, int j, const Ray &ray) const
{
	Assert(0<=i && i<nx-1 && 0<=j && j<ny-1);
	const Float Ox=ray.O.x-i, Oy=ray.O.y-j;
	const Float Dx=ray.D.x, Dy=ray.D.y;
	const int pos=VERT(i,j);
	const Float z00 = z[pos], z10 = z[pos+1], z01 = z[pos+nx];
	const Float zd0 = z10-z00, z0d = z01-z00;
	Float numer, denom;

	// intersect with first triangle
	numer = ray.O.z - z00 - Ox*zd0 - Oy*z0d;
	denom = -ray.D.z + Dx*zd0 + Dy*z0d;
	if ((denom > 0 && numer > ray.mint*denom && numer <= ray.maxt*denom)
		|| (denom < 0 && numer < ray.mint*denom && numer >= ray.maxt*denom)) {
		const Float t = numer/denom;
		const Float x = Ox+t*Dx;
		const Float y = Oy+t*Dy;
		if (x > -1e-6 && y > -1e-6 && x+y < 1+1e-6)
			return true;
	}

	// intersect with second triangle
	const Float z11 = z[pos+nx+1];
	const Float zd1 = z11-z01, z1d = z11-z10;
	numer = ray.O.z + z1d - z01 - Ox*zd1 - Oy*z1d;
	denom = -ray.D.z + Dx*zd1 + Dy*z1d;
	if ((denom > 0 && numer > ray.mint*denom && numer <= ray.maxt*denom)
		|| (denom < 0 && numer < ray.mint*denom && numer >= ray.maxt*denom)) {
		const Float t = numer/denom;
		const Float x = Ox+t*Dx;
		const Float y = Oy+t*Dy;
		if (x < 1+1e-6 && y < 1+1e-6 && x+y > 1-1e-6)
			return true;
	}

	return false; // intersected neither triangle
}

