summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-14 09:31:51 -0700
committerGitHub <noreply@github.com>2023-03-14 09:31:51 -0700
commite291f60c6b083eaa74aed5307a6e9461274c1642 (patch)
treebde9b45a9e09ebbe173fae1821237b258a9ff800 /source
parenta911ca6e06ce41e403b80fe6054162393491c8ac (diff)
Support `fwd_diff(bwd_diff(f))`. (#2697)
* Support `fwd_diff(bwd_diff(f))`. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-addr-inst-elimination.cpp7
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp8
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp7
-rw-r--r--source/slang/slang-ir-autodiff.cpp414
-rw-r--r--source/slang/slang-ir-inst-defs.h6
-rw-r--r--source/slang/slang-ir-insts.h29
-rw-r--r--source/slang/slang-ir-peephole.cpp13
-rw-r--r--source/slang/slang-ir-validate.cpp6
-rw-r--r--source/slang/slang-ir.cpp104
9 files changed, 580 insertions, 14 deletions
diff --git a/source/slang/slang-ir-addr-inst-elimination.cpp b/source/slang/slang-ir-addr-inst-elimination.cpp
index 42bc88106..6715f2c6a 100644
--- a/source/slang/slang-ir-addr-inst-elimination.cpp
+++ b/source/slang/slang-ir-addr-inst-elimination.cpp
@@ -150,10 +150,13 @@ struct AddressInstEliminationContext
for (auto use = addrInst->firstUse; use; )
{
+ auto nextUse = use->nextUse;
+
if (as<IRDecoration>(use->getUser()))
+ {
+ use = nextUse;
continue;
-
- auto nextUse = use->nextUse;
+ }
switch (use->getUser()->getOp())
{
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 091e7f1ab..0f51a6c62 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -258,10 +258,6 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o
auto diffType = (IRType*)differentiableTypeConformanceContext.getDifferentialForType(builder, origType);
return (IRType*)findOrTranscribePrimalInst(builder, diffType);
}
- else if (origType->getOp() == kIROp_LookupWitness)
- {
- return (IRType*)findOrTranscribePrimalInst(builder, (IRInst*)primalType);
- }
return (IRType*)transcribe(builder, origType);
}
@@ -282,9 +278,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy
{
case kIROp_Param:
if (as<IRTypeType>(primalType->getDataType()))
- return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(
- builder,
- (IRType*)primalType));
+ return differentiateType(builder, origType);
else if (as<IRWitnessTableType>(primalType->getDataType()))
return (IRType*)primalType;
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 2347c7a8f..e01452972 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -307,7 +307,12 @@ struct ExtractPrimalFuncContext
IRCloneEnv cloneEnv;
fieldType = cloneInst(&cloneEnv, &genTypeBuilder, fieldType);
}
- return genTypeBuilder.createStructField(structType, structKey, (IRType*)fieldType);
+ auto structField = genTypeBuilder.createStructField(structType, structKey, (IRType*)fieldType);
+ if (auto diffFieldType = backwardPrimalTranscriber->differentiateType(&genTypeBuilder, (IRType*)fieldType))
+ {
+ genTypeBuilder.addIntermediateContextFieldDifferentialTypeDecoration(structField, diffFieldType);
+ }
+ return structField;
}
void storeInst(
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index edea3847d..4d22d9eed 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -482,6 +482,7 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent)
case kIROp_BackwardDerivativePrimalReturnDecoration:
case kIROp_AutoDiffOriginalValueDecoration:
case kIROp_UserDefinedBackwardDerivativeDecoration:
+ case kIROp_IntermediateContextFieldDifferentialTypeDecoration:
decor->removeAndDeallocate();
break;
default:
@@ -631,6 +632,8 @@ struct AutoDiffPass : public InstPassBase
{
List<IRInst*> args;
auto subBase = processIntermediateContextTypeBase(builder, spec->getBase());
+ if (!subBase)
+ return nullptr;
for (UInt a = 0; a < spec->getArgCount(); a++)
args.add(spec->getArg(a));
auto actualType = builder->emitSpecializeInst(
@@ -645,6 +648,9 @@ struct AutoDiffPass : public InstPassBase
auto inner = findGenericReturnVal(baseGeneric);
if (auto typeDecor = inner->findDecoration<IRBackwardDerivativeIntermediateTypeDecoration>())
{
+ if (!isTypeFullyDifferentiated(typeDecor->getBackwardDerivativeIntermediateType()))
+ return nullptr;
+
return typeDecor->getBackwardDerivativeIntermediateType();
}
}
@@ -652,6 +658,8 @@ struct AutoDiffPass : public InstPassBase
{
if (auto typeDecor = func->findDecoration<IRBackwardDerivativeIntermediateTypeDecoration>())
{
+ if (!isTypeFullyDifferentiated(typeDecor->getBackwardDerivativeIntermediateType()))
+ return nullptr;
return typeDecor->getBackwardDerivativeIntermediateType();
}
}
@@ -671,6 +679,10 @@ struct AutoDiffPass : public InstPassBase
bool lowerIntermediateContextType(IRBuilder* builder)
{
bool changed = false;
+ OrderedHashSet<IRInst*> loweredIntermediateTypes;
+
+ // Replace all `BackwardDiffIntermediateContextType` insts with the struct type
+ // that we generated during backward diff pass.
processAllInsts([&](IRInst* inst)
{
switch (inst->getOp())
@@ -685,6 +697,7 @@ struct AutoDiffPass : public InstPassBase
auto type = processIntermediateContextTypeBase(&subBuilder, baseFunc);
if (type)
{
+ loweredIntermediateTypes.Add(type);
inst->replaceUsesWith(type);
inst->removeAndDeallocate();
changed = true;
@@ -695,14 +708,400 @@ struct AutoDiffPass : public InstPassBase
break;
}
});
+
+ // Now we generate the differential type for the intermediate context type
+ // to allow higher order differentiation.
+ generateDifferentialImplementationForContextType(loweredIntermediateTypes);
return changed;
}
+ // Utility function for topology sorting the intermediate context types.
+ bool isIntermediateContextTypeReadyForProcess(OrderedHashSet<IRInst*>& contextTypes, OrderedHashSet<IRInst*>& sortedSet, IRInst* t)
+ {
+ if (!contextTypes.Contains(t))
+ return true;
+
+ switch (t->getOp())
+ {
+ case kIROp_StructType:
+ {
+ bool canAddNow = true;
+ for (auto f : as<IRStructType>(t)->getFields())
+ {
+ if (!isIntermediateContextTypeReadyForProcess(contextTypes, sortedSet, f->getFieldType()))
+ {
+ canAddNow = false;
+ break;
+ }
+ }
+ return canAddNow;
+ }
+ case kIROp_Specialize:
+ return isIntermediateContextTypeReadyForProcess(contextTypes, sortedSet, as<IRSpecialize>(t)->getBase());
+ case kIROp_Generic:
+ return isIntermediateContextTypeReadyForProcess(contextTypes, sortedSet, findGenericReturnVal(as<IRGeneric>(t)));
+ default:
+ return true;
+ }
+ }
+
+ struct IntermediateContextTypeDifferentialInfo
+ {
+ IRInst* diffType = nullptr;
+ IRInst* diffWitness = nullptr;
+ IRInst* diffDiffWitness = nullptr;
+ IRInst* zeroMethod = nullptr;
+ IRInst* addMethod = nullptr;
+ };
+
+ // Register the differential type for an intermediate context type to the derivative functions that uses the type.
+ void registerDiffContextType(
+ IRBuilder& builder,
+ IRDifferentiableTypeDictionaryDecoration* diffDecor,
+ OrderedDictionary<IRInst*, IntermediateContextTypeDifferentialInfo>& diffTypes,
+ IRInst* origType)
+ {
+ HashSet<IRInst*> registeredType;
+ for (auto entry : diffDecor->getChildren())
+ {
+ if (auto e = as<IRDifferentiableTypeDictionaryItem>(entry))
+ {
+ registeredType.Add(e->getOperand(0));
+ }
+ }
+ // Use a work list to recursively walk through all sub fields of the struct type.
+ List<IRInst*> wlist;
+ wlist.add(origType);
+ for (Index i = 0; i < wlist.getCount(); i++)
+ {
+ auto t = wlist[i];
+ IntermediateContextTypeDifferentialInfo diffInfo;
+ if (!diffTypes.TryGetValue(t, diffInfo))
+ continue;
+ if (registeredType.Add(t))
+ builder.addDifferentiableTypeEntry(diffDecor, t, diffInfo.diffWitness);
+ else
+ continue;
+
+ if (auto structType = as<IRStructType>(getResolvedInstForDecorations(t)))
+ {
+ for (auto f : structType->getFields())
+ {
+ wlist.add(f->getFieldType());
+ }
+ }
+ }
+ }
+
+ void generateDifferentialImplementationForContextType(OrderedHashSet<IRInst*>& contextTypes)
+ {
+ // First we are going to topology sort all intermediate context types.
+ OrderedHashSet<IRInst*> sortedContextTypes;
+ for (;;)
+ {
+ auto lastCount = sortedContextTypes.Count();
+ for (auto t : contextTypes)
+ {
+ if (sortedContextTypes.Contains(t))
+ continue;
+ // Have all dependent types been added yet?
+ if (isIntermediateContextTypeReadyForProcess(contextTypes, sortedContextTypes, t))
+ sortedContextTypes.Add(t);
+ }
+ if (lastCount == sortedContextTypes.Count())
+ break;
+ }
+
+ // After the types are sorted, we start to generate the differential type and IDifferentiable witnesses
+ // for them.
+
+ OrderedDictionary<IRInst*, IntermediateContextTypeDifferentialInfo> diffTypes;
+ IRBuilder builder(module);
+ for (auto t : sortedContextTypes)
+ {
+ if (t->getOp() == kIROp_Generic || t->getOp() == kIROp_StructType)
+ {
+ // For generics/struct types, we will generate a new generic/struct type representing the differntial.
+
+ SLANG_RELEASE_ASSERT(t->getParent() && t->getParent()->getOp() == kIROp_Module);
+ builder.setInsertBefore(t);
+ auto diffInfo = fillDifferentialTypeImplementation(diffTypes, t);
+ diffTypes[t] = diffInfo;
+ }
+ else if (auto specialize = as<IRSpecialize>(t))
+ {
+ // A specialize of a context type translates to a specialize of its differential type/witness.
+
+ IntermediateContextTypeDifferentialInfo baseInfo;
+ SLANG_RELEASE_ASSERT(diffTypes.TryGetValue(specialize->getBase(), baseInfo));
+ builder.setInsertBefore(t);
+ List<IRInst*> args;
+ for (UInt i = 0; i < specialize->getArgCount(); i++)
+ args.add(specialize->getArg(i));
+ IntermediateContextTypeDifferentialInfo info;
+ info.diffType = builder.emitSpecializeInst(
+ builder.getTypeKind(), baseInfo.diffType, (UInt)args.getCount(), args.getBuffer());
+ info.diffWitness = builder.emitSpecializeInst(
+ builder.getWitnessTableType(autodiffContext->differentiableInterfaceType),
+ baseInfo.diffWitness,
+ (UInt)args.getCount(),
+ args.getBuffer());
+ diffTypes[t] = info;
+ }
+ else
+ {
+ // If `t` is not a specialize, it'd better be processed by now.
+ // We currently don't support the `LookupInterfaceMethod` case, since it can't
+ // appear in a derivative function because we will only call the backward diff function without a intermediate-type
+ // via an interface.
+ SLANG_RELEASE_ASSERT(diffTypes.ContainsKey(t));
+ }
+ }
+
+ // Register the differential types into the conformance dictionaries of the functions that uses them.
+ for (auto t : diffTypes)
+ {
+ HashSet<IRFunc*> registeredFuncs;
+ for (auto use = t.Key->firstUse; use; use = use->nextUse)
+ {
+ auto parentFunc = getParentFunc(use->getUser());
+ if (!parentFunc)
+ continue;
+ if (!registeredFuncs.Add(parentFunc))
+ continue;
+ if (auto dictDecor = parentFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
+ {
+ registerDiffContextType(builder, dictDecor, diffTypes, t.Key);
+ }
+ }
+ }
+ }
+
+ IntermediateContextTypeDifferentialInfo fillDifferentialTypeImplementationForStruct(
+ OrderedDictionary<IRInst*, IntermediateContextTypeDifferentialInfo>& diffTypes,
+ IRStructType* originalType,
+ IRStructType* diffType)
+ {
+ IntermediateContextTypeDifferentialInfo result;
+ result.diffType = diffType;
+
+ IRBuilder builder(diffType);
+ builder.setInsertInto(diffType);
+
+ // Generate the fields for all differentiable members of the original struct type.
+ for (auto field : originalType->getFields())
+ {
+ IRInst* diffFieldType = nullptr;
+ if (auto diffDecor = field->findDecoration<IRIntermediateContextFieldDifferentialTypeDecoration>())
+ {
+ diffFieldType = diffDecor->getDifferentialType();
+ }
+ else
+ {
+ IntermediateContextTypeDifferentialInfo diffFieldTypeInfo;
+ diffTypes.TryGetValue(field->getDataType(), diffFieldTypeInfo);
+ diffFieldType = diffFieldTypeInfo.diffType;
+ }
+ if (diffFieldType)
+ {
+ IRBuilder keyBuilder = builder;
+ keyBuilder.setInsertBefore(maybeFindOuterGeneric(originalType));
+ auto diffKey = keyBuilder.createStructKey();
+ builder.createStructField(diffType, diffKey, (IRType*)diffFieldType);
+ builder.addDecoration(field->getKey(), kIROp_DerivativeMemberDecoration, diffKey);
+ builder.addDecoration(diffKey, kIROp_DerivativeMemberDecoration, diffKey);
+ }
+ }
+
+ builder.setInsertAfter(diffType);
+
+ // For now, we are going to structurally derive dadd and dzero methods for intermediate context types,
+ // because it is tricky for us to obtain the original witness tables for the fields at this point.
+ // This is inconsistent with how we are dealing with dadd and dzero methods via witness table lookup,
+ // and can lead to problems if the user defines any non-trivial dadd/dzero methods.
+ //
+ // TODO: we should consider rewrite this logic to be witness table lookup based, or simplify the entire
+ // type system and IR passes to always use structurally derived methods instead of user-provided
+ // methods.
+ IRInst* zeroMethod = nullptr;
+ {
+ auto zeroMethodType = builder.getFuncType(List<IRType*>(), diffType);
+ zeroMethod = builder.createFunc();
+ zeroMethod->setFullType(zeroMethodType);
+ result.zeroMethod = zeroMethod;
+ builder.setInsertInto(zeroMethod);
+ builder.emitBlock();
+ builder.emitReturn(builder.emitDefaultConstruct(diffType));
+ }
+
+ builder.setInsertAfter(zeroMethod);
+ IRInst* addMethod = nullptr;
+ {
+ List<IRType*> paramTypes;
+ paramTypes.add(diffType);
+ paramTypes.add(diffType);
+ auto addMethodType = builder.getFuncType(List<IRType*>(), diffType);
+ addMethod = builder.createFunc();
+ result.addMethod = addMethod;
+ addMethod->setFullType(addMethodType);
+ builder.setInsertInto(addMethod);
+ builder.emitBlock();
+ auto param1 = builder.emitParam(diffType);
+ auto param2 = builder.emitParam(diffType);
+ builder.emitReturn(builder.emitStructuralAdd(param1, param2));
+ }
+
+ builder.setInsertAfter(addMethod);
+ auto diffTypeIsDiffWitness = builder.createWitnessTable(autodiffContext->differentiableInterfaceType, diffType);
+ auto origTypeIsDiffWitness = builder.createWitnessTable(autodiffContext->differentiableInterfaceType, originalType);
+ result.diffWitness = origTypeIsDiffWitness;
+
+ builder.createWitnessTableEntry(origTypeIsDiffWitness, autodiffContext->differentialAssocTypeStructKey, diffType);
+ builder.createWitnessTableEntry(origTypeIsDiffWitness, autodiffContext->differentialAssocTypeWitnessStructKey, diffTypeIsDiffWitness);
+ builder.createWitnessTableEntry(origTypeIsDiffWitness, autodiffContext->zeroMethodStructKey, zeroMethod);
+ builder.createWitnessTableEntry(origTypeIsDiffWitness, autodiffContext->addMethodStructKey, addMethod);
+
+ builder.createWitnessTableEntry(diffTypeIsDiffWitness, autodiffContext->differentialAssocTypeStructKey, diffType);
+ builder.createWitnessTableEntry(diffTypeIsDiffWitness, autodiffContext->differentialAssocTypeWitnessStructKey, diffTypeIsDiffWitness);
+ builder.createWitnessTableEntry(diffTypeIsDiffWitness, autodiffContext->zeroMethodStructKey, zeroMethod);
+ builder.createWitnessTableEntry(diffTypeIsDiffWitness, autodiffContext->addMethodStructKey, addMethod);
+ return result;
+ }
+
+ IntermediateContextTypeDifferentialInfo fillDifferentialTypeImplementation(
+ OrderedDictionary<IRInst*, IntermediateContextTypeDifferentialInfo>& diffTypes,
+ IRInst* originalType)
+ {
+ if (originalType->getOp() == kIROp_StructType)
+ {
+ IRBuilder builder(originalType);
+ builder.setInsertBefore(originalType);
+ auto diffType = builder.createStructType();
+ return fillDifferentialTypeImplementationForStruct(
+ diffTypes,
+ as<IRStructType>(originalType),
+ as<IRStructType>(diffType));
+ }
+ else if (auto genType = as<IRGeneric>(originalType))
+ {
+ // For generics, we process the inner struct type as normal,
+ // and then hoist the additional insts we created from the generic.
+
+ auto structType = as<IRStructType>(findGenericReturnVal(genType));
+ SLANG_RELEASE_ASSERT(structType);
+
+ auto innerResult = fillDifferentialTypeImplementation(diffTypes, structType);
+ IRBuilder builder(originalType);
+ builder.setInsertBefore(originalType);
+
+ // Now we hoist the new values from the generic to form their independent generics.
+ IRInst* specInst = nullptr;
+ IntermediateContextTypeDifferentialInfo result;
+ if (innerResult.diffType)
+ result.diffType = hoistValueFromGeneric(builder, innerResult.diffType, specInst, true);
+ if (innerResult.zeroMethod)
+ {
+ hoistValueFromGeneric(builder, innerResult.zeroMethod->getFullType(), specInst, true);
+ result.zeroMethod = hoistValueFromGeneric(builder, innerResult.zeroMethod, specInst, true);
+ }
+ if (innerResult.addMethod)
+ {
+ hoistValueFromGeneric(builder, innerResult.addMethod->getFullType(), specInst, true);
+ result.addMethod = hoistValueFromGeneric(builder, innerResult.addMethod, specInst, true);
+ }
+ if (innerResult.diffDiffWitness)
+ result.diffDiffWitness = hoistValueFromGeneric(builder, innerResult.diffDiffWitness, specInst, true);
+ if (innerResult.diffWitness)
+ {
+ builder.setInsertBefore(innerResult.diffWitness);
+ List<IRInst*> args;
+ for (auto param : genType->getParams())
+ args.add(param);
+ as<IRWitnessTable>(innerResult.diffWitness)->setConcreteType((IRType*)builder.emitSpecializeInst(
+ builder.getTypeKind(), originalType, (UInt)args.getCount(), args.getBuffer()));
+ result.diffWitness = hoistValueFromGeneric(builder, innerResult.diffWitness, specInst, true);
+ }
+ return result;
+ }
+ return IntermediateContextTypeDifferentialInfo();
+ }
+
+ HashSet<IRInst*> fullyDifferentiatedInsts;
+
+ // Returns true if `type` is fully differentiated, i.e. does not have
+ // any unmaterialized intermediate context types.
+ bool isTypeFullyDifferentiated(IRInst* type)
+ {
+ if (fullyDifferentiatedInsts.Contains(type))
+ return true;
+ if (type->getOp() == kIROp_BackwardDiffIntermediateContextType)
+ return false;
+ if (auto structType = as<IRStructType>(type))
+ {
+ for (auto f : structType->getFields())
+ if (!isTypeFullyDifferentiated(f->getFieldType()))
+ return false;
+ }
+ else if (auto genType = as<IRGeneric>(type))
+ {
+ bool result = isTypeFullyDifferentiated(findGenericReturnVal(genType));
+ if (result)
+ fullyDifferentiatedInsts.Add(genType);
+ return result;
+ }
+ switch (type->getOp())
+ {
+ case kIROp_ArrayType:
+ case kIROp_UnsizedArrayType:
+ case kIROp_InOutType:
+ case kIROp_OutType:
+ case kIROp_PtrType:
+ case kIROp_DifferentialPairType:
+ case kIROp_DifferentialPairUserCodeType:
+ case kIROp_AttributedType:
+ for (UInt i = 0; i < type->getOperandCount(); i++)
+ if (!isTypeFullyDifferentiated(type->getOperand(i)))
+ return false;
+ default:
+ fullyDifferentiatedInsts.Add(type);
+ return true;
+ }
+ }
+
+ // Returns true if `func` is fully differentiated, i.e. does not have
+ // any differentiate insts.
+ bool isFullyDifferentiated(IRFunc* func)
+ {
+ if (fullyDifferentiatedInsts.Contains(func))
+ return true;
+
+ for (auto block : func->getBlocks())
+ {
+ for (auto ii : block->getChildren())
+ {
+ switch (ii->getOp())
+ {
+ case kIROp_ForwardDifferentiate:
+ case kIROp_BackwardDifferentiate:
+ case kIROp_BackwardDifferentiatePrimal:
+ case kIROp_BackwardDifferentiatePropagate:
+ case kIROp_BackwardDiffIntermediateContextType:
+ return false;
+ }
+ if (ii->getDataType() && !isTypeFullyDifferentiated(ii->getDataType()))
+ return false;
+ }
+ }
+ fullyDifferentiatedInsts.Add(func);
+ return true;
+ }
+
// Process all differentiate calls, and recursively generate code for forward and backward
// derivative functions.
//
bool processReferencedFunctions(IRBuilder* builder)
{
+ fullyDifferentiatedInsts.Clear();
bool hasChanges = false;
for (;;)
{
@@ -725,6 +1124,12 @@ struct AutoDiffPass : public InstPassBase
case kIROp_Func:
case kIROp_Specialize:
case kIROp_LookupWitness:
+ if (auto innerFunc = as<IRFunc>(getResolvedInstForDecorations(inst->getOperand(0))))
+ {
+ // Skip functions whose body still has a differentiate inst (higher order func).
+ if (!isFullyDifferentiated(innerFunc))
+ return;
+ }
autoDiffWorkList.add(inst);
break;
default:
@@ -845,6 +1250,11 @@ struct AutoDiffPass : public InstPassBase
if (!changed)
break;
+ if (lowerIntermediateContextType(builder))
+ {
+ hasChanges = true;
+ }
+
// We have done transcribing the functions, now it is time to demote all DifferentialPair types
// and their operations down to DifferentialPairUserCodeType and *UserCode operations so they
// can be treated just like normal types with no special semantics in future processing, and won't
@@ -854,10 +1264,6 @@ struct AutoDiffPass : public InstPassBase
hasChanges |= changed;
}
- if (lowerIntermediateContextType(builder))
- {
- hasChanges = true;
- }
return hasChanges;
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 28c682c91..bb8cfc378 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -320,6 +320,9 @@ INST(MakeOptionalValue, makeOptionalValue, 1, 0)
INST(MakeOptionalNone, makeOptionalNone, 1, 0)
INST(Call, call, 1, 0)
+// Structural addition of two values of the same type.
+INST(StructuralAdd, structuralAdd, 2, 0)
+
INST(RTTIObject, rtti_object, 0, 0)
INST(Alloca, alloca, 1, 0)
@@ -814,6 +817,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// forward-differentiated updateElement inst.
INST(PrimalElementTypeDecoration, primalElementType, 1, 0)
+ /// Used by the auto-diff pass to mark the differential type of an intermediate context field.
+ INST(IntermediateContextFieldDifferentialTypeDecoration, IntermediateContextFieldDifferentialTypeDecoration, 1, 0)
+
/// Used by the auto-diff pass to hold a reference to a
/// differential member of a type in its associated differential type.
INST(DerivativeMemberDecoration, derivativeMemberDecoration, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index f3181ecf7..43893bfe6 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -783,6 +783,20 @@ struct IRPrimalElementTypeDecoration : IRDecoration
IRInst* getPrimalElementType() { return getOperand(0); }
};
+struct IRIntermediateContextFieldDifferentialTypeDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_IntermediateContextFieldDifferentialTypeDecoration
+ };
+
+ IR_LEAF_ISA(IntermediateContextFieldDifferentialTypeDecoration)
+
+ IRInst* getDifferentialType() { return getOperand(0); }
+ IRInst* getDifferentialWitness() { return getOperand(1); }
+
+};
+
struct IRBackwardDifferentiableDecoration : IRDecoration
{
enum
@@ -2205,6 +2219,11 @@ struct IRWitnessTable : IRInst
return (IRType*) getOperand(0);
}
+ void setConcreteType(IRType* t)
+ {
+ return setOperand(0, t);
+ }
+
IR_LEAF_ISA(WitnessTable)
};
@@ -2867,6 +2886,7 @@ public:
IRInst* addPrimalValueStructKeyDecoration(IRInst* target, IRStructKey* key);
IRInst* addPrimalElementTypeDecoration(IRInst* target, IRInst* type);
+ IRInst* addIntermediateContextFieldDifferentialTypeDecoration(IRInst* target, IRInst* type);
// Add a differentiable type entry to the appropriate dictionary.
IRInst* addDifferentiableTypeEntry(IRInst* dictDecoration, IRInst* irType, IRInst* conformanceWitness);
@@ -2949,6 +2969,15 @@ public:
/// the inst.
IRInst* emitDefaultConstructRaw(IRType* type);
+ /// Emits appropriate inst for structurally adding two values of `type`.
+ /// If `fallback` is true, will emit `StructuralAdd` inst on unknown types.
+ /// Otherwise, returns nullptr if we can't materialize the inst.
+ IRInst* emitStructuralAdd(IRInst* val0, IRInst* val1, bool fallback = true);
+
+ /// Emits a raw `StructuralAdd` opcode without attempting to fold/materialize
+ /// the inst.
+ IRInst* emitStructuralAddRaw(IRInst* val0, IRInst* val1);
+
IRInst* emitCast(
IRType* type,
IRInst* value);
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index 5d5a41726..a5ec50b2c 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -633,6 +633,19 @@ struct PeepholeContext : InstPassBase
}
}
break;
+ case kIROp_StructuralAdd:
+ {
+ IRBuilder builder(module);
+ builder.setInsertBefore(inst);
+ // See if we can replace the generic add inst with concrete values.
+ if (auto newCtor = builder.emitStructuralAdd(inst->getOperand(0), inst->getOperand(1), false))
+ {
+ inst->replaceUsesWith(newCtor);
+ maybeRemoveOldInst(inst);
+ changed = true;
+ }
+ }
+ break;
case kIROp_Add:
case kIROp_Mul:
case kIROp_Sub:
diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp
index d5c0aa432..55e0f0168 100644
--- a/source/slang/slang-ir-validate.cpp
+++ b/source/slang/slang-ir-validate.cpp
@@ -208,6 +208,12 @@ namespace Slang
if (allInGlobalScope)
return;
+ // Allow exceptions.
+ switch (inst->getOp())
+ {
+ case kIROp_DifferentiableTypeDictionaryItem:
+ return;
+ }
//
// We failed to find `operandParent` while walking the ancestors of `inst`,
// so something had gone wrong.
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 08c066f5d..f61e5a10e 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3509,6 +3509,104 @@ namespace Slang
return nullptr;
}
+ IRInst* IRBuilder::emitStructuralAddRaw(IRInst* val0, IRInst* val1)
+ {
+ IRInst* args[2] = { val0, val1 };
+ return emitIntrinsicInst(val0->getFullType(), kIROp_StructuralAdd, 2, args);
+ }
+
+ IRInst* IRBuilder::emitStructuralAdd(IRInst* val0, IRInst* val1, bool fallback)
+ {
+ auto type = val0->getFullType();
+ SLANG_RELEASE_ASSERT(val0->getFullType() == val1->getFullType());
+ IRType* actualType = val0->getFullType();
+ for (;;)
+ {
+ if (auto attr = as<IRAttributedType>(actualType))
+ actualType = attr->getBaseType();
+ else if (auto rateQualified = as<IRRateQualifiedType>(actualType))
+ actualType = rateQualified->getValueType();
+ else
+ break;
+ }
+ if (as<IRBasicType>(actualType))
+ return emitAdd(type, val0, val1);
+
+ switch (actualType->getOp())
+ {
+ case kIROp_PtrType:
+ case kIROp_VectorType:
+ case kIROp_MatrixType:
+ return emitAdd(type, val0, val1);
+ case kIROp_TupleType:
+ {
+ List<IRInst*> elements;
+ auto tupleType = as<IRTupleType>(actualType);
+ for (UInt i = 0; i < tupleType->getOperandCount(); i++)
+ {
+ auto operand = tupleType->getOperand(i);
+ if (as<IRAttr>(operand))
+ break;
+ auto inner = emitStructuralAdd(
+ emitGetTupleElement((IRType*)operand, val0, i),
+ emitGetTupleElement((IRType*)operand, val1, i),
+ fallback);
+ if (!inner)
+ return nullptr;
+ elements.add(inner);
+ }
+ return emitMakeTuple(tupleType, elements);
+ }
+ case kIROp_StructType:
+ {
+ List<IRInst*> elements;
+ auto structType = as<IRStructType>(actualType);
+ for (auto field : structType->getFields())
+ {
+ auto fieldType = field->getFieldType();
+ auto inner = emitStructuralAdd(
+ emitFieldExtract(fieldType, val0, field->getKey()),
+ emitFieldExtract(fieldType, val1, field->getKey()),
+ fallback);
+ if (!inner)
+ return nullptr;
+ elements.add(inner);
+ }
+ return emitMakeStruct(type, elements);
+ }
+ case kIROp_ArrayType:
+ {
+ auto arrayType = as<IRArrayType>(actualType);
+ if (auto count = as<IRIntLit>(arrayType->getElementCount()))
+ {
+ auto elementType = arrayType->getElementType();
+ List<IRInst*> elements;
+ constexpr int maxCount = 4096;
+ if (count->getValue() > maxCount)
+ break;
+ for (IRIntegerValue i = 0; i < count->getValue(); i++)
+ {
+ auto index = getIntValue(getIntType(), i);
+ auto element = emitStructuralAdd(
+ emitElementExtract(elementType, val0, index),
+ emitElementExtract(elementType, val1, index),
+ fallback);
+ elements.add(element);
+ }
+ return emitMakeArray(type, elements.getCount(), elements.getBuffer());
+ }
+ break;
+ }
+ default:
+ break;
+ }
+ if (fallback)
+ {
+ return emitStructuralAddRaw(val0, val1);
+ }
+ return nullptr;
+ }
+
static int _getTypeStyleId(IRType* type)
{
if (auto vectorType = as<IRVectorType>(type))
@@ -3928,6 +4026,11 @@ namespace Slang
return addDecoration(target, kIROp_PrimalElementTypeDecoration, type);
}
+ IRInst* IRBuilder::addIntermediateContextFieldDifferentialTypeDecoration(IRInst* target, IRInst* type)
+ {
+ return addDecoration(target, kIROp_IntermediateContextFieldDifferentialTypeDecoration, type);
+ }
+
RefPtr<IRModule> IRModule::create(Session* session)
{
RefPtr<IRModule> module = new IRModule(session);
@@ -7028,6 +7131,7 @@ namespace Slang
case kIROp_Nop:
case kIROp_undefined:
case kIROp_DefaultConstruct:
+ case kIROp_StructuralAdd:
case kIROp_Specialize:
case kIROp_LookupWitness:
case kIROp_GetSequentialID: