//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): -target spirv -entry computeMain -stage compute

// Diagnostic test: compiles the compute entry point below and expects the
// compiler to report error 31158 on the [PrimalSubstituteOf(add)] attribute.

//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBufferPrimal
RWStructuredBuffer<float> outputBufferPrimal;

//TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4):name=gradBuffer
RWStructuredBuffer<float> gradBuffer;

// Bundles a primal buffer with a gradient buffer and attaches custom
// derivative / substitute methods to `add`.
struct BufferWithGrad
{
    RWStructuredBuffer<float> primal;
    RWStructuredBuffer<float> grad;

    // Adds `value` (passed through detach()) to element 0 of `primal`.
    [Differentiable]
    void add(float value)
    {
        primal[0] = primal[0] + detach(value);
    }

    // Empty-bodied method registered as the primal substitute of `add`.
    // NOTE: the expectation comment below uses a line-relative offset, so it
    // must stay on the line immediately above the attribute.
    // check for diagnostic:
    // CHECK-DAG: ([[# @LINE+1]]): error 31158
    [PrimalSubstituteOf(add)]
    void add_subst(float value)
    {
    }

    // Registered as the backward derivative of `add`: overwrites `d` with a
    // pair built from its existing primal part and element 0 of `grad`.
    [BackwardDerivativeOf(add)]
    void add_bwd(inout DifferentialPair<float> d)
    {
        d = diffPair(d.p, grad[0]);
    }
}

// Differentiable wrapper that invokes `add` with the constant 1.0f.
[Differentiable]
void diffCall(BufferWithGrad result)
{
    result.add(1.0f);
}

// Entry point: wraps the two global buffers in a BufferWithGrad, then runs
// diffCall once directly and once through bwd_diff.
[shader("compute")]
[numthreads(1, 1, 1)]
void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
{
    BufferWithGrad bg = {outputBufferPrimal, gradBuffer};
    diffCall(bg);
    bwd_diff(diffCall)(bg);
}