#version 430 core

// ---- Simulation parameters (uniforms supplied by the host application) ----
// Grid spacing; main() squares it into dx2/dy2 and divides bed-shape terms by it.
uniform int water_resolution;
// Scales snow accumulation; also divides the melt value written to ice0 channel z.
uniform float ACCUMULATION_RATE;
// Scales ablation (negative snowfall) when the mean temperature is at or above -5 C.
uniform float ABLATION_RATE;
// Multiplier applied to the ice diffusivity in diffusivityWithSliding().
uniform float FLOW_RATE;
// Scales the velocity-driven erosion terms.
uniform float EROSION;
// Sediment capacity: erosion tapers to zero as sediment approaches it, and
// sediment is clamped to at most twice this value.
uniform float SEDIMENT_LIMIT;
// Surface-slope penalty: ice thinner than maxIceSlope * ratio is not displayed,
// otherwise this amount is subtracted from the displayed height.
uniform float DISPLAY_SLOPE_RATIO;
// When true, main() writes diagnostics to debugOutput0Write / debugOutput1Write.
uniform bool WRITE_DEBUG;
// NOTE(review): not referenced by the visible main(); possibly used by the
// included helpers or unused - confirm.
uniform ivec2 DISPLAY_CHUNK_DIMENSIONS;
// Converts rainfall channel y into extra melt while snow-days remain.
uniform float RAINFALL_NORMALISED_TO_SIM_CONVERSION;
// Gates the per-day snow counter stored in ice2 channel y.
uniform bool dayProgressed;
// Replaces snowfall with -LAVA_MELT_RATE where lava is present.
uniform float LAVA_MELT_RATE;
// Global time scale; many update terms are multiplied by it, and snow-day
// bookkeeping is skipped when it is not positive.
uniform float TIMESTEP;
// Divides the extra land slope that sub-zero temperatures allow snow to sit on.
uniform float TEMPERATURE_SLOPE_MULTIPLIER;

// 8x8 workgroups; each invocation handles one texel (dataPos in main()).
layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in;

// ---- Read-only inputs (all accessed with texelFetch) ----
// ice0: x = ice thickness, y = sediment, z = melt / ACCUMULATION_RATE, w = erosion.
layout (binding = 0) uniform sampler2DRect ice0Read;
// ice1: x = smoothed display ice height, y = hysteresis countdown.
layout (binding = 1) uniform sampler2DRect ice1Read;
// ice2: x = in-snow flag (0/1), y = snow-day counter in 1/255 steps.
layout (binding = 2) uniform sampler2DRect ice2Read;
// Texel (0,0) holds the previous max-change sentinel used to size the timestep.
layout (binding = 3) uniform usampler2DRect sentinals0Read;
// Channel x is used as the land (bed) height.
layout (binding = 4) uniform sampler2DRect water0Read;
// xy decoded by normalisedToTemperature() from the included temperature helpers.
layout (binding = 5) uniform sampler2DRect temperature0Read;
// x scales snow accumulation; y gates snowability and adds melt.
layout (binding = 6) uniform sampler2DRect rainfall2Read;
// Channel y >= 1 makes a texel snowable regardless of slope.
layout (binding = 7) uniform sampler2DRect rockProperties0Read;
// Channel x > 0 marks lava: it melts ice and prevents snow cover.
layout (binding = 8) uniform sampler2DRect lava0Read;

// ---- Outputs (same channel layouts as the matching *Read inputs) ----
layout (binding = 0, rgba32f) uniform restrict image2DRect ice0Write;
layout (binding = 1, rg32f) uniform restrict image2DRect ice1Write;
layout (binding = 2, rg8) uniform restrict image2DRect ice2Write;
// One texel per 128x128 block of the grid; x is set to 1 where display ice exists.
layout (binding = 3, rg8) uniform restrict image2DRect iceActivation0Write;
// imageAtomicMax target at (0,0) for this pass's max-change sentinel.
layout (binding = 4, r32ui) uniform restrict uimage2DRect sentinals0Write;
layout (binding = 5, rgba32f) uniform restrict image2DRect debugOutput0Write;
layout (binding = 6, rgba32f) uniform restrict image2DRect debugOutput1Write;

// Small offset keeping denominators away from zero.
const float eps = 1e-6;
// Coefficient of the h^5 term of the ice diffusivity (presumably internal
// deformation in a shallow-ice model - confirm).
const float Gamma_d = 7.26e-5;
// Coefficient of the h^3 term of the ice diffusivity (presumably basal sliding - confirm).
const float Gamma_s = 3.27;

#include "common/temperature.glsl"

// Superbee flux limiter (Sweby family with beta = 2):
//   phi(r) = max(0, min(2r, 1), min(r, 2))
// r is the ratio of successive thickness differences; the result scales the
// half-cell slope used to reconstruct thickness at cell faces.
float phi(float r) {
    const float beta = 2.0;
    float compressive  = min(beta * r, 1.0);
    float ratioLimited = min(r, beta);
    return max(max(compressive, ratioLimited), 0.0);
}

// Upwinded ice diffusivity at one cell face.
// Evaluates D(h) = h^3 * (Gamma_d * h^2 + Gamma_s) * grad_s * FLOW_RATE for the
// reconstructed thickness on the plus side (h_p) and minus side (h_m) of the
// face. It then returns the smaller or larger value depending on which side's
// surface (s_p, s_m) is higher and which side is thicker. The selection rule
// appears to follow Jarosch et al. (2013); confirm against that paper.
// Falls through to 0 only when no comparison holds (e.g. a NaN input).
float diffusivityWithSliding(float grad_s, float h_p, float h_m, float s_p, float s_m) {
    float cubedPlus  = h_p*h_p*h_p;
    float cubedMinus = h_m*h_m*h_m;
    float plusSide  = cubedPlus  * (Gamma_d*h_p*h_p + Gamma_s) * grad_s * FLOW_RATE;
    float minusSide = cubedMinus * (Gamma_d*h_m*h_m + Gamma_s) * grad_s * FLOW_RATE;
    float smaller = min(plusSide, minusSide);
    float larger  = max(plusSide, minusSide);

    bool minusThinner = h_m <= h_p;
    bool minusThicker = h_m >  h_p;

    if (s_p <= s_m) {
        // Surface does not rise towards the plus side.
        if (minusThinner) return smaller;
        if (minusThicker) return larger;
    } else if (s_p > s_m) {
        // Surface rises towards the plus side.
        if (minusThinner) return larger;
        if (minusThicker) return smaller;
    }
    return 0;
}

// Compute entry point: advances ice thickness and sediment for one texel
// (dataPos) by one step. It also updates the snow state and the smoothed
// display height, and raises the shared max-change sentinel that sizes the
// next pass's timestep.
void main() {
    ivec2 dataPos = ivec2(gl_GlobalInvocationID.xy);
    
    // Squared grid spacing. Both axes share water_resolution, so dx2 == dy2;
    // x (i) terms use dx2 and y (j) terms use dy2.
    float dx2 = water_resolution*water_resolution;
    float dy2 = water_resolution*water_resolution;
    vec4 previousIceValues = texelFetch(ice0Read, dataPos);
    vec4 ice2Data = texelFetch(ice2Read, dataPos);
    vec4 lava0Data = texelFetch(lava0Read, dataPos);

    float sediment = previousIceValues[1];
    // 8-neighbourhood. Entries 0-3 are +x, -x, +y, -y and must keep that order:
    // they are paired with directionalFlows[0..3] in the sediment loop below.
    vec4 neighbours[8];
    neighbours[0] = texelFetch(ice0Read, dataPos + ivec2(1, 0));
    neighbours[1] = texelFetch(ice0Read, dataPos + ivec2(-1, 0));
    neighbours[2] = texelFetch(ice0Read, dataPos + ivec2(0, 1));
    neighbours[3] = texelFetch(ice0Read, dataPos + ivec2(0, -1));
    neighbours[4] = texelFetch(ice0Read, dataPos + ivec2(-1, -1));
    neighbours[5] = texelFetch(ice0Read, dataPos + ivec2(1, -1));
    neighbours[6] = texelFetch(ice0Read, dataPos + ivec2(1, 1));
    neighbours[7] = texelFetch(ice0Read, dataPos + ivec2(-1, 1));

    float neighboursLandHeight[8];
    neighboursLandHeight[0] = texelFetch(water0Read, dataPos + ivec2(1, 0))[0];
    neighboursLandHeight[1] = texelFetch(water0Read, dataPos + ivec2(-1, 0))[0];
    neighboursLandHeight[2] = texelFetch(water0Read, dataPos + ivec2(0, 1))[0];
    neighboursLandHeight[3] = texelFetch(water0Read, dataPos + ivec2(0, -1))[0];
    neighboursLandHeight[4] = texelFetch(water0Read, dataPos + ivec2(-1, -1))[0];
    neighboursLandHeight[5] = texelFetch(water0Read, dataPos + ivec2(1, -1))[0];
    neighboursLandHeight[6] = texelFetch(water0Read, dataPos + ivec2(1, 1))[0];
    neighboursLandHeight[7] = texelFetch(water0Read, dataPos + ivec2(-1, 1))[0];

    // Ice thickness here and up to two texels away along each axis
    // (the stencil needed by the limited face reconstruction below).
    float h     = previousIceValues[0];
    float h_ip  = neighbours[0][0];
    float h_ipp = texelFetch(ice0Read, dataPos + ivec2(2, 0))[0];
    float h_im  = neighbours[1][0];
    float h_imm = texelFetch(ice0Read, dataPos + ivec2(-2, 0))[0];
    float h_jp  = neighbours[2][0];
    float h_jpp = texelFetch(ice0Read, dataPos + ivec2(0, 2))[0];
    float h_jm  = neighbours[3][0];
    float h_jmm = texelFetch(ice0Read, dataPos + ivec2(0, -2))[0];

    vec4 land = texelFetch(water0Read, dataPos);
    // Ice surface = bed height + ice thickness.
    float z    = land[0];
    float s    = h    + z;

    float _s = z;
    float s_ip = h_ip + neighboursLandHeight[0];
    float s_im = h_im + neighboursLandHeight[1];
    float s_jp = h_jp + neighboursLandHeight[2]; 
    float s_jm = h_jm + neighboursLandHeight[3];
    
    // Ratios of successive thickness differences for the phi() limiter;
    // eps keeps the denominators non-zero on flat ice.
    float r_i  = (h    - h_im) /(h_ip  - h    + eps);
    float r_ip = (h_ip - h)    /(h_ipp - h_ip + eps);
    float r_im = (h_im - h_imm)/(h     - h_im + eps);
    
    float r_j  = (h    - h_jm) /(h_jp  - h    + eps);
    float r_jp = (h_jp - h)    /(h_jpp - h_jp + eps);
    float r_jm = (h_jm - h_jmm)/(h     - h_jm + eps);
    
    // Limited ice thickness at each cell face (staggered grid).
    // up = +1/2, dn = -1/2; _m / _p = reconstructed from the minus / plus side.
    float h_iup_m = h    + 0.5 * phi(r_i)  * (h_ip  - h);
    float h_iup_p = h_ip - 0.5 * phi(r_ip) * (h_ipp - h_ip);
    float h_idn_m = h_im + 0.5 * phi(r_im) * (h     - h_im);
    float h_idn_p = h    - 0.5 * phi(r_i)  * (h_ip  - h);
    
    float h_jup_m = h    + 0.5 * phi(r_j)  * (h_jp  - h);
    float h_jup_p = h_jp - 0.5 * phi(r_jp) * (h_jpp - h_jp);
    float h_jdn_m = h_jm + 0.5 * phi(r_jm) * (h     - h_jm);
    float h_jdn_p = h    - 0.5 * phi(r_j)  * (h_jp  - h);
    
    float s_diff_ip = abs(s_ip - s);
    float s_diff_im = abs(s_im - s);
    float s_diff_jp = abs(s_jp - s);
    float s_diff_jm = abs(s_jm - s);

    float maxIceSlope = max(max(s_diff_ip, s_diff_im), max(s_diff_jp, s_diff_jm));
    float maxLandSlope = max(
        max(abs(land[0] - neighboursLandHeight[0]), abs(land[0] - neighboursLandHeight[1])), 
        max(abs(land[0] - neighboursLandHeight[2]), abs(land[0] - neighboursLandHeight[3]))
    );
    // Lowest surface (bed + ice) among ice-covered neighbours. Used to sink the
    // display height below it when ice is not shown at this texel.
    float lowestNeighbourDisplayIce = 999999.f;
    for (int i=0; i<8; i++) {
        float _height = neighbours[i][0] + neighboursLandHeight[i]; 
        if (neighbours[i][0] > 0.f && _height < lowestNeighbourDisplayIce) {
            lowestNeighbourDisplayIce = _height;
        }
    }

    // Squared surface gradient across each face.
    float grad_s_iup = (s_diff_ip * s_diff_ip)/dx2;
    float grad_s_idn = (s_diff_im * s_diff_im)/dx2;
    float grad_s_jup = (s_diff_jp * s_diff_jp)/dy2;
    float grad_s_jdn = (s_diff_jm * s_diff_jm)/dy2;
    
    // Bed-shape terms used to weight erosion: half the bed difference across the
    // texel, and how far this texel's bed sits above the midpoint of its two
    // axis neighbours (per unit spacing).
    float xAcross = (neighboursLandHeight[0] - neighboursLandHeight[1]) / 2;
    float yAcross = (neighboursLandHeight[2] - neighboursLandHeight[3]) / 2;
    float xOut = (_s - (neighboursLandHeight[1] + xAcross)) / water_resolution;
    float yOut = (_s - (neighboursLandHeight[3] + yAcross)) / water_resolution;

    // Upwinded diffusivities at the 4 cell faces.
    float D_iup = diffusivityWithSliding(grad_s_iup, h_iup_p, h_iup_m, s_ip, s);
    float D_idn = diffusivityWithSliding(grad_s_idn, h_idn_p, h_idn_m, s, s_im);
    float D_jup = diffusivityWithSliding(grad_s_jup, h_jup_p, h_jup_m, s_jp, s);
    float D_jdn = diffusivityWithSliding(grad_s_jdn, h_jdn_p, h_jdn_m, s, s_jm);    

    // Face fluxes; the sign follows (neighbour surface - this surface), and the
    // sediment loop treats positive values as inflow.
    vec4 directionalFlows;
    directionalFlows[0] = D_iup*(s_ip - s); 
    directionalFlows[1] = D_idn*(s_im - s);
    directionalFlows[2] = D_jup*(s_jp - s);
    directionalFlows[3] = D_jdn*(s_jm - s);

    float div_q_i = (directionalFlows[0] + directionalFlows[1])/dx2;
    float div_q_j = (directionalFlows[2] + directionalFlows[3])/dy2;
    // Adaptive substep: inverse of the previous max-change sentinel, capped at
    // MAX_TIMESTEP and scaled by TIMESTEP. The sentinel is a uint, so
    // max(..., 1) only affects a zero sentinel, which still yields MAX_TIMESTEP
    // but without dividing by zero.
    float previousMaxChange = float(texelFetch(sentinals0Read, ivec2(0, 0))[0]);
    const float MAX_TIMESTEP = 0.01f;
    float timestep = clamp(1.f / max(previousMaxChange, 1.f), 0.f, MAX_TIMESTEP) * TIMESTEP;
    float totalIceFlow = timestep * (div_q_i + div_q_j);
    // Publish this texel's relative change for the next pass. The atomic is
    // skipped unless the change exceeds half the previous maximum.
    float maxAllowedChange = max(h / 30.f, 1.f);
    float maxChange = abs(div_q_i + div_q_j) / maxAllowedChange;
    if (maxChange > (previousMaxChange / 2.f)) {
        imageAtomicMax(sentinals0Write, ivec2(0, 0), uint(maxChange));
    }
    totalIceFlow = clamp(totalIceFlow, -2.f, 2.f);

    // Sediment advection. Outflow carries this texel's sediment and inflow
    // carries the neighbour's, each as a fraction clamped to [-timestep, timestep].
    float previousSediment = sediment;
    if (h > 0.01f) {
        for (int i=0; i<4; i++) {
            // Faces 0/1 are x faces, 2/3 are y faces.
            float spacing2 = (i < 2) ? dx2 : dy2;
            float sedimentFlow;
            if (directionalFlows[i] < 0.f) {
                sedimentFlow = clamp(timestep * directionalFlows[i] / (h * spacing2), -timestep, timestep);
                sediment += (sedimentFlow * previousSediment);
            } else {
                // Guard the neighbour thickness: an ice-free neighbour with zero
                // flux would otherwise give 0/0 (NaN) and poison the sediment.
                sedimentFlow = clamp(timestep * directionalFlows[i] / (max(neighbours[i][0], eps) * spacing2), -timestep, timestep);
                sediment += (sedimentFlow * neighbours[i][1]);
            }
        }  
    } else {
        sediment = 0.f;
    }
    
    // Per-axis speed proxy from face flux magnitudes, clamped to [0, 1].
    float xVelocity = (abs(directionalFlows[0]) + abs(directionalFlows[1])) / (2 * (h + eps)) * 0.001f;
    float yVelocity = (abs(directionalFlows[2]) + abs(directionalFlows[3])) / (2 * (h + eps)) * 0.001f; 
    xVelocity = clamp(xVelocity, 0.f, 1.f);
    yVelocity = clamp(yVelocity, 0.f, 1.f);

    // Erosion along x is weighted by the y-axis bed shape, and vice versa.
    float xAcrossFactor = (1.f + abs(yAcross / water_resolution) * 1.f), xStickingOutFactor = (1.f + clamp(yOut * 2.f, -1.f, 1.f));
    float yAcrossFactor = (1.f + abs(xAcross / water_resolution) * 1.f), yStickingOutFactor = (1.f + clamp(xOut * 2.f, -1.f, 1.f));

    float xErosion = (xVelocity * xVelocity * EROSION * min(1.f, h)) * xAcrossFactor * xStickingOutFactor * TIMESTEP;
    float yErosion = (yVelocity * yVelocity * EROSION * min(1.f, h)) * yAcrossFactor * yStickingOutFactor * TIMESTEP;
    // Erosion tapers off as sediment approaches SEDIMENT_LIMIT.
    float erosion = clamp(xErosion + yErosion, -0.01f, 0.01f) * clamp((SEDIMENT_LIMIT - sediment) / SEDIMENT_LIMIT, 0.f, 1.f);
    sediment = clamp(sediment, 0, SEDIMENT_LIMIT * 2.f);
    // No positive erosion where the bed is below height 1.
    if (land[0] < 1.f) {
        erosion = min(erosion, 0.f);
    }

    // Mass balance: accumulation below -5 C mean, otherwise ablation
    // proportional to how far the mean is above -5 C.
    Temperature temperature  = normalisedToTemperature(texelFetch(temperature0Read, dataPos).xy);
    vec2 rainfall = texelFetch(rainfall2Read, dataPos).xy;
    float snowfall = 0.f;
    float snowfallAccumulationPercentage = clamp(1.f - (temperature.percentageYearAboveZeroCelsius * 2.f), 0.f, 1.f);
    if (temperature.meanCelsius < -5.f) {
        snowfall = snowfallAccumulationPercentage * ACCUMULATION_RATE * rainfall[0];
    } else {
        snowfall = -(temperature.meanCelsius + 5.f) * ABLATION_RATE;
    }
    snowfall *= (clamp(timestep, MAX_TIMESTEP / 10.f, MAX_TIMESTEP) * TIMESTEP);
    float meltRatio = 1.f;
    if (lava0Data[0] > 0.f) {
        snowfall = -LAVA_MELT_RATE;
    }

    // Melting ice drops a proportional share of its sediment. max(h, eps)
    // avoids dividing by zero on ice-free texels; with zero snowfall (e.g.
    // TIMESTEP == 0) that would be 0/0 and write NaN into the erosion channel.
    float sedimentDropped = (sediment * (clamp(-snowfall / max(h, eps), 0.f, 1.f))) * TIMESTEP;
    erosion -= sedimentDropped;

    meltRatio = meltRatio * TIMESTEP;
    float change = (snowfall + totalIceFlow) * TIMESTEP;
    float melt = clamp(-snowfall, 0.f, h) * meltRatio;
    
    melt = melt * clamp((temperature.currentCelsius + 10.f) / 20.f, 0.f, 1.f);

    if (TIMESTEP > 0.f) {
        /* We're using a spare attribute here that is only 8 bits, rather than 32 bit float
           So instead of storing how much snow should later melt, we just store the count of
           days that it snowed, and estimate from that how much snow should melt */
        if (temperature.currentCelsius < 0.f) {
            if (dayProgressed) {
                ice2Data[1] += 1 / 255.f; 
            }
        } else {
            if (ice2Data[1] > 0.f) {
                if (dayProgressed) {
                    ice2Data[1] = max(ice2Data[1] - 4 / 255.f, 0.f);
                }
                melt += rainfall[1] * RAINFALL_NORMALISED_TO_SIM_CONVERSION * 4.f;
            }
        }
    }

    float newIceHeight = max(h + change, 0);
    imageStore(ice0Write, dataPos, vec4(newIceHeight, max(sediment + erosion, 0.f), melt / ACCUMULATION_RATE, erosion));
    vec4 existingDisplayHeight = texelFetch(ice1Read, dataPos);

    // Seasonal snow cover: requires a gentle enough slope (colder allows
    // steeper), existing display ice, or rock flagged as snowable; also some
    // rainfall and no lava.
    vec4 rockProperties0Data = texelFetch(rockProperties0Read, dataPos);
    bool snowable = (maxLandSlope < (9.f + (max(0.f, -temperature.currentCelsius) / TEMPERATURE_SLOPE_MULTIPLIER) ) || existingDisplayHeight[0] > 0.f || rockProperties0Data[1] >= 1.f) 
        && rainfall[1] > 0.02f && lava0Data[0] == 0.f;
    // bool inSnow = snowable && temperature.currentCelsius < 0.f && (temperature.meanCelsius > 0.f || snowfallAccumulationPercentage < 0.9f);
    bool inSnow = snowable && temperature.currentCelsius < 0.f;
    ice2Data[0] = inSnow ? 1.f : 0.f;

    // Display height target: average of old and new thickness minus a slope
    // penalty, or a negative value sunk below the lowest ice-covered neighbour.
    newIceHeight = (newIceHeight + h) / 2.f;
    if (newIceHeight < (maxIceSlope * DISPLAY_SLOPE_RATIO) || newIceHeight < 1.f) {
        newIceHeight = min(lowestNeighbourDisplayIce - land[0] - 4.f, -4.f);
    } else {
        newIceHeight = newIceHeight - (maxIceSlope * DISPLAY_SLOPE_RATIO);
    }

    // Hysteresis: a sign flip (ice appearing/disappearing) only takes effect
    // once the countdown in channel y reaches 0; otherwise ease 10% toward target.
    if ((existingDisplayHeight[0] < 0.f) != (newIceHeight < 0.f)) {
        existingDisplayHeight[1] -= 1.f;
        if (existingDisplayHeight[1] <= 0.f) {
            existingDisplayHeight[0] = min(newIceHeight, 0.1f);
            existingDisplayHeight[1] = 256.f;
        }
    } else {
        existingDisplayHeight[0] = mix(existingDisplayHeight[0], newIceHeight, 0.1f);
        existingDisplayHeight[1] = 256.f;
    }

    imageStore(ice1Write, dataPos, existingDisplayHeight);
    imageStore(ice2Write, dataPos, ice2Data);

    // Flag the 128x128 block containing this texel as having visible ice.
    if (existingDisplayHeight[0] > 0.f) {
        imageStore(iceActivation0Write, dataPos / 128, vec4(1.f, 0.f, 0.f, 0.f));
    }

    if (WRITE_DEBUG) {
        imageStore(debugOutput0Write, dataPos, vec4(temperature.percentageYearAboveZeroCelsius, yErosion, xVelocity, yVelocity));
        imageStore(debugOutput1Write, dataPos, directionalFlows);
    }
}