diff options
| author | Yong He <yonghe@outlook.com> | 2023-02-21 12:51:46 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-21 12:51:46 -0800 |
| commit | 0ef7aa85d3a6b2ff1d6b25576b4d9eff188c1a6a (patch) | |
| tree | fd8f2b6e528e01a90cc2f34b2fe8ebf6cc5f97a9 /tests | |
| parent | 6bca0ec355aae2955c7de1cd16c2dc0dfe46f19c (diff) | |
Fix transposeCall. (#2669)
* Modify control-flow test case
* Update reverse-control-flow-3.slang
* Fix `transposeCall`.
* Fix.
---------
Co-authored-by: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/autodiff/reverse-control-flow-3.slang | 220 | ||||
| -rw-r--r-- | tests/autodiff/reverse-control-flow-3.slang.expected.txt | 5 | ||||
| -rw-r--r-- | tests/autodiff/reverse-inout-param-3.slang | 29 | ||||
| -rw-r--r-- | tests/autodiff/reverse-inout-param-3.slang.expected.txt | 7 |
4 files changed, 261 insertions, 0 deletions
diff --git a/tests/autodiff/reverse-control-flow-3.slang b/tests/autodiff/reverse-control-flow-3.slang new file mode 100644 index 000000000..e94f55b03 --- /dev/null +++ b/tests/autodiff/reverse-control-flow-3.slang @@ -0,0 +1,220 @@ +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer + +RWStructuredBuffer<float> outputBuffer; + +struct PathState +{ + uint depth; + bool terminated; + + bool isHit() { return !terminated; } + bool isTerminated() { return terminated; } +}; + +struct PathResult : IDifferentiable +{ + float thp; + float L; +} +struct VisibilityQuery +{ + bool test(); +} + +struct ClosestHitQuery +{ + bool test(); +} +void generatePath(uint pathID, out PathState path) +{ + path.terminated = false; + path.depth = 0; +} + +[BackwardDifferentiable] +float lightEval(uint depth) +{ + if (depth == 1) + { + return 5.0f; + } + else + { + return 0.0f; + } +} + +struct MaterialParam : IDifferentiable +{ + float roughness; +} + +[BackwardDifferentiable] +MaterialParam getParam(uint id) +{ + MaterialParam p; + p.roughness = 0.5f; + return p; +} + +[ForwardDerivativeOf(getParam)] +DifferentialPair<MaterialParam> d_getParam(uint id) +{ + MaterialParam p; + p.roughness = 0.5f; + MaterialParam.Differential d; + d.roughness = 0.0f; + return diffPair(p, d); +} + +[BackwardDerivativeOf(getParam)] +void d_getParam(uint id, MaterialParam.Differential diff) +{ + outputBuffer[id] += diff.roughness; +} + + +[BackwardDifferentiable] +void updatePathThroughput(inout PathResult path, const float weight) +{ + path.thp *= weight; +} + +struct BSDFSample : IDifferentiable +{ + float val; +} + +[BackwardDifferentiable] +bool bsdfGGXSample(const MaterialParam bsdfParams, out BSDFSample result) +{ + result.val = bsdfParams.roughness; + return true; +} + +[BackwardDifferentiable] +bool generateScatterRay(const MaterialParam bsdfParams, inout PathState path, inout PathResult pathRes) +{ + BSDFSample result; + bool valid = bsdfGGXSample(bsdfParams, result); + return generateScatterRay(result, bsdfParams, path, pathRes, valid); +} + +/** Generates a new scatter ray using BSDF importance sampling. + \param[in] sd Shading data. + \param[in] mi Material instance at the shading point. + \param[in,out] path The path state. + \return True if a ray was generated, false otherwise. +*/ +[BackwardDifferentiable] +bool generateScatterRay(const BSDFSample bs, const MaterialParam bsdfParams, inout PathState path, inout PathResult pathRes, bool valid) +{ + if (valid) valid = generateScatterRay(bs, bsdfParams, path, pathRes); + return valid; +} + +/** Generates a new scatter ray given a valid BSDF sample. + \param[in] bs BSDF sample (assumed to be valid). + \param[in] sd Shading data. + \param[in] mi Material instance at the shading point. + \param[in,out] path The path state. + \return True if a ray was generated, false otherwise. +*/ +[BackwardDifferentiable] +bool generateScatterRay(const BSDFSample bs, const MaterialParam bsdfParams, inout PathState path, inout PathResult pathRes) +{ + updatePathThroughput(pathRes, bs.val); + return true; +} + +[BackwardDifferentiable] +void handleHit(inout PathState path, inout PathResult rs, inout VisibilityQuery vq) +{ + var param = getParam(0); + + bool lastVertex = param.roughness > 0.8; + if (lastVertex) + { + path.terminated = true; + return; + } + + generateScatterRay(param, path, rs); + + rs.L = rs.thp * lightEval(path.depth); + + // Decide on next hit + if (path.depth < 1) + path.terminated = false; + else + path.terminated = true; +} + +[BackwardDifferentiable] +float bsdfEval(const MaterialParam mparam) +{ + return mparam.roughness; +} + +[BackwardDifferentiable] +void nextHit(inout PathState path, inout PathResult rs, inout ClosestHitQuery cq) +{ + path.depth = path.depth + 1; +} + +[BackwardDifferentiable] +void handleMiss(inout PathState path, inout PathResult rs) +{ + rs.L = 0.0f; + path.terminated = true; +} + +[BackwardDifferentiable] +bool tracePath(uint pathID, out PathState path, inout PathResult pathRes) +{ + generatePath(pathID, path); + + float thp = pathRes.thp; + float L = pathRes.L; + + [ForceUnroll] + for (int i = 0; i < 3; ++i) + { + if (path.isHit()) + { + VisibilityQuery vq; + handleHit(path, pathRes, vq); + + if (path.isTerminated()) break; + + ClosestHitQuery chq; + nextHit(path, pathRes, chq); + } + else + { + handleMiss(path, pathRes); + } + } + + return true; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + { + PathResult pathRes; + pathRes.L = 1.f; + pathRes.thp = 1.f; + + PathResult.Differential pathResD; + pathResD.L = 1.0f; + pathResD.thp = 0.f; + + var dpx = diffPair(pathRes, pathResD); + __bwd_diff(tracePath)(1, dpx); // Expect: 5.0 in outputBuffer[3] + } + +} diff --git a/tests/autodiff/reverse-control-flow-3.slang.expected.txt b/tests/autodiff/reverse-control-flow-3.slang.expected.txt new file mode 100644 index 000000000..f77c12531 --- /dev/null +++ b/tests/autodiff/reverse-control-flow-3.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +5.000000 +0.000000 +0.000000 +0.000000
\ No newline at end of file diff --git a/tests/autodiff/reverse-inout-param-3.slang b/tests/autodiff/reverse-inout-param-3.slang new file mode 100644 index 000000000..8f98b1b40 --- /dev/null +++ b/tests/autodiff/reverse-inout-param-3.slang @@ -0,0 +1,29 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[BackwardDifferentiable] +void assign(inout float rs, float v) +{ + rs = v; +} + +[BackwardDifferentiable] +void f(inout float p, float x) +{ + assign(p, x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var x = diffPair(5.0, 0.0); + var pp = diffPair(1.0, 3.0); + __bwd_diff(f)(pp, x); + + outputBuffer[0] = pp.p; // should be 1, since bwd_diff does not write back new primal val. + outputBuffer[1] = x.d; // 3 +}
\ No newline at end of file diff --git a/tests/autodiff/reverse-inout-param-3.slang.expected.txt b/tests/autodiff/reverse-inout-param-3.slang.expected.txt new file mode 100644 index 000000000..1d8606e6c --- /dev/null +++ b/tests/autodiff/reverse-inout-param-3.slang.expected.txt @@ -0,0 +1,7 @@ +type: float +1.000000 +3.000000 +0.000000 +0.000000 +0.000000 + |
