From ec5e019fa9732b99b75b2a3ca4f2ff5a7a3d2f33 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Mon, 18 Nov 2024 16:34:03 -0500 Subject: Add `IDifferentiablePtrType` support for arrays (#5576) * Add `IDifferentiablePtrType` support for arrays - Also fixes an issue with spirv-emit of constructors that contain references to global params * Fix GLSL legalization for arrays of resource types --- source/slang/diff.meta.slang | 6 +++ source/slang/slang-ir-autodiff.cpp | 7 ++- source/slang/slang-ir-specialize-resources.cpp | 3 ++ source/slang/slang-ir-spirv-legalize.cpp | 10 ++++- tests/autodiff/diff-ptr-type-array.slang | 59 ++++++++++++++++++++++++++ 5 files changed, 81 insertions(+), 4 deletions(-) create mode 100644 tests/autodiff/diff-ptr-type-array.slang diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 6042ff5cc..1200aef42 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -1241,6 +1241,12 @@ extension Array : IDifferentiable } } +__generic +extension Array : IDifferentiablePtrType +{ + typedef Array Differential; +} + __generic extension Tuple : IDifferentiable { diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index cb37b6242..5c05b0811 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -1275,7 +1275,9 @@ IRInst* DifferentiableTypeConformanceContext::buildArrayWitness( IRWitnessTable* table = nullptr; if (target == DiffConformanceKind::Value) { - SLANG_ASSERT(isDifferentiableValueType((IRType*)arrayType)); + if (!isDifferentiableValueType((IRType*)arrayType)) + return nullptr; + auto innerWitness = tryGetDifferentiableWitness( builder, as(arrayType)->getElementType(), @@ -1360,7 +1362,8 @@ IRInst* DifferentiableTypeConformanceContext::buildArrayWitness( } else if (target == DiffConformanceKind::Ptr) { - SLANG_ASSERT(isDifferentiablePtrType((IRType*)arrayType)); + if (!isDifferentiablePtrType((IRType*)arrayType)) + return nullptr; table = builder->createWitnessTable( sharedContext->differentiablePtrInterfaceType, diff --git a/source/slang/slang-ir-specialize-resources.cpp b/source/slang/slang-ir-specialize-resources.cpp index 56f468ac7..22cd9cb3f 100644 --- a/source/slang/slang-ir-specialize-resources.cpp +++ b/source/slang/slang-ir-specialize-resources.cpp @@ -1308,6 +1308,9 @@ bool specializeResourceUsage(CodeGenContext* codeGenContext, IRModule* irModule) bool isIllegalGLSLParameterType(IRType* type) { + if (auto arrayType = as(type)) + return isIllegalGLSLParameterType(arrayType->getElementType()); + if (as(type)) return true; if (as(type)) diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index 4baa28d67..1de2edd4a 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -1465,15 +1465,21 @@ struct SPIRVLegalizationContext : public SourceEmitterBase void maybeHoistConstructInstToGlobalScope(IRInst* inst) { - // If all of the operands to this instruction are global, we can hoist - // this constructor to be a global too. This is important to make sure + // If all of the operands to this instruction are global, and are not global + // variables, we can hoist this constructor to be a global too. + // This is important to make sure // that vectors made of constant components end up being emitted as // constant vectors (using OpConstantComposite). UIndex opIndex = 0; for (auto operand = inst->getOperands(); opIndex < inst->getOperandCount(); operand++, opIndex++) + { if (operand->get()->getParent() != m_module->getModuleInst()) return; + + if (as(operand->get())) + return; + } inst->insertAtEnd(m_module->getModuleInst()); } diff --git a/tests/autodiff/diff-ptr-type-array.slang b/tests/autodiff/diff-ptr-type-array.slang new file mode 100644 index 000000000..30e6fe963 --- /dev/null +++ b/tests/autodiff/diff-ptr-type-array.slang @@ -0,0 +1,59 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type +//DISABLE_TEST(compute):COMPARE_COMPUTE:-wgpu + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +// ----- MyPtrType definition ----- +struct MyPtrType : IDifferentiablePtrType +{ + typealias Differential = MyPtrType; + + RWStructuredBuffer buffer; + uint offset; + + float load(uint idx) { return buffer[offset + idx]; } + void accumulate(uint idx, float value) { buffer[offset + idx] += value; } +} + +[BackwardDerivative(load_bwd)] +float load(MyPtrType[2] b, uint idx) +{ + return b[1].load(idx); +} + +void load_bwd(DifferentialPtrPair b, uint idx, float grad) +{ + b.d[1].accumulate(idx, grad); +} + +// ------ +[Differentiable] +float reduce(MyPtrType a) +{ + return load( { a, a }, 0) + load( { a, a }, 1); +} + +[Differentiable] +float test(MyPtrType b) +{ + return reduce(b); +} + +[numthreads(1, 1, 1)] +void computeMain(uint id: SV_DispatchThreadID) +{ + outputBuffer[0] = 1; // CHECK: 1 + outputBuffer[1] = 2; // CHECK: 2 + + // Denote the first two elements in the buffer as the primal buffer and the last two elements + // for the derivative. + var b = DifferentialPtrPair( { outputBuffer, 0 }, { outputBuffer, 2 }); + + bwd_diff(test)(b, 1.5f); + + // Check locations [2] and [3] in the buffer + // CHECK: 1.5 + // CHECK: 1.5 +} -- cgit v1.2.3