#include "martiniglass.h"

MartiniGlassContour::MartiniGlassContour( const Transform &o2w, Float _scale, 
    int _numPoints, Float * _contour ) : Shape( o2w ) {

    numPoints = _numPoints;

    // _contour holds numPoints interleaved (x, z) pairs: x of point i at
    // index 2*i and z at index 2*i + 1. Copy them, applying the uniform
    // scale factor as each coordinate is read.
    contour = new ContourPoint[numPoints];
    for ( int i = 0; i < numPoints; i++ ) {
        contour[i].x = _contour[2*i] * _scale;
        contour[i].z = _contour[2*i + 1] * _scale;
    }

    // One strip per pair of consecutive contour points, then the bounding
    // box around all of them (ConstructBoundingBoxes reads the strips).
    ConstructStrips();

    ConstructBoundingBoxes();
}


BBox
MartiniGlassContour::Bound() const {
    // Object-space box enclosing every strip, as assembled by
    // ConstructBoundingBoxes(); handed back by value.
    const BBox &objectBox = *bbox;
    return objectBox;
}


void
MartiniGlassContour::ConstructStrips() {
    Point p1, p2;
    Normal n1, n2;
    Normal nTmp1, nTmp2;
    Float d1, d2;
    Float dx1, dz1;
    Float dx2, dz2; 

    strip = new (Strip*)[numPoints-1];

    for ( int i = 0; i < numPoints-1; i++ ) {

	if ( i == 0 ) {
	    dx1 = contour[1].x - contour[0].x;
	    dz1 = contour[1].z - contour[0].z;

	    n1 = Normal( 0, 0, 0 );
	    if ( dz1 > 0 )
		n1.x = -1.0;
	    else
		n1.x = 1.0;
	    n1.z = -dx1 * n1.x / dz1;
	    n1.Normalize();
	}
	else {
	    dx1 = contour[i].x - contour[i-1].x;
	    dz1 = contour[i].z - contour[i-1].z;
	    d1 = sqrt( dx1*dx1 + dz1*dz1 );
	    dx2 = contour[i+1].x - contour[i].x;
	    dz2 = contour[i+1].z - contour[i].z;
	    d2 = sqrt( dx2*dx2 + dz2*dz2 );

	    nTmp1 = Normal( 0, 0, 0 );
	    if ( dz1 > 0 ) 
		nTmp1.x = -1.0;
	    else
		nTmp1.x = 1.0;
	    nTmp1.z = -dx1* nTmp1.x / dz1;
	    nTmp1.Normalize();
	    
	    nTmp2 = Normal( 0, 0, 0 );
	    if ( dz2 > 0 )
		nTmp2.x = -1.0;
	    else
		nTmp2.x = 1.0;
	    nTmp2.z = -dx2 * nTmp2.x / dz2;
	    nTmp2.Normalize();	

	    // Interpolate the two normals according to area/distance between 
	    // contour points

	    n1 = (d1 * nTmp1 + d2 * nTmp2) / (d1 + d2);
	    n1.Normalize();
	}
	    
	if ( i == numPoints - 2 ) {
	    dx1 = contour[i+1].x - contour[i].x;
	    dz1 = contour[i+1].z - contour[i].z;

	    n2 = Normal( 0, 0, 0 );
	    if ( dz1 > 0 )
		n2.x = -1.0;
	    else
		n2.x = 1.0;
	    n2.z = -dx1 * n2.x / dz1;
	    n2.Normalize();
	}
	else {
	    dx1 = contour[i+1].x - contour[i].x;
	    dz1 = contour[i+1].z - contour[i].z;
	    d1 = sqrt( dx1*dx1 + dz1*dz1 );
	    dx2 = contour[i+2].x - contour[i+1].x;
	    dz2 = contour[i+2].z - contour[i+1].z;
	    d2 = sqrt( dx2*dx2 + dz2*dz2 );

	    nTmp1 = Normal( 0, 0, 0 );
	    if ( dz1 > 0 ) 
		nTmp1.x = -1.0;
	    else
		nTmp1.x = 1.0;
	    nTmp1.z = -dx1* nTmp1.x / dz1;
	    nTmp1.Normalize();
	    
	    nTmp2 = Normal( 0, 0, 0 );
	    if ( dz2 > 0 )
		nTmp2.x = -1.0;
	    else
		nTmp2.x = 1.0;
	    nTmp2.z = -dx2 * nTmp2.x / dz2;
	    nTmp2.Normalize();	

	    // Interpolate the two normals according to area/distance between 
	    // contour points

	    n2 = (d1 * nTmp1 + d2 * nTmp2) / (d1 + d2);
	    n2.Normalize();
	}

	p1 = Point( contour[i].x, 0, contour[i].z );
	p2 = Point( contour[i+1].x, 0, contour[i+1].z );
	strip[i] = new Strip( p1, p2, n1, n2 );
    }
}

void
MartiniGlassContour::ConstructBoundingBoxes() {
    bbox = new BBox( strip[0]->Bound() );

    for ( int i = 1; i < numPoints-1; i++ ) 
	*bbox = Union( *bbox, strip[i]->Bound() );
}

bool
MartiniGlassContour::Intersect( const Ray &r, DifferentialGeometry *dg ) const {

    bool intersectFound = false;
    
    Ray ray = WorldToObject( r );

    if ( bbox->IntersectP( ray ) ) {
	ray.maxt = r.maxt;
	for ( int i = 0; i < numPoints-1; i++ )
	    if ( strip[i]->Bound().IntersectP( ray ) ) {
		ray.maxt = r.maxt;
		if ( strip[i]->Intersect( ray, dg ) ) {
		    intersectFound = true;
		    r.maxt = ray.maxt;
		}
	    }
    }

    if ( intersectFound ) {
	*dg = ObjectToWorld( *dg );
    }

    return intersectFound;
}

MartiniGlassContour::Strip::Strip( Point & _p1, Point & _p2, 
    Normal & _n1, Normal & _n2 ) {

    // Endpoints of the profile segment and the shading normals at each end.
    p1 = _p1;
    p2 = _p2;
    n1 = _n1;
    n2 = _n2;

    // The strip is modeled as a band of a cone around the z axis. The apex
    // computation below divides by the radial change (p2.x - p1.x). For a
    // vertical segment (a cylinder band) the apex and height become infinite.
    // The intersection test then evaluates 0 * inf = NaN and reports NaN
    // hits, so reject such segments here, just as the height check below
    // rejects horizontal ones.
    assert( p1.x != p2.x );

    // Largest radius reached on the segment; bounds the strip in x and y.
    radius = ( p1.x > p2.x ) ? p1.x : p2.x;

    // z where the line through p1 and p2 crosses the axis (x == 0).
    apexZ = p2.z - p2.x * (p2.z - p1.z) / (p2.x - p1.x);
    zMin = min( p1.z, p2.z );
    zMax = max( p1.z, p2.z );

    // Axial distance from the apex to the far end of the strip. Assuming
    // both endpoints have x >= 0, that end has the largest radius, so
    // radius / height is the cone's slope used by Intersect().
    if ( apexZ <= zMin )
        height = zMax - apexZ;
    else
        height = apexZ - zMin;

    // Box around the segment's full revolution about z.
    Point b1 = Point( -radius, -radius, zMin );
    Point b2 = Point( radius, radius, zMax );
    bbox = new BBox( b1, b2 );

    assert( height > 0 );
}

BBox
MartiniGlassContour::Strip::Bound() const {
    // Box computed once in the Strip constructor, spanning [-radius, radius]
    // in x and y and [zMin, zMax] in z; handed back by value.
    const BBox &stripBox = *bbox;
    return stripBox;
}

bool
MartiniGlassContour::Strip::Intersect( const Ray &r, DifferentialGeometry *dg ) const {
	// Local copy of the (object-space) ray; r.maxt is shrunk on a hit.
	Ray ray = r;

	// Implicit double cone around the z axis with apex (0, 0, apexZ):
	//   x^2 + y^2 = k (z - apexZ)^2,   k = (radius / height)^2.
	// Substituting O + t D yields the quadratic A t^2 + B t + C = 0.
	double k = radius / height;
	k = k*k;
	double A = ray.D.x * ray.D.x +
		   ray.D.y * ray.D.y -
		   k * ray.D.z * ray.D.z;
	double B = 2 * (ray.D.x * ray.O.x +
			ray.D.y * ray.O.y -
			k * ray.D.z * (ray.O.z-apexZ) );
	double C = ray.O.x * ray.O.x +
		   ray.O.y * ray.O.y -
		   k * (ray.O.z - apexZ) * (ray.O.z-apexZ);

	// Roots via q = -(B +/- sqrt(disc)) / 2, avoiding cancellation.
	double discrim = B * B - 4. * A * C;
	if (discrim < 0.) return false;
	double rootDiscrim = sqrt(discrim);
	double q;
	if (B < 0) q = -.5 * (B - rootDiscrim);
	else       q = -.5 * (B + rootDiscrim);
	Float t0 = q / A, t1 = C / q;
	if (t0 > t1) swap(t0, t1);
	if (t0 > ray.maxt || t1 < ray.mint)
		return false;
	Float thit = t0;
	if (t0 < ray.mint) thit = t1;
	if (thit > ray.maxt) return false;

	// The cone equation covers both nappes and all z; only hits inside
	// [zMin, zMax] lie on this strip. If the chosen root falls outside,
	// fall back to the farther root t1 unless it was already the one tried.
	Point Phit = ray(thit);
	if (Phit.z < zMin || Phit.z > zMax ) {
		if (thit == t1) return false;
		thit = t1;
		if (t1 > ray.maxt) return false;
		Phit = ray(thit);
		if (Phit.z < zMin || Phit.z > zMax )
			return false;
	}

	// Surface parameters of the accepted hit: u is the azimuth about z
	// mapped to [0, 1]; v is |z - apexZ| as a fraction of the height.
	double u = atan2(Phit.y, Phit.x);
	if ( u < 0 )
		u = 2*M_PI + u;
	u /= 2*M_PI;
	double v = fabs( Phit.z - apexZ ) / height;

	// Interpolate the profile normal linearly in z (n1 at p1.z, n2 at p2.z),
	// then rotate it about z to the hit's azimuth.
	Normal interpN = (( Phit.z - p2.z ) * n1 + ( p1.z - Phit.z ) * n2) / 
	    (p1.z - p2.z);
	interpN.Normalize();
	Normal N( interpN.x * cos( u*2*M_PI ), interpN.x * sin( u*2*M_PI ), interpN.z );

	// S points along increasing azimuth; on the axis, where the azimuth is
	// undefined, +x is used instead. T completes the frame.
	Vector S, T;
	Float denom = sqrt(Phit.x*Phit.x + Phit.y*Phit.y);
	if (denom == 0) {
		S = Vector( 1,0,0 );
	} else  {
		S = Vector( - Phit.y / denom, Phit.x / denom );
	}
	T = Cross( Vector(N), S );
	*dg = DifferentialGeometry(Phit, N, S, T, u, v);
	r.maxt = thit;
	return true;
}
