From 714ee76af46b96c32724f0d6edb159fddeffc6bf Mon Sep 17 00:00:00 2001
From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>
Date: Mon, 17 Mar 2025 08:05:14 -0700
Subject: Fix crash when swizzling non-differentiable types (#6613)

* Fix crash when swizzling non-differentiable types

* Update slang-ir-autodiff-fwd.cpp
---
 source/slang/slang-ir-autodiff-fwd.cpp          | 30 +++++++----
 tests/autodiff/non-differentiable-swizzle.slang | 70 +++++++++++++++++++++++++
 2 files changed, 89 insertions(+), 11 deletions(-)
 create mode 100644 tests/autodiff/non-differentiable-swizzle.slang

diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 0302d9ce7..e146ac3e0 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -955,20 +955,28 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
 InstPair ForwardDiffTranscriber::transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle)
 {
     IRInst* primalSwizzle = maybeCloneForPrimalInst(builder, origSwizzle);
-
     if (auto diffBase = lookupDiffInst(origSwizzle->getBase(), nullptr))
     {
-        List<IRInst*> swizzleIndices;
-        for (UIndex ii = 0; ii < origSwizzle->getElementCount(); ii++)
-            swizzleIndices.add(origSwizzle->getElementIndex(ii));
+        // `diffBase` may exist even if the type is non-differentiable (e.g. IRCall inst that
+        // creates other differentiable outputs).
+        //
+        // We'll check to see if we can get a differential for the type in order to determine
+        // whether to generate a differential swizzle inst.
+        //
+        if (auto diffType = differentiateType(builder, primalSwizzle->getDataType()))
+        {
+            List<IRInst*> swizzleIndices;
+            for (UIndex ii = 0; ii < origSwizzle->getElementCount(); ii++)
+                swizzleIndices.add(origSwizzle->getElementIndex(ii));
 
-        return InstPair(
-            primalSwizzle,
-            builder->emitSwizzle(
-                differentiateType(builder, primalSwizzle->getDataType()),
-                diffBase,
-                origSwizzle->getElementCount(),
-                swizzleIndices.getBuffer()));
+            return InstPair(
+                primalSwizzle,
+                builder->emitSwizzle(
+                    diffType,
+                    diffBase,
+                    origSwizzle->getElementCount(),
+                    swizzleIndices.getBuffer()));
+        }
     }
 
     return InstPair(primalSwizzle, nullptr);
diff --git a/tests/autodiff/non-differentiable-swizzle.slang b/tests/autodiff/non-differentiable-swizzle.slang
new file mode 100644
index 000000000..879497423
--- /dev/null
+++ b/tests/autodiff/non-differentiable-swizzle.slang
@@ -0,0 +1,70 @@
+//TEST:SIMPLE(filecheck=CUDA): -target cuda -line-directive-mode none
+
+// This test just needs to compile without crashing.
+
+// CUDA: __device__ void s_bwd_myKernel_0
+
+[Differentiable]
+uint3 load_foo(int32_t g_idx, TensorView indices)
+{
+    return {
+        indices[uint2(g_idx, 0)],
+        indices[uint2(g_idx, 1)],
+        indices[uint2(g_idx, 2)]
+    };
+}
+
+[Differentiable]
+Triangle load_triangle(DiffTensorView vertices, uint3 index_set)
+{
+    float3[3] triangle = {
+        read_t3_float3(index_set.x, vertices),
+        read_t3_float3(index_set.y, vertices),
+        read_t3_float3(index_set.z, vertices)
+    };
+    return { triangle };
+}
+
+struct Triangle : IDifferentiable
+{
+    float3[3] verts;
+}
+
+[Differentiable]
+float3 read_t3_float3(uint32_t idx, DiffTensorView t3)
+{
+    return float3(t3[uint2(idx, 0)],
+                  t3[uint2(idx, 1)],
+                  t3[uint2(idx, 2)]);
+}
+
+[Differentiable] TriangleFoo load_triangle_foo(int32_t g_idx, DiffTensorView vertices, TensorView indices, DiffTensorView vertex_color)
+{
+    uint3 index_set = load_foo(g_idx, indices);
+    Triangle triangle = load_triangle(vertices, index_set);
+    float3 c0 = read_t3_float3(index_set.x, vertex_color);
+    float3 c1 = read_t3_float3(index_set.y, vertex_color);
+    float3 c2 = read_t3_float3(index_set.z, vertex_color);
+    TriangleFoo result = { triangle, 0.f, {c0, c1, c2} };
+    return result;
+}
+
+struct TriangleFoo : IDifferentiable
+{
+    Triangle triangle;
+    float density;
+    float3 vertex_color[3];
+};
+
+[AutoPyBindCUDA]
+[Differentiable]
+[CudaKernel]
+void myKernel(DiffTensorView vertices, TensorView indices, DiffTensorView vertex_color)
+{
+    if (cudaThreadIdx().x > 0)
+        return;
+
+    TriangleFoo.Differential dp_g = TriangleFoo.dzero();
+
+    bwd_diff(load_triangle_foo)(cudaThreadIdx().x, vertices, indices, vertex_color, dp_g);
+}
-- 
cgit v1.2.3