summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-10 18:46:57 -0800
committerGitHub <noreply@github.com>2023-02-10 18:46:57 -0800
commitaec57d849ae20a305d08348cf543d19eabc2e2d6 (patch)
treeafac620a888d27ee1000b036c4ab8c3773180af3
parent6e7b424953ae6732d4863e887e7e452396095d71 (diff)
Fix several autodiff bugs. (#2643)
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp10
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp4
-rw-r--r--tests/autodiff/bsdf/bsdf-auto-rev.slang98
-rw-r--r--tests/autodiff/bsdf/bsdf-sample.slang46
-rw-r--r--tests/autodiff/bsdf/bsdf-sample.slang.expected.txt11
-rw-r--r--tests/autodiff/reverse-struct-out.slang50
-rw-r--r--tests/autodiff/reverse-struct-out.slang.expected.txt6
7 files changed, 220 insertions, 5 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index fca34f9a2..7782bd39c 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -148,7 +148,7 @@ InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRIns
builder->markInstAsDifferential(diffRightTimesLeft, resultType);
builder->markInstAsDifferential(diffSub, resultType);
- auto diffMul = builder->emitMul(resultType, primalRight, primalRight);
+ auto diffMul = builder->emitMul(primalRight->getFullType(), primalRight, primalRight);
builder->markInstAsPrimal(diffMul);
auto diffDiv = builder->emitDiv(diffType, diffSub, diffMul);
@@ -877,6 +877,14 @@ InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRI
diffBase, diffAccessChain, diffVal);
builder->addPrimalElementTypeDecoration(diffUpdateElement, primalElementType);
}
+ else
+ {
+ auto primalElementType = primalVal->getDataType();
+ auto zeroElementDiff = getDifferentialZeroOfType(builder, primalElementType);
+ diffUpdateElement = builder->emitUpdateElement(
+ diffBase, diffAccessChain, zeroElementDiff);
+ builder->addPrimalElementTypeDecoration(diffUpdateElement, primalElementType);
+ }
}
}
return InstPair(primalUpdateField, diffUpdateElement);
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 25f6c3964..6aaa40baf 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -532,10 +532,6 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
stripTempDecorations(func);
- // Run simplification to DCE unnecessary insts.
- eliminateDeadCode(func);
- eliminateDeadCode(primalFunc);
-
return primalFunc;
}
} // namespace Slang
diff --git a/tests/autodiff/bsdf/bsdf-auto-rev.slang b/tests/autodiff/bsdf/bsdf-auto-rev.slang
new file mode 100644
index 000000000..c2ba434f8
--- /dev/null
+++ b/tests/autodiff/bsdf/bsdf-auto-rev.slang
@@ -0,0 +1,98 @@
+//TEST_IGNORE_FILE:
+
+struct ShadingData
+{
+ float3 V;
+ float3 N;
+ float3 T;
+ float3 B;
+
+ float3 fromLocal(float3 v)
+ {
+ return T * v.x + B * v.y + N * v.z;
+ }
+
+ float3 toLocal(float3 v)
+ {
+ return float3(dot(v, T), dot(v, B), dot(v, N));
+ }
+};
+
+struct Auto_Bwd_ScatterSample : IDifferentiable
+{
+ float3 wo;
+ float pdf;
+ float3 weight;
+};
+
+struct Auto_Bwd_BSDFParameters : IDifferentiable
+{
+ float3 albedo;
+ float roughness;
+};
+
+[BackwardDifferentiable]
+void bsdfGGXSample(in ShadingData sd, in Auto_Bwd_BSDFParameters params, out Auto_Bwd_ScatterSample result)
+{
+ float3 wiLocal = no_diff(sd.toLocal(sd.V));
+ float2 u = float2(0.8, 0.3);
+
+ // Taken from Rendering.Materials.Microfacet. Follows the Walter et al. EGSR07 BTDF paper
+ float alphaSqr = params.roughness * params.roughness;
+ float phi = u.y * (2 * 3.1415926);
+ float tanThetaSqr = alphaSqr * u.x / (1 - u.x);
+ float cosTheta = 1 / sqrt(1 + tanThetaSqr);
+ float r = sqrt(max(1 - cosTheta * cosTheta, 0));
+
+ float3 hLocal = float3(cos(phi) * r, sin(phi) * r, cosTheta); // half-vector local space
+ float wiDotH = dot(wiLocal, hLocal);
+ float3 woLocal = 2 * hLocal * wiDotH - wiLocal; // outgoing vector local space
+
+ float pdf = bsdfGGXPDF(hLocal, params) / (4.f * wiDotH);
+ result.wo = no_diff(sd.fromLocal(woLocal)); // wo to world.
+ result.pdf = detach(pdf);
+ result.weight = evalGGXDivByPDF(wiLocal, woLocal, hLocal, params) * pdf / detach(pdf);
+}
+
+[BackwardDifferentiable]
+float3 F(float3 f0, float3 f90, float cosTheta)
+{
+ return f0 + (f90 - f0) * pow(max(1 - cosTheta, 0.f), 5.f);
+}
+
+[BackwardDifferentiable]
+float evalLambdaGGX(float alphaSqr, float cosTheta)
+{
+ float cosThetaSqr = cosTheta * cosTheta;
+ float tanThetaSqr = max(1 - cosThetaSqr, 0) / cosThetaSqr;
+ return 0.5 * (-1 + sqrt(1 + alphaSqr * tanThetaSqr));
+}
+
+[BackwardDifferentiable]
+float G(float alpha, float cosThetaI, float cosThetaO)
+{
+ float alphaSqr = alpha * alpha;
+ float lambdaI = evalLambdaGGX(alphaSqr, cosThetaI);
+ float lambdaO = evalLambdaGGX(alphaSqr, cosThetaO);
+ return 1.0 / (1 + lambdaI + lambdaO);
+}
+
+[BackwardDifferentiable]
+float3 evalGGXDivByPDF(in float3 wi, in float3 wo, in float3 h, in Auto_Bwd_BSDFParameters params)
+{
+ const float3 F0Color = params.albedo;
+ let F90Color = float3(1.0, 1.0, 1.0);
+ return F(F0Color, F90Color, dot(wi, h)) * G(params.roughness, wi.z, wo.z) * dot(wi, h) / (wi.z * h.z);
+}
+
+[BackwardDifferentiable]
+float bsdfGGXPDF(in float3 hLocal, in Auto_Bwd_BSDFParameters params)
+{
+ float cosTheta = hLocal.z;
+
+ float alpha = params.roughness;
+ float a2 = alpha * alpha;
+ float d = ((cosTheta * a2 - cosTheta) * cosTheta + 1);
+
+ return (a2 / (d * d * 3.1415926)) * cosTheta;
+}
diff --git a/tests/autodiff/bsdf/bsdf-sample.slang b/tests/autodiff/bsdf/bsdf-sample.slang
new file mode 100644
index 000000000..8a9508791
--- /dev/null
+++ b/tests/autodiff/bsdf/bsdf-sample.slang
@@ -0,0 +1,46 @@
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+__exported import bsdf_auto_rev;
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ ShadingData sd;
+ sd.N = float3(0, 0, 1);
+ sd.T = float3(1, 0, 0);
+ sd.B = float3(0, 1, 0);
+ sd.V = normalize(float3(0.3, 0.5, 0.8));
+
+ {
+ DifferentialPair<Auto_Bwd_BSDFParameters> dp_params = DifferentialPair<Auto_Bwd_BSDFParameters>(
+ { float3(0.9, 0.6, 0.4), 0.1 },
+ { float3(0, 0, 0), 0 });
+
+ Auto_Bwd_ScatterSample.Differential dOut = { float3(0, 0, 0), 0, float3(1, 0, 0) };
+ __bwd_diff(bsdfGGXSample)(sd, dp_params, dOut);
+
+ outputBuffer[0] = dp_params.d.albedo[0];
+ outputBuffer[1] = dp_params.d.albedo[1];
+ outputBuffer[2] = dp_params.d.albedo[2];
+ outputBuffer[3] = dp_params.d.roughness;
+ }
+
+ {
+ DifferentialPair<Auto_Bwd_BSDFParameters> dp_params = DifferentialPair<Auto_Bwd_BSDFParameters>(
+ { float3(0.9, 0.6, 0.4), 0.1 },
+ { float3(0, 0, 0), 1.0 });
+ DifferentialPair<Auto_Bwd_ScatterSample> dp_result;
+ __fwd_diff(bsdfGGXSample)(sd, dp_params, dp_result);
+
+ outputBuffer[4] = dp_result.p.weight[0];
+ outputBuffer[5] = dp_result.p.weight[1];
+ outputBuffer[6] = dp_result.p.weight[2];
+ outputBuffer[7] = dp_result.d.weight[0];
+ outputBuffer[8] = dp_result.d.weight[1];
+ outputBuffer[9] = dp_result.d.weight[2];
+ }
+}
diff --git a/tests/autodiff/bsdf/bsdf-sample.slang.expected.txt b/tests/autodiff/bsdf/bsdf-sample.slang.expected.txt
new file mode 100644
index 000000000..e2e25558f
--- /dev/null
+++ b/tests/autodiff/bsdf/bsdf-sample.slang.expected.txt
@@ -0,0 +1,11 @@
+type: float
+1.093531
+0.000000
+0.000000
+-18.207390
+0.984221
+0.656162
+0.437456
+-18.207394
+-12.138763
+-8.093011 \ No newline at end of file
diff --git a/tests/autodiff/reverse-struct-out.slang b/tests/autodiff/reverse-struct-out.slang
new file mode 100644
index 000000000..af2f8becf
--- /dev/null
+++ b/tests/autodiff/reverse-struct-out.slang
@@ -0,0 +1,50 @@
+
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+struct A : IDifferentiable
+{
+ float x;
+ float y;
+};
+
+struct B : IDifferentiable
+{
+ float x;
+ float y;
+};
+
+typedef DifferentialPair<A> dpA;
+
+float id(float x)
+{
+ return x;
+}
+
+[BackwardDifferentiable]
+void f(A input, out B rs)
+{
+ rs.x = input.x * input.x;
+ // Derivative of rs.x should still propagate through this no_diff call.
+ rs.y = no_diff id(input.y);
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ {
+ A a = {3.0, 2.0};
+ A.Differential azero = {0.0, 0.0};
+
+ dpA dpa = dpA(a, azero);
+
+ B.Differential dout = {1.0, 1.0};
+
+ __bwd_diff(f)(dpa, dout);
+ outputBuffer[0] = dpa.d.x; // Expect: 6
+ outputBuffer[1] = dpa.d.y; // Expect: 0
+ }
+}
diff --git a/tests/autodiff/reverse-struct-out.slang.expected.txt b/tests/autodiff/reverse-struct-out.slang.expected.txt
new file mode 100644
index 000000000..f5ad0d81f
--- /dev/null
+++ b/tests/autodiff/reverse-struct-out.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+6.000000
+0.000000
+0.000000
+0.000000
+0.000000