#version 430

#include "common/constants.glsl"
#include "common/data/biomes.glsl"
#include "common/octohedral.glsl"

// Colour atlas (array texture) sampled in main() at
// (plantDefinitions[type].rotations[view].start + size * uv, layer).
// Channel usage as read in main() (meaning inferred from use — TODO confirm):
//   .r weights the leaf tint fColour, .g feeds the ambient-occlusion term,
//   .b blends leaf vs. trunk, .a weights the trunk colour.
uniform sampler2DArray coloursTexture;
// Normal atlas, same coordinates as coloursTexture. Holds two octahedral-encoded
// normals per texel: .xy (paired with the trunk colour) and .zw (paired with the leaf tint).
uniform sampler2DArray normalsTexture;
// NOTE(review): not referenced in the visible shader code.
uniform sampler2D atlasTexture;
// Dotted with the blended, rotated normal to get the diffuse factor.
uniform vec3 lightDirection;
// Colours and scalar weights of the diffuse and ambient lighting terms.
uniform vec3 diffuseColour;
uniform vec3 ambientColour;
uniform float diffuseAmount;
// NOTE(review): not referenced in the visible shader code.
uniform float specularAmount;
uniform float ambientAmount;
// NOTE(review): not referenced in the visible shader code.
uniform mat4 ciViewMatrixInverse;
// Projects plane-hit points to clip space to compute the written fragment depth.
uniform mat4 ciModelViewProjection;
// NOTE(review): not referenced in the visible shader code.
uniform bool DEBUG_BILLBOARDS;
// How many of the (at most three) entries of `views` are blended.
uniform int INTERPOLATED_VIEWS;
// When enabled, the ambient term is darkened by (1 - colour.g) * viewRatio * AMBIENT_OCCLUSION_AMOUNT per view.
uniform bool ambientOcclusionEnabled;
uniform float AMBIENT_OCCLUSION_AMOUNT;

// Per-view billboard plane basis, indexed by the view indices in `views`.
// Sized by BIOME_BITMAP_TOTAL_SIZE from biomes.glsl (presumably the number of
// pre-rendered views — TODO confirm).
uniform vec3 PLANE_NORMALS[BIOME_BITMAP_TOTAL_SIZE];
uniform vec3 PLANE_UPS[BIOME_BITMAP_TOTAL_SIZE];
uniform vec3 PLANE_RIGHTS[BIOME_BITMAP_TOTAL_SIZE];

// Leaf tint; alpha > 0.5 selects the leaf+trunk path in main(), otherwise trunk only.
in vec4 fColour;
// Multiplies both the shaded RGB and the output alpha.
in float fOpacity;
// Per-fragment scales for the diffuse and ambient terms.
in float fDiffuse;
in float fAmbient;
// Angle used to rotate the blended normal about +Y (applied as -fRotation).
in float fRotation;
// Point that every view plane passes through (billboard centre).
in vec3 centrePosition;
// Fragment position; the view ray runs from here towards cameraEyePosition.
in vec3 rotatedWorldPosition;
// Indices into PLANE_* and plantDefinitions[type].rotations for the blended views.
flat in ivec3 views;
// Blend weight of each entry in `views`.
flat in vec3 viewRatios; 
// Index into plantDefinitions.
flat in int type;
// Scales plant width/height when mapping plane hits to [0,1] plane coordinates.
flat in float fScale;
// Camera eye position, in the same space as the positions above.
flat in vec3 cameraEyePosition;

// Shaded RGBA result.
out vec4 _output;

// Rotates `v` by `angle` radians about `axis` using Rodrigues' rotation formula:
//   v' = v*cos + (axis x v)*sin + axis*(axis . v)*(1 - cos)
// `axis` is expected to be unit length; the formula does not normalize it.
vec3 rotate(vec3 v, vec3 axis, float angle) {
    float c = cos(angle);
    float s = sin(angle);

    vec3 scaledInput   = v * c;
    vec3 tangentialPart = cross(axis, v) * s;
    vec3 axialPart      = axis * dot(axis, v) * (1.0 - c);

    return scaledInput + tangentialPart + axialPart;
}

// Intersects the line P0 + t*d with the plane through `planePoint` having
// normal `planeNormal`. `t` may be negative: this is a line, not a ray.
// Returns vec3(0.0) when the line is (nearly) parallel to the plane.
vec3 linePlaneIntersection(vec3 P0, vec3 d, vec3 planePoint, vec3 planeNormal) {
    float normalDotDirection = dot(planeNormal, d);

    bool isParallel = abs(normalDotDirection) < 1e-6;
    if (isParallel) {
        return vec3(0.0);
    }

    float lineParameter = dot(planeNormal, planePoint - P0) / normalDotDirection;
    return P0 + lineParameter * d;
}

// Maps a point on view plane `plane` to that plane's [0,1]x[0,1] billboard
// coordinates. (0.5, 0.5) is `centre`. The extent is the plant's
// width/height from plantDefinitions[type], scaled by fScale, measured along
// PLANE_RIGHTS[plane] (u) and PLANE_UPS[plane] (v).
vec2 getPlaneCoordinates(int plane, vec3 intersection, vec3 centre, int type) {
	vec3 offsetFromCentre = intersection - centre;
	vec2 projected = vec2(
		dot(offsetFromCentre, PLANE_RIGHTS[plane]),
		dot(offsetFromCentre, PLANE_UPS[plane])
	);
	vec2 extent = vec2(
		fScale * plantDefinitions[type].width,
		fScale * plantDefinitions[type].height
	);
	return vec2(0.5f, 0.5f) + projected / extent;
}

// Fragment entry point for billboard plants.
//
// For each of the first INTERPOLATED_VIEWS entries of `views`, the line through
// the camera eye and this fragment is intersected with that view's plane
// (through centrePosition). Hits inside the plane's open unit square sample the
// colour/normal atlases, and the samples are blended with the weights in
// viewRatios. The blend is lit with a diffuse + ambient model. The weighted
// depth of the plane hits is written to gl_FragDepth. Fragments whose
// accumulated alpha is zero are discarded.
void main() {
	vec4 projectedOutput = vec4(0.f);
	vec3 projectedNormal = vec3(0.f);
	float depth = 0.f;
	float viewTotal = 0.f;
	vec3 viewRay = normalize(cameraEyePosition - rotatedWorldPosition);
	float ambientOcclusion = 1.f;

	// Loop-invariant: the trunk colour depends only on `type`.
	vec4 trunkColour = vec4(plantDefinitions[type].trunkColourRed, plantDefinitions[type].trunkColourGreen,
							plantDefinitions[type].trunkColourBlue, 1.f);

	for (int i = 0; i < 3; i++) {
		// Only the first INTERPOLATED_VIEWS entries of `views` take part in the blend.
		if (i >= INTERPOLATED_VIEWS) {
			continue;
		}
		int view = views[i];
		float weight = viewRatios[i];
		vec3 intersection = linePlaneIntersection(cameraEyePosition, viewRay, centrePosition, PLANE_NORMALS[view]);
		vec2 planeCoords = getPlaneCoordinates(view, intersection, centrePosition, type);

		// Ignore hits outside the open unit square of this view's plane.
		bool insidePlane = planeCoords.x > 0.f && planeCoords.x < 1.f &&
						   planeCoords.y > 0.f && planeCoords.y < 1.f;
		if (!insidePlane) {
			continue;
		}

		// Flip v before sampling (presumably the atlas is stored top-down — TODO confirm).
		planeCoords.y = 1.f - planeCoords.y;
		vec3 texCoords = vec3(
			plantDefinitions[type].rotations[view].start + (plantDefinitions[type].rotations[view].size * planeCoords),
			plantDefinitions[type].rotations[view].layer
		);
		vec4 colour = texture(coloursTexture, texCoords);
		vec4 normal = texture(normalsTexture, texCoords);

		// Skip texels where neither .a nor .r is positive: nothing to draw for this view.
		if (!(colour.a > 0.f || colour.r > 0.f)) {
			continue;
		}

		if (fColour.a > 0.5f) {
			// Leaf + trunk: .b blends the leaf tint (normal.zw) against the trunk colour (normal.xy).
			projectedOutput += fColour * colour.r * colour.b * weight;
			projectedOutput += trunkColour * colour.a * (1.f - colour.b) * weight;
			projectedNormal += getNormalFromOctahedralCoordinates(normal.zw) * colour.b * weight;
			projectedNormal += getNormalFromOctahedralCoordinates(normal.xy) * (1.f - colour.b) * weight;
		} else {
			// Trunk colour and trunk normal only.
			projectedOutput += trunkColour * colour.a * weight;
			projectedNormal += getNormalFromOctahedralCoordinates(normal.xy) * weight;
		}
		viewTotal += weight;

		if (ambientOcclusionEnabled) {
			ambientOcclusion -= (1.f - colour.g) * weight * AMBIENT_OCCLUSION_AMOUNT;
		}

		// Accumulate the weighted NDC depth of the plane hit.
		vec4 clipPosition = ciModelViewProjection * vec4(intersection, 1.f);
		depth += (clipPosition.z / clipPosition.w) * weight;
	}

	if (projectedOutput.a == 0.f) {
		discard;
	} else {
		// Rotate the blended normal by -fRotation about +Y (presumably undoing the
		// instance's yaw — TODO confirm). The sign flip is kept from the original shader.
		projectedNormal = normalize(rotate(-projectedNormal, vec3(0.f, 1.f, 0.f), -fRotation));
		float diffuse = max(dot(projectedNormal, lightDirection), 0.f);
		_output = vec4(
			(diffuseColour * projectedOutput.rgb * fDiffuse * diffuse * diffuseAmount * fOpacity) +
			(ambientColour * projectedOutput.rgb * fAmbient * ambientAmount * fOpacity * ambientOcclusion),
			fOpacity * projectedOutput.a
		);

		// `depth / viewTotal` is an NDC z in [-1, 1]. gl_FragDepth is window-space
		// depth, so apply the viewport depth-range mapping
		//   z_w = ((f - n) * z_ndc + n + f) / 2.
		// This assumes the default GL_NEGATIVE_ONE_TO_ONE clip convention
		// (no glClipControl(..., GL_ZERO_TO_ONE)).
		float ndcDepth = depth / viewTotal;
		gl_FragDepth = (gl_DepthRange.diff * ndcDepth + gl_DepthRange.near + gl_DepthRange.far) * 0.5f;
	}
}