diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 37 |
2 files changed, 48 insertions, 0 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index a45a3abf9..04acad435 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1329,9 +1329,12 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* return transcribeSwizzle(builder, as<IRSwizzle>(origInst)); case kIROp_MakeTuple: + case kIROp_Neg: return transcribeByPassthrough(builder, origInst); + case kIROp_UpdateElement: return transcribeUpdateElement(builder, origInst); + case kIROp_unconditionalBranch: return transcribeControlFlow(builder, origInst); @@ -1346,6 +1349,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_FieldExtract: case kIROp_FieldAddress: return transcribeFieldExtract(builder, origInst); + case kIROp_GetElement: case kIROp_GetElementPtr: return transcribeGetElement(builder, origInst); @@ -1361,12 +1365,15 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_MakeDifferentialPair: return transcribeMakeDifferentialPair(builder, as<IRMakeDifferentialPair>(origInst)); + case kIROp_DifferentialPairGetPrimal: case kIROp_DifferentialPairGetDifferential: return transcribeDifferentialPairGetElement(builder, origInst); + case kIROp_ExtractExistentialValue: case kIROp_MakeExistential: return transcribeSingleOperandInst(builder, origInst); + case kIROp_ExtractExistentialType: { IRInst* witnessTable; @@ -1377,8 +1384,10 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* } case kIROp_ExtractExistentialWitnessTable: return transcribeExtractExistentialWitnessTable(builder, origInst); + case kIROp_WrapExistential: return transcribeWrapExistential(builder, origInst); + case kIROp_undefined: return transcribeUndefined(builder, origInst); @@ -1387,8 +1396,10 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* // so we treat this inst as non differentiable. // We can extend the frontend and IR with a separate op-code that can provide an explicit diff value. return trascribeNonDiffInst(builder, origInst); + case kIROp_StructKey: return InstPair(origInst, nullptr); + case kIROp_Unreachable: { auto unreachInst = builder->emitUnreachable(); diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index eef820804..d9b28ea3c 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -1247,6 +1247,8 @@ struct DiffTransposePass case kIROp_Add: case kIROp_Mul: case kIROp_Sub: + case kIROp_Div: + case kIROp_Neg: return transposeArithmetic(builder, fwdInst, revValue); case kIROp_Call: @@ -1927,6 +1929,41 @@ struct DiffTransposePass SLANG_ASSERT_FAILURE("Neither operand of a mul instruction is a differential inst"); } } + case kIROp_Div: + { + if (isDifferentialInst(fwdInst->getOperand(0))) + { + SLANG_RELEASE_ASSERT(!isDifferentialInst(fwdInst->getOperand(1))); + + // (Out = dA / B) -> (dA += dOut / B) + return TranspositionResult( + List<RevGradient>( + RevGradient( + fwdInst->getOperand(0), + builder->emitDiv(operandType, revValue, fwdInst->getOperand(1)), + fwdInst))); + } + { + SLANG_ASSERT_FAILURE("The first operand of a div inst must be a differential inst"); + } + } + case kIROp_Neg: + { + if (isDifferentialInst(fwdInst->getOperand(0))) + { + // (Out = -dA) -> (dA += -dOut) + return TranspositionResult( + List<RevGradient>( + RevGradient( + fwdInst->getOperand(0), + builder->emitNeg(operandType, revValue), + fwdInst))); + } + else + { + SLANG_ASSERT_FAILURE("Cannot transpose neg of a non-differentiable inst"); + } + } default: SLANG_ASSERT_FAILURE("Unhandled arithmetic"); |
