
#include "lrt.h"
#include "primitives.h"
#include "accel.h"

bool Scene::Intersect(const Ray & ray, Float mint, Float * maxt,
					  HitInfo * hit) const
{
	if (!accelerator)
		return false;
	return accelerator->IntersectClosest(ray, mint, maxt, hit);
}

Scene::Scene(const Options &, const vector < Primitive * >&prim)
:	primitives(prim)
{
	// Only build a grid when there is something to put in it; an empty
	// scene keeps a NULL accelerator, which the query methods check for.
	accelerator = NULL;
	if (!primitives.empty())
		accelerator = new GridAccelerator(primitives);
}

Scene::~Scene()
{
	// Release the acceleration structure built in the constructor.
	// The primitive pointers themselves are not deleted by the Scene.
	delete accelerator;
}

bool Scene::Unoccluded(const Point & p1, const Point & p2) const
{
	if (!accelerator)
		return true;

	static int shadowRayChecks = 0, shadowRayOccluded = 0;
	if (shadowRayChecks == 0)
		StatsRegisterRatio(STATS_DETAILED, "Integration",
						   "Finite Shadow Ray Checks",
						   &shadowRayOccluded, &shadowRayChecks);
	++shadowRayChecks;

	Float tmin = 1e-6;
	Float tmax = 1. - tmin;
	if (accelerator->IntersectClosest(Ray(p1, p2 - p1), tmin, &tmax, NULL)) {
		++shadowRayOccluded;
		return false;
	}

	return true;
}

bool Scene::Unoccluded(const Ray & r) const
{
	if (!accelerator)
		return true;

	static int shadowRayChecks = 0, shadowRayOccluded = 0;
	if (shadowRayChecks == 0)
		StatsRegisterRatio(STATS_DETAILED, "Integration",
						   "Infinite Shadow Ray Checks",
						   &shadowRayOccluded, &shadowRayChecks);
	++shadowRayChecks;

	Float tmin = 1e-6;
	Float tmax = INFINITY;
	Ray ray = r;
	ray.D.Normalize();
	if (accelerator->IntersectClosest(ray, tmin, &tmax, NULL)) {
		++shadowRayOccluded;
		return false;
	}

	return true;
}

Accelerator::~Accelerator()
{
	// The base class holds nothing that needs releasing.
}

ListAccelerator::ListAccelerator(const vector < Primitive * > &prims)
{
	// Placeholder: the primitive list is currently ignored.
}

bool ListAccelerator::IntersectClosest(const Ray & ray, Float mint,
									   Float * maxt, HitInfo * hit)
{
	// Placeholder: no primitives are tested, so every ray misses and
	// neither *maxt nor *hit is touched.
	const bool found = false;
	return found;
}

GridAccelerator::GridAccelerator(const vector < Primitive * >&primitives)
{
	// World-space bounds of the whole (unrefined) input set.
	Assert(primitives.size() > 0);
	bbox = primitives[0]->BoundWorldSpace();
	for (u_int i = 1; i < primitives.size(); ++i)
		bbox = Union(bbox, primitives[i]->BoundWorldSpace());

	// Repeatedly refine until only directly intersectable primitives remain;
	// those are what get stored in the voxels.
	vector < Primitive * >expandedPrimitives;
	vector < Primitive * >primitivesToProcess = primitives;
	while (!primitivesToProcess.empty()) {
		Primitive *prim = primitivesToProcess.back();
		primitivesToProcess.pop_back();

		if (prim->CanIntersect())
			expandedPrimitives.push_back(prim);
		else {
			vector < Primitive * >refined;
			prim->Refine(&refined);
			primitivesToProcess.insert(primitivesToProcess.end(),
									   refined.begin(), refined.end());
		}
	}

	// Grid resolution: about 3 * cbrt(N) voxels along the widest axis, the
	// other axes scaled by their extent relative to it, clamped to [1, 100]
	// per axis. The extents are padded slightly (1.00001).
	int cubeRoot = (int) pow(expandedPrimitives.size(), .333333333);
	Float dx = 1.00001 * (bbox.pMax.x - bbox.pMin.x);
	Float dy = 1.00001 * (bbox.pMax.y - bbox.pMin.y);
	Float dz = 1.00001 * (bbox.pMax.z - bbox.pMin.z);
	Float maxWidth = max(dx, max(dy, dz));
	// If the bounds collapse to a single point every width is zero; use a
	// 1x1x1 grid instead of dividing by zero (which would turn the voxel
	// counts below into NaN).
	Float invmaxWidth = (maxWidth > 0.) ? 1.0 / maxWidth : 0.;
	Assert(invmaxWidth >= 0.);
	XVoxels = min( 100, max(1, 3 * Round(cubeRoot * dx * invmaxWidth)));
	YVoxels = min( 100, max(1, 3 * Round(cubeRoot * dy * invmaxWidth)));
	ZVoxels = min( 100, max(1, 3 * Round(cubeRoot * dz * invmaxWidth)));

	// Per-axis voxel size; a zero-width axis gets a zero inverse so that
	// world-to-voxel conversion maps everything to voxel 0.
	XWidth = dx / XVoxels;
	YWidth = dy / YVoxels;
	ZWidth = dz / ZVoxels;
	InvXWidth = (XWidth == 0.) ? 0. : 1. / XWidth;
	InvYWidth = (YWidth == 0.) ? 0. : 1. / YWidth;
	InvZWidth = (ZWidth == 0.) ? 0. : 1. / ZWidth;
	cells = new list < Primitive * >[XVoxels * YVoxels * ZVoxels];

	addPrimitivesToVoxels(expandedPrimitives);
}

void GridAccelerator::addPrimitivesToVoxels(const vector <
											Primitive * >&primitives)
{
	for (u_int i = 0; i < primitives.size(); ++i) {

		BBox primBounds = primitives[i]->BoundWorldSpace();
		int x0 = max(x2v(primBounds.pMin.x), 0);
		int x1 = min(x2v(primBounds.pMax.x), XVoxels - 1);
		int y0 = max(y2v(primBounds.pMin.y), 0);
		int y1 = min(y2v(primBounds.pMax.y), YVoxels - 1);
		int z0 = max(z2v(primBounds.pMin.z), 0);
		int z1 = min(z2v(primBounds.pMax.z), ZVoxels - 1);

		for (int x = x0; x <= x1; ++x)
			for (int y = y0; y <= y1; ++y)
				for (int z = z0; z <= z1; ++z) {
					int offset = z * XVoxels * YVoxels + y * XVoxels + x;
					cells[offset].push_back(primitives[i]);
				}

	}
}

GridAccelerator::~GridAccelerator()
{
	// Free the voxel array allocated with new[] in the constructor; the
	// primitives referenced from the lists are not deleted.
	delete [] cells;
}

bool GridAccelerator::IntersectClosest(const Ray & ray, Float mint,
									   Float * maxt, HitInfo * hitInfo)
{

	Float rayT = *maxt;
	if (bbox.Inside(ray(mint)))
		rayT = mint;
	else if (!bbox.IntersectP(ray, &rayT))
		return false;
	Point gridIntersect = ray(rayT);

	int x = x2v(gridIntersect.x);
	if (x == XVoxels)
		x--;
	Assert(x >= 0 && x < XVoxels);

	Float NextXCrossing, DeltaX;
	int StepX, OutX;
	if (fabs(ray.D.x) < RI_EPSILON) {

		NextXCrossing = INFINITY;
		DeltaX = 0;
		OutX = -1;

	} else if (ray.D.x > 0) {

		NextXCrossing = rayT + (v2x(x + 1) - gridIntersect.x) / ray.D.x;
		DeltaX = XWidth / ray.D.x;
		StepX = 1;
		OutX = XVoxels;

	} else {

		NextXCrossing = rayT + (v2x(x) - gridIntersect.x) / ray.D.x;
		DeltaX = -XWidth / ray.D.x;
		StepX = -1;
		OutX = -1;

	}

	int y = y2v(gridIntersect.y);
	if (y == YVoxels)
		y--;
	Float NextYCrossing, DeltaY;
	int StepY, OutY;

	Assert(y >= 0 && y < YVoxels);
	if (fabs(ray.D.y) < RI_EPSILON) {
		NextYCrossing = INFINITY;
		DeltaY = 0;
	} else if (ray.D.y < 0) {
		NextYCrossing = rayT + (v2y(y) - gridIntersect.y) / ray.D.y;
		DeltaY = -YWidth / ray.D.y;
		StepY = OutY = -1;
	} else {
		NextYCrossing = rayT + (v2y(y + 1) - gridIntersect.y) / ray.D.y;
		DeltaY = YWidth / ray.D.y;
		StepY = 1;
		OutY = YVoxels;
	}

	int z = z2v(gridIntersect.z);
	if (z == ZVoxels)
		z--;
	Float NextZCrossing, DeltaZ;
	int StepZ, OutZ;

	Assert(z >= 0 && z < ZVoxels);
	if (fabs(ray.D.z) < RI_EPSILON) {
		NextZCrossing = INFINITY;
		DeltaZ = 0;
	} else if (ray.D.z < 0) {
		NextZCrossing = rayT + (v2z(z) - gridIntersect.z) / ray.D.z;
		DeltaZ = -ZWidth / ray.D.z;
		StepZ = OutZ = -1;
	} else {
		NextZCrossing = rayT + (v2z(z + 1) - gridIntersect.z) / ray.D.z;
		DeltaZ = ZWidth / ray.D.z;
		StepZ = 1;
		OutZ = ZVoxels;
	}

	bool hitSomething = false;
	for (;;) {
		int offset = z * XVoxels * YVoxels + y * XVoxels + x;
		list < Primitive * >*primitiveList = cells + offset;
		if (primitiveList->size() > 0) {

			const Transform *lastTransform = NULL;
			Ray rayObj;
			list < Primitive * >::iterator iter = primitiveList->begin();
			while (iter != primitiveList->end()) {
				Primitive *prim = *iter;

				const Transform *primW2O =
					&(prim->attributes->WorldToObject);
				if (primW2O != lastTransform) {
					rayObj = (*primW2O) (ray);
					lastTransform = primW2O;
				}

				if (prim->IntersectClosest(rayObj, mint, maxt, hitInfo)) {
					hitSomething = true;
					if (!hitInfo)
						break;
				}

				++iter;
			}

		}

		if (NextXCrossing < NextYCrossing && NextXCrossing < NextZCrossing) {

			if (*maxt < NextXCrossing)
				break;
			x += StepX;
			if (x == OutX)
				break;
			NextXCrossing += DeltaX;

		} else if (NextZCrossing < NextYCrossing) {

			if (*maxt < NextZCrossing)
				break;
			z += StepZ;
			if (z == OutZ)
				break;
			NextZCrossing += DeltaZ;

		} else {

			if (*maxt < NextYCrossing)
				break;
			y += StepY;
			if (y == OutY)
				break;
			NextYCrossing += DeltaY;

		}

	}
	return hitSomething;

}

HBVNode::HBVNode(const pair < Primitive *, BBox * >&prim)
:	bounds(*prim.second)
{
	// Leaf node: keeps its own copy of the bounds and points at a primitive.
	primitive = prim.first;
	isLeaf = true;
}

HBVNode::HBVNode(HBVNode * child1, HBVNode * child2)
:	bounds(Union(child1->bounds, child2->bounds))
{
	// Interior node: its bounds enclose both children, which it owns.
	children[0] = child1;
	children[1] = child2;
	isLeaf = false;
}

HBVNode::~HBVNode()
{
	// Leaves do not delete their primitive; interior nodes delete both
	// subtrees recursively.
	if (isLeaf)
		return;
	delete children[0];
	delete children[1];
}

HBVAccelerator::HBVAccelerator(const vector < Primitive * >&primitives)
{
	// Build the hierarchy over the given primitives up front.
	HBVNode *tree = BuildHBV(primitives);
	root = tree;
}

HBVNode *HBVAccelerator::BuildHBV(const vector < Primitive * >&primitives)
{
	// Compute each primitive's world-space bounds once. The pairs handed to
	// RecursiveBuild point into this vector. It is sized up front, so those
	// pointers stay valid, and HBVNode copies the bounds it keeps, so the
	// storage may be released when this function returns (RAII instead of
	// new[]/delete[], which leaked if anything threw).
	vector<BBox> bboxes(primitives.size());
	vector<pair<Primitive *, BBox *> > prims;
	prims.reserve(primitives.size());
	for (u_int i = 0; i < primitives.size(); ++i) {
		bboxes[i] = primitives[i]->BoundWorldSpace();
		prims.push_back(make_pair(primitives[i], &bboxes[i]));
	}

	// Randomly permute the input order before building. Fisher-Yates: swap
	// element j with a random element from [j, n). The previous
	// "swap with any index" loop does not give uniformly distributed
	// permutations.
	for (u_int j = 0; j + 1 < prims.size(); ++j) {
		u_int s = j + u_int(random() % (prims.size() - j));
		swap(prims[j], prims[s]);
	}

	// NULL when primitives is empty.
	return RecursiveBuild(prims);
}



HBVNode *HBVAccelerator::RecursiveBuild(const vector<pair<Primitive *, BBox *> > &primitives) {

if (primitives.size() == 0)
	return NULL;
else if (primitives.size() == 1)
	return new HBVNode(primitives[0]);
else if (primitives.size() == 2)
	return new HBVNode(new HBVNode(primitives[0]),
					   new HBVNode(primitives[1]));
else {

		vector<pair<Primitive *, BBox *> > p1, p2;

	BBox bounds(*primitives[0].second);
	for (u_int i = 0; i < primitives.size(); ++i)
		bounds = Union(bounds, *primitives[i].second);

	int axis;
	Float dx = bounds.pMax.x - bounds.pMin.x;
	Float dy = bounds.pMax.y - bounds.pMin.y;
	Float dz = bounds.pMax.z - bounds.pMin.z;
	if (dx > dy) {
		if (dz > dz)
			axis = 0;
		else
			axis = 2;
	} else {
		if (dy > dz)
			axis = 1;
		else
			axis = 2;
	}

#if 0
		vector<pair<Primitive *, Point> > primPartition;
	primPartition.reserve(primitives.size());
	for (u_int i = 0; i < primitives.size(); ++i) {
		Point Pcenter = 0.5 * primitives[i].second->pMin +
			0.5 * primitives[i].second->pMax;
		primPartition.push_back(primitives[i].first, Pcenter);
	}
#endif
	Float pmid;
	switch (axis) {
	case 0:					// X

#if 0
		int mid = XPartition(primPartition, 0, primPartition.size());
		int i;
		for (i = 0; i < mid; ++i)
			p1.push_back(primitives[i]);
		for (; i < primPartition.size(); ++i)
			p2.push_back(primitives[i]);
#else
		pmid = (bounds.pMax.x + bounds.pMin.x) * 0.5;
		for (u_int i = 0; i < primitives.size(); ++i) {
			if (primitives[i].second->pMax.x < pmid) {
				p1.push_back(primitives[i]);
			} else if (primitives[i].second->pMin.x > pmid) {
				p2.push_back(primitives[i]);
			} else {
				if (RandomFloat() > 0.5)
					p1.push_back(primitives[i]);
				else
					p2.push_back(primitives[i]);
			}
		}
#endif

		break;
	case 1:					// Y

		pmid = (bounds.pMax.y + bounds.pMin.y) * 0.5;
		for (u_int i = 0; i < primitives.size(); ++i) {
			if (primitives[i].second->pMax.y < pmid) {
				p1.push_back(primitives[i]);
			} else if (primitives[i].second->pMin.y > pmid) {
				p2.push_back(primitives[i]);
			} else {
				if (RandomFloat() > 0.5)
					p1.push_back(primitives[i]);
				else
					p2.push_back(primitives[i]);
			}
		}

		break;
	case 2:					// Z

		pmid = (bounds.pMax.z + bounds.pMin.z) * 0.5;
		for (u_int i = 0; i < primitives.size(); ++i) {
			if (primitives[i].second->pMax.z < pmid) {
				p1.push_back(primitives[i]);
			} else if (primitives[i].second->pMin.z > pmid) {
				p2.push_back(primitives[i]);
			} else {
				if (RandomFloat() > 0.5)
					p1.push_back(primitives[i]);
				else
					p2.push_back(primitives[i]);
			}
		}

		break;
	}

	HBVNode *side1 = RecursiveBuild(p1);
	HBVNode *side2 = RecursiveBuild(p2);
	if (side1) {
		if (side2)
			return new HBVNode(side1, side2);
		else
			return side1;
	} else
		return side2;

}
}

int HBVAccelerator::XPartition(vector<pair<Primitive *, Point> > &prims, int start, int end)
{
	// Unimplemented stub: prims, start and end are ignored. Any call trips
	// the assertion below; if execution continues, 0 is returned.
	Assert(1 == 0);
	return 0;
}

HBVAccelerator::~HBVAccelerator()
{
	// Deleting the root frees the whole tree via ~HBVNode.
	delete root;
}

bool HBVAccelerator::IntersectClosest(const Ray & ray, Float mint,
									  Float * maxt, HitInfo * hit)
{
	// Walk the hierarchy from the root; root is passed by reference so the
	// traversal can replace lazily refined leaves in place.
	bool found = RecursiveIntersect(root, ray, mint, maxt, hit);
	return found;
}

bool HBVAccelerator::RecursiveIntersect(HBVNode * &node, const Ray & ray,
										Float mint, Float * maxt,
										HitInfo * hit)
{
	// An empty subtree cannot be hit. BuildHBV returns NULL for an empty
	// list, which happens for an empty scene or for a primitive that refines
	// into nothing; previously that tripped an assertion / NULL dereference.
	if (node == NULL)
		return false;

	// Cull the node unless the ray starts inside its bounds or enters them
	// at a parameter >= mint.
	Float hitt = *maxt;
	if (!node->bounds.Inside(ray(mint)) &&
		(!node->bounds.IntersectP(ray, &hitt) || hitt < mint))
		return false;

	if (!node->isLeaf) {
		bool eitherHit =
			RecursiveIntersect(node->children[0], ray, mint, maxt, hit);
		if (eitherHit && !hit)
			// quick out for shadow rays
			return true;
		if (RecursiveIntersect(node->children[1], ray, mint, maxt, hit))
			eitherHit = true;
		return eitherHit;
	}

	Primitive *prim = node->primitive;
	if (prim->CanIntersect()) {
		// Only transform the ray when the primitive is actually tested.
		Ray rayo = (prim->attributes->WorldToObject) (ray);
		return prim->IntersectClosest(rayo, mint, maxt, hit);
	}

	// Lazy refinement: replace this leaf (in the parent's child slot, via the
	// reference parameter) with a hierarchy over the refined primitives, then
	// traverse the replacement. Deleting the leaf does not delete the
	// primitive.
	vector < Primitive * >refined;
	prim->Refine(&refined);
	HBVNode *newNode = BuildHBV(refined);
	delete node;
	node = newNode;
	return RecursiveIntersect(node, ray, mint, maxt, hit);
}
