summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2025-03-17 08:05:14 -0700
committerGitHub <noreply@github.com>2025-03-17 15:05:14 +0000
commit714ee76af46b96c32724f0d6edb159fddeffc6bf (patch)
tree3ac6fc10580acd4cf250f5439c8d88aa1457fb6e
parent98ff41989b04ce883e9dc9f4464c45290d30c560 (diff)
Fix crash when swizzling non-differentiable types (#6613)
* Fix crash when swizzling non-differentiable types * Update slang-ir-autodiff-fwd.cpp
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp30
-rw-r--r--tests/autodiff/non-differentiable-swizzle.slang70
2 files changed, 89 insertions, 11 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 0302d9ce7..e146ac3e0 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -955,20 +955,28 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
InstPair ForwardDiffTranscriber::transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle)
{
IRInst* primalSwizzle = maybeCloneForPrimalInst(builder, origSwizzle);
-
if (auto diffBase = lookupDiffInst(origSwizzle->getBase(), nullptr))
{
- List<IRInst*> swizzleIndices;
- for (UIndex ii = 0; ii < origSwizzle->getElementCount(); ii++)
- swizzleIndices.add(origSwizzle->getElementIndex(ii));
+ // `diffBase` may exist even if the type is non-differentiable (e.g. IRCall inst that
+ // creates other differentiable outputs).
+ //
+ // We'll check to see if we can get a differential for the type in order to determine
+ // whether to generate a differential swizzle inst.
+ //
+ if (auto diffType = differentiateType(builder, primalSwizzle->getDataType()))
+ {
+ List<IRInst*> swizzleIndices;
+ for (UIndex ii = 0; ii < origSwizzle->getElementCount(); ii++)
+ swizzleIndices.add(origSwizzle->getElementIndex(ii));
- return InstPair(
- primalSwizzle,
- builder->emitSwizzle(
- differentiateType(builder, primalSwizzle->getDataType()),
- diffBase,
- origSwizzle->getElementCount(),
- swizzleIndices.getBuffer()));
+ return InstPair(
+ primalSwizzle,
+ builder->emitSwizzle(
+ diffType,
+ diffBase,
+ origSwizzle->getElementCount(),
+ swizzleIndices.getBuffer()));
+ }
}
return InstPair(primalSwizzle, nullptr);
diff --git a/tests/autodiff/non-differentiable-swizzle.slang b/tests/autodiff/non-differentiable-swizzle.slang
new file mode 100644
index 000000000..879497423
--- /dev/null
+++ b/tests/autodiff/non-differentiable-swizzle.slang
@@ -0,0 +1,70 @@
+//TEST:SIMPLE(filecheck=CUDA): -target cuda -line-directive-mode none
+
+// This test just needs to compile without crashing.
+
+// CUDA: __device__ void s_bwd_myKernel_0
+
+[Differentiable]
+uint3 load_foo(int32_t g_idx, TensorView<int32_t> indices)
+{
+ return {
+ indices[uint2(g_idx, 0)],
+ indices[uint2(g_idx, 1)],
+ indices[uint2(g_idx, 2)]
+ };
+}
+
+[Differentiable]
+Triangle load_triangle(DiffTensorView vertices, uint3 index_set)
+{
+ float3[3] triangle = {
+ read_t3_float3(index_set.x, vertices),
+ read_t3_float3(index_set.y, vertices),
+ read_t3_float3(index_set.z, vertices)
+ };
+ return { triangle };
+}
+
+struct Triangle : IDifferentiable
+{
+ float3[3] verts;
+}
+
+[Differentiable]
+float3 read_t3_float3(uint32_t idx, DiffTensorView t3)
+{
+ return float3(t3[uint2(idx, 0)],
+ t3[uint2(idx, 1)],
+ t3[uint2(idx, 2)]);
+}
+
+[Differentiable] TriangleFoo load_triangle_foo(int32_t g_idx, DiffTensorView vertices, TensorView<int32_t> indices, DiffTensorView vertex_color)
+{
+ uint3 index_set = load_foo(g_idx, indices);
+ Triangle triangle = load_triangle(vertices, index_set);
+ float3 c0 = read_t3_float3(index_set.x, vertex_color);
+ float3 c1 = read_t3_float3(index_set.y, vertex_color);
+ float3 c2 = read_t3_float3(index_set.z, vertex_color);
+ TriangleFoo result = { triangle, 0.f, {c0, c1, c2} };
+ return result;
+}
+
+struct TriangleFoo : IDifferentiable
+{
+ Triangle triangle;
+ float density;
+ float3 vertex_color[3];
+};
+
+[AutoPyBindCUDA]
+[Differentiable]
+[CudaKernel]
+void myKernel(DiffTensorView vertices, TensorView<int32_t> indices, DiffTensorView vertex_color)
+{
+ if (cudaThreadIdx().x > 0)
+ return;
+
+ TriangleFoo.Differential dp_g = TriangleFoo.dzero();
+
+ bwd_diff(load_triangle_foo)(cudaThreadIdx().x, vertices, indices, vertex_color, dp_g);
+}