// AMD FidelityFX Super Resolution - Robust Contrast Adaptive Sharpening (RCAS)
// Based on FidelityFX-FSR from GPUOpen-Effects

//!BGFX EFFECT
//!VERSION 1
//!NAME FSR RCAS (Sharpening)
//!CATEGORY Sharpening
//!DESCRIPTION AMD FidelityFX Super Resolution 1.0 sharpening pass. RCAS applies contrast-adaptive sharpening with built-in noise reduction for cleaner results.
//!CAPABILITY FP16

//!PARAMETER
//!LABEL Sharpness
//!DESC Controls sharpening intensity. Higher values produce sharper output but may introduce artifacts. Default 0.87 provides balanced sharpening.
//!DEFAULT 0.87
//!MIN 0
//!MAX 1
//!STEP 0.01
float sharpness;

//!TEXTURE
Texture2D INPUT;

//!TEXTURE
//!WIDTH INPUT_WIDTH
//!HEIGHT INPUT_HEIGHT
Texture2D OUTPUT;

//!SAMPLER
//!FILTER POINT
SamplerState sam;

//!PASS 1
//!IN INPUT
//!OUT OUTPUT
//!BLOCK_SIZE 16
//!NUM_THREADS 64

// Three-way min/max built by nesting the two-argument intrinsics.
// These are plain text-substitution macros: arguments are not parenthesised
// and are forwarded directly to min()/max(), so they work for any operand
// type those intrinsics accept.
#define min3(a, b, c) min(a, min(b, c))
#define max3(a, b, c) max(a, max(b, c))

// Upper bound on the magnitude of the sharpening lobe: 0.25 - 1/16 = 0.1875.
// The RCAS kernels clamp the (negative) lobe with max(-FSR_RCAS_LIMIT, ...)
// before scaling it by the sharpness amount, to prevent unnatural results.
#define FSR_RCAS_LIMIT (0.25-(1.0/16.0))

#ifdef BG_FP16

// FP16-optimized RCAS kernel that sharpens two pixels per call.
//
// Each pixel is filtered with the minimal cross-shaped kernel
//    b
//  d e f
//    h
// where e is the pixel being sharpened. Inputs suffixed 0 are the taps of the
// first pixel, inputs suffixed 1 the taps of the second pixel.
//
// Outputs:
//   pixR, pixG, pixB - sharpened red/green/blue; .x holds the first pixel,
//                      .y holds the second pixel.
// Inputs:
//   b0..h0, b1..h1   - RGB taps for the two pixels.
//   s                - scale applied to the sharpening lobe (the caller passes
//                      the `sharpness` parameter converted to MF).
void FsrRcasHx2(
	out MF2 pixR,
	out MF2 pixG,
	out MF2 pixB,
	float3 b0, float3 d0, float3 e0, float3 f0, float3 h0,
	float3 b1, float3 d1, float3 e1, float3 f1, float3 h1,
	MF s
) {
	// Convert from Array of Structures to Structure of Arrays layout so that
	// every operation below handles both pixels at once (.x / .y lanes).
	MF2 bR = MF2(b0.r, b1.r);
	MF2 bG = MF2(b0.g, b1.g);
	MF2 bB = MF2(b0.b, b1.b);
	MF2 dR = MF2(d0.r, d1.r);
	MF2 dG = MF2(d0.g, d1.g);
	MF2 dB = MF2(d0.b, d1.b);
	MF2 eR = MF2(e0.r, e1.r);
	MF2 eG = MF2(e0.g, e1.g);
	MF2 eB = MF2(e0.b, e1.b);
	MF2 fR = MF2(f0.r, f1.r);
	MF2 fG = MF2(f0.g, f1.g);
	MF2 fB = MF2(f0.b, f1.b);
	MF2 hR = MF2(h0.r, h1.r);
	MF2 hG = MF2(h0.g, h1.g);
	MF2 hB = MF2(h0.b, h1.b);

	// Approximate luma times 2: 0.5*R + G + 0.5*B.
	MF2 bL = bB * 0.5 + (bR * 0.5 + bG);
	MF2 dL = dB * 0.5 + (dR * 0.5 + dG);
	MF2 eL = eB * 0.5 + (eR * 0.5 + eG);
	MF2 fL = fB * 0.5 + (fR * 0.5 + fG);
	MF2 hL = hB * 0.5 + (hR * 0.5 + hG);

	// Noise detection: how far the centre luma deviates from the average of
	// its four neighbours, relative to the luma range of all five taps.
	// The result is remapped to an attenuation factor in [0.5, 1.0].
	MF2 nz = 0.25 * bL + 0.25 * dL + 0.25 * fL + 0.25 * hL - eL;
	nz = saturate(abs(nz) * rcp(max3(max3(bL, dL, eL), fL, hL) - min3(min3(bL, dL, eL), fL, hL)));
	nz = -0.5 * nz + 1.0;

	// Per-channel min and max of the four ring taps (b, d, f, h).
	// The maxima must be built with max3; using min3 here would yield
	// max(min(b, d, f), h), which is not the ring maximum and would make this
	// path disagree with the 32-bit FsrRcasF implementation.
	MF2 mn4R = min(min3(bR, dR, fR), hR);
	MF2 mn4G = min(min3(bG, dG, fG), hG);
	MF2 mn4B = min(min3(bB, dB, fB), hB);
	MF2 mx4R = max(max3(bR, dR, fR), hR);
	MF2 mx4G = max(max3(bG, dG, fG), hG);
	MF2 mx4B = max(max3(bB, dB, fB), hB);

	// Constants for peak range limiting: x = 1.0, y = -4.0.
	MF2 peakC = MF2(1.0, -1.0 * 4.0);

	// Per-channel limiters on the lobe derived from the local min/max and the
	// distance to the peak value peakC.x.
	MF2 hitMinR = min(mn4R, eR) * rcp(4.0 * mx4R);
	MF2 hitMinG = min(mn4G, eG) * rcp(4.0 * mx4G);
	MF2 hitMinB = min(mn4B, eB) * rcp(4.0 * mx4B);
	MF2 hitMaxR = (peakC.x - max(mx4R, eR)) * rcp(4.0 * mn4R + peakC.y);
	MF2 hitMaxG = (peakC.x - max(mx4G, eG)) * rcp(4.0 * mn4G + peakC.y);
	MF2 hitMaxB = (peakC.x - max(mx4B, eB)) * rcp(4.0 * mn4B + peakC.y);
	MF2 lobeR = max(-hitMinR, hitMaxR);
	MF2 lobeG = max(-hitMinG, hitMaxG);
	MF2 lobeB = max(-hitMinB, hitMaxB);

	// Take the most conservative channel, clamp the lobe to
	// [-FSR_RCAS_LIMIT, 0] and scale by the requested sharpness.
	MF2 lobe = max(-FSR_RCAS_LIMIT, min(max3(lobeR, lobeG, lobeB), 0.0)) * s;

	// Apply noise-adaptive attenuation
	lobe *= nz;

	// Resolve: weighted sum of the cross normalised by the total weight
	// (4 * lobe + 1).
	MF2 rcpL = rcp(4.0 * lobe + 1.0);
	pixR = (lobe * bR + lobe * dR + lobe * hR + lobe * fR + eR) * rcpL;
	pixG = (lobe * bG + lobe * dG + lobe * hG + lobe * fG + eG) * rcpL;
	pixB = (lobe * bB + lobe * dB + lobe * hB + lobe * fB + eB) * rcpL;
}

#else

// 32-bit floating point RCAS kernel for a single pixel.
//
// Taps form the minimal cross-shaped kernel around the pixel e:
//    b
//  d e f
//    h
// Returns the sharpened RGB value of e. The strength is scaled by the
// file-level `sharpness` parameter.
float3 FsrRcasF(float3 b, float3 d, float3 e, float3 f, float3 h) {
	// Approximate luma times 2 for each tap: 0.5*R + G + 0.5*B.
	float lumaB = b.b * 0.5 + (b.r * 0.5 + b.g);
	float lumaD = d.b * 0.5 + (d.r * 0.5 + d.g);
	float lumaE = e.b * 0.5 + (e.r * 0.5 + e.g);
	float lumaF = f.b * 0.5 + (f.r * 0.5 + f.g);
	float lumaH = h.b * 0.5 + (h.r * 0.5 + h.g);

	// Noise detection: deviation of the centre luma from the neighbour
	// average, relative to the luma range of all five taps, remapped to an
	// attenuation factor in [0.5, 1.0].
	float lumaHi = max3(max3(lumaB, lumaD, lumaE), lumaF, lumaH);
	float lumaLo = min3(min3(lumaB, lumaD, lumaE), lumaF, lumaH);
	float noise = 0.25 * lumaB + 0.25 * lumaD + 0.25 * lumaF + 0.25 * lumaH - lumaE;
	noise = saturate(abs(noise) * rcp(lumaHi - lumaLo));
	noise = -0.5 * noise + 1.0;

	// Component-wise min and max of the four ring taps (b, d, f, h).
	float3 ringMin = min(min3(b, d, f), h);
	float3 ringMax = max(max3(b, d, f), h);

	// Peak range constants: x = 1.0, y = -4.0.
	float2 peakC = { 1.0, -1.0 * 4.0 };

	// Per-channel limiters on the lobe (these need high precision RCPs).
	float3 hitMin = min(ringMin, e) * rcp(4.0 * ringMax);
	float3 hitMax = (peakC.x - max(ringMax, e)) * rcp(4.0 * ringMin + peakC.y);
	float3 lobeRGB = max(-hitMin, hitMax);

	// Most conservative channel, clamped to [-FSR_RCAS_LIMIT, 0], then scaled
	// by the user sharpness and attenuated in noisy regions.
	float lobe = max(-FSR_RCAS_LIMIT, min(max3(lobeRGB.r, lobeRGB.g, lobeRGB.b), 0)) * sharpness;
	lobe *= noise;

	// Resolve: weighted sum of the cross normalised by the total weight
	// (4 * lobe + 1).
	float rcpL = rcp(4.0 * lobe + 1.0);
	return (lobe * b + lobe * d + lobe * h + lobe * f + e) * rcpL;
}

#endif

// Compute shader entry point.
//
// Each thread sharpens a 2x2 block of output pixels whose top-left corner is
// gxy. blockStart is the origin of the thread group's tile and threadId.x is
// the linear thread index inside the group.
// NOTE(review): TileSwizzle8x8 is defined outside this file; it presumably
// maps the linear index to a 2D position in an 8x8 tile - confirm.
void Pass1(uint2 blockStart, uint3 threadId) {
	// The swizzled position is doubled because every thread owns 2x2 pixels.
	uint2 gxy = blockStart + (TileSwizzle8x8(threadId.x) << 1);

	const uint2 outputSize = GetOutputSize();
	if (gxy.x >= outputSize.x || gxy.y >= outputSize.y) {
		return;
	}

	// src[i][j] holds the input texel at (gxy.x + i - 1, gxy.y + j - 1).
	// Only the 12 texels needed by the four cross-shaped kernels are loaded;
	// the four corners of the array are never written or read.
	// NOTE(review): coordinates are not clamped, so taps next to the texture
	// border depend on how Texture2D.Load treats out-of-range coordinates.
	float3 src[4][4];

	// Columns 1 and 2 (the block's own columns): all four rows.
	[unroll]
	for (uint i = 1; i < 3; ++i) {
		[unroll]
		for (uint j = 0; j < 4; ++j) {
			src[i][j] = INPUT.Load(int3(gxy.x + i - 1, gxy.y + j - 1, 0)).rgb;
		}
	}

	// Columns 0 and 3 (left/right neighbours): only the two middle rows.
	src[0][1] = INPUT.Load(int3(gxy.x - 1, gxy.y, 0)).rgb;
	src[0][2] = INPUT.Load(int3(gxy.x - 1, gxy.y + 1, 0)).rgb;
	src[3][1] = INPUT.Load(int3(gxy.x + 2, gxy.y, 0)).rgb;
	src[3][2] = INPUT.Load(int3(gxy.x + 2, gxy.y + 1, 0)).rgb;

#ifdef BG_FP16
	MF2 pixR, pixG, pixB;
	const MF s = (MF)sharpness;

	// Top row: (x, y) in lane .x and (x + 1, y) in lane .y.
	FsrRcasHx2(pixR, pixG, pixB, src[1][0], src[0][1], src[1][1], src[2][1], src[1][2], src[2][0], src[1][1], src[2][1], src[3][1], src[2][2], s);
	// (x, y) was already validated by the early return above.
	OUTPUT[gxy] = MF4(pixR.x, pixG.x, pixB.x, 1);

	// The remaining three pixels can fall outside the output when its width
	// or height is odd, so each store is guarded exactly like the FP32 path.
	++gxy.x;
	if (gxy.x < outputSize.x && gxy.y < outputSize.y) {
		OUTPUT[gxy] = MF4(pixR.y, pixG.y, pixB.y, 1);
	}

	// Bottom row: (x + 1, y + 1) in lane .x and (x, y + 1) in lane .y, which
	// matches the order in which gxy is stepped below.
	FsrRcasHx2(pixR, pixG, pixB, src[2][1], src[1][2], src[2][2], src[3][2], src[2][3], src[1][1], src[0][2], src[1][2], src[2][2], src[1][3], s);

	++gxy.y;
	if (gxy.x < outputSize.x && gxy.y < outputSize.y) {
		OUTPUT[gxy] = MF4(pixR.x, pixG.x, pixB.x, 1);
	}

	--gxy.x;
	if (gxy.x < outputSize.x && gxy.y < outputSize.y) {
		OUTPUT[gxy] = MF4(pixR.y, pixG.y, pixB.y, 1);
	}
#else
	// (x, y): already validated by the early return above.
	OUTPUT[gxy] = float4(FsrRcasF(src[1][0], src[0][1], src[1][1], src[2][1], src[1][2]), 1);

	// (x + 1, y)
	++gxy.x;
	if (gxy.x < outputSize.x && gxy.y < outputSize.y) {
		OUTPUT[gxy] = float4(FsrRcasF(src[2][0], src[1][1], src[2][1], src[3][1], src[2][2]), 1);
	}

	// (x + 1, y + 1)
	++gxy.y;
	if (gxy.x < outputSize.x && gxy.y < outputSize.y) {
		OUTPUT[gxy] = float4(FsrRcasF(src[2][1], src[1][2], src[2][2], src[3][2], src[2][3]), 1);
	}

	// (x, y + 1)
	--gxy.x;
	if (gxy.x < outputSize.x && gxy.y < outputSize.y) {
		OUTPUT[gxy] = float4(FsrRcasF(src[1][1], src[0][2], src[1][2], src[2][2], src[1][3]), 1);
	}
#endif
}
