#include "heightfield.h"

// Builds an nx-by-ny heightfield from the caller's height samples `zs`
// (row-major: sample (i, j) is zs[i + j*x]). The shape keeps its own copy
// of the heights, a per-vertex normal cache, and a cached bounding box.
Heightfield::Heightfield(const Transform &o2w, int x, int y,
		float *zs)
	: Shape(o2w) {
	nx = x;
	ny = y;
	const int count = nx * ny;

	// Private copy of the height samples.
	z = new float[count];
	memcpy(z, zs, count * sizeof(float));

	// Vertex-normal cache used by GrabNormal(); a zero vector marks an
	// entry that has not been computed yet.
	ns = new Vector[count];
	for (int i = 0; i < count; ++i)
		ns[i].x = ns[i].y = ns[i].z = 0;

	// Object-space bounds, reused by the traversal code.
	bounds = new BBox(Bound());
}
// Releases everything the constructor allocated: the height samples, the
// vertex-normal cache and the cached bounding box. (Previously only the
// heights were freed, leaking `ns` and `bounds`.)
Heightfield::~Heightfield() {
	delete[] z;
	delete[] ns;
	delete bounds;
}
// Object-space bounding box. The grid always spans [0,1] in x and y, so
// only the z extent (lowest and highest sample) depends on the data.
BBox Heightfield::Bound() const {
	const int count = nx * ny;
	float lo = z[0];
	float hi = z[0];
	for (int i = 1; i < count; ++i) {
		const float h = z[i];
		lo = (h < lo) ? h : lo;
		hi = (h > hi) ? h : hi;
	}
	return BBox(Point(0, 0, lo), Point(1, 1, hi));
}
// Reports that this shape answers ray queries itself, through the
// Intersect()/IntersectP() grid-traversal methods in this file.
bool Heightfield::CanIntersect() const {
	const bool answersRayQueries = true;
	return answersRayQueries;
}

// Object-space position of grid vertex (x, y): x and y are mapped onto
// [0, 1] and z is the stored height. No bounds checking is performed, so
// callers must pass 0 <= x < nx and 0 <= y < ny.
Vector Heightfield::GetPoint(int x, int y) const
{
  const Float px = (Float)x / (nx - 1);
  const Float py = (Float)y / (ny - 1);
  return Vector(px, py, z[y * nx + x]);
}


// Smoothed vertex normal at grid vertex (x, y), in object space: the
// normalized average of the unit normals of every mesh triangle that shares
// the vertex. Results are memoized in ns[], where a zero-length entry means
// "not computed yet".
//
// Border vertices have fewer than six adjacent triangles. Only triangles
// whose vertices lie inside the grid are accumulated. (Unconditionally using
// x-1 / x+1 / y-1 / y+1 read z[] out of bounds or wrapped onto the
// neighbouring row.)
Vector Heightfield::GrabNormal(int x, int y) const
{
  const int idx = x + (y*nx);
  if (ns[idx].Length() > 0.001)
    return ns[idx];

  const bool hasL = (x > 0);
  const bool hasR = (x < nx-1);
  const bool hasD = (y > 0);
  const bool hasU = (y < ny-1);

  const Vector p = GetPoint(x, y);
  Vector result(0,0,0);
  int count = 0;

  // The six candidate triangles follow the cell split
  // (x,y),(x+1,y),(x+1,y+1) / (x,y),(x+1,y+1),(x,y+1). Each cross product
  // is ordered so that a flat field yields +z.
  if (hasL && hasD) {
    result += (Cross(GetPoint(x-1, y) - p,
                     GetPoint(x-1, y-1) - GetPoint(x-1, y))).Hat();
    result += (Cross(GetPoint(x-1, y-1) - GetPoint(x, y-1),
                     GetPoint(x, y-1) - p)).Hat();
    count += 2;
  }
  if (hasD && hasR) {
    result += (Cross(GetPoint(x, y-1) - p,
                     GetPoint(x+1, y) - p)).Hat();
    count += 1;
  }
  if (hasR && hasU) {
    result += (Cross(GetPoint(x+1, y) - p,
                     GetPoint(x+1, y+1) - GetPoint(x+1, y))).Hat();
    result += (Cross(GetPoint(x+1, y+1) - GetPoint(x, y+1),
                     GetPoint(x, y+1) - p)).Hat();
    count += 2;
  }
  if (hasU && hasL) {
    result += (Cross(GetPoint(x, y+1) - p,
                     GetPoint(x-1, y) - p)).Hat();
    count += 1;
  }

  // Average, then normalize (interior vertices still divide by 6 as before).
  if (count > 0)
    result /= (Float)count;

  result.Normalize();
  ns[idx] = result;

  return result;
}

bool Heightfield::VisitedCell(const Ray &ray, DifferentialGeometry *dg, int x,
int y) const
{
  Point verts[4];
  int vi[6];
  Float height1, height2, height3, height4;
  bool result1 = false, result2= false;
  Vector normal1, normal2, normal3;
  Float u, v;

  //  if (x<20 || x>25) return false;

  //printf("checking cell %d %d\n", x, y);

  height1 = z[x+(y*nx)];
  height2 = z[(x+1)+(y*nx)];
  height3 = z[x+((y+1)*nx)];
  height4 = z[(x+1)+((y+1)*nx)];

  verts[0] = Point((Float)x/(nx-1), (Float)y/(ny-1), height1);
  verts[1] = Point((Float)(x+1)/(nx-1), (Float)y/(ny-1), height2);
  verts[2] = Point((Float)x/(nx-1), (Float)(y+1)/(ny-1), height3);
  verts[3] = Point((Float)(x+1)/(nx-1), (Float)(y+1)/(ny-1), height4);

  BBox gridBox1(verts[0], verts[3]);
  BBox gridBox2(verts[1], verts[2]);
  
  Ray rt = WorldToObject(ray);

  if ((!gridBox1.IntersectP(rt)) && 
      (!gridBox2.IntersectP(rt))) return false;

  //  printf("check for real\n");

  /* printf("Vertex set:\n");
  printf("%g %g %g\n", verts[0].x, verts[0].y, verts[0].z);
  printf("%g %g %g\n", verts[1].x, verts[1].y, verts[1].z);
  printf("%g %g %g\n", verts[2].x, verts[2].y, verts[2].z);*/
  vi[0] = 0;
  vi[1] = 1;
  vi[2] = 3;
  vi[3] = 0;
  vi[4] = 3;
  vi[5] = 2;
  TriangleMesh tmesh (ObjectToWorld, 2, 4, vi, verts);
  Triangle tri1(ObjectToWorld, &tmesh, 0);
  Triangle tri2(ObjectToWorld, &tmesh, 1);
  if (tri1.Intersect(ray, dg))
    {    
      result1 = true;
    }
  if (tri2.Intersect(ray, dg))
    {
      result1 = false;
      result2 = true;
    }

  if (result1)
    {
      //      printf("interpolating\n");
      *dg = WorldToObject(*dg);
      normal1 = GrabNormal(x,y);
      normal2 = GrabNormal(x+1,y);
      normal3 = GrabNormal(x+1,y+1);
      
      u = (GetPoint(x+1, y).x - dg->P.x)/(1.0/(nx-1));
      v = (dg->P.y - GetPoint(x+1, y).y)/(1.0/(ny-1));
      //    printf("barycentric: u: %g, v: %g\n", u, v);
      
      dg->N = Normal((1-v)*((1-u)*normal2 + (u*normal1)) + (v*normal3));
      dg->u = dg->P.x;
      dg->v = dg->P.y;

      *dg = ObjectToWorld(*dg);

      dg->N.Normalize();  
    }
  else if (result2)
    {
      //printf("interpolating\n");
      *dg = WorldToObject(*dg);
      normal1 = GrabNormal(x,y);
      normal2 = GrabNormal(x+1, y+1);
      normal3 = GrabNormal(x, y+1);

      u = (dg->P.x - GetPoint(x, y+1).x)/(1.0/(nx-1));
      v = (GetPoint(x, y+1).y - dg->P.y)/(1.0/(ny-1));      
      
      dg->u = dg->P.x;
      dg->v = dg->P.y;

      dg->N = Normal((1-v)*((1-u)*normal3 + (u*normal2)) + (v*normal1));
      *dg = ObjectToWorld(*dg);
      
      dg->N.Normalize();
    }

  return (result1||result2);

}

// Walks the grid cells crossed by `ray` in front-to-back order. It returns
// true at the first cell where VisitedCell() reports a hit, in which case
// *dg has been filled in there.
//   ray     - the object-space ray, already advanced so that ray(ray.mint)
//             is where it enters the heightfield bounds.
//   origRay - the original world-space ray, used for the per-cell
//             triangle tests.
//
// This is a 2D grid DDA (Amanatides & Woo) in grid units, where vertex i
// sits at coordinate i. tNextX / tNextY are the ray-parameter offsets from
// ray.mint at which the ray next crosses a vertical / horizontal grid line.
// Whichever is smaller decides which neighbouring cell is entered next.
// This replaces a slope-based walk that never stepped the minor axis for
// rays heading towards -x (shallow slopes) or -y (steep slopes).
bool Heightfield::TraverseHeightfield(Ray &ray, const Ray &origRay, DifferentialGeometry *dg)
const
{
  const Point start = ray(ray.mint);
  const Float gx = start.x * (nx-1);
  const Float gy = start.y * (ny-1);
  const Float dx = ray.D.x * (nx-1);
  const Float dy = ray.D.y * (ny-1);

  // Starting cell, clamped. A ray entering through the x = 1 / y = 1 faces
  // (or lying a rounding error outside the grid) still starts in a real cell.
  int x = (int)gx;
  int y = (int)gy;
  if (x < 0) x = 0;
  if (x > nx-2) x = nx-2;
  if (y < 0) y = 0;
  if (y > ny-2) y = ny-2;

  const bool movesX = (dx != 0);
  const bool movesY = (dy != 0);
  const int stepX = (dx > 0) ? 1 : -1;
  const int stepY = (dy > 0) ? 1 : -1;
  // Parameter distance between successive grid lines along each axis.
  const Float tDeltaX = movesX ? stepX / dx : 0;
  const Float tDeltaY = movesY ? stepY / dy : 0;
  // Parameter offset (from ray.mint) of the first grid line crossed.
  Float tNextX = movesX ? ((x + (stepX > 0 ? 1 : 0)) - gx) / dx : 0;
  Float tNextY = movesY ? ((y + (stepY > 0 ? 1 : 0)) - gy) / dy : 0;

  for (;;) {
    if (VisitedCell(origRay, dg, x, y)) return true;

    // A ray with no x/y motion only ever covers its starting cell.
    if (!movesX && !movesY) return false;

    const bool stepInX = movesX && (!movesY || tNextX < tNextY);
    Float tCross;
    if (stepInX) {
      tCross = tNextX;
      x += stepX;
      tNextX += tDeltaX;
    } else {
      tCross = tNextY;
      y += stepY;
      tNextY += tDeltaY;
    }

    // Left the grid in x / y.
    if (x < 0 || x > nx-2 || y < 0 || y > ny-2) return false;

    // Entered the next cell beyond the end of the ray segment.
    const Float t = ray.mint + tCross;
    if (t > ray.maxt) return false;

    // Left the height range. z is linear in t, so the ray cannot come back.
    const Float zc = ray(t).z;
    if (zc < bounds->pMin.z || zc > bounds->pMax.z) return false;
  }
}

bool Heightfield::Intersect(const Ray &ray, DifferentialGeometry *dg)
const
{
  // printf("got here\n");
  //printf("ray origin before: %g %g %g\n", ray.O.x, ray.O.y, ray.O.z);
  Ray rt = WorldToObject(ray);
  //printf("ray origin after: %g %g %g\n", rt.O.x, rt.O.y, rt.O.z); 
  Float rayT = rt.maxt;

  if (bounds->Inside(rt(rt.mint)))
    {
      //printf("starting inside");
      rayT = rt.mint;
    }
  else {
    Ray rb = rt;
    //  printf("outside boxy");
    //   printf("rb maxt before intersect = %g\n", rb.maxt);
    if (!bounds->IntersectP(rb))
      return false;
    rayT = rb.maxt;
    //printf("rb maxt before intersect = %g\n", rt.maxt);  
    //printf("rb maxt after intersect = %g\n", rb.maxt);
  }
  //  printf("got here");

  rt.mint = rayT;

  //printf("mint = %g %g %g\n", rt(rt.mint).x, rt(rt.mint).y, rt(rt.mint).z);
  return (TraverseHeightfield(rt, ray, dg));
 }

bool Heightfield::IntersectP(const Ray &ray) const
{
  DifferentialGeometry dg;
  return (Intersect(ray, &dg));
}

// Converts the heightfield into a single TriangleMesh with two triangles per
// grid cell. Each cell is split along the (x,y)-(x+1,y+1) diagonal, the same
// split VisitedCell() uses, so the refined mesh matches the directly
// intersected surface.
void Heightfield::Refine(vector<Shape *> &refined) const {
	const int ntris = 2*(nx-1)*(ny-1);
	int *verts = new int[3*ntris];
	// Vertex positions: x/y mapped onto [0,1], z from the height samples.
	// Allocated exactly once; it used to be allocated a second time,
	// leaking the first array.
	Point *P = new Point[nx*ny];
	int pos = 0;
	for (int y = 0; y < ny; ++y) {
		for (int x = 0; x < nx; ++x) {
			P[pos].x = (float)x / (float)(nx-1);
			P[pos].y = (float)y / (float)(ny-1);
			P[pos].z = z[pos];
			++pos;
		}
	}
	int *vp = verts;
	for (int y = 0; y < ny-1; ++y) {
		for (int x = 0; x < nx-1; ++x) {
			const int v00 = x + y*nx;   // (x,   y)
			const int v10 = v00 + 1;    // (x+1, y)
			const int v01 = v00 + nx;   // (x,   y+1)
			const int v11 = v01 + 1;    // (x+1, y+1)

			*vp++ = v00;
			*vp++ = v10;
			*vp++ = v11;

			*vp++ = v00;
			*vp++ = v11;
			*vp++ = v01;
		}
	}
	refined.push_back(new TriangleMesh(ObjectToWorld, ntris,
		nx*ny, verts, P));
	// NOTE(review): freeing these here presumes TriangleMesh copies the
	// index and vertex arrays. Confirm.
	delete[] P;
	delete[] verts;
}
