#include "heightfield.h"

extern float heightMinZ;
extern Transform* transformHeightFromWorld;


/*
 * Constructs a heightfield from an x-by-y grid of height samples.
 *
 *   o2w  object-to-world transform, forwarded to the Shape base class
 *   x,y  number of samples along each grid axis (stored as nx, ny)
 *   zs   x*y height values; sample (i,j) is read from zs[nx*j + i].
 *        The values are copied, so the caller keeps ownership of zs.
 *
 * With HW2 defined, a per-vertex table (position + accumulated normal)
 * is also built, together with the min/max height used by Bound().
 * With PROJECT defined, two file-external globals are updated from
 * this instance.
 */
Heightfield::Heightfield(const Transform &o2w, int x, int y,
		float *zs)
	: Shape(o2w) {
	nx = x;
	ny = y;
	z = new float[nx*ny];
	memcpy(z, zs, nx*ny*sizeof(float));

#ifdef HW2

	/* Spacing between neighbouring samples; the grid spans [0,1] in x and y. */
	InvNumX = 1./(nx - 1);
	InvNumY = 1./(ny - 1);

	/* Array of nx columns, each holding ny vertices.  (The parenthesised
	   form "new (vertexInfo*)[nx]" is not valid standard C++.) */
	heightField = new vertexInfo*[nx];
	for (int i = 0; i < nx; i++) {
	  heightField[i] = new vertexInfo[ny];
	}

	/* Fill in vertex positions, zero the normals, and track the height range. */
	minz = z[0];
	maxz = z[0];
	for (int i = 0; i < nx; i++) {
	  for (int j = 0; j < ny; j++) {
	    const int k = nx*j + i;
	    heightField[i][j].P = Point(InvNumX*i, InvNumY*j, z[k]);
	    heightField[i][j].N = Normal(0,0,0);
	    if (z[k] < minz) minz = z[k];
	    if (z[k] > maxz) maxz = z[k];
	  }
	}

	/*Normal Computation at vertices*/
	/* Each cell is split into two triangles sharing the (i,j)-(i+1,j+1)
	   diagonal.  One sixth of each triangle's (unnormalised) face normal
	   is added to each of its three corners. */
	for (int i = 0; i < nx-1; i++) {
	  for (int j = 0; j < ny-1; j++) {
	    Vector V1,V2;
	    Normal N;

	    /* Triangle (i,j), (i+1,j), (i+1,j+1) */
	    V1 = heightField[i+1][j].P - heightField[i][j].P;
	    V2 = heightField[i+1][j+1].P - heightField[i][j].P;
	    N = Normal(Cross(V2,V1));
	    N /= 6;
	    heightField[i][j].N += N;
	    heightField[i+1][j].N += N;
	    heightField[i+1][j+1].N += N;

	    /* Triangle (i,j), (i+1,j+1), (i,j+1); the diagonal edge is reused as V1 */
	    V1 = V2;
	    V2 = heightField[i][j+1].P - heightField[i][j].P;
	    N = Normal(Cross(V2,V1));
	    N /= 6;
	    heightField[i][j].N += N;
	    heightField[i][j+1].N += N;
	    heightField[i+1][j+1].N += N;
	  }
	}
#endif

#ifdef PROJECT
	/* Publish this heightfield's world-to-object transform and lowest z
	   through the extern globals declared at the top of the file.
	   NOTE(review): any Transform previously stored in the global is not
	   released here - confirm who owns it. */
	transformHeightFromWorld= new Transform(WorldToObject);
	heightMinZ= (Bound()).pMin.z;
#endif

}

/*
 * Releases the storage allocated by the constructor: the per-vertex
 * table (HW2 builds only) and the private copy of the height samples.
 */
Heightfield::~Heightfield() {
#ifdef HW2
	/* Free every column first, then the array of column pointers. */
	for (int column = 0; column < nx; ++column) {
	  delete[] heightField[column];
	}
	delete[] heightField;
#endif

	delete[] z;
}

/*
 * Object-space bounding box: the unit square in x/y, extended in z from
 * the lowest to the highest height sample.
 */
BBox Heightfield::Bound() const {
#ifdef HW2
	/* The height range was cached by the constructor. */
	return BBox(Point(0,0,minz), Point(1,1,maxz));
#else
	/* No cached range in this build: scan every sample. */
	const int sampleCount = nx*ny;
	float lowest = z[0];
	float highest = z[0];
	for (int idx = 1; idx < sampleCount; ++idx) {
		const float height = z[idx];
		if (height < lowest) lowest = height;
		if (height > highest) highest = height;
	}
	return BBox(Point(0,0,lowest), Point(1,1,highest));
#endif
}

#ifdef HW2
bool Heightfield::Intersect(const Ray &r, DifferentialGeometry *dg) const {
  Ray ray = WorldToObject(r);

  /*Checks for intersection against the overall grid bound*/
  Float rayT = ray.maxt;
  if(Bound().Inside(ray(ray.mint)))
    rayT = ray.mint;
  else {
    Ray rb = ray;
    if (!Bound().IntersectP(rb))
      return false;
    rayT = rb.maxt;
  }
  Point gridIntersect = ray(rayT);

  /*Setting up for stepping through the grid*/
  Float invRayDirX = 1./ray.D.x;
  int x = int(gridIntersect.x*(nx-1));
  if (x == (nx-1)) x--;
  Assert(x >= 0 && x < (nx-1));
  
  Float NextXCrossing, DeltaX;
  int StepX, OutX;
  if (fabs(ray.D.x) < 1e-6) {
    NextXCrossing = INFINITY;
    DeltaX = 0;
    OutX = -1;
  }
  else if (ray.D.x > 0) {
    NextXCrossing = rayT + ((x+1)*InvNumX - gridIntersect.x)*invRayDirX;
    DeltaX = InvNumX*invRayDirX;
    StepX = 1;
    OutX = nx - 1;
  }
  else {
    NextXCrossing = rayT + (x*InvNumX - gridIntersect.x)*invRayDirX;
    DeltaX = - InvNumX*invRayDirX;
    StepX = -1;
    OutX = -1;
  }

  Float invRayDirY = 1./ray.D.y;
  int y = int(gridIntersect.y*(ny-1));
  if (y == (ny-1)) y--;
  Assert(y >= 0 && y < (ny-1));
  
  Float NextYCrossing, DeltaY;
  int StepY, OutY;
  if (fabs(ray.D.y) < 1e-6) {
    NextYCrossing = INFINITY;
    DeltaY = 0;
    OutY = -1;
  }
  else if (ray.D.y > 0) {
    NextYCrossing = rayT + ((y+1)*InvNumY - gridIntersect.y)*invRayDirY;
    DeltaY = InvNumY*invRayDirY;
    StepY = 1;
    OutY = ny - 1;
  }
  else {
    NextYCrossing = rayT + (y*InvNumY - gridIntersect.y)*invRayDirY;
    DeltaY = - InvNumY*invRayDirY;
    StepY = -1;
    OutY = -1;
  }

  /*Checking for intersections inside the grid, and moving along the grid
    as well*/
  bool hitSomething = false;
  for (;;) {
    vertexInfo* Triangle[3];

    Assert(x >= 0 && x < (nx-1));
    Assert(y >= 0 && y < (ny-1));

    for (int i=0; i<2; i++) {
      if (i==0) {
	Triangle[0] = &(heightField[x][y]);
	Triangle[1] = &(heightField[x+1][y]);
	Triangle[2] = &(heightField[x+1][y+1]);
      }
      else {
	Triangle[0] = &(heightField[x][y]);
	Triangle[1] = &(heightField[x+1][y+1]);
	Triangle[2] = &(heightField[x][y+1]);
      }
      
      Vector E1 = Triangle[1]->P - Triangle[0]->P;
      Vector E2 = Triangle[2]->P - Triangle[0]->P;
      Vector S_1 = Cross(ray.D, E2);
      Float divisor = Dot(S_1, E1);
      if (divisor != 0.) {
	Float invDivisor = 1./divisor;
	Vector T = ray.O - Triangle[0]->P;
	Float u = Dot(T, S_1) * invDivisor;
	if (u>= 0. && u<=1.0) {
	  Vector S_2 = Cross(T,E1);
	  Float v = Dot(ray.D, S_2) * invDivisor;
	  if (v>=0. && (u+v) <= 1.0) {
	    Float t = Dot(E2, S_2) * invDivisor;
	    if (t >= ray.mint && t <= ray.maxt) {
#ifdef PHONG_INTERPOLATE
	      Normal N = (1-u-v)*Triangle[0]->N + u*Triangle[1]->N + v*Triangle[2]->N;
#else
	      Normal N = Normal(Cross(E2, E1));
#endif

	      Vector S, T;
//	      S = Cross(Vector(N), Vector(0, 0, 1));
//	      T = Cross(Vector(N), S);

	      S = Vector(1, 0, 0);
	      T = Vector(0, 1, 0);

	      *dg = DifferentialGeometry(ray(t), N.Hat(), S.Hat(), T.Hat(), ray(t).x, ray(t).y);

	      *dg = ObjectToWorld(*dg);
	      ray.maxt = t;
	      r.maxt = t;
	      hitSomething = true;
	    }
	  }
	}
      }
    }

    /*Updating for next intersection in the grid*/
    if (hitSomething) break;
    if (NextXCrossing < NextYCrossing) {
      if (ray.maxt < NextXCrossing)
	break;
      x += StepX;
      if (x == OutX)
	break;
      NextXCrossing += DeltaX;
    }
    else {
      if (ray.maxt < NextYCrossing)
	break;
      y += StepY;
      if (y == OutY)
	break;
      NextYCrossing += DeltaY;
    }
  }

  return hitSomething;
}

#else

/*
 * Always returns false in this (non-HW2) build, which provides Refine()
 * instead of a direct Intersect().
 */
bool Heightfield::CanIntersect() const
{
	const bool directlyIntersectable = false;
	return directlyIntersectable;
}

/*
 * Converts the heightfield into a single TriangleMesh and appends it to
 * `refined`.
 *
 * Vertices are laid out row by row (index = x + y*nx) over the unit
 * square in x/y, with z taken from the stored height samples.  Every
 * grid cell contributes two triangles sharing the (x,y)-(x+1,y+1)
 * diagonal.
 */
void Heightfield::Refine(vector<Shape *> &refined) const {
	int ntris = 2*(nx-1)*(ny-1);
	int *verts = new int[3*ntris];
	/* Allocated exactly once; a second allocation here would leak the first. */
	Point *P = new Point[nx*ny];
	int x, y;

	/* Vertex positions. */
	int pos = 0;
	for (y = 0; y < ny; ++y) {
		for (x = 0; x < nx; ++x) {
			P[pos].x = (float)x / (float)(nx-1);
			P[pos].y = (float)y / (float)(ny-1);
			P[pos].z = z[pos];
			++pos;
		}
	}

	/* Triangle vertex indices, two triangles per cell. */
	int *vp = verts;
	for (y = 0; y < ny-1; ++y) {
		for (x = 0; x < nx-1; ++x) {
			const int v00 = x + y*nx;	/* (x,   y)   */
			const int v10 = v00 + 1;	/* (x+1, y)   */
			const int v01 = v00 + nx;	/* (x,   y+1) */
			const int v11 = v01 + 1;	/* (x+1, y+1) */

			*vp++ = v00;
			*vp++ = v10;
			*vp++ = v11;

			*vp++ = v00;
			*vp++ = v11;
			*vp++ = v01;
		}
	}
	refined.push_back(new TriangleMesh(ObjectToWorld, ntris,
		nx*ny, verts, P));
	/* NOTE(review): the arrays are freed right away, so this relies on
	   TriangleMesh keeping its own copies - confirm against its constructor. */
	delete[] P;
	delete[] verts;
}
#endif
