diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-15 09:39:21 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-15 09:39:21 -0700 |
| commit | bf308241b54ae9c421a29aa5620da9fb3ec15245 (patch) | |
| tree | acf114b9e0677f6b6494b105130d7043b1be872b | |
| parent | 176eaa9f7770ad81cbd71def8a1551d6237167bd (diff) | |
Properly implement differential witness of intermediate context type. (#2699)
* Properly implement differential witness of intermediate context type.
* Modify test to include a loop.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 136 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 54 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir-peephole.cpp | 13 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 20 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 103 | ||||
| -rw-r--r-- | tests/autodiff/high-order-backward-diff-3.slang | 10 |
12 files changed, 210 insertions, 154 deletions
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 0f51a6c62..e3ef357ee 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -172,25 +172,149 @@ IRInst* AutoDiffTranscriberBase::maybeCloneForPrimalInst(IRBuilder* builder, IRI return primal; } +IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey); + // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. IRWitnessTable* AutoDiffTranscriberBase::getDifferentialPairWitness(IRBuilder* builder, IRInst* inOriginalDiffPairType, IRInst* inPrimalDiffPairType) { // Differentiate the pair type to get it's differential (which is itself a pair) - auto diffDiffPairType = differentiateType(builder, (IRType*)inOriginalDiffPairType); - + auto diffDiffPairType = (IRType*)differentiateType(builder, (IRType*)inOriginalDiffPairType); + + auto addMethod = builder->createFunc(); + auto zeroMethod = builder->createFunc(); + auto table = builder->createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, (IRType*)inPrimalDiffPairType); // And place it in the synthesized witness table. builder->createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffDiffPairType); + builder->createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeWitnessStructKey, table); + builder->createWitnessTableEntry(table, autoDiffSharedContext->addMethodStructKey, addMethod); + builder->createWitnessTableEntry(table, autoDiffSharedContext->zeroMethodStructKey, zeroMethod); + + bool isUserCodeType = as<IRDifferentialPairUserCodeType>(inOriginalDiffPairType) ? true : false; - // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`. + // Fill in differential method implementations. + auto elementType = as<IRDifferentialPairTypeBase>(inPrimalDiffPairType)->getValueType(); + auto innerWitness = as<IRDifferentialPairTypeBase>(inPrimalDiffPairType)->getWitness(); + { + // Add method. + IRBuilder b = *builder; + b.setInsertInto(addMethod); + IRType* paramTypes[2] = { diffDiffPairType, diffDiffPairType }; + addMethod->setFullType(b.getFuncType(2, paramTypes, diffDiffPairType)); + b.emitBlock(); + auto p0 = b.emitParam(diffDiffPairType); + auto p1 = b.emitParam(diffDiffPairType); + + // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type. + auto innerAdd = _lookupWitness(&b, innerWitness, autoDiffSharedContext->addMethodStructKey); + IRInst* argsPrimal[2] = { + isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p0) : b.emitDifferentialPairGetPrimal(p0), + isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p1) : b.emitDifferentialPairGetPrimal(p1) }; + auto primalPart = b.emitCallInst(elementType, innerAdd, 2, argsPrimal); + IRInst* argsDiff[2] = { + isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p0) : b.emitDifferentialPairGetDifferential(elementType, p0), + isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p1) : b.emitDifferentialPairGetDifferential(elementType, p1)}; + auto diffPart = b.emitCallInst(elementType, innerAdd, 2, argsDiff); + auto retVal = + isUserCodeType + ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, primalPart, diffPart) + : b.emitMakeDifferentialPair(diffDiffPairType, primalPart, diffPart); + b.emitReturn(retVal); + } + { + // Zero method. + IRBuilder b = *builder; + b.setInsertInto(zeroMethod); + zeroMethod->setFullType(b.getFuncType(0, nullptr, diffDiffPairType)); + b.emitBlock(); + auto innerZero = _lookupWitness(&b, innerWitness, autoDiffSharedContext->zeroMethodStructKey); + auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr); + auto retVal = + isUserCodeType + ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, zeroVal, zeroVal) + : b.emitMakeDifferentialPair(diffDiffPairType, zeroVal, zeroVal); + b.emitReturn(retVal); + } + // Record this in the context for future lookups differentiableTypeConformanceContext.differentiableWitnessDictionary[(IRType*)inOriginalDiffPairType] = table; return table; } +// Get or construct `:IDifferentiable` conformance for an Array. +IRWitnessTable* AutoDiffTranscriberBase::getArrayWitness(IRBuilder* builder, IRInst* inOriginalArrayType, IRInst* inPrimalArrayType) +{ + // Differentiate the pair type to get it's differential (which is itself a pair) + auto diffArrayType = (IRType*)differentiateType(builder, (IRType*)inOriginalArrayType); + + if (!diffArrayType) + return nullptr; + + auto innerWitness = tryGetDifferentiableWitness(builder, as<IRArrayTypeBase>(inOriginalArrayType)->getElementType()); + + auto addMethod = builder->createFunc(); + auto zeroMethod = builder->createFunc(); + + auto table = builder->createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, (IRType*)inPrimalArrayType); + + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffArrayType); + builder->createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeWitnessStructKey, table); + builder->createWitnessTableEntry(table, autoDiffSharedContext->addMethodStructKey, addMethod); + builder->createWitnessTableEntry(table, autoDiffSharedContext->zeroMethodStructKey, zeroMethod); + + auto elementType = as<IRArrayTypeBase>(diffArrayType)->getElementType(); + + // Fill in differential method implementations. + { + // Add method. + IRBuilder b = *builder; + b.setInsertInto(addMethod); + IRType* paramTypes[2] = { diffArrayType, diffArrayType }; + addMethod->setFullType(b.getFuncType(2, paramTypes, diffArrayType)); + b.emitBlock(); + auto p0 = b.emitParam(diffArrayType); + auto p1 = b.emitParam(diffArrayType); + + // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type. + auto innerAdd = _lookupWitness(&b, innerWitness, autoDiffSharedContext->addMethodStructKey); + auto resultVar = b.emitVar(diffArrayType); + IRBlock* loopBodyBlock = nullptr; + IRBlock* loopBreakBlock = nullptr; + auto loopCounter = emitLoopBlocks(&b, b.getIntValue(b.getIntType(), 0), as<IRArrayTypeBase>(diffArrayType)->getElementCount(), loopBodyBlock, loopBreakBlock); + b.setInsertBefore(loopBodyBlock->getTerminator()); + + IRInst* args[2] = { + b.emitElementExtract(p0, loopCounter), + b.emitElementExtract(p1, loopCounter) }; + auto elementResult = b.emitCallInst(elementType, innerAdd, 2, args); + auto addr = b.emitElementAddress(resultVar, loopCounter); + b.emitStore(addr, elementResult); + b.setInsertInto(loopBreakBlock); + b.emitReturn(b.emitLoad(resultVar)); + } + { + // Zero method. + IRBuilder b = *builder; + b.setInsertInto(zeroMethod); + zeroMethod->setFullType(b.getFuncType(0, nullptr, diffArrayType)); + b.emitBlock(); + + auto innerZero = _lookupWitness(&b, innerWitness, autoDiffSharedContext->zeroMethodStructKey); + auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr); + auto retVal = b.emitMakeArrayFromElement(diffArrayType, zeroVal); + b.emitReturn(retVal); + } + + // Record this in the context for future lookups + differentiableTypeConformanceContext.differentiableWitnessDictionary[(IRType*)inOriginalArrayType] = table; + + return table; +} + IRInst* AutoDiffTranscriberBase::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType) { IRInst* witness = @@ -204,10 +328,14 @@ IRInst* AutoDiffTranscriberBase::tryGetDifferentiableWitness(IRBuilder* builder, { auto primalType = lookupPrimalInst(builder, originalType, nullptr); SLANG_RELEASE_ASSERT(primalType); - if (auto primalPairType = as<IRDifferentialPairType>(primalType)) + if (auto primalPairType = as<IRDifferentialPairTypeBase>(primalType)) { witness = getDifferentialPairWitness(builder, originalType, primalPairType); } + else if (auto arrayType = as<IRArrayType>(primalType)) + { + witness = getArrayWitness(builder, originalType, arrayType); + } else if (auto extractExistential = as<IRExtractExistentialType>(originalType)) { differentiateExtractExistentialType(builder, extractExistential, witness); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h index 47e568645..d5070689e 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.h +++ b/source/slang/slang-ir-autodiff-transcriber-base.h @@ -97,6 +97,7 @@ struct AutoDiffTranscriberBase // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. IRWitnessTable* getDifferentialPairWitness(IRBuilder* builder, IRInst* inOriginalDiffPairType, IRInst* inPrimalDiffPairType); + IRWitnessTable* getArrayWitness(IRBuilder* builder, IRInst* inOriginalArrayType, IRInst* inPrimalArrayType); IRInst* tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType); diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 1cd6a0e33..a92978817 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -2888,6 +2888,7 @@ struct DiffTransposePass auto diffElementType = (IRType*)diffTypeContext.getDifferentialForType(builder, arrayType->getElementType()); SLANG_RELEASE_ASSERT(diffElementType); auto arraySize = arrayType->getElementCount(); + if (auto constArraySize = as<IRIntLit>(arraySize)) { List<IRInst*> args; diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index e01452972..5b59416d4 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -308,9 +308,10 @@ struct ExtractPrimalFuncContext fieldType = cloneInst(&cloneEnv, &genTypeBuilder, fieldType); } auto structField = genTypeBuilder.createStructField(structType, structKey, (IRType*)fieldType); - if (auto diffFieldType = backwardPrimalTranscriber->differentiateType(&genTypeBuilder, (IRType*)fieldType)) + + if (auto witness = backwardPrimalTranscriber->tryGetDifferentiableWitness(&genTypeBuilder, (IRType*)fieldType)) { - genTypeBuilder.addIntermediateContextFieldDifferentialTypeDecoration(structField, diffFieldType); + genTypeBuilder.addIntermediateContextFieldDifferentialTypeDecoration(structField, witness); } return structField; } diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 4d22d9eed..517b9e3ea 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -399,9 +399,7 @@ IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* t IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key) { if (auto conformance = lookUpConformanceForType(origType)) - { return _lookupWitness(builder, conformance, key); - } return nullptr; } @@ -889,40 +887,44 @@ struct AutoDiffPass : public InstPassBase builder.setInsertInto(diffType); // Generate the fields for all differentiable members of the original struct type. + struct FieldInfo + { + IRStructField* field; + IRInst* witness; + }; + List<FieldInfo> diffFields; + for (auto field : originalType->getFields()) { - IRInst* diffFieldType = nullptr; + IRInst* diffFieldWitness = nullptr; if (auto diffDecor = field->findDecoration<IRIntermediateContextFieldDifferentialTypeDecoration>()) { - diffFieldType = diffDecor->getDifferentialType(); + diffFieldWitness = diffDecor->getDifferentialWitness(); } else { IntermediateContextTypeDifferentialInfo diffFieldTypeInfo; diffTypes.TryGetValue(field->getDataType(), diffFieldTypeInfo); - diffFieldType = diffFieldTypeInfo.diffType; + diffFieldWitness = diffFieldTypeInfo.diffWitness; } - if (diffFieldType) + if (diffFieldWitness) { + FieldInfo info; IRBuilder keyBuilder = builder; keyBuilder.setInsertBefore(maybeFindOuterGeneric(originalType)); auto diffKey = keyBuilder.createStructKey(); - builder.createStructField(diffType, diffKey, (IRType*)diffFieldType); + auto diffFieldType = _lookupWitness(&keyBuilder, diffFieldWitness, autodiffContext->differentialAssocTypeStructKey); + info.field = builder.createStructField(diffType, diffKey, (IRType*)diffFieldType); + info.witness = diffFieldWitness; builder.addDecoration(field->getKey(), kIROp_DerivativeMemberDecoration, diffKey); builder.addDecoration(diffKey, kIROp_DerivativeMemberDecoration, diffKey); + diffFields.add(info); } } builder.setInsertAfter(diffType); - // For now, we are going to structurally derive dadd and dzero methods for intermediate context types, - // because it is tricky for us to obtain the original witness tables for the fields at this point. - // This is inconsistent with how we are dealing with dadd and dzero methods via witness table lookup, - // and can lead to problems if the user defines any non-trivial dadd/dzero methods. - // - // TODO: we should consider rewrite this logic to be witness table lookup based, or simplify the entire - // type system and IR passes to always use structurally derived methods instead of user-provided - // methods. + // Implement `dadd` and `dzero` methods. IRInst* zeroMethod = nullptr; { auto zeroMethodType = builder.getFuncType(List<IRType*>(), diffType); @@ -931,7 +933,14 @@ struct AutoDiffPass : public InstPassBase result.zeroMethod = zeroMethod; builder.setInsertInto(zeroMethod); builder.emitBlock(); - builder.emitReturn(builder.emitDefaultConstruct(diffType)); + List<IRInst*> fieldVals; + for (auto info : diffFields) + { + auto innerZeroMethod = _lookupWitness(&builder, info.witness, autodiffContext->zeroMethodStructKey); + IRInst* val = builder.emitCallInst(info.field->getFieldType(), innerZeroMethod, 0, nullptr); + fieldVals.add(val); + } + builder.emitReturn(builder.emitMakeStruct(diffType, fieldVals)); } builder.setInsertAfter(zeroMethod); @@ -948,7 +957,18 @@ struct AutoDiffPass : public InstPassBase builder.emitBlock(); auto param1 = builder.emitParam(diffType); auto param2 = builder.emitParam(diffType); - builder.emitReturn(builder.emitStructuralAdd(param1, param2)); + List<IRInst*> fieldVals; + for (auto info : diffFields) + { + auto innerAddMethod = _lookupWitness(&builder, info.witness, autodiffContext->addMethodStructKey); + IRInst* args[2] = { + builder.emitFieldExtract(info.field->getFieldType(), param1, info.field->getKey()), + builder.emitFieldExtract(info.field->getFieldType(), param2, info.field->getKey()), + }; + IRInst* val = builder.emitCallInst(info.field->getFieldType(), innerAddMethod, 2, args); + fieldVals.add(val); + } + builder.emitReturn(builder.emitMakeStruct(diffType, fieldVals)); } builder.setInsertAfter(addMethod); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index bb8cfc378..71d9315bd 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -320,9 +320,6 @@ INST(MakeOptionalValue, makeOptionalValue, 1, 0) INST(MakeOptionalNone, makeOptionalNone, 1, 0) INST(Call, call, 1, 0) -// Structural addition of two values of the same type. -INST(StructuralAdd, structuralAdd, 2, 0) - INST(RTTIObject, rtti_object, 0, 0) INST(Alloca, alloca, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 43893bfe6..0f5c36dcb 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -792,8 +792,7 @@ struct IRIntermediateContextFieldDifferentialTypeDecoration : IRDecoration IR_LEAF_ISA(IntermediateContextFieldDifferentialTypeDecoration) - IRInst* getDifferentialType() { return getOperand(0); } - IRInst* getDifferentialWitness() { return getOperand(1); } + IRInst* getDifferentialWitness() { return getOperand(0); } }; @@ -2886,7 +2885,7 @@ public: IRInst* addPrimalValueStructKeyDecoration(IRInst* target, IRStructKey* key); IRInst* addPrimalElementTypeDecoration(IRInst* target, IRInst* type); - IRInst* addIntermediateContextFieldDifferentialTypeDecoration(IRInst* target, IRInst* type); + IRInst* addIntermediateContextFieldDifferentialTypeDecoration(IRInst* target, IRInst* witness); // Add a differentiable type entry to the appropriate dictionary. IRInst* addDifferentiableTypeEntry(IRInst* dictDecoration, IRInst* irType, IRInst* conformanceWitness); @@ -2969,15 +2968,6 @@ public: /// the inst. IRInst* emitDefaultConstructRaw(IRType* type); - /// Emits appropriate inst for structurally adding two values of `type`. - /// If `fallback` is true, will emit `StructuralAdd` inst on unknown types. - /// Otherwise, returns nullptr if we can't materialize the inst. - IRInst* emitStructuralAdd(IRInst* val0, IRInst* val1, bool fallback = true); - - /// Emits a raw `StructuralAdd` opcode without attempting to fold/materialize - /// the inst. - IRInst* emitStructuralAddRaw(IRInst* val0, IRInst* val1); - IRInst* emitCast( IRType* type, IRInst* value); diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index a5ec50b2c..5d5a41726 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -633,19 +633,6 @@ struct PeepholeContext : InstPassBase } } break; - case kIROp_StructuralAdd: - { - IRBuilder builder(module); - builder.setInsertBefore(inst); - // See if we can replace the generic add inst with concrete values. - if (auto newCtor = builder.emitStructuralAdd(inst->getOperand(0), inst->getOperand(1), false)) - { - inst->replaceUsesWith(newCtor); - maybeRemoveOldInst(inst); - changed = true; - } - } - break; case kIROp_Add: case kIROp_Mul: case kIROp_Sub: diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 254734965..de03a1661 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -554,6 +554,26 @@ IROp getSwapSideComparisonOp(IROp op) } } +IRInst* emitLoopBlocks(IRBuilder* builder, IRInst* initVal, IRInst* finalVal, IRBlock*& loopBodyBlock, IRBlock*& loopBreakBlock) +{ + IRBuilder loopBuilder = *builder; + auto loopHeadBlock = loopBuilder.emitBlock(); + loopBodyBlock = loopBuilder.emitBlock(); + loopBreakBlock = loopBuilder.emitBlock(); + auto loopContinueBlock = loopBuilder.emitBlock(); + builder->emitLoop(loopHeadBlock, loopBreakBlock, loopHeadBlock, 1, &initVal); + loopBuilder.setInsertInto(loopHeadBlock); + auto loopParam = loopBuilder.emitParam(initVal->getFullType()); + auto cmpResult = loopBuilder.emitLess(loopParam, finalVal); + loopBuilder.emitIfElse(cmpResult, loopBodyBlock, loopBreakBlock, loopBreakBlock); + loopBuilder.setInsertInto(loopBodyBlock); + loopBuilder.emitBranch(loopContinueBlock); + loopBuilder.setInsertInto(loopContinueBlock); + auto newParam = loopBuilder.emitAdd(loopParam->getFullType(), loopParam, loopBuilder.getIntValue(loopBuilder.getIntType(), 1)); + loopBuilder.emitBranch(loopHeadBlock, 1, &newParam); + return loopParam; +} + void setInsertBeforeOrdinaryInst(IRBuilder* builder, IRInst* inst) { if (as<IRParam>(inst)) diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 62156cad6..0989dee33 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -182,6 +182,10 @@ void setInsertBeforeOrdinaryInst(IRBuilder* builder, IRInst* inst); // Set IRBuilder to insert after `inst`. If `inst` is a param, it will insert after the last param. void setInsertAfterOrdinaryInst(IRBuilder* builder, IRInst* inst); +// Emit a loop structure with a simple incrementing counter. +// Returns the loop counter `IRParam`. +IRInst* emitLoopBlocks(IRBuilder* builder, IRInst* initVal, IRInst* finalVal, IRBlock*& loopBodyBlock, IRBlock*& loopBreakBlock); + } #endif diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index f61e5a10e..9f877969a 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3509,104 +3509,6 @@ namespace Slang return nullptr; } - IRInst* IRBuilder::emitStructuralAddRaw(IRInst* val0, IRInst* val1) - { - IRInst* args[2] = { val0, val1 }; - return emitIntrinsicInst(val0->getFullType(), kIROp_StructuralAdd, 2, args); - } - - IRInst* IRBuilder::emitStructuralAdd(IRInst* val0, IRInst* val1, bool fallback) - { - auto type = val0->getFullType(); - SLANG_RELEASE_ASSERT(val0->getFullType() == val1->getFullType()); - IRType* actualType = val0->getFullType(); - for (;;) - { - if (auto attr = as<IRAttributedType>(actualType)) - actualType = attr->getBaseType(); - else if (auto rateQualified = as<IRRateQualifiedType>(actualType)) - actualType = rateQualified->getValueType(); - else - break; - } - if (as<IRBasicType>(actualType)) - return emitAdd(type, val0, val1); - - switch (actualType->getOp()) - { - case kIROp_PtrType: - case kIROp_VectorType: - case kIROp_MatrixType: - return emitAdd(type, val0, val1); - case kIROp_TupleType: - { - List<IRInst*> elements; - auto tupleType = as<IRTupleType>(actualType); - for (UInt i = 0; i < tupleType->getOperandCount(); i++) - { - auto operand = tupleType->getOperand(i); - if (as<IRAttr>(operand)) - break; - auto inner = emitStructuralAdd( - emitGetTupleElement((IRType*)operand, val0, i), - emitGetTupleElement((IRType*)operand, val1, i), - fallback); - if (!inner) - return nullptr; - elements.add(inner); - } - return emitMakeTuple(tupleType, elements); - } - case kIROp_StructType: - { - List<IRInst*> elements; - auto structType = as<IRStructType>(actualType); - for (auto field : structType->getFields()) - { - auto fieldType = field->getFieldType(); - auto inner = emitStructuralAdd( - emitFieldExtract(fieldType, val0, field->getKey()), - emitFieldExtract(fieldType, val1, field->getKey()), - fallback); - if (!inner) - return nullptr; - elements.add(inner); - } - return emitMakeStruct(type, elements); - } - case kIROp_ArrayType: - { - auto arrayType = as<IRArrayType>(actualType); - if (auto count = as<IRIntLit>(arrayType->getElementCount())) - { - auto elementType = arrayType->getElementType(); - List<IRInst*> elements; - constexpr int maxCount = 4096; - if (count->getValue() > maxCount) - break; - for (IRIntegerValue i = 0; i < count->getValue(); i++) - { - auto index = getIntValue(getIntType(), i); - auto element = emitStructuralAdd( - emitElementExtract(elementType, val0, index), - emitElementExtract(elementType, val1, index), - fallback); - elements.add(element); - } - return emitMakeArray(type, elements.getCount(), elements.getBuffer()); - } - break; - } - default: - break; - } - if (fallback) - { - return emitStructuralAddRaw(val0, val1); - } - return nullptr; - } - static int _getTypeStyleId(IRType* type) { if (auto vectorType = as<IRVectorType>(type)) @@ -4026,9 +3928,9 @@ namespace Slang return addDecoration(target, kIROp_PrimalElementTypeDecoration, type); } - IRInst* IRBuilder::addIntermediateContextFieldDifferentialTypeDecoration(IRInst* target, IRInst* type) + IRInst* IRBuilder::addIntermediateContextFieldDifferentialTypeDecoration(IRInst* target, IRInst* witness) { - return addDecoration(target, kIROp_IntermediateContextFieldDifferentialTypeDecoration, type); + return addDecoration(target, kIROp_IntermediateContextFieldDifferentialTypeDecoration, witness); } RefPtr<IRModule> IRModule::create(Session* session) @@ -7131,7 +7033,6 @@ namespace Slang case kIROp_Nop: case kIROp_undefined: case kIROp_DefaultConstruct: - case kIROp_StructuralAdd: case kIROp_Specialize: case kIROp_LookupWitness: case kIROp_GetSequentialID: diff --git a/tests/autodiff/high-order-backward-diff-3.slang b/tests/autodiff/high-order-backward-diff-3.slang index eb3866b96..100a9a1e0 100644 --- a/tests/autodiff/high-order-backward-diff-3.slang +++ b/tests/autodiff/high-order-backward-diff-3.slang @@ -14,14 +14,20 @@ struct A : IDifferentiable [BackwardDifferentiable] float f(A x) { - return x.x * x.x; + A rs; + rs.x = 1.0; + for (int i = 0; i < 2; i++) + rs.x = rs.x * x.x; + return rs.x; } [BackwardDifferentiable] float outerF(A x) { A nx; - nx.x = x.x * x.x; + nx.x = 1.0; + for (int i = 0; i < 2; i++) + nx.x = nx.x * x.x; nx.nx = 2;//x.nx; return f(nx); } |
