#include "lrt.h"
#include "primitives.h"
#include "accel.h"
// Frees the array of per-voxel primitive lists (indexed by voxel in Intersect).
GridAccelerator::~GridAccelerator()
{
	delete [] cells;
}
bool GridAccelerator::Intersect(const Ray &ray, Surf *surf, bool directFromEye) const {
	Float rayT = ray.maxt;
	if (bounds.Inside(ray(ray.mint)))
		rayT = ray.mint;
	else {
		Ray rb = ray;
		if (!bounds.IntersectP(rb))
			return false;
		rayT = rb.maxt;
	}
	Point gridIntersect = ray(rayT);
	int x = x2v(gridIntersect.x);
	if (x == XVoxels) x--;
	Assert( x >= 0 && x < XVoxels );
	Float NextXCrossing, DeltaX;
	int StepX, OutX;
	if (fabs(ray.D.x) < 1e-6) {
		NextXCrossing = INFINITY;
		DeltaX = 0;
		OutX = -1;
	}
	else if (ray.D.x > 0) {
		NextXCrossing = rayT + ( v2x( x+1 ) - gridIntersect.x )/ray.D.x;
		DeltaX = XWidth / ray.D.x;
		StepX = 1;
		OutX = XVoxels;
	}
	else {
		NextXCrossing = rayT + ( v2x( x ) - gridIntersect.x )/ray.D.x;
		DeltaX = - XWidth / ray.D.x;
		StepX = -1;
		OutX = -1;
	}
	int y = y2v( gridIntersect.y );
	if (y == YVoxels) y--;
	Float NextYCrossing, DeltaY;
	int StepY, OutY;
	
	Assert( y >= 0 && y < YVoxels );
	if (fabs(ray.D.y) < 1e-6) {
		NextYCrossing = INFINITY;
		DeltaY = 0;
		OutY = -1;
	}
	else if (ray.D.y < 0) {
		NextYCrossing = rayT + ( v2y( y ) - gridIntersect.y )/ray.D.y;
		DeltaY = - YWidth / ray.D.y;
		StepY = OutY = -1;
	}
	else {
		NextYCrossing = rayT + ( v2y( y+1 ) - gridIntersect.y )/ray.D.y;
		DeltaY = YWidth / ray.D.y;
		StepY = 1;
		OutY = YVoxels;
	}
	int z = z2v( gridIntersect.z );
	if (z == ZVoxels) z--;
	Float NextZCrossing, DeltaZ;
	int StepZ, OutZ;
	
	Assert( z >= 0 && z < ZVoxels );
	if (fabs(ray.D.z) < 1e-6) {
		NextZCrossing = INFINITY;
		DeltaZ = 0;
		OutZ = -1;
	}
	else if (ray.D.z < 0) {
		NextZCrossing = rayT + ( v2z( z ) - gridIntersect.z )/ray.D.z;
		DeltaZ = - ZWidth / ray.D.z;
		StepZ = OutZ = -1;
	}
	else {
		NextZCrossing = rayT + ( v2z( z+1 ) - gridIntersect.z )/ray.D.z;
		DeltaZ = ZWidth / ray.D.z;
		StepZ = 1;
		OutZ = ZVoxels;
	}
	bool hitSomething = false;
	int rayId = curMailboxId++;
	for (;;) {
		int offset = z*XVoxels*YVoxels + y*XVoxels + x;
		list<MailboxPrim *> *primitiveList = cells + offset;
		if (primitiveList->size() > 0) {
			list<MailboxPrim *>::iterator iter;
			for (iter = primitiveList->begin(); iter != primitiveList->end();
					++iter) {
				MailboxPrim *mp = *iter;
				if (mp->lastMailboxId == rayId)
					continue;
				mp->lastMailboxId = rayId;
				if (!mp->primitive->CanIntersect()) {
					vector<Primitive *> p;
					mp->primitive->Refine(p);
					Assert(p.size() > 0);
					if (p.size() == 1 && p[0]->CanIntersect())
						mp->primitive = p[0];
					else
						mp->primitive = new GridAccelerator(p);
				}
				if (mp->primitive->Intersect(ray, surf, directFromEye))
					hitSomething = true;
			}
		}
		if (NextXCrossing < NextYCrossing &&
			NextXCrossing < NextZCrossing) {
			if (ray.maxt < NextXCrossing)
				break;
			x += StepX;
			if (x == OutX)
				break;
			NextXCrossing += DeltaX;
		}
		else if (NextZCrossing < NextYCrossing) {
			if (ray.maxt < NextZCrossing)
				break;
			z += StepZ;
			if (z == OutZ)
				break;
			NextZCrossing += DeltaZ;
		}
		else {
			if (ray.maxt < NextYCrossing)
				break;
			y += StepY;
			if (y == OutY)
				break;
			NextYCrossing += DeltaY;
		}
	}
	return hitSomething;
}
bool GridAccelerator::IntersectP(const Ray &ray) const {
	Surf surf;
	return Intersect(ray, &surf, false);
}
// Leaf node: references a single primitive and keeps its own copy of that
// primitive's bounding box (so the source BBox may be freed afterwards).
HBVNode::HBVNode(const pair<Primitive *, BBox *> &prim)
		: bounds(*prim.second) {
	primitive = prim.first;
	isLeaf = true;
}
// Interior node whose bounds are the union of both children's bounds. The
// children are deleted by ~HBVNode, so this node takes ownership of them.
HBVNode::HBVNode(HBVNode *child1, HBVNode *child2)
		: bounds(Union(child1->bounds, child2->bounds)) {
	children[0] = child1;
	children[1] = child2;
	isLeaf = false;
}
// Recursively frees the subtree below an interior node. A leaf deletes nothing;
// the primitive it references is not freed here.
HBVNode::~HBVNode() {
	if (isLeaf)
		return;
	delete children[0];
	delete children[1];
}
// Initializes the PrimitiveSet base from `p`, then builds the bounding-volume
// hierarchy over the same primitives and stores its root.
HBVAccelerator::HBVAccelerator(const vector<Primitive *> &p)
	: PrimitiveSet(p) {
	HBVNode *tree = BuildHBV(p);
	root = tree;
}
// Builds a bounding-volume hierarchy over `primitives` and returns its root
// (NULL when `primitives` is empty).
//
// Each primitive's world-space bounds are computed once. The input order is
// then randomly permuted before the recursive build.
HBVNode *HBVAccelerator::BuildHBV(const vector<Primitive *> &primitives) {
	// The bounds live in a vector so they are released even if the build
	// throws. `prims` stores pointers into it, so it must not be resized after
	// this point. HBVNode copies the bounds it needs, so freeing them at the
	// end of this function is safe.
	vector<BBox> bboxes(primitives.size());
	vector<pair<Primitive *, BBox *> > prims;
	prims.reserve(primitives.size());
	for (u_int i = 0; i < primitives.size(); ++i) {
		bboxes[i] = primitives[i]->BoundWorldSpace();
		prims.push_back(make_pair(primitives[i], &bboxes[i]));
	}
	// Unbiased Fisher-Yates shuffle. Swapping every element with an arbitrary
	// index over the whole range does not yield a uniform permutation.
	for (u_int j = u_int(prims.size()); j > 1; --j) {
		u_int s = u_int(random() % j);
		swap(prims[j - 1], prims[s]);
	}
	return RecursiveBuild(prims);
}


HBVNode *HBVAccelerator::RecursiveBuild(const vector<pair<Primitive *, BBox *> > &primitives) {
	if (primitives.size() == 0)
		return NULL;
	else if (primitives.size() == 1)
		return new HBVNode(primitives[0]);
	else if (primitives.size() == 2)
		return new HBVNode(new HBVNode(primitives[0]), new HBVNode(primitives[1]));
	else {
		vector<pair<Primitive *, BBox *> > p1, p2;
		BBox bounds(*primitives[0].second);
		for (u_int i = 0; i < primitives.size(); ++i)
			bounds = Union(bounds, *primitives[i].second);
		int axis;
		Float dx = bounds.pMax.x - bounds.pMin.x;
		Float dy = bounds.pMax.y - bounds.pMin.y;
		Float dz = bounds.pMax.z - bounds.pMin.z;
		if (dx > dy) {
			if (dz > dz) axis = 0;
			else axis = 2;
		}
		else {
			if (dy > dz) axis = 1;
			else axis = 2;
		}
		#if 0
		vector<pair<Primitive *, Point> > primPartition;
		primPartition.reserve(primitives.size());
		for (u_int i = 0; i < primitives.size(); ++i) {
			Point Pcenter = 0.5 * primitives[i].second->pMin +
				0.5 * primitives[i].second->pMax;
			primPartition.push_back(primitives[i].first, Pcenter);
		}
		#endif
		Float pmid;
		switch (axis) {
			case 0: // X
				#if 0
				int mid = XPartition(primPartition, 0, primPartition.size());
				int i;
				for (i = 0; i < mid; ++i)
					p1.push_back(primitives[i]);
				for (; i < primPartition.size(); ++i)
					p2.push_back(primitives[i]);
				#else
				pmid = (bounds.pMax.x + bounds.pMin.x) * 0.5;
				for (u_int i = 0; i < primitives.size(); ++i) {
					if (primitives[i].second->pMax.x < pmid) {
						p1.push_back(primitives[i]);
					}
					else if (primitives[i].second->pMin.x > pmid) {
						p2.push_back(primitives[i]);
					}
					else {
						if (RandomFloat() > 0.5) p1.push_back(primitives[i]);
						else p2.push_back(primitives[i]);
					}
				}
				#endif
				break;
			case 1: // Y
				pmid = (bounds.pMax.y + bounds.pMin.y) * 0.5;
				for (u_int i = 0; i < primitives.size(); ++i) {
					if (primitives[i].second->pMax.y < pmid) {
						p1.push_back(primitives[i]);
					}
					else if (primitives[i].second->pMin.y > pmid) {
						p2.push_back(primitives[i]);
					}
					else {
						if (RandomFloat() > 0.5) p1.push_back(primitives[i]);
						else p2.push_back(primitives[i]);
					}
				}
				break;
			case 2: // Z
				pmid = (bounds.pMax.z + bounds.pMin.z) * 0.5;
				for (u_int i = 0; i < primitives.size(); ++i) {
					if (primitives[i].second->pMax.z < pmid) {
						p1.push_back(primitives[i]);
					}
					else if (primitives[i].second->pMin.z > pmid) {
						p2.push_back(primitives[i]);
					}
					else {
						if (RandomFloat() > 0.5) p1.push_back(primitives[i]);
						else p2.push_back(primitives[i]);
					}
				}
				break;
		}
		HBVNode *side1 = RecursiveBuild(p1);
		HBVNode *side2 = RecursiveBuild(p2);
		if (side1) {
			if (side2) return new HBVNode(side1, side2);
			else return side1;
		}
		else return side2;
	}
}
// Unimplemented placeholder. Judging by its name and signature, it was meant
// to partition prims[start, end) along x and return a split index.
// NOTE(review): TODO confirm the intended contract before implementing.
// As written, it always fails the assertion below and returns 0. No enabled
// code in this file calls it.
int HBVAccelerator::XPartition(vector<pair<Primitive *, Point> > &prims, int start, int end)
{
//	int ranitem = random() % (end-start);
	Assert(1 == 0);
	return 0;
}
// Frees the whole hierarchy; ~HBVNode recursively deletes interior children.
// NOTE(review): Intersect casts `root` to HBVNode *, which suggests `root` may
// not be declared with that type. Deleting through the same cast guarantees
// ~HBVNode runs; the cast is a no-op if root already is an HBVNode *.
HBVAccelerator::~HBVAccelerator() {
	delete (HBVNode *)root;
}
bool HBVAccelerator::Intersect(const Ray &ray, Surf *surf, bool directFromEye) const {
	return RecursiveIntersect((HBVNode *)root, ray, surf, directFromEye);
}
// Intersects `ray` with the subtree rooted at `node`.
//
// A subtree is skipped unless the ray starts inside its bounds or hits them.
// Leaves forward to their primitive's Intersect. Interior nodes visit both
// children, except that a NULL `surf` (shadow-ray query) returns on the first
// hit. `node` is taken by reference, but this function does not reassign it.
// Returns true if any primitive in the subtree reported an intersection.
bool HBVAccelerator::RecursiveIntersect(HBVNode *&node, const Ray &ray, Surf *surf, bool directFromEye) {
	Assert(node != NULL);
	if (!node->bounds.Inside(ray(ray.mint))) {
		// Test the box on a copy, as GridAccelerator::Intersect does:
		// BBox::IntersectP updates the ray it is given (the grid reads rb.maxt
		// afterwards), and it must not narrow the caller's ray.
		Ray boxRay = ray;
		if (!node->bounds.IntersectP(boxRay))
			return false;
	}

	if (node->isLeaf)
		return node->primitive->Intersect(ray, surf, directFromEye);

	Assert(node->children[0] && node->children[1]);
	bool eitherHit = RecursiveIntersect(node->children[0], ray, surf, directFromEye);
	if (eitherHit && !surf)
		// quick out for shadow rays
		return true;
	if (RecursiveIntersect(node->children[1], ray, surf, directFromEye))
		eitherHit = true;
	return eitherHit;
}
