diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-11-18 16:34:03 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-11-18 13:34:03 -0800 |
| commit | ec5e019fa9732b99b75b2a3ca4f2ff5a7a3d2f33 (patch) | |
| tree | f314f5b16ad18dd3325a7c3a4228242d6d448752 /source/slang | |
| parent | 05903f708856a70d68bf41bbfb2b06620508dee0 (diff) | |
Add `IDifferentiablePtrType` support for arrays (#5576)
* Add `IDifferentiablePtrType` support for arrays
- Also fixes an issue with spirv-emit of constructors that contain references to global params
* Fix GLSL legalization for arrays of resource types
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-resources.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-spirv-legalize.cpp | 10 |
4 files changed, 22 insertions, 4 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 6042ff5cc..1200aef42 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -1241,6 +1241,12 @@ extension Array<T, N> : IDifferentiable } } +__generic<T : IDifferentiablePtrType, let N : int> +extension Array<T, N> : IDifferentiablePtrType +{ + typedef Array<T.Differential, N> Differential; +} + __generic<each T : IDifferentiable> extension Tuple<T> : IDifferentiable { diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index cb37b6242..5c05b0811 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -1275,7 +1275,9 @@ IRInst* DifferentiableTypeConformanceContext::buildArrayWitness( IRWitnessTable* table = nullptr; if (target == DiffConformanceKind::Value) { - SLANG_ASSERT(isDifferentiableValueType((IRType*)arrayType)); + if (!isDifferentiableValueType((IRType*)arrayType)) + return nullptr; + auto innerWitness = tryGetDifferentiableWitness( builder, as<IRArrayTypeBase>(arrayType)->getElementType(), @@ -1360,7 +1362,8 @@ IRInst* DifferentiableTypeConformanceContext::buildArrayWitness( } else if (target == DiffConformanceKind::Ptr) { - SLANG_ASSERT(isDifferentiablePtrType((IRType*)arrayType)); + if (!isDifferentiablePtrType((IRType*)arrayType)) + return nullptr; table = builder->createWitnessTable( sharedContext->differentiablePtrInterfaceType, diff --git a/source/slang/slang-ir-specialize-resources.cpp b/source/slang/slang-ir-specialize-resources.cpp index 56f468ac7..22cd9cb3f 100644 --- a/source/slang/slang-ir-specialize-resources.cpp +++ b/source/slang/slang-ir-specialize-resources.cpp @@ -1308,6 +1308,9 @@ bool specializeResourceUsage(CodeGenContext* codeGenContext, IRModule* irModule) bool isIllegalGLSLParameterType(IRType* type) { + if (auto arrayType = as<IRArrayTypeBase>(type)) + return isIllegalGLSLParameterType(arrayType->getElementType()); + if (as<IRParameterGroupType>(type)) return true; if (as<IRHLSLStructuredBufferTypeBase>(type)) diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index 4baa28d67..1de2edd4a 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -1465,15 +1465,21 @@ struct SPIRVLegalizationContext : public SourceEmitterBase void maybeHoistConstructInstToGlobalScope(IRInst* inst) { - // If all of the operands to this instruction are global, we can hoist - // this constructor to be a global too. This is important to make sure + // If all of the operands to this instruction are global, and are not global + // variables, we can hoist this constructor to be a global too. + // This is important to make sure // that vectors made of constant components end up being emitted as // constant vectors (using OpConstantComposite). UIndex opIndex = 0; for (auto operand = inst->getOperands(); opIndex < inst->getOperandCount(); operand++, opIndex++) + { if (operand->get()->getParent() != m_module->getModuleInst()) return; + + if (as<IRGlobalParam>(operand->get())) + return; + } inst->insertAtEnd(m_module->getModuleInst()); } |
