// Compute shader for the GPU water flow simulation pass that calculates per-vertex position offsets, triangle-flip state and triangle normals
#version 430

// One invocation per water cell (gl_GlobalInvocationID.xy), in 8x8 workgroups.
layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in;

// Previous-step inputs, all read with texelFetch (integer texel coordinates).
// waterOffset0Read: .xy = vertex offset, .z = triangle-flip flag, .w = flip hysteresis counter
// (the same layout this pass writes to waterOffset0Write).
layout (binding = 0) uniform sampler2DRect waterOffset0Read;
// Not referenced directly below; presumably read by getWaterData() in common/data/water.glsl — verify.
layout (binding = 1) uniform sampler2DRect water0Read;
// Neighbour flow amounts; main() reads component k as the flow toward direction k
// (0 = -y, 1 = +x, 2 = +y, 3 = -x).
layout (binding = 2) uniform sampler2DRect water1Read;
// .x = outflow direction index of the cell (same 0..3 encoding as above).
layout (binding = 3) uniform sampler2DRect water2Read;
// Not referenced directly below; presumably read by getWaterData() — verify.
layout (binding = 4) uniform sampler2DRect water3Read;

// Output: new offset / flip flag / hysteresis counter per cell.
layout (binding = 0, rgba32f) uniform restrict image2DRect waterOffset0Write;
// Output: layer 0 and layer 1 hold the normals of the two triangles of this cell's quad.
// NOTE(review): rgba8 is an unsigned-normalized format, so imageStore clamps components to [0,1];
// negative normal components (x/z on slopes) are stored as 0. Consider rgba8_snorm or a
// n*0.5+0.5 encoding — either needs a matching host texture format / consumer decode.
layout (binding = 1, rgba8) uniform restrict image3D normals0Write;
// Output: debug values, written only when WRITE_DEBUG is set.
layout (binding = 2, rgba32f) uniform restrict image2DRect debugOutput0Write;

// Grid size in cells; main() only applies the flow offset to cells at least 1 away from the edge.
uniform ivec2 water_size;
// Horizontal spacing between grid vertices used when building the normal triangles.
uniform int water_resolution;
// Enables the store to debugOutput0Write.
uniform bool WRITE_DEBUG;
// Not referenced in the visible code of this shader.
uniform bool OFFSET_ENHANCEMENTS;
// Distance from a triangle edge below which the vertex is pushed away from that edge.
uniform float TRIANGLE_SPACING_CUTOFF;
// Scale factor for that push.
uniform float TRIANGLE_SPACING_SPEED;
// Water height above which a corner counts as "river" for the triangle-flip decision.
uniform float RIVER_THRESHOLD;
// Not referenced in the visible code of this shader.
uniform bool NEW_TRIANGLE_SWITCHING;

#include "common/data/water.glsl"

// Orthogonally projects `point` onto the infinite line through lineStart and lineEnd.
// If the two line points coincide the line has no direction, so lineStart is returned;
// this avoids dividing by a zero length (which would produce NaN that then propagates
// into the stored vertex offsets).
vec2 projectPointOntoLine(vec2 lineStart, vec2 lineEnd, vec2 point) {
    vec2 lineVec = lineEnd - lineStart;
    float lineLengthSq = dot(lineVec, lineVec);
    if (lineLengthSq <= 0.0) {
        return lineStart;
    }
    // t = position of the projection along lineVec (0 at lineStart, 1 at lineEnd).
    float t = dot(point - lineStart, lineVec) / lineLengthSq;
    return lineStart + t * lineVec;
}

// A triangle touching the current vertex. Only the two other corners are stored, relative to
// the current cell's undisplaced position; the implied third corner is the current vertex.
struct Triangle {
	vec2 point1, point2; // the two corners opposite the current vertex (they form the far edge)
};

// Unit normal of triangle (p1, p2, p3): edge p1->p2 crossed with edge p1->p3, so the
// direction follows the right-hand rule for the winding p1 -> p2 -> p3.
vec3 calculateNormal(vec3 p1, vec3 p2, vec3 p3) {
    vec3 edgeA = p2 - p1;
    vec3 edgeB = p3 - p1;
    vec3 unnormalised = cross(edgeA, edgeB);
    return normalize(unnormalised);
}

// Entry point: one invocation per water cell at gl_GlobalInvocationID.xy. The cell is treated as the
// top-left vertex of a quad whose other corners are (+1,0), (0,+1) and (+1,+1) (negative y is "top").
// Per cell this pass:
//   1. decides, with hysteresis, whether this cell's quad uses the (0,0)-(1,1) diagonal (flip = 1)
//      or the (1,0)-(0,1) diagonal (flip = 0),
//   2. nudges the vertex offset sideways from the strongest inflow and the cell's outflow direction,
//   3. pushes the vertex away from the far edges of the triangles around it,
//   4. computes the two triangle normals of this cell's quad from the undisplaced land heights,
// then writes offset/flip/counter to waterOffset0Write and the normals to layers 0/1 of normals0Write.
// NOTE(review): neighbour fetches are not clamped, so border cells read texels outside the textures
// (texelFetch returns undefined values there); only step 2 is restricted to interior cells.
void main() {
	ivec2 dataPos = ivec2(gl_GlobalInvocationID.xy);

	// Cardinal neighbour directions: 0 = (0,-1), 1 = (+1,0), 2 = (0,+1), 3 = (-1,0).
	// The direction pointing from neighbour i back to this cell is therefore (i + 2) % 4.
	const ivec2 DIRECTIONS[4] = ivec2[4](ivec2(0, -1), ivec2(1, 0), ivec2(0, 1), ivec2(-1, 0));

	// Water/land data for this cell's quad corners: 0 = (0,0), 1 = (+1,0), 2 = (0,+1), 3 = (+1,+1).
	WaterData waterData[4];
	waterData[0] = getWaterData(dataPos + ivec2(0, 0), false, false, false);
	waterData[1] = getWaterData(dataPos + ivec2(1, 0), false, false, false);
	waterData[2] = getWaterData(dataPos + ivec2(0, 1), false, false, false);
	waterData[3] = getWaterData(dataPos + ivec2(1, 1), false, false, false);

	vec4 water2Data = texelFetch(water2Read, dataPos);
	// Previous output of this pass: .xy = offset, .z = flip flag, .w = flip hysteresis counter.
	vec4 waterOffset0Data = texelFetch(waterOffset0Read, dataPos);
	float previousFlipTriangles = waterOffset0Data[2];
	float flipTrianglesCounter = waterOffset0Data[3];

	// Per-neighbour state, indexed like DIRECTIONS.
	vec4 neighbourOther[4];   // water2: .x = that neighbour's outflow direction
	vec4 neighbourFlows[4];   // water1: per-direction flow amounts
	vec4 neighbourOffsets[4]; // previous offsets / flip flags of the neighbours
	for (int i = 0; i < 4; i++) {
		ivec2 neighbourPos = dataPos + DIRECTIONS[i];
		neighbourOther[i] = texelFetch(water2Read, neighbourPos);
		neighbourFlows[i] = texelFetch(water1Read, neighbourPos);
		neighbourOffsets[i] = texelFetch(waterOffset0Read, neighbourPos);
	}

	// --- 1. Triangle-flip decision -------------------------------------------------------------
	// Flip when corners 0 and 3 and one of corners 1/2 are above RIVER_THRESHOLD while the other
	// of corners 1/2 is below it, so the quad's diagonal follows where the water is.
	float flipTriangles = 0.0;
	if (
		(waterData[0].waterHeight > RIVER_THRESHOLD && waterData[1].waterHeight > RIVER_THRESHOLD && 
		 waterData[3].waterHeight > RIVER_THRESHOLD && waterData[2].waterHeight < RIVER_THRESHOLD) ||
		(waterData[0].waterHeight > RIVER_THRESHOLD && waterData[2].waterHeight > RIVER_THRESHOLD && 
		 waterData[3].waterHeight > RIVER_THRESHOLD && waterData[1].waterHeight < RIVER_THRESHOLD)) {
		flipTriangles = 1.0;
	}

	// Near the ocean (lowest corner below land height 0) flip when that lowest corner is 1 or 2,
	// to smooth out the coastline.
	float _min = min(min(waterData[0].landHeight, waterData[1].landHeight), min(waterData[2].landHeight, waterData[3].landHeight));
	if (_min < 0.0 && (_min == waterData[1].landHeight || _min == waterData[2].landHeight)) {
		flipTriangles = 1.0;
	}

	// Hysteresis: a requested change of flip state is only accepted once the counter has dropped
	// below zero. Each frame the request persists decrements it (keeping the old state), and it
	// resets to 10 whenever the state is unchanged or a change is accepted.
	if (flipTriangles != previousFlipTriangles) {
		if (flipTrianglesCounter < 0.0) {
			flipTrianglesCounter = 10.0;
		} else {
			flipTrianglesCounter = flipTrianglesCounter - 1.0;
			flipTriangles = previousFlipTriangles;
		}
	} else {
		flipTrianglesCounter = 10.0;
	}

	// --- 2. Flow-driven offset ------------------------------------------------------------------
	// Find the neighbour with the biggest flow into this cell: neighbour i flows into us when its
	// outflow direction is (i + 2) % 4, and its flow amount in that direction is component (i + 2) % 4.
	// Strict '>' keeps the lowest index on ties; flows at or below 0.001 are ignored.
	float biggest = 0.001;
	int biggestAt = -1;
	for (int i = 0; i < 4; i++) {
		int opposite = (i + 2) % 4;
		if (neighbourOther[i].x == float(opposite) && neighbourFlows[i][opposite] > biggest) {
			biggest = neighbourFlows[i][opposite];
			biggestAt = i;
		}
	}

	const float maxDeviation = 0.4;
	const float speed = 30.0;
	int outflowDirection = int(water2Data[0]);
	// Guard the neighbourOffsets[] index: a direction outside 0..3 would be an out-of-bounds
	// array access (undefined behaviour). NOTE(review): confirm which values water2.x can hold.
	bool validOutflow = outflowDirection >= 0 && outflowDirection < 4;
	bool interior = dataPos.x >= 1 && dataPos.y >= 1 && dataPos.x < water_size.x - 1 && dataPos.y < water_size.y - 1;

	vec4 output1;
	if (biggestAt != -1 && validOutflow && interior) {
		vec2 offset = waterOffset0Data.xy;
		vec2 inflowOffset = neighbourOffsets[biggestAt].xy;
		vec2 outflowOffset = neighbourOffsets[outflowDirection].xy;

		// Displaced inflow neighbour -> this vertex, and this vertex -> displaced outflow neighbour.
		vec2 inflowVector = -vec2(DIRECTIONS[biggestAt]) - inflowOffset + offset;
		vec2 outflowVector = vec2(DIRECTIONS[outflowDirection]) - offset + outflowOffset;
		vec2 overallVector = inflowVector + outflowVector;

		// mag = 1 - cos(angle between inflow and outflow): 0 for straight flow, up to 2 for a reversal.
		// A zero-length vector has no angle; using 0 there avoids a NaN offset, which would otherwise
		// be written back and persist in every later frame.
		float lengthProduct = length(inflowVector) * length(outflowVector);
		float mag = lengthProduct > 0.0 ? 1.0 - (dot(inflowVector, outflowVector) / lengthProduct) : 0.0;

		// Push perpendicular to the overall flow, towards the side selected by the sign of the 2D
		// cross product of overallVector and inflowVector.
		float side = (overallVector.x * inflowVector.y) - (inflowVector.x * overallVector.y);
		vec2 perpendicular = vec2(-overallVector.y, overallVector.x);
		if (side > 0.0) {
			perpendicular = vec2(overallVector.y, -overallVector.x);
		}

		output1 = vec4(
			clamp(offset + (perpendicular / (2000.0 / speed)) * mag, -maxDeviation, maxDeviation),
			flipTriangles, flipTrianglesCounter
		);
	} else {
		// No usable inflow/outflow (or a border cell): relax the offset back towards zero.
		output1 = vec4(waterOffset0Data.xy * 0.9, flipTriangles, flipTrianglesCounter);
	}

	// --- 3. Triangle spacing -------------------------------------------------------------------
	// Collect the far edge of every triangle this vertex belongs to in the four quads around it.
	// Each quad's split is taken from the previous flip flag (.z) of the quad's top-left cell.
	int numTriangles = 0;
	Triangle triangles[8]; // 8 = most triangles we could be part of

	vec4 topLeftOffset = texelFetch(waterOffset0Read, dataPos + ivec2(-1, -1));
	vec4 topRightOffset = texelFetch(waterOffset0Read, dataPos + ivec2(1, -1));
	vec4 bottomLeftOffset = texelFetch(waterOffset0Read, dataPos + ivec2(-1, 1));
	vec4 bottomRightOffset = texelFetch(waterOffset0Read, dataPos + ivec2(1, 1));

	// Quad with top-left at (-1,-1); this vertex is its bottom-right corner.
	if (topLeftOffset[2] == 1.0) {
		triangles[numTriangles++] = Triangle(vec2(-1, -1) + topLeftOffset.xy, vec2(0, -1) + neighbourOffsets[0].xy);
		triangles[numTriangles++] = Triangle(vec2(-1, -1) + topLeftOffset.xy, vec2(-1, 0) + neighbourOffsets[3].xy);
	} else {
		triangles[numTriangles++] = Triangle(vec2(-1, 0) + neighbourOffsets[3].xy, vec2(0, -1) + neighbourOffsets[0].xy);
	}

	// Quad with top-left at (0,-1); this vertex is its bottom-left corner.
	if (neighbourOffsets[0][2] == 1.0) {
		triangles[numTriangles++] = Triangle(vec2(0, -1) + neighbourOffsets[0].xy, vec2(1, 0) + neighbourOffsets[1].xy);
	} else {
		triangles[numTriangles++] = Triangle(vec2(0, -1) + neighbourOffsets[0].xy, vec2(1, -1) + topRightOffset.xy);
		triangles[numTriangles++] = Triangle(vec2(1, -1) + topRightOffset.xy, vec2(1, 0) + neighbourOffsets[1].xy);
	}

	// This cell's own quad; this vertex is its top-left corner.
	if (waterOffset0Data[2] == 1.0) {
		triangles[numTriangles++] = Triangle(vec2(1, 0) + neighbourOffsets[1].xy, vec2(1, 1) + bottomRightOffset.xy);
		triangles[numTriangles++] = Triangle(vec2(1, 1) + bottomRightOffset.xy, vec2(0, 1) + neighbourOffsets[2].xy);
	} else {
		triangles[numTriangles++] = Triangle(vec2(1, 0) + neighbourOffsets[1].xy, vec2(0, 1) + neighbourOffsets[2].xy);
	}

	// Quad with top-left at (-1,0); this vertex is its top-right corner.
	if (neighbourOffsets[3][2] == 1.0) {
		triangles[numTriangles++] = Triangle(vec2(-1, 0) + neighbourOffsets[3].xy, vec2(0, 1) + neighbourOffsets[2].xy);
	} else {
		triangles[numTriangles++] = Triangle(vec2(-1, 0) + neighbourOffsets[3].xy, vec2(-1, 1) + bottomLeftOffset.xy);
		triangles[numTriangles++] = Triangle(vec2(-1, 1) + bottomLeftOffset.xy, vec2(0, 1) + neighbourOffsets[2].xy);
	}

	// Push the vertex (at output1.xy relative to its undisplaced position) away from every far edge
	// closer than TRIANGLE_SPACING_CUTOFF, presumably to keep triangles from degenerating.
	vec2 addedForce = vec2(0, 0);
	for (int i = 0; i < numTriangles; i++) {
		vec2 projectedPoint = projectPointOntoLine(triangles[i].point1, triangles[i].point2, output1.xy);
		// could compare squared distances to avoid the sqrt inside distance()
		float perpendicularDistance = distance(projectedPoint, output1.xy);
		if (perpendicularDistance < TRIANGLE_SPACING_CUTOFF) {
			addedForce += (output1.xy - projectedPoint) * TRIANGLE_SPACING_SPEED * (0.1 + (TRIANGLE_SPACING_CUTOFF - perpendicularDistance));
		}
	}
	output1.xy += addedForce;

	// --- 4. Normals ----------------------------------------------------------------------------
	/* for now, everything looks better if we don't calculate the actual normals with water offset taken into account
		but rather the normals without the offset vertices.  This might change later.
		It's less accurate, but it sidesteps issues from some triangles being much steeper or flatter than they really should be */

	// if (output1[2] == 0.0) {
	// 	normals[0] = calculateNormal(
	// 		(vec3(output1.x, 0, output1.y) * water_resolution) + vec3(0, waterData[0].landHeight, 0), 
	// 		(vec3(neighbourOffsets[2].x, 0, 1 + neighbourOffsets[2].y) * water_resolution) + vec3(0, waterData[2].landHeight, 0),
	// 		(vec3(1 + neighbourOffsets[1].x, 0, neighbourOffsets[1].y) * water_resolution) + vec3(0, waterData[1].landHeight, 0)
	// 	);
	// 	normals[1] = calculateNormal(
	// 		(vec3(1 + neighbourOffsets[1].x, 0, neighbourOffsets[1].y) * water_resolution) + vec3(0, waterData[1].landHeight, 0),
	// 		(vec3(neighbourOffsets[2].x, 0, 1 + neighbourOffsets[2].y) * water_resolution) + vec3(0, waterData[2].landHeight, 0),
	// 		(vec3(1 + bottomRightOffset.x, 0, 1 + bottomRightOffset.y) * water_resolution) + vec3(0, waterData[3].landHeight, 0)
	// 	);
	// } else {
	// 	normals[0] = calculateNormal(
	// 		(vec3(output1.x, 0, output1.y) * water_resolution) + vec3(0, waterData[0].landHeight, 0), 
	// 		(vec3(1 + bottomRightOffset.x, 0, 1 + bottomRightOffset.y) * water_resolution) + vec3(0, waterData[3].landHeight, 0),
	// 		(vec3(1 + neighbourOffsets[1].x, 0, neighbourOffsets[1].y) * water_resolution) + vec3(0, waterData[1].landHeight, 0)
	// 	);
	// 	normals[1] = calculateNormal(
	// 		(vec3(output1.x, 0, output1.y) * water_resolution) + vec3(0, waterData[0].landHeight, 0), 
	// 		(vec3(neighbourOffsets[2].x, 0, 1 + neighbourOffsets[2].y) * water_resolution) + vec3(0, waterData[2].landHeight, 0),
	// 		(vec3(1 + bottomRightOffset.x, 0, 1 + bottomRightOffset.y) * water_resolution) + vec3(0, waterData[3].landHeight, 0)
	// 	);
	// }

	// Undisplaced quad corners: x/z on the grid scaled by water_resolution, y = land height.
	float res = float(water_resolution);
	vec3 corner00 = vec3(0.0, waterData[0].landHeight, 0.0);
	vec3 corner10 = vec3(res, waterData[1].landHeight, 0.0);
	vec3 corner01 = vec3(0.0, waterData[2].landHeight, res);
	vec3 corner11 = vec3(res, waterData[3].landHeight, res);

	// Split the quad along the same diagonal as the (new) flip flag.
	vec3 normals[2];
	if (output1[2] == 0.0) {
		// Diagonal (1,0)-(0,1).
		normals[0] = calculateNormal(corner00, corner01, corner10);
		normals[1] = calculateNormal(corner10, corner01, corner11);
	} else {
		// Diagonal (0,0)-(1,1).
		normals[0] = calculateNormal(corner00, corner11, corner10);
		normals[1] = calculateNormal(corner00, corner01, corner11);
	}

	imageStore(waterOffset0Write, dataPos, output1);
	imageStore(normals0Write, ivec3(dataPos.x, dataPos.y, 0), vec4(normals[0], 0));
	imageStore(normals0Write, ivec3(dataPos.x, dataPos.y, 1), vec4(normals[1], 0));

	if (WRITE_DEBUG) {
		// .x/.y = vertical component of each triangle normal, .z = marker value, .w = 0.
		vec4 debugOutput = vec4(normals[0].y, normals[1].y, -999.0, 0.0);
		imageStore(debugOutput0Write, dataPos, debugOutput);
	}
}