summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-11-18 16:34:03 -0500
committerGitHub <noreply@github.com>2024-11-18 13:34:03 -0800
commitec5e019fa9732b99b75b2a3ca4f2ff5a7a3d2f33 (patch)
treef314f5b16ad18dd3325a7c3a4228242d6d448752 /source/slang
parent05903f708856a70d68bf41bbfb2b06620508dee0 (diff)
Add `IDifferentiablePtrType` support for arrays (#5576)
* Add `IDifferentiablePtrType` support for arrays - Also fixes an issue with spirv-emit of constructors that contain references to global params * Fix GLSL legalization for arrays of resource types
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/diff.meta.slang6
-rw-r--r--source/slang/slang-ir-autodiff.cpp7
-rw-r--r--source/slang/slang-ir-specialize-resources.cpp3
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp10
4 files changed, 22 insertions, 4 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 6042ff5cc..1200aef42 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -1241,6 +1241,12 @@ extension Array<T, N> : IDifferentiable
}
}
+__generic<T : IDifferentiablePtrType, let N : int>
+extension Array<T, N> : IDifferentiablePtrType
+{
+ typedef Array<T.Differential, N> Differential;
+}
+
__generic<each T : IDifferentiable>
extension Tuple<T> : IDifferentiable
{
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index cb37b6242..5c05b0811 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -1275,7 +1275,9 @@ IRInst* DifferentiableTypeConformanceContext::buildArrayWitness(
IRWitnessTable* table = nullptr;
if (target == DiffConformanceKind::Value)
{
- SLANG_ASSERT(isDifferentiableValueType((IRType*)arrayType));
+ if (!isDifferentiableValueType((IRType*)arrayType))
+ return nullptr;
+
auto innerWitness = tryGetDifferentiableWitness(
builder,
as<IRArrayTypeBase>(arrayType)->getElementType(),
@@ -1360,7 +1362,8 @@ IRInst* DifferentiableTypeConformanceContext::buildArrayWitness(
}
else if (target == DiffConformanceKind::Ptr)
{
- SLANG_ASSERT(isDifferentiablePtrType((IRType*)arrayType));
+ if (!isDifferentiablePtrType((IRType*)arrayType))
+ return nullptr;
table = builder->createWitnessTable(
sharedContext->differentiablePtrInterfaceType,
diff --git a/source/slang/slang-ir-specialize-resources.cpp b/source/slang/slang-ir-specialize-resources.cpp
index 56f468ac7..22cd9cb3f 100644
--- a/source/slang/slang-ir-specialize-resources.cpp
+++ b/source/slang/slang-ir-specialize-resources.cpp
@@ -1308,6 +1308,9 @@ bool specializeResourceUsage(CodeGenContext* codeGenContext, IRModule* irModule)
bool isIllegalGLSLParameterType(IRType* type)
{
+ if (auto arrayType = as<IRArrayTypeBase>(type))
+ return isIllegalGLSLParameterType(arrayType->getElementType());
+
if (as<IRParameterGroupType>(type))
return true;
if (as<IRHLSLStructuredBufferTypeBase>(type))
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index 4baa28d67..1de2edd4a 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -1465,15 +1465,21 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
void maybeHoistConstructInstToGlobalScope(IRInst* inst)
{
- // If all of the operands to this instruction are global, we can hoist
- // this constructor to be a global too. This is important to make sure
+ // If all of the operands to this instruction are global, and are not global
+ // variables, we can hoist this constructor to be a global too.
+ // This is important to make sure
// that vectors made of constant components end up being emitted as
// constant vectors (using OpConstantComposite).
UIndex opIndex = 0;
for (auto operand = inst->getOperands(); opIndex < inst->getOperandCount();
operand++, opIndex++)
+ {
if (operand->get()->getParent() != m_module->getModuleInst())
return;
+
+ if (as<IRGlobalParam>(operand->get()))
+ return;
+ }
inst->insertAtEnd(m_module->getModuleInst());
}