summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-15 09:39:21 -0700
committerGitHub <noreply@github.com>2023-03-15 09:39:21 -0700
commitbf308241b54ae9c421a29aa5620da9fb3ec15245 (patch)
treeacf114b9e0677f6b6494b105130d7043b1be872b
parent176eaa9f7770ad81cbd71def8a1551d6237167bd (diff)
Properly implement differential witness of intermediate context type. (#2699)
* Properly implement differential witness of intermediate context type. * Modify test to include a loop. --------- Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp136
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.h1
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h1
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp5
-rw-r--r--source/slang/slang-ir-autodiff.cpp54
-rw-r--r--source/slang/slang-ir-inst-defs.h3
-rw-r--r--source/slang/slang-ir-insts.h14
-rw-r--r--source/slang/slang-ir-peephole.cpp13
-rw-r--r--source/slang/slang-ir-util.cpp20
-rw-r--r--source/slang/slang-ir-util.h4
-rw-r--r--source/slang/slang-ir.cpp103
-rw-r--r--tests/autodiff/high-order-backward-diff-3.slang10
12 files changed, 210 insertions, 154 deletions
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 0f51a6c62..e3ef357ee 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -172,25 +172,149 @@ IRInst* AutoDiffTranscriberBase::maybeCloneForPrimalInst(IRBuilder* builder, IRI
return primal;
}
+IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey);
+
// Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
IRWitnessTable* AutoDiffTranscriberBase::getDifferentialPairWitness(IRBuilder* builder, IRInst* inOriginalDiffPairType, IRInst* inPrimalDiffPairType)
{
// Differentiate the pair type to get it's differential (which is itself a pair)
- auto diffDiffPairType = differentiateType(builder, (IRType*)inOriginalDiffPairType);
-
+ auto diffDiffPairType = (IRType*)differentiateType(builder, (IRType*)inOriginalDiffPairType);
+
+ auto addMethod = builder->createFunc();
+ auto zeroMethod = builder->createFunc();
+
auto table = builder->createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, (IRType*)inPrimalDiffPairType);
// And place it in the synthesized witness table.
builder->createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffDiffPairType);
+ builder->createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeWitnessStructKey, table);
+ builder->createWitnessTableEntry(table, autoDiffSharedContext->addMethodStructKey, addMethod);
+ builder->createWitnessTableEntry(table, autoDiffSharedContext->zeroMethodStructKey, zeroMethod);
+
+ bool isUserCodeType = as<IRDifferentialPairUserCodeType>(inOriginalDiffPairType) ? true : false;
- // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`.
+ // Fill in differential method implementations.
+ auto elementType = as<IRDifferentialPairTypeBase>(inPrimalDiffPairType)->getValueType();
+ auto innerWitness = as<IRDifferentialPairTypeBase>(inPrimalDiffPairType)->getWitness();
+ {
+ // Add method.
+ IRBuilder b = *builder;
+ b.setInsertInto(addMethod);
+ IRType* paramTypes[2] = { diffDiffPairType, diffDiffPairType };
+ addMethod->setFullType(b.getFuncType(2, paramTypes, diffDiffPairType));
+ b.emitBlock();
+ auto p0 = b.emitParam(diffDiffPairType);
+ auto p1 = b.emitParam(diffDiffPairType);
+
+ // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type.
+ auto innerAdd = _lookupWitness(&b, innerWitness, autoDiffSharedContext->addMethodStructKey);
+ IRInst* argsPrimal[2] = {
+ isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p0) : b.emitDifferentialPairGetPrimal(p0),
+ isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p1) : b.emitDifferentialPairGetPrimal(p1) };
+ auto primalPart = b.emitCallInst(elementType, innerAdd, 2, argsPrimal);
+ IRInst* argsDiff[2] = {
+ isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p0) : b.emitDifferentialPairGetDifferential(elementType, p0),
+ isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p1) : b.emitDifferentialPairGetDifferential(elementType, p1)};
+ auto diffPart = b.emitCallInst(elementType, innerAdd, 2, argsDiff);
+ auto retVal =
+ isUserCodeType
+ ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, primalPart, diffPart)
+ : b.emitMakeDifferentialPair(diffDiffPairType, primalPart, diffPart);
+ b.emitReturn(retVal);
+ }
+ {
+ // Zero method.
+ IRBuilder b = *builder;
+ b.setInsertInto(zeroMethod);
+ zeroMethod->setFullType(b.getFuncType(0, nullptr, diffDiffPairType));
+ b.emitBlock();
+ auto innerZero = _lookupWitness(&b, innerWitness, autoDiffSharedContext->zeroMethodStructKey);
+ auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr);
+ auto retVal =
+ isUserCodeType
+ ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, zeroVal, zeroVal)
+ : b.emitMakeDifferentialPair(diffDiffPairType, zeroVal, zeroVal);
+ b.emitReturn(retVal);
+ }
+
// Record this in the context for future lookups
differentiableTypeConformanceContext.differentiableWitnessDictionary[(IRType*)inOriginalDiffPairType] = table;
return table;
}
+// Get or construct `:IDifferentiable` conformance for an Array.
+IRWitnessTable* AutoDiffTranscriberBase::getArrayWitness(IRBuilder* builder, IRInst* inOriginalArrayType, IRInst* inPrimalArrayType)
+{
+ // Differentiate the pair type to get it's differential (which is itself a pair)
+ auto diffArrayType = (IRType*)differentiateType(builder, (IRType*)inOriginalArrayType);
+
+ if (!diffArrayType)
+ return nullptr;
+
+ auto innerWitness = tryGetDifferentiableWitness(builder, as<IRArrayTypeBase>(inOriginalArrayType)->getElementType());
+
+ auto addMethod = builder->createFunc();
+ auto zeroMethod = builder->createFunc();
+
+ auto table = builder->createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, (IRType*)inPrimalArrayType);
+
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffArrayType);
+ builder->createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeWitnessStructKey, table);
+ builder->createWitnessTableEntry(table, autoDiffSharedContext->addMethodStructKey, addMethod);
+ builder->createWitnessTableEntry(table, autoDiffSharedContext->zeroMethodStructKey, zeroMethod);
+
+ auto elementType = as<IRArrayTypeBase>(diffArrayType)->getElementType();
+
+ // Fill in differential method implementations.
+ {
+ // Add method.
+ IRBuilder b = *builder;
+ b.setInsertInto(addMethod);
+ IRType* paramTypes[2] = { diffArrayType, diffArrayType };
+ addMethod->setFullType(b.getFuncType(2, paramTypes, diffArrayType));
+ b.emitBlock();
+ auto p0 = b.emitParam(diffArrayType);
+ auto p1 = b.emitParam(diffArrayType);
+
+ // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type.
+ auto innerAdd = _lookupWitness(&b, innerWitness, autoDiffSharedContext->addMethodStructKey);
+ auto resultVar = b.emitVar(diffArrayType);
+ IRBlock* loopBodyBlock = nullptr;
+ IRBlock* loopBreakBlock = nullptr;
+ auto loopCounter = emitLoopBlocks(&b, b.getIntValue(b.getIntType(), 0), as<IRArrayTypeBase>(diffArrayType)->getElementCount(), loopBodyBlock, loopBreakBlock);
+ b.setInsertBefore(loopBodyBlock->getTerminator());
+
+ IRInst* args[2] = {
+ b.emitElementExtract(p0, loopCounter),
+ b.emitElementExtract(p1, loopCounter) };
+ auto elementResult = b.emitCallInst(elementType, innerAdd, 2, args);
+ auto addr = b.emitElementAddress(resultVar, loopCounter);
+ b.emitStore(addr, elementResult);
+ b.setInsertInto(loopBreakBlock);
+ b.emitReturn(b.emitLoad(resultVar));
+ }
+ {
+ // Zero method.
+ IRBuilder b = *builder;
+ b.setInsertInto(zeroMethod);
+ zeroMethod->setFullType(b.getFuncType(0, nullptr, diffArrayType));
+ b.emitBlock();
+
+ auto innerZero = _lookupWitness(&b, innerWitness, autoDiffSharedContext->zeroMethodStructKey);
+ auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr);
+ auto retVal = b.emitMakeArrayFromElement(diffArrayType, zeroVal);
+ b.emitReturn(retVal);
+ }
+
+ // Record this in the context for future lookups
+ differentiableTypeConformanceContext.differentiableWitnessDictionary[(IRType*)inOriginalArrayType] = table;
+
+ return table;
+}
+
IRInst* AutoDiffTranscriberBase::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType)
{
IRInst* witness =
@@ -204,10 +328,14 @@ IRInst* AutoDiffTranscriberBase::tryGetDifferentiableWitness(IRBuilder* builder,
{
auto primalType = lookupPrimalInst(builder, originalType, nullptr);
SLANG_RELEASE_ASSERT(primalType);
- if (auto primalPairType = as<IRDifferentialPairType>(primalType))
+ if (auto primalPairType = as<IRDifferentialPairTypeBase>(primalType))
{
witness = getDifferentialPairWitness(builder, originalType, primalPairType);
}
+ else if (auto arrayType = as<IRArrayType>(primalType))
+ {
+ witness = getArrayWitness(builder, originalType, arrayType);
+ }
else if (auto extractExistential = as<IRExtractExistentialType>(originalType))
{
differentiateExtractExistentialType(builder, extractExistential, witness);
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h
index 47e568645..d5070689e 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.h
+++ b/source/slang/slang-ir-autodiff-transcriber-base.h
@@ -97,6 +97,7 @@ struct AutoDiffTranscriberBase
// Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
IRWitnessTable* getDifferentialPairWitness(IRBuilder* builder, IRInst* inOriginalDiffPairType, IRInst* inPrimalDiffPairType);
+ IRWitnessTable* getArrayWitness(IRBuilder* builder, IRInst* inOriginalArrayType, IRInst* inPrimalArrayType);
IRInst* tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType);
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 1cd6a0e33..a92978817 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -2888,6 +2888,7 @@ struct DiffTransposePass
auto diffElementType = (IRType*)diffTypeContext.getDifferentialForType(builder, arrayType->getElementType());
SLANG_RELEASE_ASSERT(diffElementType);
auto arraySize = arrayType->getElementCount();
+
if (auto constArraySize = as<IRIntLit>(arraySize))
{
List<IRInst*> args;
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index e01452972..5b59416d4 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -308,9 +308,10 @@ struct ExtractPrimalFuncContext
fieldType = cloneInst(&cloneEnv, &genTypeBuilder, fieldType);
}
auto structField = genTypeBuilder.createStructField(structType, structKey, (IRType*)fieldType);
- if (auto diffFieldType = backwardPrimalTranscriber->differentiateType(&genTypeBuilder, (IRType*)fieldType))
+
+ if (auto witness = backwardPrimalTranscriber->tryGetDifferentiableWitness(&genTypeBuilder, (IRType*)fieldType))
{
- genTypeBuilder.addIntermediateContextFieldDifferentialTypeDecoration(structField, diffFieldType);
+ genTypeBuilder.addIntermediateContextFieldDifferentialTypeDecoration(structField, witness);
}
return structField;
}
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 4d22d9eed..517b9e3ea 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -399,9 +399,7 @@ IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* t
IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key)
{
if (auto conformance = lookUpConformanceForType(origType))
- {
return _lookupWitness(builder, conformance, key);
- }
return nullptr;
}
@@ -889,40 +887,44 @@ struct AutoDiffPass : public InstPassBase
builder.setInsertInto(diffType);
// Generate the fields for all differentiable members of the original struct type.
+ struct FieldInfo
+ {
+ IRStructField* field;
+ IRInst* witness;
+ };
+ List<FieldInfo> diffFields;
+
for (auto field : originalType->getFields())
{
- IRInst* diffFieldType = nullptr;
+ IRInst* diffFieldWitness = nullptr;
if (auto diffDecor = field->findDecoration<IRIntermediateContextFieldDifferentialTypeDecoration>())
{
- diffFieldType = diffDecor->getDifferentialType();
+ diffFieldWitness = diffDecor->getDifferentialWitness();
}
else
{
IntermediateContextTypeDifferentialInfo diffFieldTypeInfo;
diffTypes.TryGetValue(field->getDataType(), diffFieldTypeInfo);
- diffFieldType = diffFieldTypeInfo.diffType;
+ diffFieldWitness = diffFieldTypeInfo.diffWitness;
}
- if (diffFieldType)
+ if (diffFieldWitness)
{
+ FieldInfo info;
IRBuilder keyBuilder = builder;
keyBuilder.setInsertBefore(maybeFindOuterGeneric(originalType));
auto diffKey = keyBuilder.createStructKey();
- builder.createStructField(diffType, diffKey, (IRType*)diffFieldType);
+ auto diffFieldType = _lookupWitness(&keyBuilder, diffFieldWitness, autodiffContext->differentialAssocTypeStructKey);
+ info.field = builder.createStructField(diffType, diffKey, (IRType*)diffFieldType);
+ info.witness = diffFieldWitness;
builder.addDecoration(field->getKey(), kIROp_DerivativeMemberDecoration, diffKey);
builder.addDecoration(diffKey, kIROp_DerivativeMemberDecoration, diffKey);
+ diffFields.add(info);
}
}
builder.setInsertAfter(diffType);
- // For now, we are going to structurally derive dadd and dzero methods for intermediate context types,
- // because it is tricky for us to obtain the original witness tables for the fields at this point.
- // This is inconsistent with how we are dealing with dadd and dzero methods via witness table lookup,
- // and can lead to problems if the user defines any non-trivial dadd/dzero methods.
- //
- // TODO: we should consider rewrite this logic to be witness table lookup based, or simplify the entire
- // type system and IR passes to always use structurally derived methods instead of user-provided
- // methods.
+ // Implement `dadd` and `dzero` methods.
IRInst* zeroMethod = nullptr;
{
auto zeroMethodType = builder.getFuncType(List<IRType*>(), diffType);
@@ -931,7 +933,14 @@ struct AutoDiffPass : public InstPassBase
result.zeroMethod = zeroMethod;
builder.setInsertInto(zeroMethod);
builder.emitBlock();
- builder.emitReturn(builder.emitDefaultConstruct(diffType));
+ List<IRInst*> fieldVals;
+ for (auto info : diffFields)
+ {
+ auto innerZeroMethod = _lookupWitness(&builder, info.witness, autodiffContext->zeroMethodStructKey);
+ IRInst* val = builder.emitCallInst(info.field->getFieldType(), innerZeroMethod, 0, nullptr);
+ fieldVals.add(val);
+ }
+ builder.emitReturn(builder.emitMakeStruct(diffType, fieldVals));
}
builder.setInsertAfter(zeroMethod);
@@ -948,7 +957,18 @@ struct AutoDiffPass : public InstPassBase
builder.emitBlock();
auto param1 = builder.emitParam(diffType);
auto param2 = builder.emitParam(diffType);
- builder.emitReturn(builder.emitStructuralAdd(param1, param2));
+ List<IRInst*> fieldVals;
+ for (auto info : diffFields)
+ {
+ auto innerAddMethod = _lookupWitness(&builder, info.witness, autodiffContext->addMethodStructKey);
+ IRInst* args[2] = {
+ builder.emitFieldExtract(info.field->getFieldType(), param1, info.field->getKey()),
+ builder.emitFieldExtract(info.field->getFieldType(), param2, info.field->getKey()),
+ };
+ IRInst* val = builder.emitCallInst(info.field->getFieldType(), innerAddMethod, 2, args);
+ fieldVals.add(val);
+ }
+ builder.emitReturn(builder.emitMakeStruct(diffType, fieldVals));
}
builder.setInsertAfter(addMethod);
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index bb8cfc378..71d9315bd 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -320,9 +320,6 @@ INST(MakeOptionalValue, makeOptionalValue, 1, 0)
INST(MakeOptionalNone, makeOptionalNone, 1, 0)
INST(Call, call, 1, 0)
-// Structural addition of two values of the same type.
-INST(StructuralAdd, structuralAdd, 2, 0)
-
INST(RTTIObject, rtti_object, 0, 0)
INST(Alloca, alloca, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 43893bfe6..0f5c36dcb 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -792,8 +792,7 @@ struct IRIntermediateContextFieldDifferentialTypeDecoration : IRDecoration
IR_LEAF_ISA(IntermediateContextFieldDifferentialTypeDecoration)
- IRInst* getDifferentialType() { return getOperand(0); }
- IRInst* getDifferentialWitness() { return getOperand(1); }
+ IRInst* getDifferentialWitness() { return getOperand(0); }
};
@@ -2886,7 +2885,7 @@ public:
IRInst* addPrimalValueStructKeyDecoration(IRInst* target, IRStructKey* key);
IRInst* addPrimalElementTypeDecoration(IRInst* target, IRInst* type);
- IRInst* addIntermediateContextFieldDifferentialTypeDecoration(IRInst* target, IRInst* type);
+ IRInst* addIntermediateContextFieldDifferentialTypeDecoration(IRInst* target, IRInst* witness);
// Add a differentiable type entry to the appropriate dictionary.
IRInst* addDifferentiableTypeEntry(IRInst* dictDecoration, IRInst* irType, IRInst* conformanceWitness);
@@ -2969,15 +2968,6 @@ public:
/// the inst.
IRInst* emitDefaultConstructRaw(IRType* type);
- /// Emits appropriate inst for structurally adding two values of `type`.
- /// If `fallback` is true, will emit `StructuralAdd` inst on unknown types.
- /// Otherwise, returns nullptr if we can't materialize the inst.
- IRInst* emitStructuralAdd(IRInst* val0, IRInst* val1, bool fallback = true);
-
- /// Emits a raw `StructuralAdd` opcode without attempting to fold/materialize
- /// the inst.
- IRInst* emitStructuralAddRaw(IRInst* val0, IRInst* val1);
-
IRInst* emitCast(
IRType* type,
IRInst* value);
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index a5ec50b2c..5d5a41726 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -633,19 +633,6 @@ struct PeepholeContext : InstPassBase
}
}
break;
- case kIROp_StructuralAdd:
- {
- IRBuilder builder(module);
- builder.setInsertBefore(inst);
- // See if we can replace the generic add inst with concrete values.
- if (auto newCtor = builder.emitStructuralAdd(inst->getOperand(0), inst->getOperand(1), false))
- {
- inst->replaceUsesWith(newCtor);
- maybeRemoveOldInst(inst);
- changed = true;
- }
- }
- break;
case kIROp_Add:
case kIROp_Mul:
case kIROp_Sub:
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 254734965..de03a1661 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -554,6 +554,26 @@ IROp getSwapSideComparisonOp(IROp op)
}
}
+IRInst* emitLoopBlocks(IRBuilder* builder, IRInst* initVal, IRInst* finalVal, IRBlock*& loopBodyBlock, IRBlock*& loopBreakBlock)
+{
+ IRBuilder loopBuilder = *builder;
+ auto loopHeadBlock = loopBuilder.emitBlock();
+ loopBodyBlock = loopBuilder.emitBlock();
+ loopBreakBlock = loopBuilder.emitBlock();
+ auto loopContinueBlock = loopBuilder.emitBlock();
+ builder->emitLoop(loopHeadBlock, loopBreakBlock, loopHeadBlock, 1, &initVal);
+ loopBuilder.setInsertInto(loopHeadBlock);
+ auto loopParam = loopBuilder.emitParam(initVal->getFullType());
+ auto cmpResult = loopBuilder.emitLess(loopParam, finalVal);
+ loopBuilder.emitIfElse(cmpResult, loopBodyBlock, loopBreakBlock, loopBreakBlock);
+ loopBuilder.setInsertInto(loopBodyBlock);
+ loopBuilder.emitBranch(loopContinueBlock);
+ loopBuilder.setInsertInto(loopContinueBlock);
+ auto newParam = loopBuilder.emitAdd(loopParam->getFullType(), loopParam, loopBuilder.getIntValue(loopBuilder.getIntType(), 1));
+ loopBuilder.emitBranch(loopHeadBlock, 1, &newParam);
+ return loopParam;
+}
+
void setInsertBeforeOrdinaryInst(IRBuilder* builder, IRInst* inst)
{
if (as<IRParam>(inst))
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 62156cad6..0989dee33 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -182,6 +182,10 @@ void setInsertBeforeOrdinaryInst(IRBuilder* builder, IRInst* inst);
// Set IRBuilder to insert after `inst`. If `inst` is a param, it will insert after the last param.
void setInsertAfterOrdinaryInst(IRBuilder* builder, IRInst* inst);
+// Emit a loop structure with a simple incrementing counter.
+// Returns the loop counter `IRParam`.
+IRInst* emitLoopBlocks(IRBuilder* builder, IRInst* initVal, IRInst* finalVal, IRBlock*& loopBodyBlock, IRBlock*& loopBreakBlock);
+
}
#endif
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index f61e5a10e..9f877969a 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3509,104 +3509,6 @@ namespace Slang
return nullptr;
}
- IRInst* IRBuilder::emitStructuralAddRaw(IRInst* val0, IRInst* val1)
- {
- IRInst* args[2] = { val0, val1 };
- return emitIntrinsicInst(val0->getFullType(), kIROp_StructuralAdd, 2, args);
- }
-
- IRInst* IRBuilder::emitStructuralAdd(IRInst* val0, IRInst* val1, bool fallback)
- {
- auto type = val0->getFullType();
- SLANG_RELEASE_ASSERT(val0->getFullType() == val1->getFullType());
- IRType* actualType = val0->getFullType();
- for (;;)
- {
- if (auto attr = as<IRAttributedType>(actualType))
- actualType = attr->getBaseType();
- else if (auto rateQualified = as<IRRateQualifiedType>(actualType))
- actualType = rateQualified->getValueType();
- else
- break;
- }
- if (as<IRBasicType>(actualType))
- return emitAdd(type, val0, val1);
-
- switch (actualType->getOp())
- {
- case kIROp_PtrType:
- case kIROp_VectorType:
- case kIROp_MatrixType:
- return emitAdd(type, val0, val1);
- case kIROp_TupleType:
- {
- List<IRInst*> elements;
- auto tupleType = as<IRTupleType>(actualType);
- for (UInt i = 0; i < tupleType->getOperandCount(); i++)
- {
- auto operand = tupleType->getOperand(i);
- if (as<IRAttr>(operand))
- break;
- auto inner = emitStructuralAdd(
- emitGetTupleElement((IRType*)operand, val0, i),
- emitGetTupleElement((IRType*)operand, val1, i),
- fallback);
- if (!inner)
- return nullptr;
- elements.add(inner);
- }
- return emitMakeTuple(tupleType, elements);
- }
- case kIROp_StructType:
- {
- List<IRInst*> elements;
- auto structType = as<IRStructType>(actualType);
- for (auto field : structType->getFields())
- {
- auto fieldType = field->getFieldType();
- auto inner = emitStructuralAdd(
- emitFieldExtract(fieldType, val0, field->getKey()),
- emitFieldExtract(fieldType, val1, field->getKey()),
- fallback);
- if (!inner)
- return nullptr;
- elements.add(inner);
- }
- return emitMakeStruct(type, elements);
- }
- case kIROp_ArrayType:
- {
- auto arrayType = as<IRArrayType>(actualType);
- if (auto count = as<IRIntLit>(arrayType->getElementCount()))
- {
- auto elementType = arrayType->getElementType();
- List<IRInst*> elements;
- constexpr int maxCount = 4096;
- if (count->getValue() > maxCount)
- break;
- for (IRIntegerValue i = 0; i < count->getValue(); i++)
- {
- auto index = getIntValue(getIntType(), i);
- auto element = emitStructuralAdd(
- emitElementExtract(elementType, val0, index),
- emitElementExtract(elementType, val1, index),
- fallback);
- elements.add(element);
- }
- return emitMakeArray(type, elements.getCount(), elements.getBuffer());
- }
- break;
- }
- default:
- break;
- }
- if (fallback)
- {
- return emitStructuralAddRaw(val0, val1);
- }
- return nullptr;
- }
-
static int _getTypeStyleId(IRType* type)
{
if (auto vectorType = as<IRVectorType>(type))
@@ -4026,9 +3928,9 @@ namespace Slang
return addDecoration(target, kIROp_PrimalElementTypeDecoration, type);
}
- IRInst* IRBuilder::addIntermediateContextFieldDifferentialTypeDecoration(IRInst* target, IRInst* type)
+ IRInst* IRBuilder::addIntermediateContextFieldDifferentialTypeDecoration(IRInst* target, IRInst* witness)
{
- return addDecoration(target, kIROp_IntermediateContextFieldDifferentialTypeDecoration, type);
+ return addDecoration(target, kIROp_IntermediateContextFieldDifferentialTypeDecoration, witness);
}
RefPtr<IRModule> IRModule::create(Session* session)
@@ -7131,7 +7033,6 @@ namespace Slang
case kIROp_Nop:
case kIROp_undefined:
case kIROp_DefaultConstruct:
- case kIROp_StructuralAdd:
case kIROp_Specialize:
case kIROp_LookupWitness:
case kIROp_GetSequentialID:
diff --git a/tests/autodiff/high-order-backward-diff-3.slang b/tests/autodiff/high-order-backward-diff-3.slang
index eb3866b96..100a9a1e0 100644
--- a/tests/autodiff/high-order-backward-diff-3.slang
+++ b/tests/autodiff/high-order-backward-diff-3.slang
@@ -14,14 +14,20 @@ struct A : IDifferentiable
[BackwardDifferentiable]
float f(A x)
{
- return x.x * x.x;
+ A rs;
+ rs.x = 1.0;
+ for (int i = 0; i < 2; i++)
+ rs.x = rs.x * x.x;
+ return rs.x;
}
[BackwardDifferentiable]
float outerF(A x)
{
A nx;
- nx.x = x.x * x.x;
+ nx.x = 1.0;
+ for (int i = 0; i < 2; i++)
+ nx.x = nx.x * x.x;
nx.nx = 2;//x.nx;
return f(nx);
}