summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp11
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h37
2 files changed, 48 insertions, 0 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index a45a3abf9..04acad435 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1329,9 +1329,12 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
return transcribeSwizzle(builder, as<IRSwizzle>(origInst));
case kIROp_MakeTuple:
+ case kIROp_Neg:
return transcribeByPassthrough(builder, origInst);
+
case kIROp_UpdateElement:
return transcribeUpdateElement(builder, origInst);
+
case kIROp_unconditionalBranch:
return transcribeControlFlow(builder, origInst);
@@ -1346,6 +1349,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_FieldExtract:
case kIROp_FieldAddress:
return transcribeFieldExtract(builder, origInst);
+
case kIROp_GetElement:
case kIROp_GetElementPtr:
return transcribeGetElement(builder, origInst);
@@ -1361,12 +1365,15 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_MakeDifferentialPair:
return transcribeMakeDifferentialPair(builder, as<IRMakeDifferentialPair>(origInst));
+
case kIROp_DifferentialPairGetPrimal:
case kIROp_DifferentialPairGetDifferential:
return transcribeDifferentialPairGetElement(builder, origInst);
+
case kIROp_ExtractExistentialValue:
case kIROp_MakeExistential:
return transcribeSingleOperandInst(builder, origInst);
+
case kIROp_ExtractExistentialType:
{
IRInst* witnessTable;
@@ -1377,8 +1384,10 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
}
case kIROp_ExtractExistentialWitnessTable:
return transcribeExtractExistentialWitnessTable(builder, origInst);
+
case kIROp_WrapExistential:
return transcribeWrapExistential(builder, origInst);
+
case kIROp_undefined:
return transcribeUndefined(builder, origInst);
@@ -1387,8 +1396,10 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
// so we treat this inst as non differentiable.
// We can extend the frontend and IR with a separate op-code that can provide an explicit diff value.
return trascribeNonDiffInst(builder, origInst);
+
case kIROp_StructKey:
return InstPair(origInst, nullptr);
+
case kIROp_Unreachable:
{
auto unreachInst = builder->emitUnreachable();
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index eef820804..d9b28ea3c 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -1247,6 +1247,8 @@ struct DiffTransposePass
case kIROp_Add:
case kIROp_Mul:
case kIROp_Sub:
+ case kIROp_Div:
+ case kIROp_Neg:
return transposeArithmetic(builder, fwdInst, revValue);
case kIROp_Call:
@@ -1927,6 +1929,41 @@ struct DiffTransposePass
SLANG_ASSERT_FAILURE("Neither operand of a mul instruction is a differential inst");
}
}
+ case kIROp_Div:
+ {
+ if (isDifferentialInst(fwdInst->getOperand(0)))
+ {
+ SLANG_RELEASE_ASSERT(!isDifferentialInst(fwdInst->getOperand(1)));
+
+ // (Out = dA / B) -> (dA += dOut / B)
+ return TranspositionResult(
+ List<RevGradient>(
+ RevGradient(
+ fwdInst->getOperand(0),
+ builder->emitDiv(operandType, revValue, fwdInst->getOperand(1)),
+ fwdInst)));
+ }
+ {
+ SLANG_ASSERT_FAILURE("The first operand of a div inst must be a differential inst");
+ }
+ }
+ case kIROp_Neg:
+ {
+ if (isDifferentialInst(fwdInst->getOperand(0)))
+ {
+ // (Out = -dA) -> (dA += -dOut)
+ return TranspositionResult(
+ List<RevGradient>(
+ RevGradient(
+ fwdInst->getOperand(0),
+ builder->emitNeg(operandType, revValue),
+ fwdInst)));
+ }
+ else
+ {
+ SLANG_ASSERT_FAILURE("Cannot transpose neg of a non-differentiable inst");
+ }
+ }
default:
SLANG_ASSERT_FAILURE("Unhandled arithmetic");