// Compute shader for the first step of the GPU water flow simulation
#version 430

// Work group of 8x8x1 invocations; main() treats each invocation as one cell
// (dataPos = gl_GlobalInvocationID.xy).
layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in;

// Input state, read with texelFetch. Sampler bindings refer to texture units while the
// image bindings below refer to image units, so reusing binding numbers 0-3 does not collide.
// NOTE(review): water0Read..water3Read are not referenced directly in this file; presumably
// they are read by getWaterData in common/data/water.glsl — confirm.
layout (binding = 0) uniform sampler2DRect water0Read;
layout (binding = 1) uniform sampler2DRect water1Read;
layout (binding = 2) uniform sampler2DRect water2Read;
layout (binding = 3) uniform sampler2DRect water3Read;
// Per-neighbour flood flows (one component per neighbour index), presumably from the previous step.
layout (binding = 4) uniform sampler2DRect flood0Read;

// Output state, written with imageStore. Only three water images are written by this pass.
// NOTE(review): flood0Write uses rgba16, a normalized format, so stored flood flows are
// clamped to [0, 1] — confirm this is intended.
layout (binding = 0, rgba32f) uniform restrict image2DRect water0Write;
layout (binding = 1, rgba32f) uniform restrict image2DRect water1Write;
layout (binding = 2, rgba32f) uniform restrict image2DRect water2Write;
layout (binding = 3, rgba16) uniform restrict image2DRect flood0Write;

// Simulation parameters supplied by the host.
uniform float TIMESTEP;           // multiplies height differences into water flow, and outflow into per-step volume
uniform float RAIN_STRENGTH;      // NOTE(review): not used in this file; may be used by an included header
uniform float FLOW_PERCENT;       // fraction of the previous flow carried into this step
uniform float SLOPE_OFFSET;       // NOTE(review): not used in this file; may be used by an included header
uniform float ONE_DIRECTION_BIAS; // scales the gap between the two steepest drops when deciding single-direction flow

#include "common/noise.glsl"
#include "common/data/water.glsl"
#include "common/util.glsl"

/* Reorders the four components of _input from largest to smallest using adjacent
   exchanges (a bubble sort that carries the largest values toward index 0).
   Every exchange is mirrored in _directions, so afterwards _directions[k] still names
   whatever _directions held for the value now stored in _input[k].
   Adjacent equal values are exchanged as well, so ties end up in reverse of their
   original order. */
void sort(inout vec4 _input, inout ivec4 _directions) {
	for (int settled = 0; settled < 3; ++settled) {
		/* Walk from the back toward the already-settled prefix. */
		int k = 3;
		while (k > settled) {
			float upper = _input[k - 1];
			float lower = _input[k];
			if (upper <= lower) {
				_input[k - 1] = lower;
				_input[k] = upper;

				int upperDirection = _directions[k - 1];
				_directions[k - 1] = _directions[k];
				_directions[k] = upperDirection;
			}
			k--;
		}
	}
}

/* Entry point: one invocation per cell, dataPos = gl_GlobalInvocationID.xy.
   For that cell this pass
     1. computes an outflow toward each of the four neighbours (index 0: y-1, 1: x+1,
        2: y+1, 3: x-1) for both the water and the flood layer, blending in the previous
        flows via FLOW_PERCENT;
     2. when the steepest drop exceeds the second steepest by enough (scaled by
        ONE_DIRECTION_BIAS) relative to the cell's water, keeps only one water direction;
     3. scales the outflows so they cannot exceed what the cell holds;
     4. writes the results to water0Write..water2Write and flood0Write.
   NOTE(review): there is no bounds check; cells on the texture edge read neighbours at
   out-of-range coordinates unless getWaterData handles that — confirm. */
void main() {
	// Land lower than this is treated as being at this height when comparing levels.
	const float MIN_LAND_HEIGHT = -4.f;

	ivec2 dataPos = ivec2(gl_GlobalInvocationID.xy);

	WaterData data = getWaterData(dataPos, true, true, true);
	data.primaryFlowDirection = 0;
	vec4 oldFlows = data.flows;
	vec4 floodFlows = texelFetch(flood0Read, dataPos);

	WaterData neighbours[4];
	neighbours[0] = getWaterData(dataPos + ivec2(0, -1), false, false, true);
	neighbours[1] = getWaterData(dataPos + ivec2(1, 0), false, false, true);
	neighbours[2] = getWaterData(dataPos + ivec2(0, 1), false, false, true);
	neighbours[3] = getWaterData(dataPos + ivec2(-1, 0), false, false, true);

	// This cell's surface levels do not change inside the neighbour loop.
	float landHeight = max(MIN_LAND_HEIGHT, data.landHeight);
	float waterLevel = landHeight + data.waterHeight;
	float floodLevel = landHeight + data.floodHeight;

	float largestFlow = -1.0;
	vec4 heightDifferences;
	ivec4 directions = ivec4(0, 1, 2, 3);
	float totalFlow = 0.f, totalFloodFlow = 0.f;

	/* Iterate over each neighbour to calculate the flow amount */
	for (int i=0; i<4; i++) {
		float neighbourLand = max(MIN_LAND_HEIGHT, neighbours[i].landHeight);
		float heightDifference = waterLevel - (neighbourLand + neighbours[i].waterHeight);
		float floodDifference = floodLevel - (neighbourLand + neighbours[i].floodHeight);
		float flow = (TIMESTEP * heightDifference);
		float floodFlow = floodDifference * 0.1f;
		heightDifferences[i] = heightDifference;

		// Track the neighbour with the strongest pull; the first one wins on ties.
		if (flow > largestFlow) {
			largestFlow = flow;
			data.primaryFlowDirection = i;
		}

		// Blend the previous flow with this step's pull; outflows are never negative.
		data.flows[i] = max((data.flows[i] * FLOW_PERCENT) + flow, 0.f);
		floodFlows[i] = max((floodFlows[i] * FLOW_PERCENT) + max(floodFlow, 0.f), 0.f);
		totalFlow += data.flows[i];
		totalFloodFlow += floodFlows[i];
	}

	// Largest height difference (steepest drop) first; directions follows the reordering.
	sort(heightDifferences, directions);
	float heightDifferenceBetweenFirstAndSecond = heightDifferences[0] - heightDifferences[1];

	/* If the lowest neighbour can handle all the flow without the height reaching the level
	   of any of the neighbours, only flow that direction */
	if ((heightDifferenceBetweenFirstAndSecond * ONE_DIRECTION_BIAS) > data.waterHeight) {
		totalFlow = 0.f;

		/* Switch to the less favourable flow direction with probability proportional to the ratio
		   of its height difference compared to the primary height difference.
		   The ratio is only meaningful when the steepest neighbour is actually downhill:
		   at 0 the division is undefined, and below 0 (cell lower than every neighbour)
		   the ratio exceeds 1 and would favour the more uphill neighbour. An uphill
		   secondary neighbour gets no weight. */
		float primaryFlowPreference = rand(dataPos + vec2(directions[0] * 7, directions[0] * 13));
		float secondaryFlowPreference = 0.f;
		if (heightDifferences[0] > 0.f) {
			secondaryFlowPreference = rand(dataPos + vec2(directions[1] * 2, directions[1] * 49)) * max(heightDifferences[1], 0.f) / heightDifferences[0];
		}
		if (secondaryFlowPreference > primaryFlowPreference) {
			data.primaryFlowDirection = directions[1];
		}

		for (int i=0; i<4; i++) {
			/* Reset all other flows back to what they were previously */
			if (i != data.primaryFlowDirection) {
				data.flows[i] = max(oldFlows[i] * FLOW_PERCENT, 0.f);
			}
			totalFlow += data.flows[i];
		}
	}

	/* If the desired out flows exceeds the water available, scale the flows.
	   The per-step outflow (totalFlow * TIMESTEP) is capped at waterHeight / 1.4 and at
	   the largest height difference / 1.4. */
	float flowThisStep = totalFlow * TIMESTEP;
	float flowScaling = 0.f, floodScaling = 1.f;
	if (totalFlow > 0.f) {
		float ratio = min(1.f, data.waterHeight / (flowThisStep * 1.4f));
		float diffRatio = clamp(heightDifferences[0] / (flowThisStep * 1.4f), 0.f, 1.f);
		flowScaling = min(ratio, diffRatio);
	}
	// NOTE(review): unlike the water layer, flood outflow is capped without TIMESTEP — confirm intended.
	float floodAvailable = data.floodHeight;
	if (totalFloodFlow > 0.f) {
		floodScaling = min(1.f, floodAvailable / (totalFloodFlow * 1.f));
	}

	for (int i=0; i<4; i++) {
		data.flows[i] = data.flows[i] * flowScaling;
		floodFlows[i] = floodFlows[i] * floodScaling;
	}

	// outputWaterData fills four vectors; this pass has no water3Write image, so output4 is not stored.
	vec4 output1, output2, output3, output4;
	outputWaterData(data, output1, output2, output3, output4);
	imageStore(water0Write, dataPos, output1);
	imageStore(water1Write, dataPos, output2);
	imageStore(water2Write, dataPos, output3);
	imageStore(flood0Write, dataPos, floodFlows);
}