summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-08-17 14:45:13 -0400
committerGitHub <noreply@github.com>2023-08-17 14:45:13 -0400
commit945409c4c6871c18aad24086c594cc66b5913733 (patch)
tree41eed63f115971d82875e23acbec77d78be4cf3a /source
parent216fc18661fd6e05053b4cc864396e6017e85b04 (diff)
Initial support for differentiating existential types (#3111)
* Merge * WIP: Complete auto-diff logic for existential types * Revert "Add compiler option for generating representative hash" This reverts commit 13b09ef4621e73844c96d64d9c111a8ed0d45aae. * More fixes for fwd-mode AD on existential types * Add anyValueSize inference pass * Fix checking of `Differential.Differential==Differential` * In-progress: infer any-value-size for existential types * Existentials now work in forward-mode * Overhaul handling of existential AD types. Fwd-mode works, reverse-mode requires front-end changes * Reverse-mode now works on existentials * Cleanup * Remove diff rules for create existential object for now * Revert treat-as-differentiable changes * Fixes * More fixes * Cleanup * more cleanup * signed/unsigned * Revert "Cleanup" This reverts commit e4f7d71f07bb207736f90708961eeecd09a1b652. * Cleanup (again) * Remove public/export/keep-alive on null differential after AD pass * Minor fix * Update dictionary accessors * Keep export decoration * More fixes + Support for `kIROp_PackAnyValue` * Merge upstream * Update expected-failure.txt
Diffstat (limited to 'source')
-rw-r--r--source/slang/diff.meta.slang24
-rw-r--r--source/slang/slang-ast-decl.h3
-rw-r--r--source/slang/slang-check-decl.cpp22
-rw-r--r--source/slang/slang-check-expr.cpp6
-rwxr-xr-xsource/slang/slang-compiler.h2
-rw-r--r--source/slang/slang-emit.cpp1
-rw-r--r--source/slang/slang-ir-any-value-inference.cpp231
-rw-r--r--source/slang/slang-ir-any-value-inference.h13
-rw-r--r--source/slang/slang-ir-any-value-marshalling.cpp41
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp57
-rw-r--r--source/slang/slang-ir-autodiff-fwd.h2
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.cpp232
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp133
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.h9
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h186
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h5
-rw-r--r--source/slang/slang-ir-autodiff.cpp344
-rw-r--r--source/slang/slang-ir-autodiff.h51
-rw-r--r--source/slang/slang-ir-clone.cpp10
-rw-r--r--source/slang/slang-ir-generics-lowering-context.h1
-rw-r--r--source/slang/slang-ir-inst-defs.h7
-rw-r--r--source/slang/slang-ir-insts.h19
-rw-r--r--source/slang/slang-ir-lower-generics.cpp6
-rw-r--r--source/slang/slang-ir-lower-reinterpret.cpp6
-rw-r--r--source/slang/slang-ir-ssa.cpp20
-rw-r--r--source/slang/slang-ir.cpp11
-rw-r--r--source/slang/slang-lower-to-ir.cpp31
29 files changed, 1304 insertions, 173 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index ce0e72d34..423b6bfd0 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -25,6 +25,30 @@ attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute;
__attributeTarget(FunctionDeclBase)
attribute_syntax [NoDiffThis] : NoDiffThisAttribute;
+// A 'none-type' that acts as a run-time sentinel for zero differentials.
+public struct NullDifferential : IDifferentiable
+{
+ // for now, we'll use at least one field to make sure the type is non-empty
+ uint dummy;
+ typedef NullDifferential Differential;
+
+ [Differentiable]
+ [ForceInline]
+ static Differential dzero() { return { 0 }; }
+
+ [Differentiable]
+ [ForceInline]
+ static Differential dadd(Differential, Differential) { return { 0 }; }
+
+ [Differentiable]
+ [ForceInline]
+ static Differential dmul<T: __BuiltinRealType>(T, Differential) { return { 0 }; }
+};
+
+// Existential check for null differential type
+__intrinsic_op($(kIROp_IsDifferentialNull))
+bool isDifferentialNull(IDifferentiable obj);
+
/// Represents a GPU view of a tensor.
__generic<T>
__magic_type(TensorViewType)
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h
index 8266d77c7..553a5c26f 100644
--- a/source/slang/slang-ast-decl.h
+++ b/source/slang/slang-ast-decl.h
@@ -554,6 +554,9 @@ class DerivativeRequirementDecl : public FunctionDeclBase
// The original requirement decl.
Decl* originalRequirementDecl = nullptr;
+
+ // Type to use for 'ThisType'
+ Type* diffThisType;
};
// A reference to a synthesized decl representing a differentiable function requirement, this decl will
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index cb3db9e39..f25821dac 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -1656,8 +1656,7 @@ namespace Slang
RequirementWitness witnessValue;
auto requirementDecl = m_astBuilder->getSharedASTBuilder()->findBuiltinRequirementDecl(BuiltinRequirementKind::DifferentialType);
if (!inheritanceDecl->witnessTable->getRequirementDictionary().tryGetValue(requirementDecl, witnessValue))
- return;
-
+ return;
// A type used as differential type must have itself as its own differential type.
if (witnessValue.getFlavor() != RequirementWitness::Flavor::val)
return;
@@ -5781,6 +5780,16 @@ namespace Slang
interfaceDecl->members.add(reqDecl);
reqDecl->parentDecl = interfaceDecl;
+ if (!decl->hasModifier<NoDiffThisAttribute>())
+ {
+ // Build decl-ref-type from interface.
+ auto interfaceType = DeclRefType::create(getASTBuilder(), makeDeclRef(interfaceDecl));
+
+ // If the interface is differentiable, make the this type a pair.
+ if (tryGetDifferentialType(getASTBuilder(), interfaceType))
+ reqDecl->diffThisType = getDifferentialPairType(interfaceType);
+ }
+
auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
reqRef->referencedDecl = reqDecl;
reqRef->parentDecl = decl;
@@ -5800,6 +5809,15 @@ namespace Slang
setFuncTypeIntoRequirementDecl(reqDecl, diffFuncType);
interfaceDecl->members.add(reqDecl);
reqDecl->parentDecl = interfaceDecl;
+ if (!decl->hasModifier<NoDiffThisAttribute>())
+ {
+ // Build decl-ref-type from interface.
+ auto interfaceType = DeclRefType::create(getASTBuilder(), makeDeclRef(interfaceDecl));
+
+ // If the interface is differentiable, make the this type a pair.
+ if (tryGetDifferentialType(getASTBuilder(), interfaceType))
+ reqDecl->diffThisType = getDifferentialPairType(interfaceType);
+ }
auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
reqRef->referencedDecl = reqDecl;
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 05cb6262b..3d2f81edb 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -500,6 +500,12 @@ namespace Slang
// Don't synthesize for ThisType.
if (as<ThisTypeDecl>(subType->getDeclRef().getDecl()))
return nullptr;
+
+ // If the inner most subtype is itself an associated type, then we're dealing
+ // with an abstract type. There's not need to synthesize anythin at this point.
+ //
+ if (as<AssocTypeDecl>(subType->getDeclRef().getDecl()))
+ return nullptr;
// If we reach here, we are expecting a synthesized decl defined in `subType`.
// Instead of returning a DeclRefExpr to the requirement decl, we synthesize a placeholder decl
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index f29dc8dae..d87b755c7 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -2671,7 +2671,7 @@ namespace Slang
virtual SLANG_NO_THROW void SLANG_MCALL setDebugInfoFormat(SlangDebugInfoFormat format) SLANG_OVERRIDE;
virtual SLANG_NO_THROW void SLANG_MCALL setReportDownstreamTime(bool value) SLANG_OVERRIDE;
virtual SLANG_NO_THROW void SLANG_MCALL setReportPerfBenchmark(bool value) SLANG_OVERRIDE;
-
+
void setHLSLToVulkanLayoutOptions(int targetIndex, HLSLToVulkanLayoutOptions* vulkanLayoutOptions);
EndToEndCompileRequest(
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 82f2a8fd3..6521b05ba 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -5,6 +5,7 @@
#include "../compiler-core/slang-name.h"
+#include "slang-ir-any-value-inference.h"
#include "slang-ir-bind-existentials.h"
#include "slang-ir-byte-address-legalize.h"
#include "slang-ir-collect-global-uniforms.h"
diff --git a/source/slang/slang-ir-any-value-inference.cpp b/source/slang/slang-ir-any-value-inference.cpp
new file mode 100644
index 000000000..eb4aa670f
--- /dev/null
+++ b/source/slang/slang-ir-any-value-inference.cpp
@@ -0,0 +1,231 @@
+#include "slang-ir-any-value-inference.h"
+
+#include "slang-ir-generics-lowering-context.h"
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
+#include "slang-ir-layout.h"
+#include "../core/slang-func-ptr.h"
+
+namespace Slang
+{
+
+ void _findDependenciesOfTypeInSet(IRType* type, HashSet<IRInterfaceType*>& targetSet, List<IRInterfaceType*>& result)
+ {
+ switch (type->getOp())
+ {
+ case kIROp_InterfaceType:
+ {
+ auto interfaceType = cast<IRInterfaceType>(type);
+ if (targetSet.contains(interfaceType))
+ {
+ result.add(interfaceType);
+ return;
+ }
+ }
+ break;
+ case kIROp_StructType:
+ {
+ auto structType = cast<IRStructType>(type);
+ for (auto field : structType->getFields())
+ {
+ _findDependenciesOfTypeInSet(field->getFieldType(), targetSet, result);
+ }
+ }
+ break;
+ default:
+ {
+ for (UInt i = 0; i < type->getOperandCount(); i++)
+ {
+ if (auto operandType = as<IRType>(type->getOperand(i)))
+ _findDependenciesOfTypeInSet(operandType, targetSet, result);
+ }
+ }
+ break;
+ }
+ }
+
+ List<IRInterfaceType*> findDependenciesOfTypeInSet(IRType* type, HashSet<IRInterfaceType*> targetSet)
+ {
+ List<IRInterfaceType*> result;
+ _findDependenciesOfTypeInSet(type, targetSet, result);
+
+ return result;
+ }
+
+ void _sortTopologically(
+ IRInterfaceType* interfaceType,
+ HashSet<IRInterfaceType*>& visited,
+ List<IRInterfaceType*>& sortedInterfaceTypes,
+ const Func<HashSet<IRInterfaceType*>, IRInterfaceType*>& getDependencies)
+ {
+ if (visited.contains(interfaceType))
+ return;
+
+ visited.add(interfaceType);
+
+ for (auto dependency : getDependencies(interfaceType))
+ {
+ _sortTopologically(dependency, visited, sortedInterfaceTypes, getDependencies);
+ }
+
+ sortedInterfaceTypes.add(interfaceType);
+ }
+
+ List<IRInterfaceType*> sortTopologically(
+ HashSet<IRInterfaceType*> interfaceTypes,
+ const Func<HashSet<IRInterfaceType*>, IRInterfaceType*>& getDependencies)
+ {
+ List<IRInterfaceType*> sortedInterfaceTypes;
+ HashSet<IRInterfaceType*> visited;
+ for (auto interfaceType : interfaceTypes)
+ {
+ _sortTopologically(interfaceType, visited, sortedInterfaceTypes, getDependencies);
+ }
+ return sortedInterfaceTypes;
+ }
+
+ void inferAnyValueSizeWhereNecessary(
+ IRModule* module)
+ {
+ // Go through the global insts and collect all interface types.
+ // For each interface type, infer its any-value-size, by looking up
+ // all witness tables whose conformance type matches the interface type.
+ // then using _calcNaturalSizeAndAlignment to find the max size.
+ //
+ // Note: we only infer any-value-size for interface types that are used
+ // as a generic type parameter, because we don't want to infer any-value-size
+ // for interface types that are used as a witness table type.
+ //
+
+ HashSet<IRInst*> implementedInterfaces;
+ // Add all interface type that are implemented by at least one type to a set.
+ for (auto inst : module->getGlobalInsts())
+ {
+ if (inst->getOp() == kIROp_WitnessTable)
+ {
+ auto interfaceType = cast<IRWitnessTableType>(inst->getDataType())->getConformanceType();
+ implementedInterfaces.add(interfaceType);
+ }
+ }
+
+ // Collect all interface types that require inference.
+ HashSet<IRInterfaceType*> interfaceTypes;
+ for (auto inst : module->getGlobalInsts())
+ {
+ if (inst->getOp() == kIROp_InterfaceType)
+ {
+ auto interfaceType = cast<IRInterfaceType>(inst);
+
+ // Do not infer anything for COM interfaces.
+ if (isComInterfaceType((IRType*)interfaceType))
+ continue;
+
+ // Also skip builtin types.
+ if (interfaceType->findDecoration<IRBuiltinDecoration>())
+ continue;
+
+ // If the interface already has an explicit any-value-size, don't infer anything.
+ if (interfaceType->findDecoration<IRAnyValueSizeDecoration>())
+ continue;
+
+ // Skip interfaces that are not implemented by any type.
+ if (!implementedInterfaces.contains(interfaceType))
+ continue;
+
+ interfaceTypes.add(interfaceType);
+ }
+ }
+
+ Dictionary<IRInterfaceType*, List<IRInst*>> mapInterfaceToImplementations;
+
+ // Collect all concrete types that conform to this interface type.
+ for (auto interfaceType : interfaceTypes)
+ {
+ IRWitnessTableType* witnessTableType = nullptr;
+ // Find witness table type corresponding to this interface.
+ for (auto use = interfaceType->firstUse; use; use = use->nextUse)
+ {
+ if (auto _witnessTableType = as<IRWitnessTableType>(use->getUser()))
+ {
+ if (_witnessTableType->getConformanceType() == interfaceType && _witnessTableType->hasUses())
+ {
+ witnessTableType = _witnessTableType;
+ break;
+ }
+ }
+ }
+
+ // If we hit this case, we have an interface without any conforming implementations.
+ // This case should be handled before this point.
+ //
+ SLANG_ASSERT(witnessTableType);
+
+ List<IRInst*> implList;
+
+ // Walk through all the uses of this witness table type to find the witness tables.
+ for (auto use = witnessTableType->firstUse; use; use = use->nextUse)
+ {
+ auto witnessTable = as<IRWitnessTable>(use->getUser());
+ if (!witnessTable || witnessTable->getDataType() != witnessTableType)
+ continue;
+
+ auto concreteImpl = witnessTable->getConcreteType();
+
+ // Only consider implementations at the top-level (ignore those nested
+ // in generics)
+ //
+ if (concreteImpl->getParent() == module->getModuleInst())
+ implList.add(concreteImpl);
+ }
+
+ mapInterfaceToImplementations.add(interfaceType, implList);
+ }
+
+ Dictionary<IRInterfaceType*, HashSet<IRInterfaceType*>> interfaceDependencyMap;
+
+ // Collect dependencies for each interface.
+ for (auto interfaceType : interfaceTypes)
+ {
+ HashSet<IRInterfaceType*> dependencySet;
+ for (auto impl : mapInterfaceToImplementations[interfaceType])
+ {
+ auto dependencies = findDependenciesOfTypeInSet((IRType*)impl, interfaceTypes);
+ for (auto dependency : dependencies)
+ dependencySet.add(dependency);
+ }
+ interfaceDependencyMap.add(interfaceType, dependencySet);
+ }
+
+ // Sort the interface types in topological order.
+ // This is necessary because we need to infer the any-value-size of an interface type
+ // before we infer the any-value-size of an interface type that depends on it.
+ //
+ List<IRInterfaceType*> sortedInterfaceTypes = sortTopologically(interfaceTypes, [&](IRInterfaceType* interfaceType)
+ {
+ return interfaceDependencyMap[interfaceType];
+ });
+
+ for (auto interfaceType : sortedInterfaceTypes)
+ {
+ IRIntegerValue maxAnyValueSize = -1;
+ for (auto implType : mapInterfaceToImplementations[interfaceType])
+ {
+ IRSizeAndAlignment sizeAndAlignment;
+ getNaturalSizeAndAlignment((IRType*)implType, &sizeAndAlignment);
+
+ maxAnyValueSize = Math::Max(maxAnyValueSize, sizeAndAlignment.size);
+ }
+
+ // Should not encounter interface types without any conforming implementations.
+ SLANG_ASSERT(maxAnyValueSize >= 0);
+
+ // If we found a max size, add an any-value-size decoration to the interface type.
+ if (maxAnyValueSize >= 0)
+ {
+ IRBuilder builder(module);
+ builder.addAnyValueSizeDecoration(interfaceType, maxAnyValueSize);
+ }
+ }
+ }
+}; \ No newline at end of file
diff --git a/source/slang/slang-ir-any-value-inference.h b/source/slang/slang-ir-any-value-inference.h
new file mode 100644
index 000000000..eb202d626
--- /dev/null
+++ b/source/slang/slang-ir-any-value-inference.h
@@ -0,0 +1,13 @@
+// slang-ir-any-value-inference.h
+#pragma once
+
+#include "../core/slang-common.h"
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-compiler.h"
+
+namespace Slang
+{
+ void inferAnyValueSizeWhereNecessary(
+ IRModule* module);
+}
diff --git a/source/slang/slang-ir-any-value-marshalling.cpp b/source/slang/slang-ir-any-value-marshalling.cpp
index 79aea9011..ed7818dbf 100644
--- a/source/slang/slang-ir-any-value-marshalling.cpp
+++ b/source/slang/slang-ir-any-value-marshalling.cpp
@@ -1,5 +1,6 @@
#include "slang-ir-any-value-marshalling.h"
+#include "../core/slang-math.h"
#include "slang-ir-generics-lowering-context.h"
#include "slang-ir.h"
#include "slang-ir-insts.h"
@@ -782,6 +783,46 @@ namespace Slang
auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc);
return alignUp(offset, 4) + alignUp((SlangInt)size, 4);
}
+ case kIROp_LookupWitness:
+ {
+ auto witnessTableVal = type->getOperand(0);
+ auto key = type->getOperand(1);
+ IRType* assocType = nullptr;
+ if (auto witnessTableType = as<IRWitnessTableTypeBase>(witnessTableVal->getDataType()))
+ {
+ auto interfaceType = as<IRInterfaceType>(witnessTableType->getConformanceType());
+
+ // Walk through interface operands to find a match, the result should be an
+ // associated type entry.
+ //
+ for (UIndex ii = 0; ii < interfaceType->getOperandCount(); ii++)
+ {
+ auto entry = cast<IRInterfaceRequirementEntry>(interfaceType->getOperand(ii));
+ if (entry->getRequirementKey() == key &&
+ as<IRAssociatedType>(entry->getRequirementVal()))
+ {
+ assocType = (IRType*)entry->getRequirementVal();
+ break;
+ }
+ }
+ }
+
+ if (!assocType)
+ return -1;
+
+ IRIntegerValue anyValueSize = kInvalidAnyValueSize;
+ for (UInt i = 0; i < assocType->getOperandCount(); i++)
+ {
+ anyValueSize = Math::Min(
+ anyValueSize,
+ SharedGenericsLoweringContext::getInterfaceAnyValueSize(assocType->getOperand(i), type->sourceLoc));
+ }
+
+ if (anyValueSize == kInvalidAnyValueSize)
+ return -1;
+
+ return alignUp(offset, 4) + alignUp((SlangInt)anyValueSize, 4);
+ }
default:
if (as<IRTextureTypeBase>(type) || as<IRSamplerStateTypeBase>(type))
{
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index c17d7d5c4..2662498ed 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -142,6 +142,23 @@ InstPair ForwardDiffTranscriber::transcribeUndefined(IRBuilder* builder, IRInst*
return InstPair(primalVal, nullptr);
}
+InstPair ForwardDiffTranscriber::transcribeReinterpret(IRBuilder* builder, IRInst* origInst)
+{
+ auto primalVal = maybeCloneForPrimalInst(builder, origInst);
+
+ IRInst* diffVal = nullptr;
+
+ if (IRType* const diffType = differentiateType(builder, origInst->getFullType()))
+ {
+ if (auto diffOperand = findOrTranscribeDiffInst(builder, origInst->getOperand(0)))
+ {
+ diffVal = builder->emitReinterpret(diffType, diffOperand);
+ }
+ }
+
+ return InstPair(primalVal, diffVal);
+}
+
InstPair ForwardDiffTranscriber::transcribeVar(IRBuilder* builder, IRVar* origVar)
{
if (IRType* diffType = differentiateType(builder, origVar->getDataType()->getValueType()))
@@ -230,10 +247,12 @@ InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRIns
diffLeft,
builder->getFloatValue(
constant->getDataType(), 1.0 / constant->getValue()));
+ builder->markInstAsDifferential(diff, resultType);
}
else
{
diff = builder->emitDiv(diffType, diffLeft, primalRight);
+ builder->markInstAsDifferential(diff, resultType);
}
return InstPair(primalArith, diff);
}
@@ -247,6 +266,7 @@ InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRIns
auto diffSub = builder->emitSub(diffType, diffLeftTimesRight, diffRightTimesLeft);
builder->markInstAsDifferential(diffSub, resultType);
+
auto diffMul = builder->emitMul(primalRight->getFullType(), primalRight, primalRight);
builder->markInstAsPrimal(diffMul);
@@ -661,7 +681,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
auto pairPtrType = as<IRPtrTypeBase>(pairType);
auto pairValType = as<IRDifferentialPairType>(
pairPtrType ? pairPtrType->getValueType() : pairType);
- auto diffType = differentiableTypeConformanceContext.getDifferentialTypeFromDiffPairType(&argBuilder, pairValType);
+ auto diffType = differentiableTypeConformanceContext.getDiffTypeFromPairType(&argBuilder, pairValType);
if (auto ptrParamType = as<IRPtrTypeBase>(diffParamType))
{
// Create temp var to pass in/out arguments.
@@ -698,6 +718,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
{
auto newDiffVal = afterBuilder.emitDifferentialPairGetDifferential((IRType*)diffType, newVal);
afterBuilder.markInstAsDifferential(newDiffVal, pairValType->getValueType());
+
auto storeInst = afterBuilder.emitStore(diffArg, newDiffVal);
afterBuilder.markInstAsDifferential(storeInst, pairValType->getValueType());
}
@@ -1389,16 +1410,19 @@ InstPair ForwardDiffTranscriber::transcribeMakeExistential(IRBuilder* builder, I
SLANG_RELEASE_ASSERT(primalInterfaceType);
// If the interface type of the existential is differentiable, we emit a make existential
- // of IDifferentiable interface type and the witness table of the original type's conformance
+ // of IDifferentiable.Differential type and the witness table of the original type's conformance
// to IDifferentiable.
//
- if (auto differentialWitnessTable = tryExtractConformanceFromInterfaceType(
+ if (auto differentialWitnessTable = differentiableTypeConformanceContext.tryExtractConformanceFromInterfaceType(
builder, primalInterfaceType, (IRWitnessTable*)primalWitnessTable))
{
if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
{
+ auto differentialAssociatedType = differentiateType(builder, primalInterfaceType);
+ SLANG_ASSERT(differentialAssociatedType);
+
diffResult = builder->emitMakeExistential(
- autoDiffSharedContext->differentiableInterfaceType,
+ differentialAssociatedType,
diffBase,
differentialWitnessTable);
}
@@ -1735,6 +1759,7 @@ InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* pr
{
auto diffVal = builder.emitLoad(writeBack.value.differential);
builder.markInstAsDifferential(diffVal, primalVal->getFullType());
+
valToStore = builder.emitMakeDifferentialPair(cast<IRPtrTypeBase>(param->getFullType())->getValueType(),
primalVal, diffVal);
builder.markInstAsMixedDifferential(valToStore, valToStore->getFullType());
@@ -1867,6 +1892,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_ExtractExistentialValue:
return transcribeSingleOperandInst(builder, origInst);
+
+ case kIROp_PackAnyValue:
+ return transcribeSingleOperandInst(builder, origInst);
case kIROp_MakeExistential:
return transcribeMakeExistential(builder, as<IRMakeExistential>(origInst));
@@ -1874,10 +1902,16 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_ExtractExistentialType:
{
IRInst* witnessTable;
+ auto diffType = differentiateExtractExistentialType(
+ builder, as<IRExtractExistentialType>(origInst), witnessTable);
+
+ // Mark types as primal since they are not transposable.
+ if (diffType)
+ builder->markInstAsPrimal(diffType);
+
return InstPair(
maybeCloneForPrimalInst(builder, origInst),
- differentiateExtractExistentialType(
- builder, as<IRExtractExistentialType>(origInst), witnessTable));
+ diffType);
}
case kIROp_ExtractExistentialWitnessTable:
return transcribeExtractExistentialWitnessTable(builder, origInst);
@@ -1890,6 +1924,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_undefined:
return transcribeUndefined(builder, origInst);
+
+ case kIROp_Reinterpret:
+ return transcribeReinterpret(builder, origInst);
// Differentiable insts that should have been lowered in a previous pass.
case kIROp_SwizzledStore:
@@ -1901,7 +1938,6 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
SLANG_RELEASE_ASSERT(lookupDiffInst(swizzledStore->getDest(), nullptr) == nullptr);
return transcribeNonDiffInst(builder, swizzledStore);
}
-
// Known non-differentiable insts.
case kIROp_Not:
case kIROp_BitAnd:
@@ -1918,12 +1954,10 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_RWStructuredBufferLoadStatus:
case kIROp_RWStructuredBufferStore:
case kIROp_RWStructuredBufferGetElementPtr:
- case kIROp_Reinterpret:
case kIROp_IsType:
case kIROp_ImageSubscript:
case kIROp_ImageLoad:
case kIROp_ImageStore:
- case kIROp_PackAnyValue:
case kIROp_UnpackAnyValue:
case kIROp_GetNativePtr:
case kIROp_CastIntToFloat:
@@ -1936,6 +1970,11 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
// A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value,
// so we treat this inst as non differentiable.
// We can extend the frontend and IR with a separate op-code that can provide an explicit diff value.
+ //
+ // However, we can't skip this instruction since it also produces a _type_ which may be used by
+ // other differentiable instructions. Therefore, we'll create another existential object but with
+ // a dzero() for it's value.
+ //
case kIROp_CreateExistentialObject:
return transcribeNonDiffInst(builder, origInst);
diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h
index 4edb9301a..8d8d65c10 100644
--- a/source/slang/slang-ir-autodiff-fwd.h
+++ b/source/slang/slang-ir-autodiff-fwd.h
@@ -88,6 +88,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase
InstPair transcribeDefaultConstruct(IRBuilder* builder, IRInst* origInst);
+ InstPair transcribeReinterpret(IRBuilder* builder, IRInst* origInst);
+
virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override;
void generateTrivialFwdDiffFunc(IRFunc* primalFunc, IRFunc* diffFunc);
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
index 6ccf7caf4..ebf7a9484 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.cpp
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -3,6 +3,7 @@
#include "slang-ir-autodiff-region.h"
#include "slang-ir-simplify-cfg.h"
#include "slang-ir-util.h"
+#include "../core/slang-func-ptr.h"
#include "slang-ir.h"
namespace Slang
@@ -1087,8 +1088,12 @@ IRVar* emitIndexedLocalVar(
IRType* baseType,
const List<IndexTrackingInfo>& defBlockIndices)
{
+ // Cannot store pointers. Case should have been handled by now.
SLANG_RELEASE_ASSERT(!as<IRPtrTypeBase>(baseType));
+ // Cannot store types. Case should have been handled by now.
+ SLANG_RELEASE_ASSERT(!as<IRTypeType>(baseType));
+
IRBuilder varBuilder(varBlock->getModule());
varBuilder.setInsertBefore(varBlock->getFirstOrdinaryInst());
@@ -1242,23 +1247,112 @@ static int getInstRegionNestLevel(
return (int)result;
}
+
+struct UseChain
+{
+ List<IRUse*> chain;
+ static List<UseChain> from(
+ IRUse* baseUse,
+ Func<bool, IRUse*> isRelevantUse,
+ Func<bool, IRInst*> passthroughInst)
+ {
+ IRInst* inst = baseUse->getUser();
+
+ // Base case 1: we hit a relevant use, return a single-element chain.
+ if (isRelevantUse(baseUse))
+ {
+ UseChain baseUseChain;
+ baseUseChain.chain.add(baseUse);
+
+ return List<UseChain>(UseChain(baseUseChain));
+ }
+
+ // Base case 2: we hit an irrelevant use that is not also a passthrough.
+ // so stop here.
+ if (!passthroughInst(inst))
+ {
+ return List<UseChain>();
+ }
+
+ // Recurse.
+ List<UseChain> result;
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ List<UseChain> innerChain = from(use, isRelevantUse, passthroughInst);
+
+ for (auto& useChain : innerChain)
+ {
+ useChain.chain.add(baseUse);
+ result.add(useChain);
+ }
+ }
+
+ return result;
+ }
+
+ void replace(IRBuilder* builder, IRInst* inst)
+ {
+ SLANG_ASSERT(chain.getCount() > 0);
+
+ // Simple case: if there is only one use, then we can just replace it.
+ if (chain.getCount() == 1)
+ {
+ builder->replaceOperand(chain.getLast(), inst);
+ chain.clear();
+ return;
+ }
+
+ IRCloneEnv env;
+
+ // Pop the last use, which is the base use that needs to be replaced.
+ auto baseUse = chain.getLast();
+ chain.removeLast();
+
+ // Ensure that replacement inst is set as mapping for the baseUse.
+ env.mapOldValToNew[baseUse->get()] = inst;
+
+ auto lastInstInChain = inst;
+
+ IRBuilder chainBuilder(builder->getModule());
+ setInsertAfterOrdinaryInst(&chainBuilder, inst);
+
+ // Clone the rest of the chain.
+ for (auto& use : chain)
+ {
+ lastInstInChain = cloneInst(&env, &chainBuilder, use->getUser());
+ }
+
+ // Replace the base use.
+ builder->replaceOperand(baseUse, lastInstInChain);
+
+ chain.clear();
+ }
+
+ IRInst* getUser() const
+ {
+ SLANG_ASSERT(chain.getCount() > 0);
+ return chain.getLast()->getUser();
+ }
+};
+
+
// Trim defBlockIndices based on the indices of out of scope uses.
//
static List<IndexTrackingInfo> maybeTrimIndices(
const List<IndexTrackingInfo>& defBlockIndices,
const Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo,
- const List<IRUse*>& outOfScopeUses)
+ const List<UseChain>& outOfScopeUses)
{
// Go through uses, lookup the defBlockIndices, and remove any indices if they
// are not present in any of the uses. (This is sort of slow...)
//
List<IndexTrackingInfo> result;
- for (auto& index : defBlockIndices)
+ for (const auto& index : defBlockIndices)
{
bool found = false;
- for (auto& use : outOfScopeUses)
+ for (const auto& use : outOfScopeUses)
{
- auto useInst = use->getUser();
+ auto useInst = use.getUser();
auto useBlock = useInst->getParent();
auto useBlockIndices = indexedBlockInfo.getValue(as<IRBlock>(useBlock));
if (useBlockIndices.contains(index))
@@ -1273,6 +1367,18 @@ static List<IndexTrackingInfo> maybeTrimIndices(
return result;
}
+bool canInstBeStored(IRInst* inst)
+{
+ // Cannot store insts whose value is a type or a witness table.
+ // These insts get lowered to target-specific logic, and cannot be
+ // stored into variables or context structs as normal values.
+ //
+ if (as<IRTypeType>(inst->getDataType()) || as<IRWitnessTableType>(inst->getDataType()))
+ return false;
+
+ return true;
+}
+
/// Legalizes all accesses to primal insts from recompute and diff blocks.
///
@@ -1352,8 +1458,19 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
{
SLANG_ASSERT(!isDifferentialBlock(defaultVarBlock));
- for (auto instToStore : instSet)
+ List<IRInst*> workList;
+ for (auto inst : instSet)
+ workList.add(inst);
+
+ HashSet<IRInst*> seenInstSet;
+ while (workList.getCount() != 0)
{
+ auto instToStore = workList.getLast();
+ workList.removeLast();
+
+ if (seenInstSet.contains(instToStore))
+ continue;
+
IRBlock* defBlock = nullptr;
if (const auto ptrInst = as<IRPtrTypeBase>(instToStore->getDataType()))
{
@@ -1367,45 +1484,61 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
SLANG_RELEASE_ASSERT(defBlock);
- List<IRUse*> outOfScopeUses;
+ List<UseChain> outOfScopeUses;
for (auto use = instToStore->firstUse; use;)
{
auto nextUse = use->nextUse;
- // Only consider uses in differential blocks.
- // This method is not responsible for other blocks.
- //
- IRBlock* userBlock = getBlock(use->getUser());
- if (isDifferentialOrRecomputeBlock(userBlock))
+ // Lambda to check if a use is relevant.
+ auto isRelevantUse = [&](IRUse* use)
{
- if (!domTree->dominates(defBlock, userBlock))
- {
- outOfScopeUses.add(use);
- }
- else if (!areIndicesSubsetOf(indexedBlockInfo[defBlock], indexedBlockInfo[userBlock]))
- {
- outOfScopeUses.add(use);
- }
- else if (getInstRegionNestLevel(indexedBlockInfo, defBlock, instToStore) > 0 &&
- !isDifferentialOrRecomputeBlock(defBlock))
+ // Only consider uses in differential blocks.
+ // This method is not responsible for other blocks.
+ //
+ IRBlock* userBlock = getBlock(use->getUser());
+ if (isDifferentialOrRecomputeBlock(userBlock))
{
- outOfScopeUses.add(use);
- }
- else if (as<IRPtrTypeBase>(instToStore->getDataType()) &&
- !isDifferentialOrRecomputeBlock(defBlock))
- {
- outOfScopeUses.add(use);
+ if (!domTree->dominates(defBlock, userBlock))
+ {
+ return true;
+ }
+ else if (!areIndicesSubsetOf(indexedBlockInfo[defBlock], indexedBlockInfo[userBlock]))
+ {
+ return true;
+ }
+ else if (getInstRegionNestLevel(indexedBlockInfo, defBlock, instToStore) > 0 &&
+ !isDifferentialOrRecomputeBlock(defBlock))
+ {
+ return true;
+ }
+ else if (as<IRPtrTypeBase>(instToStore->getDataType()) &&
+ !isDifferentialOrRecomputeBlock(defBlock))
+ {
+ return true;
+ }
}
- }
+ return false;
+ };
+
+ // Lambda to check if an inst is transparent. We lookup uses 'through' transparent
+ // insts recursively.
+ //
+ auto isPassthroughInst = [&](IRInst* inst)
+ {
+ return !canInstBeStored(inst);
+ };
+
+ List<UseChain> useChains = UseChain::from(use, isRelevantUse, isPassthroughInst);
+ outOfScopeUses.addRange(useChains);
use = nextUse;
}
if (outOfScopeUses.getCount() == 0)
{
-
if (!isRecomputeInst)
processedStoreSet.add(instToStore);
+ seenInstSet.add(instToStore);
continue;
}
@@ -1457,9 +1590,9 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
for (auto use : outOfScopeUses)
{
- setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser()));
+ setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use.getUser()));
- List<IndexTrackingInfo>& useBlockIndices = indexedBlockInfo[getBlock(use->getUser())];
+ List<IndexTrackingInfo>& useBlockIndices = indexedBlockInfo[getBlock(use.getUser())];
IRInst* loadAddr = emitIndexedLoadAddressForVar(
&builder,
@@ -1467,12 +1600,37 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
defBlock,
defBlockIndices,
useBlockIndices);
- builder.replaceOperand(use, loadAddr);
+ use.replace(&builder, loadAddr);
}
if (!isRecomputeInst)
processedStoreSet.add(localVar);
}
+ else if (!canInstBeStored(instToStore))
+ {
+ // We won't actually process these insts here. Instead we'll
+ // simply make sure that their operands are either already present
+ // in the worklist or add them to the worklist for legalization.
+ //
+
+ List<IRInst*> pendingOperands;
+ for (UIndex ii = 0; ii < instToStore->getOperandCount(); ii++)
+ {
+ auto operand = instToStore->getOperand(ii);
+ if (!instSet.contains(operand) && !seenInstSet.contains(operand))
+ {
+ if(getBlock(operand) &&
+ (getBlock(operand)->getParent() == getBlock(instToStore)->getParent()))
+ pendingOperands.add(operand);
+ }
+ }
+
+ if (pendingOperands.getCount() > 0)
+ {
+ for (Index ii = pendingOperands.getCount() - 1; ii >= 0; --ii)
+ workList.add(pendingOperands[ii]);
+ }
+ }
else
{
// Handle the special case of loop counters.
@@ -1495,16 +1653,18 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
for (auto use : outOfScopeUses)
{
- List<IndexTrackingInfo> useBlockIndices = indexedBlockInfo[getBlock(use->getUser())];
- setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser()));
- builder.replaceOperand(
- use,
+ List<IndexTrackingInfo> useBlockIndices = indexedBlockInfo[getBlock(use.getUser())];
+ setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use.getUser()));
+ use.replace(
+ &builder,
loadIndexedValue(&builder, localVar, defBlock, defBlockIndices, useBlockIndices));
}
if (!isRecomputeInst)
processedStoreSet.add(localVar);
}
+
+ seenInstSet.add(instToStore);
}
};
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 532fb88ac..8d7582373 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -913,7 +913,7 @@ namespace Slang
{
primalType = diffPairType->getValueType();
diffType = (IRType*)differentiableTypeConformanceContext
- .getDifferentialTypeFromDiffPairType(builder, diffPairType);
+ .getDiffTypeFromPairType(builder, diffPairType);
}
// Now we handle each combination of parameter direction x differentiability.
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 0a9ff51a4..24e26f943 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -369,17 +369,30 @@ IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRI
autoDiffSharedContext->differentialAssocTypeWitnessStructKey);
}
}
+
+ // Obtain the witness that primalType conforms to IDifferentiable.
if (!witness)
witness = tryGetDifferentiableWitness(builder, originalType);
SLANG_RELEASE_ASSERT(witness);
- return builder->getDifferentialPairType(
+ auto pairType = builder->getDifferentialPairType(
(IRType*)primalType,
witness);
+
+ return pairType;
}
IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* origType)
{
+ // Special-case for differentiable existential types.
+ if (as<IRInterfaceType>(origType) || as<IRAssociatedType>(origType))
+ {
+ if (differentiableTypeConformanceContext.lookUpConformanceForType(origType))
+ return autoDiffSharedContext->differentiableInterfaceType;
+ else
+ return nullptr;
+ }
+
auto primalType = lookupPrimalInst(builder, origType, origType);
if (primalType->getOp() == kIROp_Param &&
primalType->getParent() && primalType->getParent()->getParent() &&
@@ -482,72 +495,17 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy
}
}
-// Given an interface type, return the lookup path from a witness table of `type` to a witness table of `IDifferentiable`.
-static bool _findDifferentiableInterfaceLookupPathImpl(
- HashSet<IRInst*>& processedTypes,
- IRInterfaceType* idiffType,
- IRInterfaceType* type,
- List<IRInterfaceRequirementEntry*>& currentPath)
-{
- if (processedTypes.contains(type))
- return false;
- processedTypes.add(type);
-
- List<IRInterfaceRequirementEntry*> lookupKeyPath;
- for (UInt i = 0; i < type->getOperandCount(); i++)
- {
- auto entry = as<IRInterfaceRequirementEntry>(type->getOperand(i));
- if (!entry) continue;
- if (auto wt = as<IRWitnessTableTypeBase>(entry->getRequirementVal()))
- {
- currentPath.add(entry);
- if (wt->getConformanceType() == idiffType)
- {
- return true;
- }
- else if (auto subInterfaceType = as<IRInterfaceType>(wt->getConformanceType()))
- {
- if (_findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, subInterfaceType, currentPath))
- return true;
- }
- currentPath.removeLast();
- }
- }
- return false;
-}
-
-List<IRInterfaceRequirementEntry*> AutoDiffTranscriberBase::findDifferentiableInterfaceLookupPath(
- IRInterfaceType* idiffType,
- IRInterfaceType* type)
-{
- List<IRInterfaceRequirementEntry*> currentPath;
- HashSet<IRInst*> processedTypes;
- _findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, type, currentPath);
- return currentPath;
-}
-
-IRInst* AutoDiffTranscriberBase::tryExtractConformanceFromInterfaceType(
- IRBuilder* builder,
- IRInterfaceType* interfaceType,
- IRWitnessTable* witnessTable)
+bool AutoDiffTranscriberBase::isExistentialType(IRType *type)
{
- SLANG_RELEASE_ASSERT(interfaceType);
-
- List<IRInterfaceRequirementEntry*> lookupKeyPath = findDifferentiableInterfaceLookupPath(
- autoDiffSharedContext->differentiableInterfaceType, interfaceType);
-
- IRInst* differentialTypeWitness = witnessTable;
- if (lookupKeyPath.getCount())
+ switch (type->getOp())
{
- // `interfaceType` does conform to `IDifferentiable`.
- for (auto node : lookupKeyPath)
- {
- differentialTypeWitness = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), differentialTypeWitness, node->getRequirementKey());
- }
- return differentialTypeWitness;
+ case kIROp_ExtractExistentialType:
+ case kIROp_InterfaceType:
+ case kIROp_AssociatedType:
+ return true;
+ default:
+ return false;
}
-
- return nullptr;
}
InstPair AutoDiffTranscriberBase::transcribeExtractExistentialWitnessTable(IRBuilder* builder, IRInst* origInst)
@@ -569,7 +527,7 @@ InstPair AutoDiffTranscriberBase::transcribeExtractExistentialWitnessTable(IRBui
if (!interfaceType)
return InstPair(primalResult, nullptr);
- if (auto differentialWitnessTable = tryExtractConformanceFromInterfaceType(
+ if (auto differentialWitnessTable = differentiableTypeConformanceContext.tryExtractConformanceFromInterfaceType(
builder, interfaceType, (IRWitnessTable*)primalResult))
{
// `interfaceType` does conform to `IDifferentiable`.
@@ -630,7 +588,7 @@ IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder*
auto interfaceType = as<IRInterfaceType>(unwrapAttributedType(origType->getOperand(0)->getDataType()));
if (!interfaceType)
return nullptr;
- List<IRInterfaceRequirementEntry*> lookupKeyPath = findDifferentiableInterfaceLookupPath(
+ List<IRInterfaceRequirementEntry*> lookupKeyPath = differentiableTypeConformanceContext.findDifferentiableInterfaceLookupPath(
autoDiffSharedContext->differentiableInterfaceType, interfaceType);
if (lookupKeyPath.getCount())
@@ -737,6 +695,13 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(IRBuilder* bui
(IRType*)primalDiffType,
primal,
autoDiffSharedContext->differentialAssocTypeWitnessStructKey);
+
+ // Mark both as primal since we're working with types
+ // (which don't need transposing)
+ //
+ builder->markInstAsPrimal(primalDiffType);
+ builder->markInstAsPrimal(diffWitness);
+
return InstPair(primal, diffWitness);
}
}
@@ -762,12 +727,31 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(IRBuilder* bui
// result, it's useful to have a method to generate zero literals of any (arithmetic) type.
// The current implementation requires that types are defined linearly.
//
-IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, IRType* originalType)
+IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(
+ IRBuilder* builder, IRType* originalType)
{
originalType = (IRType*)unwrapAttributedType(originalType);
auto primalType = (IRType*)lookupPrimalInst(builder, originalType);
if (auto diffType = differentiateType(builder, originalType))
{
+ IRInst* diffWitnessTable = nullptr;
+ IRType* diffOuterType = nullptr;
+ if (isExistentialType(diffType))
+ {
+ // Emit null differential & pack it into an IDifferentiable existential.
+
+ auto nullDiffValue = differentiableTypeConformanceContext.emitNullDifferential(builder);
+ builder->markInstAsDifferential(nullDiffValue, autoDiffSharedContext->nullDifferentialStructType);
+
+ auto nullDiffExistential = builder->emitMakeExistential(
+ diffType,
+ nullDiffValue,
+ autoDiffSharedContext->nullDifferentialWitness);
+ builder->markInstAsDifferential(nullDiffExistential, primalType);
+
+ return nullDiffExistential;
+ }
+
switch (diffType->getOp())
{
case kIROp_DifferentialPairType:
@@ -812,7 +796,7 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I
// zero method from the same witness table.
auto wt = lookupInterface->getWitnessTable();
zeroMethod = builder->emitLookupInterfaceMethodInst(builder->getFuncType(List<IRType*>(), diffType), wt, autoDiffSharedContext->zeroMethodStructKey);
- builder->markInstAsDifferential(zeroMethod);
+ builder->markInstAsPrimal(zeroMethod);
}
else
{
@@ -825,7 +809,18 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I
auto callInst = builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList);
builder->markInstAsDifferential(callInst, primalType);
- return callInst;
+ if (diffOuterType && isExistentialType(diffOuterType))
+ {
+ // Need to wrap the result back into an existential.
+ auto existentialZero = builder->emitMakeExistential(
+ diffOuterType,
+ callInst,
+ diffWitnessTable);
+ builder->markInstAsDifferential(existentialZero, primalType);
+ return existentialZero;
+ }
+ else
+ return callInst;
}
else
{
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h
index d6b2ea9ff..e9acbcd99 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.h
+++ b/source/slang/slang-ir-autodiff-transcriber-base.h
@@ -87,13 +87,8 @@ struct AutoDiffTranscriberBase
IRInst* maybeCloneForPrimalInst(IRBuilder* builder, IRInst* inst);
- List<IRInterfaceRequirementEntry*> findDifferentiableInterfaceLookupPath(
- IRInterfaceType* idiffType, IRInterfaceType* type);
-
InstPair transcribeExtractExistentialWitnessTable(IRBuilder* builder, IRInst* origInst);
- IRInst* tryExtractConformanceFromInterfaceType(IRBuilder* builder, IRInterfaceType* type, IRWitnessTable* WitnessTable);
-
void maybeMigrateDifferentiableDictionaryFromDerivativeFunc(IRBuilder* builder, IRInst* origFunc);
// Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
@@ -141,6 +136,10 @@ struct AutoDiffTranscriberBase
IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType);
+ bool isExistentialType(IRType* type);
+
+ void _markInstAsDifferential(IRBuilder* builder, IRInst* diffInst, IRInst* primalInst = nullptr);
+
virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) = 0;
// Create an empty func to represent the transcribed func of `origFunc`.
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index dad4ab192..bcebd2108 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -583,15 +583,19 @@ struct DiffTransposePass
auto nextInst = inst->getNextInst();
if (auto varInst = as<IRVar>(inst))
{
- if (auto diffDecor = varInst->findDecoration<IRDifferentialInstDecoration>())
+ if (isDifferentialInst(varInst) && tryGetPrimalTypeFromDiffInst(varInst))
{
- if (auto ptrPrimalType = as<IRPtrTypeBase>(diffDecor->getPrimalType()))
+ if (auto ptrPrimalType = as<IRPtrTypeBase>(tryGetPrimalTypeFromDiffInst(varInst)))
{
varInst->insertAtEnd(firstRevDiffBlock);
auto dzero = emitDZeroOfDiffInstType(&builder, ptrPrimalType->getValueType());
builder.emitStore(varInst, dzero);
}
+ else
+ {
+ SLANG_UNEXPECTED("Expected an pointer-typed differential variable.");
+ }
}
}
inst = nextInst;
@@ -1139,21 +1143,15 @@ struct DiffTransposePass
// Normal differentiable input parameter will become an inout DiffPair parameter
// in the propagate func. The split logic has already prepared the initial value
// to pass in. We need to define a temp variable with this initial value and pass
- // in the temp variable as argument to the inout parameter.
+ // in the temp variable as argument to the inout parameter.
auto makePairArg = as<IRMakeDifferentialPair>(arg);
SLANG_RELEASE_ASSERT(makePairArg);
auto pairType = as<IRDifferentialPairType>(arg->getDataType());
auto var = builder->emitVar(arg->getDataType());
-
- auto diffType = (IRType*)diffTypeContext.getDiffTypeFromPairType(builder, pairType);
- auto zeroMethod = diffTypeContext.getDiffZeroMethodFromPairType(builder, pairType);
- SLANG_ASSERT(zeroMethod);
- auto diffZero = builder->emitCallInst(
- diffType,
- zeroMethod,
- List<IRInst*>());
+
+ auto diffZero = emitDZeroOfDiffInstType(builder, pairType->getValueType());
// Initialize this var to (arg.primal, 0).
builder->emitStore(
@@ -1484,6 +1482,18 @@ struct DiffTransposePass
case kIROp_FloatCast:
return transposeFloatCast(builder, fwdInst, revValue);
+ case kIROp_MakeExistential:
+ return transposeMakeExistential(builder, fwdInst, revValue);
+
+ case kIROp_ExtractExistentialValue:
+ return transposeExtractExistentialValue(builder, fwdInst, revValue);
+
+ case kIROp_Reinterpret:
+ return transposeReinterpret(builder, fwdInst, revValue);
+
+ case kIROp_PackAnyValue:
+ return transposePackAnyValue(builder, fwdInst, revValue);
+
case kIROp_LoadReverseGradient:
case kIROp_ReverseGradientDiffPairRef:
case kIROp_DefaultConstruct:
@@ -1495,7 +1505,6 @@ struct DiffTransposePass
case kIROp_Switch:
case kIROp_LookupWitness:
case kIROp_ExtractExistentialType:
- case kIROp_ExtractExistentialValue:
case kIROp_ExtractExistentialWitnessTable:
{
// Ignore. transposeBlock() should take care of adding the
@@ -1574,7 +1583,7 @@ struct DiffTransposePass
if (auto diffPairType = as<IRDifferentialPairType>(revVal->getDataType()))
{
revVal = builder->emitDifferentialPairGetDifferential(
- (IRType*)diffTypeContext.getDifferentialTypeFromDiffPairType(
+ (IRType*)diffTypeContext.getDiffTypeFromPairType(
builder, diffPairType),
revVal);
}
@@ -1992,6 +2001,110 @@ struct DiffTransposePass
fwdInst)));
}
+ TranspositionResult transposeMakeExistential(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue)
+ {
+ auto isExistentialType = [&](IRInst* type) -> bool
+ {
+ switch (type->getOp())
+ {
+ case kIROp_ExtractExistentialType:
+ case kIROp_LookupWitness:
+ return true;
+ default:
+ return false;
+ }
+ };
+
+ auto diffType = fwdInst->getOperand(0)->getDataType();
+ if (isExistentialType(diffType))
+ {
+ // (A:IDiff = MakeExistential(B, W)) -> (dB: T += ExtractExistentialValue(dW))
+ return TranspositionResult(
+ List<RevGradient>(
+ RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdInst->getOperand(0),
+ builder->emitExtractExistentialValue(
+ fwdInst->getOperand(0)->getDataType(),
+ revValue),
+ fwdInst)));
+ }
+ else
+ {
+ // We have a concrete type.
+ // (A:IDiff = MakeExistential(B, W)) ->
+ // (dB: T += ExtractExistentialValue(Reinterpret(dW)))
+ auto diffValInDiffType = builder->emitReinterpret(
+ diffType,
+ builder->emitExtractExistentialValue(
+ builder->emitExtractExistentialType(revValue),
+ revValue));
+
+ return TranspositionResult(
+ List<RevGradient>(
+ RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdInst->getOperand(0),
+ diffValInDiffType,
+ fwdInst)));
+ }
+ }
+
+ TranspositionResult transposeExtractExistentialValue(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue)
+ {
+ auto primalType = tryGetPrimalTypeFromDiffInst(fwdInst);
+ SLANG_ASSERT(primalType);
+
+ // If we reach this point, revValue must be a differentiable type.
+ auto revTypeWitness = diffTypeContext.tryGetDifferentiableWitness(
+ builder,
+ primalType);
+ SLANG_ASSERT(revTypeWitness);
+
+ auto baseExistential = fwdInst->getOperand(0);
+
+ // (dA = ExtractExistentialValue(dB)) -> (dB += MakeExistential(T, A, ExtractExistentialWitness(B)))
+ return TranspositionResult(
+ List<RevGradient>(
+ RevGradient(
+ RevGradient::Flavor::Simple,
+ baseExistential,
+ builder->emitMakeExistential(
+ baseExistential->getDataType(),
+ revValue,
+ revTypeWitness),
+ fwdInst)));
+ }
+
+ TranspositionResult transposeReinterpret(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue)
+ {
+ // (A = reinterpret<T, U>(B)) -> (dB += reinterpret<U, T>(dA))
+ return TranspositionResult(
+ List<RevGradient>(
+ RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdInst->getOperand(0),
+ builder->emitReinterpret(
+ fwdInst->getOperand(0)->getDataType(),
+ revValue),
+ fwdInst)));
+ }
+
+
+ TranspositionResult transposePackAnyValue(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue)
+ {
+ // (A = packAnyValue<T, U>(B)) -> (dB += unpackAnyValue<U, T>(dA))
+ return TranspositionResult(
+ List<RevGradient>(
+ RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdInst->getOperand(0),
+ builder->emitUnpackAnyValue(
+ fwdInst->getOperand(0)->getDataType(),
+ revValue),
+ fwdInst)));
+ }
+
// Gather all reverse-mode gradients for a Load inst, aggregate them and store them in the ptr.
//
void accumulateGradientsForLoad(IRBuilder* builder, IRLoad* revLoad)
@@ -2681,13 +2794,18 @@ struct DiffTransposePass
{
// Look for differential inst decoration.
if (auto diffInstDecoration = diffInst->findDecoration<IRDifferentialInstDecoration>())
- {
return diffInstDecoration->getPrimalType();
- }
- else
- {
- return nullptr;
- }
+
+ return nullptr;
+ }
+
+ IRInst* tryGetWitnessFromDiffInst(IRInst* diffInst)
+ {
+ // Look for differential inst decoration.
+ if (auto diffInstDecoration = diffInst->findDecoration<IRDifferentialInstDecoration>())
+ return diffInstDecoration->getWitness();
+
+ return nullptr;
}
IRInst* emitDZeroOfDiffInstType(IRBuilder* builder, IRType* primalType)
@@ -2709,6 +2827,16 @@ struct DiffTransposePass
auto diffDiffPairType = builder->getDifferentialPairUserCodeType(diffType, diffWitness);
return builder->emitMakeDifferentialPairUserCode(diffDiffPairType, primalZero, diffZero);
}
+ else if (as<IRInterfaceType>(primalType) || as<IRAssociatedType>(primalType))
+ {
+ // Pack a null value into an existential type.
+ auto existentialZero = builder->emitMakeExistential(
+ autodiffContext->differentiableInterfaceType,
+ diffTypeContext.emitNullDifferential(builder),
+ autodiffContext->nullDifferentialWitness);
+
+ return existentialZero;
+ }
auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, primalType);
@@ -2720,6 +2848,19 @@ struct DiffTransposePass
zeroMethod,
List<IRInst*>());
}
+
+ IRInst* emitDAddForExistentialType(IRBuilder* builder, IRType* primalType, IRInst* op1, IRInst* op2)
+ {
+ auto existentialDAddFunc = diffTypeContext.getOrCreateExistentialDAddMethod();
+
+ // Should exist.
+ SLANG_ASSERT(existentialDAddFunc);
+
+ return builder->emitCallInst(
+ (IRType*)diffTypeContext.getDifferentialForType(builder, primalType),
+ existentialDAddFunc,
+ List<IRInst*>({ op1, op2 }));
+ }
IRInst* emitDAddOfDiffInstType(IRBuilder* builder, IRType* primalType, IRInst* op1, IRInst* op2)
{
@@ -2764,6 +2905,13 @@ struct DiffTransposePass
auto diffDiffPairType = builder->getDifferentialPairUserCodeType(diffType, diffWitness);
return builder->emitMakeDifferentialPairUserCode(diffDiffPairType, primal, diff);
}
+ else if (as<IRInterfaceType>(primalType) || as<IRAssociatedType>(primalType))
+ {
+ // If our type is existential, we need to handle the case where
+ // one or both of our operands are null-type.
+ //
+ return emitDAddForExistentialType(builder, primalType, op1, op2);
+ }
auto addMethod = diffTypeContext.getAddMethodForType(builder, primalType);
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 95ad0d921..2857424f9 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -417,7 +417,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
inst->getFullType(),
intermediateVar,
structKeyDecor->getStructKey());
- iuse->set(val);
+ builder.replaceOperand(iuse, val);
}
}
instsToRemove.add(inst);
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index 4846fc840..c57dc300f 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -88,7 +88,7 @@ struct DiffUnzipPass
}
if (auto pairType = as<IRDifferentialPairType>(type))
{
- IRInst* diffType = diffTypeContext.getDifferentialTypeFromDiffPairType(builder, pairType);
+ IRInst* diffType = diffTypeContext.getDiffTypeFromPairType(builder, pairType);
if (as<IRPtrTypeBase>(primalParam->getFullType()))
diffType = builder->getPtrType(primalParam->getFullType()->getOp(), (IRType*)diffType);
auto primalRef = builder->emitPrimalParamRef(primalParam);
@@ -286,7 +286,8 @@ struct DiffUnzipPass
if (auto fwdPairResultType = as<IRDifferentialPairType>(mixedDecoration->getPairType()))
{
primalType = fwdPairResultType->getValueType();
- diffType = (IRType*)diffTypeContext.getDifferentialForType(&globalBuilder, primalType);
+ diffType = (IRType*)diffTypeContext.getDiffTypeFromPairType(&globalBuilder, fwdPairResultType);
+ SLANG_ASSERT(diffType);
resultType = fwdPairResultType;
}
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index cb710ac6b..645662caa 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -35,6 +35,15 @@ IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementK
return entry->getSatisfyingVal();
}
}
+ else if (auto interfaceType = as<IRInterfaceType>(witness))
+ {
+ for (UIndex ii = 0; ii < interfaceType->getOperandCount(); ii++)
+ {
+ auto entry = cast<IRInterfaceRequirementEntry>(interfaceType->getOperand(ii));
+ if (entry->getRequirementKey() == requirementKey)
+ return entry->getRequirementVal();
+ }
+ }
else
{
return builder->emitLookupInterfaceMethodInst(
@@ -47,8 +56,17 @@ IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementK
static IRInst* _getDiffTypeFromPairType(AutoDiffSharedContext*sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type)
{
- auto witnessTable = type->getWitness();
- return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeStructKey);
+ auto witness = type->getWitness();
+ SLANG_RELEASE_ASSERT(witness);
+
+ // Special case when the primal type is an InterfaceType/AssociatedType
+ if (as<IRInterfaceType>(type->getValueType()) || as<IRAssociatedType>(type->getValueType()))
+ {
+ // The differential type is the IDifferentiable interface type.
+ return sharedContext->differentiableInterfaceType;
+ }
+
+ return _lookupWitness(builder, witness, sharedContext->differentialAssocTypeStructKey);
}
static IRInst* _getDiffTypeWitnessFromPairType(AutoDiffSharedContext* sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type)
@@ -332,6 +350,8 @@ AutoDiffSharedContext::AutoDiffSharedContext(IRModuleInst* inModuleInst)
zeroMethodStructKey = findZeroMethodStructKey();
addMethodStructKey = findAddMethodStructKey();
mulMethodStructKey = findMulMethodStructKey();
+ nullDifferentialStructType = findNullDifferentialStructType();
+ nullDifferentialWitness = findNullDifferentialWitness();
if (differentialAssocTypeStructKey)
isInterfaceAvailable = true;
@@ -362,6 +382,47 @@ IRInst* AutoDiffSharedContext::findDifferentiableInterface()
return nullptr;
}
+IRStructType* AutoDiffSharedContext::findNullDifferentialStructType()
+{
+ if (auto module = as<IRModuleInst>(moduleInst))
+ {
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ // TODO: Also a particularly dangerous way to look for a struct...
+ if (auto structType = as<IRStructType>(globalInst))
+ {
+ if (auto decor = structType->findDecoration<IRNameHintDecoration>())
+ {
+ if (decor->getName() == toSlice("NullDifferential"))
+ {
+ return structType;
+ }
+ }
+ }
+ }
+ }
+ return nullptr;
+}
+
+IRInst* AutoDiffSharedContext::findNullDifferentialWitness()
+{
+ if (auto module = as<IRModuleInst>(moduleInst))
+ {
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ if (auto witnessTable = as<IRWitnessTable>(globalInst))
+ {
+ if (witnessTable->getConformanceType() == differentiableInterfaceType
+ && witnessTable->getConcreteType() == nullDifferentialStructType)
+ return witnessTable;
+ }
+ }
+ }
+
+ return nullptr;
+}
+
+
IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt index)
{
if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType)
@@ -442,11 +503,9 @@ IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* b
}
IRInst* DifferentiableTypeConformanceContext::getDifferentialTypeFromDiffPairType(
- IRBuilder* builder, IRDifferentialPairTypeBase* diffPairType)
+ IRBuilder*, IRDifferentialPairTypeBase*)
{
- auto witness = diffPairType->getWitness();
- SLANG_RELEASE_ASSERT(witness);
- return _lookupWitness(builder, witness, sharedContext->differentialAssocTypeStructKey);
+ SLANG_UNIMPLEMENTED_X("");
}
IRInst* DifferentiableTypeConformanceContext::getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type)
@@ -471,6 +530,189 @@ IRInst* DifferentiableTypeConformanceContext::getDiffAddMethodFromPairType(IRBui
return _lookupWitness(builder, witnessTable, sharedContext->addMethodStructKey);
}
+IRInst *DifferentiableTypeConformanceContext::tryExtractConformanceFromInterfaceType(IRBuilder *builder, IRInterfaceType *interfaceType, IRWitnessTable *witnessTable)
+{
+ SLANG_RELEASE_ASSERT(interfaceType);
+
+ List<IRInterfaceRequirementEntry*> lookupKeyPath = findDifferentiableInterfaceLookupPath(
+ sharedContext->differentiableInterfaceType, interfaceType);
+
+ IRInst* differentialTypeWitness = witnessTable;
+ if (lookupKeyPath.getCount())
+ {
+ // `interfaceType` does conform to `IDifferentiable`.
+ for (auto node : lookupKeyPath)
+ {
+ differentialTypeWitness = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), differentialTypeWitness, node->getRequirementKey());
+ // Lookup insts are always primal values.
+ builder->markInstAsPrimal(differentialTypeWitness);
+ }
+ return differentialTypeWitness;
+ }
+
+ return nullptr;
+}
+
+// Given an interface type, return the lookup path from a witness table of `type` to a witness table of `IDifferentiable`.
+static bool _findDifferentiableInterfaceLookupPathImpl(
+ HashSet<IRInst*>& processedTypes,
+ IRInterfaceType* idiffType,
+ IRInterfaceType* type,
+ List<IRInterfaceRequirementEntry*>& currentPath)
+{
+ if (processedTypes.contains(type))
+ return false;
+ processedTypes.add(type);
+
+ List<IRInterfaceRequirementEntry*> lookupKeyPath;
+ for (UInt i = 0; i < type->getOperandCount(); i++)
+ {
+ auto entry = as<IRInterfaceRequirementEntry>(type->getOperand(i));
+ if (!entry) continue;
+ if (auto wt = as<IRWitnessTableTypeBase>(entry->getRequirementVal()))
+ {
+ currentPath.add(entry);
+ if (wt->getConformanceType() == idiffType)
+ {
+ return true;
+ }
+ else if (auto subInterfaceType = as<IRInterfaceType>(wt->getConformanceType()))
+ {
+ if (_findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, subInterfaceType, currentPath))
+ return true;
+ }
+ currentPath.removeLast();
+ }
+ }
+ return false;
+}
+
+List<IRInterfaceRequirementEntry *> DifferentiableTypeConformanceContext::findDifferentiableInterfaceLookupPath(IRInterfaceType *idiffType, IRInterfaceType *type)
+{
+ List<IRInterfaceRequirementEntry*> currentPath;
+ HashSet<IRInst*> processedTypes;
+ _findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, type, currentPath);
+ return currentPath;
+}
+
+IRFunc *DifferentiableTypeConformanceContext::getOrCreateExistentialDAddMethod()
+{
+ if (this->existentialDAddFunc)
+ return this->existentialDAddFunc;
+
+ SLANG_ASSERT(sharedContext->differentiableInterfaceType);
+ SLANG_ASSERT(sharedContext->nullDifferentialWitness);
+
+ auto builder = IRBuilder(this->sharedContext->moduleInst);
+
+ existentialDAddFunc = builder.createFunc();
+ existentialDAddFunc->setFullType(builder.getFuncType(
+ List<IRType*>({
+ sharedContext->differentiableInterfaceType,
+ sharedContext->differentiableInterfaceType,
+ }),
+ sharedContext->differentiableInterfaceType));
+
+ builder.setInsertInto(existentialDAddFunc);
+ auto entryBlock = builder.emitBlock();
+
+ builder.setInsertInto(entryBlock);
+
+ // Insert parameters.
+ auto aObj = builder.emitParam(sharedContext->differentiableInterfaceType);
+ auto bObj = builder.emitParam(sharedContext->differentiableInterfaceType);
+
+ // Check if a.type == null_differential.type
+ auto aObjWitnessIsNull = builder.emitIsDifferentialNull(aObj);
+
+ // If aObjWitnessTable is null, return bObj.
+ auto aObjWitnessIsNullBlock = builder.emitBlock();
+ builder.setInsertInto(aObjWitnessIsNullBlock);
+ builder.emitReturn(bObj);
+
+ auto aObjWitnessIsNotNullBlock = builder.emitBlock();
+ builder.setInsertInto(aObjWitnessIsNotNullBlock);
+
+ // Check if b.type == null_differential.type
+ auto bObjWitnessIsNull = builder.emitIsDifferentialNull(bObj);
+
+ // If bObjWitnessTable is null, return aObj.
+ auto bObjWitnessIsNullBlock = builder.emitBlock();
+ builder.setInsertInto(bObjWitnessIsNullBlock);
+ builder.emitReturn(aObj);
+
+ auto bObjWitnessIsNotNullBlock = builder.emitBlock();
+
+ // Emit aObj.type::dadd(aObj.val, bObj.val)
+ //
+ // Important: we're looking up dadd on the differential type, and
+ // not the primal type. This assumes that the two methods are identical,
+ // which (mathematically) they should be.
+ //
+ auto concreteDiffTypeWitnessTable = builder.emitExtractExistentialWitnessTable(aObj);
+
+ // Extract func type from the witness table type.
+ IRFuncType* dAddFuncType = nullptr;
+ for (UIndex ii = 0; ii < sharedContext->differentiableInterfaceType->getOperandCount(); ii++)
+ {
+ auto entry = cast<IRInterfaceRequirementEntry>(sharedContext->differentiableInterfaceType->getOperand(ii));
+ if (entry->getRequirementKey() == sharedContext->addMethodStructKey)
+ {
+ dAddFuncType = cast<IRFuncType>(entry->getRequirementVal());
+ break;
+ }
+ }
+
+ SLANG_ASSERT(dAddFuncType);
+
+ auto dAddMethod = builder.emitLookupInterfaceMethodInst(
+ dAddFuncType,
+ concreteDiffTypeWitnessTable,
+ sharedContext->addMethodStructKey);
+
+ // Call
+ auto dAddResult = builder.emitCallInst(
+ dAddFuncType->getResultType(),
+ dAddMethod,
+ List<IRInst*>({
+ builder.emitExtractExistentialValue(dAddFuncType->getParamType(0), aObj),
+ builder.emitExtractExistentialValue(dAddFuncType->getParamType(1), bObj)}));
+
+ // Wrap result in existential.
+ auto existentialDiffType = builder.emitMakeExistential(
+ sharedContext->differentiableInterfaceType,
+ dAddResult,
+ concreteDiffTypeWitnessTable);
+
+ builder.emitReturn(existentialDiffType);
+
+ // Emit an unreachable block to act as the after block.
+ auto unreachableBlock = builder.emitBlock();
+ builder.setInsertInto(unreachableBlock);
+ builder.emitUnreachable();
+
+ // Link up conditional blocks.
+ builder.setInsertInto(entryBlock);
+ builder.emitIfElse(
+ aObjWitnessIsNull,
+ aObjWitnessIsNullBlock,
+ aObjWitnessIsNotNullBlock,
+ unreachableBlock);
+
+ builder.setInsertInto(aObjWitnessIsNotNullBlock);
+ builder.emitIfElse(
+ bObjWitnessIsNull,
+ bObjWitnessIsNullBlock,
+ bObjWitnessIsNotNullBlock,
+ unreachableBlock);
+
+ builder.addNameHintDecoration(existentialDAddFunc, UnownedStringSlice("__existential_dadd"));
+ builder.addBackwardDifferentiableDecoration(existentialDAddFunc);
+
+ this->existentialDAddFunc = existentialDAddFunc;
+ return existentialDAddFunc;
+}
+
void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary()
{
for (auto globalInst : sharedContext->moduleInst->getChildren())
@@ -745,9 +987,20 @@ IRInst* DifferentiableTypeConformanceContext::getArrayWitness(IRBuilder* builder
return table;
}
-IRInst* DifferentiableTypeConformanceContext::getExtractExistensialTypeWitness(IRBuilder*, IRExtractExistentialType*)
+IRInst* DifferentiableTypeConformanceContext::getExtractExistensialTypeWitness(
+ IRBuilder* builder,
+ IRExtractExistentialType* extractExistentialType)
{
- SLANG_UNIMPLEMENTED_X("TODO: Implement");
+ // Check that the type's base is differentiable
+ if (differentiateType(builder, extractExistentialType->getOperand(0)->getDataType()))
+ {
+ return tryExtractConformanceFromInterfaceType(
+ builder,
+ cast<IRInterfaceType>(extractExistentialType->getOperand(0)->getDataType()),
+ (IRWitnessTable*)builder->emitExtractExistentialWitnessTable(extractExistentialType->getOperand(0)));
+ }
+
+ return nullptr;
}
@@ -1761,6 +2014,71 @@ void removeDetachInsts(IRModule* module)
pass.processModule();
}
+struct LowerNullCheckPass : InstPassBase
+{
+ LowerNullCheckPass(IRModule* module, AutoDiffSharedContext* context) :
+ InstPassBase(module), context(context)
+ {
+ }
+ void processModule()
+ {
+ List<IRInst*> nullCheckInsts;
+ processInstsOfType<IRIsDifferentialNull>(kIROp_IsDifferentialNull, [&](IRIsDifferentialNull* isDiffNullInst)
+ {
+ IRBuilder builder(module);
+ builder.setInsertBefore(isDiffNullInst);
+
+ // Extract existential type from the operand.
+ auto operand = isDiffNullInst->getBase();
+ auto operandConcreteWitness = builder.emitExtractExistentialWitnessTable(operand);
+ auto witnessID = builder.emitGetSequentialIDInst(operandConcreteWitness);
+
+ auto nullDiffWitnessTable = context->nullDifferentialWitness;
+ auto nullDiffWitnessID = builder.emitGetSequentialIDInst(nullDiffWitnessTable);
+
+ // Compare the concrete type with the null differential witness table.
+ auto isDiffNull = builder.emitEql(witnessID, nullDiffWitnessID);
+
+ isDiffNullInst->replaceUsesWith(isDiffNull);
+ nullCheckInsts.add(isDiffNullInst);
+ });
+
+ for (auto nullCheckInst : nullCheckInsts)
+ {
+ nullCheckInst->removeAndDeallocate();
+ }
+ }
+
+ private:
+ AutoDiffSharedContext* context;
+};
+
+void lowerNullCheckInsts(IRModule* module, AutoDiffSharedContext* context)
+{
+ LowerNullCheckPass pass(module, context);
+ pass.processModule();
+}
+
+void releaseNullDifferentialType(AutoDiffSharedContext* context)
+{
+ if (auto nullStruct = context->nullDifferentialStructType)
+ {
+ if (auto publicDecoration = nullStruct->findDecoration<IRPublicDecoration>())
+ publicDecoration->removeAndDeallocate();
+ if (auto keepAliveDecoration = nullStruct->findDecoration<IRKeepAliveDecoration>())
+ keepAliveDecoration->removeAndDeallocate();
+ }
+
+ if (auto nullWitness = context->nullDifferentialWitness)
+ {
+ if (auto publicDecoration = nullWitness->findDecoration<IRPublicDecoration>())
+ publicDecoration->removeAndDeallocate();
+ if (auto keepAliveDecoration = nullWitness->findDecoration<IRKeepAliveDecoration>())
+ keepAliveDecoration->removeAndDeallocate();
+ }
+
+}
+
bool finalizeAutoDiffPass(IRModule* module)
{
bool modified = false;
@@ -1777,17 +2095,25 @@ bool finalizeAutoDiffPass(IRModule* module)
removeDetachInsts(module);
+ lowerNullCheckInsts(module, &autodiffContext);
+
stripNoDiffTypeAttribute(module);
// Remove auto-diff related decorations.
stripAutoDiffDecorations(module);
+ // Remove keep-alive decorations from null-differential type
+ // so it can be DCE'd if unused.
+ //
+ releaseNullDifferentialType(&autodiffContext);
+
return modified;
}
IRBlock* getBlock(IRInst* inst)
{
- SLANG_RELEASE_ASSERT(inst);
+ if (!inst)
+ return nullptr;
if (auto block = as<IRBlock>(inst))
return block;
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index fdbf5c65e..be51fba6f 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -92,6 +92,16 @@ struct AutoDiffSharedContext
IRStructKey* mulMethodStructKey = nullptr;
+ // Refernce to NullDifferential struct type. These are used
+ // as sentinel values for uninitialized existential (interface-typed)
+ // differentials.
+ //
+ IRStructType* nullDifferentialStructType = nullptr;
+
+ // Reference to the NullDifferential : IDifferentiable witness.
+ //
+ IRInst* nullDifferentialWitness = nullptr;
+
// Modules that don't use differentiable types
// won't have the IDifferentiable interface type available.
@@ -109,6 +119,10 @@ private:
IRInst* findDifferentiableInterface();
+ IRStructType *findNullDifferentialStructType();
+
+ IRInst *findNullDifferentialWitness();
+
IRStructKey* findDifferentialTypeStructKey()
{
return getIDifferentiableStructKeyAtIndex(0);
@@ -144,9 +158,17 @@ struct DifferentiableTypeConformanceContext
IRGlobalValueWithCode* parentFunc = nullptr;
OrderedDictionary<IRType*, IRInst*> differentiableWitnessDictionary;
+ IRFunc* existentialDAddFunc = nullptr;
+
DifferentiableTypeConformanceContext(AutoDiffSharedContext* shared)
: sharedContext(shared)
- {}
+ {
+ // Populate dictionary with null differential type.
+ if (sharedContext->nullDifferentialStructType)
+ differentiableWitnessDictionary.add(
+ sharedContext->nullDifferentialStructType,
+ sharedContext->nullDifferentialWitness);
+ }
void setFunc(IRGlobalValueWithCode* func);
@@ -181,6 +203,15 @@ struct DifferentiableTypeConformanceContext
IRInst* getDiffAddMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type);
+ IRInst* tryExtractConformanceFromInterfaceType(
+ IRBuilder* builder,
+ IRInterfaceType* interfaceType,
+ IRWitnessTable* witnessTable);
+
+ List<IRInterfaceRequirementEntry*> findDifferentiableInterfaceLookupPath(
+ IRInterfaceType* idiffType,
+ IRInterfaceType* type);
+
// Lookup and return the 'Differential' type declared in the concrete type
// in order to conform to the IDifferentiable interface.
// Note that inside a generic block, this will be a witness table lookup instruction
@@ -190,6 +221,13 @@ struct DifferentiableTypeConformanceContext
{
switch (origType->getOp())
{
+ case kIROp_InterfaceType:
+ {
+ if (isDifferentiableType(origType))
+ return this->sharedContext->differentiableInterfaceType;
+ else
+ return nullptr;
+ }
case kIROp_ArrayType:
{
auto diffElementType = (IRType*)getDifferentialForType(
@@ -249,6 +287,17 @@ struct DifferentiableTypeConformanceContext
auto result = lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey);
return result;
}
+
+ IRInst* emitNullDifferential(IRBuilder* builder)
+ {
+ return builder->emitCallInst(
+ sharedContext->nullDifferentialStructType,
+ getZeroMethodForType(builder, sharedContext->nullDifferentialStructType),
+ List<IRInst*>());
+ }
+
+ IRFunc* getOrCreateExistentialDAddMethod();
+
};
struct DifferentialPairTypeBuilder
diff --git a/source/slang/slang-ir-clone.cpp b/source/slang/slang-ir-clone.cpp
index 9258e511f..43b60d6ed 100644
--- a/source/slang/slang-ir-clone.cpp
+++ b/source/slang/slang-ir-clone.cpp
@@ -267,7 +267,15 @@ IRInst* cloneInst(
env, builder, oldInst);
env->mapOldValToNew.add(oldInst, newInst);
-
+
+ // For hoistable insts, its possible that the cloned inst is the same
+ // as the original inst.
+ // Skip the decoration/children cloning in that case (which will end up
+ // in an infinite loop)
+ //
+ if (newInst == oldInst)
+ return newInst;
+
cloneInstDecorationsAndChildren(
env, builder->getModule(), oldInst, newInst);
diff --git a/source/slang/slang-ir-generics-lowering-context.h b/source/slang/slang-ir-generics-lowering-context.h
index c8a7be3ee..509df6a33 100644
--- a/source/slang/slang-ir-generics-lowering-context.h
+++ b/source/slang/slang-ir-generics-lowering-context.h
@@ -33,7 +33,6 @@ namespace Slang
Dictionary<IRInterfaceType*, IRInterfaceType*> loweredInterfaceTypes;
Dictionary<IRInterfaceType*, IRInterfaceType*> mapLoweredInterfaceToOriginal;
-
// Dictionaries for interface type requirement key-value lookups.
// Used by `findInterfaceRequirementVal`.
Dictionary<IRInterfaceType*, Dictionary<IRInst*, IRInst*>> mapInterfaceRequirementKeyValue;
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index f110c07e7..a8fdd8202 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -362,6 +362,9 @@ INST(PrimalParamRef, PrimalParamRef, 1, 0)
// to represent a reference to an inout parameter for use in the back-prop part of the computation.
INST(DiffParamRef, DiffParamRef, 1, 0)
+// Check that the value is a differential null value.
+INST(IsDifferentialNull, IsDifferentialNull, 1, 0)
+
INST(FieldExtract, get_field, 2, 0)
INST(FieldAddress, get_field_addr, 2, 0)
@@ -935,8 +938,8 @@ INST(WrapExistential, wrapExistential, 1, 0)
INST(GetValueFromBoundInterface, getValueFromBoundInterface, 1, 0)
INST(ExtractExistentialValue, extractExistentialValue, 1, 0)
-INST(ExtractExistentialType, extractExistentialType, 1, 0)
-INST(ExtractExistentialWitnessTable, extractExistentialWitnessTable, 1, 0)
+INST(ExtractExistentialType, extractExistentialType, 1, HOISTABLE)
+INST(ExtractExistentialWitnessTable, extractExistentialWitnessTable, 1, HOISTABLE)
INST(ExtractTaggedUnionTag, extractTaggedUnionTag, 1, 0)
INST(ExtractTaggedUnionPayload, extractTaggedUnionPayload, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index dade0e2f4..adfcac7fd 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -819,6 +819,7 @@ struct IRDifferentialInstDecoration : IRAutodiffInstDecoration
IRType* getPrimalType() { return (IRType*)(getOperand(0)); }
IRInst* getPrimalInst() { return getOperand(1); }
+ IRInst* getWitness() { return getOperand(2); }
};
struct IRPrimalInstDecoration : IRAutodiffInstDecoration
@@ -1018,6 +1019,17 @@ struct IRBackwardDifferentiate : IRInst
IR_LEAF_ISA(BackwardDifferentiate)
};
+struct IRIsDifferentialNull : IRInst
+{
+ enum
+ {
+ kOp = kIROp_IsDifferentialNull
+ };
+ IRInst* getBase() { return getOperand(0); }
+
+ IR_LEAF_ISA(IsDifferentialNull)
+};
+
// Retrieves the primal substitution function for the given function.
struct IRPrimalSubstitute : IRInst
{
@@ -3223,6 +3235,7 @@ public:
IRInst* emitBackwardDifferentiatePropagateInst(IRType* type, IRInst* baseFn);
IRInst* emitPrimalSubstituteInst(IRType* type, IRInst* baseFn);
IRInst* emitDetachDerivative(IRType* type, IRInst* value);
+ IRInst* emitIsDifferentialNull(IRInst* value);
IRInst* emitDispatchKernelInst(IRType* type, IRInst* baseFn, IRInst* threadGroupSize, IRInst* dispatchSize, Int argCount, IRInst* const* inArgs);
IRInst* emitCudaKernelLaunch(IRInst* baseFn, IRInst* gridDim, IRInst* blockDim, IRInst* argsArray, IRInst* cudaStream);
@@ -4177,6 +4190,12 @@ public:
addDecoration(value, kIROp_DifferentialInstDecoration, primalType, primalInst);
}
+ void markInstAsDifferential(IRInst* value, IRType* primalType, IRInst* primalInst, IRInst* witnessTable)
+ {
+ IRInst* args[] = { primalType, primalInst, witnessTable };
+ addDecoration(value, kIROp_DifferentialInstDecoration, args, 3);
+ }
+
void addCOMWitnessDecoration(IRInst* value, IRInst* witnessTable)
{
addDecoration(value, kIROp_COMWitnessDecoration, &witnessTable, 1);
diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp
index f535b97d2..19f713548 100644
--- a/source/slang/slang-ir-lower-generics.cpp
+++ b/source/slang/slang-ir-lower-generics.cpp
@@ -2,6 +2,7 @@
#include "slang-ir-lower-generics.h"
#include "slang-ir-any-value-marshalling.h"
+#include "slang-ir-any-value-inference.h"
#include "slang-ir-augment-make-existential.h"
#include "slang-ir-generics-lowering-context.h"
#include "slang-ir-lower-existential.h"
@@ -15,7 +16,10 @@
#include "slang-ir-witness-table-wrapper.h"
#include "slang-ir-ssa-simplification.h"
#include "slang-ir-util.h"
+#include "slang-ir-layout.h"
+
#include "../core/slang-performance-profiler.h"
+#include "../core/slang-func-ptr.h"
namespace Slang
{
@@ -213,6 +217,8 @@ namespace Slang
checkTypeConformanceExists(&sharedContext);
+ inferAnyValueSizeWhereNecessary(module);
+
// Replace all `makeExistential` insts with `makeExistentialWithRTTI`
// before making any other changes. This is necessary because a parameter of
// generic type will be lowered into `AnyValueType`, and after that we can no longer
diff --git a/source/slang/slang-ir-lower-reinterpret.cpp b/source/slang/slang-ir-lower-reinterpret.cpp
index 7575c8f12..689cc8505 100644
--- a/source/slang/slang-ir-lower-reinterpret.cpp
+++ b/source/slang/slang-ir-lower-reinterpret.cpp
@@ -3,6 +3,7 @@
#include "slang-ir-insts.h"
#include "slang-ir-layout.h"
#include "slang-ir-any-value-marshalling.h"
+#include "slang-ir-any-value-inference.h"
namespace Slang
{
@@ -84,6 +85,11 @@ struct ReinterpretLoweringContext
void lowerReinterpret(TargetRequest* targetReq, IRModule* module, DiagnosticSink* sink)
{
+ // Before processing reinterpret insts, ensure that existential types without
+ // user-defined sizes have inferred sizes where possible.
+ //
+ inferAnyValueSizeWhereNecessary(module);
+
ReinterpretLoweringContext context;
context.module = module;
context.targetReq = targetReq;
diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp
index 18eba677e..730943bf8 100644
--- a/source/slang/slang-ir-ssa.cpp
+++ b/source/slang/slang-ir-ssa.cpp
@@ -412,7 +412,7 @@ PhiInfo* addPhi(
{
valueType = context->getBuilder()->getRateQualifiedType(rate, valueType);
}
- IRParam* phi = builder->createParam(valueType);
+ IRParam* phi = builder->emitParam(valueType);
cloneRelevantDecorations(var, phi);
RefPtr<PhiInfo> phiInfo = new PhiInfo();
@@ -503,6 +503,7 @@ IRInst* tryRemoveTrivialPhi(
// replace uses of the phi (including its possible uses
// of itself) with the unique non-phi value.
phi->replaceUsesWith(same);
+ phi->removeAndDeallocate();
// Clear out the operands to the phi, since they won't
// actually get used in the program any more.
@@ -849,11 +850,12 @@ void processBlock(
// leave them as-is, or replace them with a value
// that we look up with local/global value numbering
- IRInst* next = nullptr;
- for (auto ii = block->getFirstInst(); ii; ii = next)
- {
- next = ii->getNextInst();
+ List<IRInst*> workList;
+ for (auto ii = block->getFirstInst(); ii; ii = ii->getNextInst())
+ workList.add(ii);
+ for (auto& ii : workList)
+ {
// Any new instructions we create to represent
// the new value will get inserted before whatever
// instruction we are working with.
@@ -1117,6 +1119,14 @@ bool constructSSA(ConstructSSAContext* context)
{
auto blockInfo = *context->blockInfos.tryGetValue(bb);
+ // First remove phis from their parent blocks.
+ for (auto phiInfo : blockInfo->phis)
+ if (!phiInfo->replacement)
+ phiInfo->phi->removeFromParent();
+
+ // Then, add them back in a consistent order, and add predecessor
+ // args in the same order.
+ //
for (auto phiInfo : blockInfo->phis)
{
// If we replaced this phi with another value,
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index c666ccc08..8d36c2e86 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3268,6 +3268,17 @@ namespace Slang
return inst;
}
+ IRInst *IRBuilder::emitIsDifferentialNull(IRInst *value)
+ {
+ auto inst = createInst<IRIsDifferentialNull>(
+ this,
+ kIROp_IsDifferentialNull,
+ getBoolType(),
+ value);
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitBackwardDifferentiateInst(IRType* type, IRInst* baseFn)
{
auto inst = createInst<IRBackwardDifferentiate>(
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 1ed79fbe3..489a89287 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -2813,6 +2813,16 @@ void collectParameterLists(
auto noDiffAttr = context->astBuilder->getNoDiffModifierVal();
thisType = context->astBuilder->getModifiedType(thisType, 1, &noDiffAttr);
}
+ else if (auto fwdDerivDeclRef = declRef.as<ForwardDerivativeRequirementDecl>())
+ {
+ thisType = fwdDerivDeclRef.getDecl()->diffThisType;
+ }
+ else if (auto bwdDerivDeclRef = declRef.as<BackwardDerivativeRequirementDecl>())
+ {
+ thisType = bwdDerivDeclRef.getDecl()->diffThisType;
+ innerThisParamDirection = kParameterDirection_InOut;
+ }
+
addThisParameter(innerThisParamDirection, thisType, ioParameterLists);
}
}
@@ -7235,7 +7245,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
auto assocType = context->irBuilder->getAssociatedType(
constraintInterfaces.getArrayView().arrayView);
- context->setValue(decl, assocType);
+ context->setValue(decl, assocType);
return LoweredValInfo::simple(assocType);
}
@@ -8446,14 +8456,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
addNameHint(subContext, irFunc, decl);
addLinkageDecoration(subContext, irFunc, decl);
- if (decl->body)
- {
- if (auto differentialAttr = decl->findModifier<DifferentiableAttribute>())
- {
- lowerDifferentiableAttribute(subContext, irFunc, differentialAttr);
- }
- }
-
// Always force inline diff setter accessor to prevent downstream compiler from complaining
// fields are not fully initialized for the first `inout` parameter.
if (as<SetterDecl>(decl))
@@ -8927,6 +8929,17 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
getBuilder()->addDecoration(irFunc, kIROp_PreferRecomputeDecoration);
}
}
+
+ if (auto diffAttr = decl->findModifier<DifferentiableAttribute>())
+ {
+ if (decl->body)
+ {
+ subContext->irBuilder->setInsertInto(irFunc->getParent());
+ lowerDifferentiableAttribute(subContext, irFunc, diffAttr);
+ subContext->irBuilder->setInsertInto(irFunc);
+ }
+ }
+
// For convenience, ensure that any additional global
// values that were emitted while outputting the function
// body appear before the function itself in the list