diff options
| author | Yong He <yonghe@outlook.com> | 2023-05-10 09:11:36 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-05-10 09:11:36 -0700 |
| commit | c8e6a6452f4e531dca09152178bae2f9a2fb999a (patch) | |
| tree | dab0f646dc520d2a187f64885e7b7fe152b49f5e /source | |
| parent | ddebd60853b3f34bfd8e89de804fd15808abf75d (diff) | |
Generate faster derivative for div by const operations. (#2877)
* Generate faster derivative for div by const operations.
* Increase `kMaxIterationsToAttempt` to 256.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 50 | ||||
| -rw-r--r-- | source/slang/slang-ir-loop-unroll.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-peephole.cpp | 68 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 68 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.h | 4 |
5 files changed, 111 insertions, 81 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index e6bfc751c..38d0c0706 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -177,7 +177,10 @@ InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRIns if (diffLeft || diffRight) { diffLeft = diffLeft ? diffLeft : getDifferentialZeroOfType(builder, primalLeft->getDataType()); + + bool diffRightIsZero = (diffRight == nullptr); diffRight = diffRight ? diffRight : getDifferentialZeroOfType(builder, primalRight->getDataType()); + diffRightIsZero = diffRightIsZero || isZero(diffRight); auto resultType = primalArith->getDataType(); auto origResultType = origArith->getDataType(); @@ -196,7 +199,7 @@ InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRIns case kIROp_Mul: { auto diffLeftTimesRight = builder->emitMul(diffType, diffLeft, primalRight); - auto diffRightTimesLeft = builder->emitMul(diffType, primalLeft, diffRight); + auto diffRightTimesLeft = builder->emitMul(diffType, diffRight, primalLeft); builder->markInstAsDifferential(diffLeftTimesRight, resultType); builder->markInstAsDifferential(diffRightTimesLeft, resultType); @@ -215,20 +218,43 @@ InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRIns } case kIROp_Div: { - auto diffLeftTimesRight = builder->emitMul(diffType, diffLeft, primalRight); - auto diffRightTimesLeft = builder->emitMul(diffType, primalLeft, diffRight); - auto diffSub = builder->emitSub(diffType, diffLeftTimesRight, diffRightTimesLeft); - builder->markInstAsDifferential(diffLeftTimesRight, resultType); - builder->markInstAsDifferential(diffRightTimesLeft, resultType); - builder->markInstAsDifferential(diffSub, resultType); + if (diffRightIsZero) + { + // Special case the dRight = 0 case here since it would be difficult + // to optimize out in the future. + IRInst* diff = nullptr; + if (auto constant = as<IRFloatLit>(primalRight)) + { + diff = builder->emitMul( + diffType, + diffLeft, + builder->getFloatValue( + constant->getDataType(), 1.0 / constant->getValue())); + } + else + { + diff = builder->emitDiv(diffType, diffLeft, primalRight); + } + return InstPair(primalArith, diff); + } + else + { + auto diffLeftTimesRight = builder->emitMul(diffType, diffLeft, primalRight); + builder->markInstAsDifferential(diffLeftTimesRight, resultType); + + auto diffRightTimesLeft = builder->emitMul(diffType, primalLeft, diffRight); + builder->markInstAsDifferential(diffRightTimesLeft, resultType); - auto diffMul = builder->emitMul(primalRight->getFullType(), primalRight, primalRight); - builder->markInstAsPrimal(diffMul); + auto diffSub = builder->emitSub(diffType, diffLeftTimesRight, diffRightTimesLeft); + builder->markInstAsDifferential(diffSub, resultType); + auto diffMul = builder->emitMul(primalRight->getFullType(), primalRight, primalRight); + builder->markInstAsPrimal(diffMul); - auto diffDiv = builder->emitDiv(diffType, diffSub, diffMul); - builder->markInstAsDifferential(diffDiv, resultType); + auto diffDiv = builder->emitDiv(diffType, diffSub, diffMul); + builder->markInstAsDifferential(diffDiv, resultType); - return InstPair(primalArith, diffDiv); + return InstPair(primalArith, diffDiv); + } } default: getSink()->diagnose(origArith->sourceLoc, diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp index f068eded4..a05700277 100644 --- a/source/slang/slang-ir-loop-unroll.cpp +++ b/source/slang/slang-ir-loop-unroll.cpp @@ -88,7 +88,7 @@ List<IRBlock*> collectBlocksInLoop(IRGlobalValueWithCode* func, IRLoop* loopIns static int _getLoopMaxIterationsToUnroll(IRLoop* loopInst) { - static constexpr int kMaxIterationsToAttempt = 100; + static constexpr int kMaxIterationsToAttempt = 256; auto forceUnrollDecor = loopInst->findDecoration<IRForceUnrollDecoration>(); if (!forceUnrollDecor) diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index f04012112..e07d1f9c4 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -112,74 +112,6 @@ struct PeepholeContext : InstPassBase return false; } - bool isZero(IRInst* inst) - { - switch (inst->getOp()) - { - case kIROp_IntLit: - return as<IRIntLit>(inst)->getValue() == 0; - case kIROp_FloatLit: - return as<IRFloatLit>(inst)->getValue() == 0.0; - case kIROp_BoolLit: - return as<IRBoolLit>(inst)->getValue() == false; - case kIROp_MakeVector: - case kIROp_MakeVectorFromScalar: - case kIROp_MakeMatrix: - case kIROp_MakeMatrixFromScalar: - case kIROp_MatrixReshape: - case kIROp_VectorReshape: - { - for (UInt i = 0; i < inst->getOperandCount(); i++) - { - if (!isZero(inst->getOperand(i))) - { - return false; - } - } - return true; - } - case kIROp_CastIntToFloat: - case kIROp_CastFloatToInt: - return isZero(inst->getOperand(0)); - default: - return false; - } - } - - bool isOne(IRInst* inst) - { - switch (inst->getOp()) - { - case kIROp_IntLit: - return as<IRIntLit>(inst)->getValue() == 1; - case kIROp_FloatLit: - return as<IRFloatLit>(inst)->getValue() == 1.0; - case kIROp_BoolLit: - return as<IRBoolLit>(inst)->getValue(); - case kIROp_MakeVector: - case kIROp_MakeVectorFromScalar: - case kIROp_MakeMatrix: - case kIROp_MakeMatrixFromScalar: - case kIROp_MatrixReshape: - case kIROp_VectorReshape: - { - for (UInt i = 0; i < inst->getOperandCount(); i++) - { - if (!isOne(inst->getOperand(i))) - { - return false; - } - } - return true; - } - case kIROp_CastIntToFloat: - case kIROp_CastFloatToInt: - return isOne(inst->getOperand(0)); - default: - return false; - } - } - bool tryOptimizeArithmeticInst(IRInst* inst) { bool allowUnsafeOptimizations = diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 55c2b18c0..a978edc48 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -846,6 +846,74 @@ bool isGlobalOrUnknownMutableAddress(IRGlobalValueWithCode* parentFunc, IRInst* return (addrInstParent != parentFunc); } +bool isZero(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_IntLit: + return as<IRIntLit>(inst)->getValue() == 0; + case kIROp_FloatLit: + return as<IRFloatLit>(inst)->getValue() == 0.0; + case kIROp_BoolLit: + return as<IRBoolLit>(inst)->getValue() == false; + case kIROp_MakeVector: + case kIROp_MakeVectorFromScalar: + case kIROp_MakeMatrix: + case kIROp_MakeMatrixFromScalar: + case kIROp_MatrixReshape: + case kIROp_VectorReshape: + { + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + if (!isZero(inst->getOperand(i))) + { + return false; + } + } + return true; + } + case kIROp_CastIntToFloat: + case kIROp_CastFloatToInt: + return isZero(inst->getOperand(0)); + default: + return false; + } +} + +bool isOne(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_IntLit: + return as<IRIntLit>(inst)->getValue() == 1; + case kIROp_FloatLit: + return as<IRFloatLit>(inst)->getValue() == 1.0; + case kIROp_BoolLit: + return as<IRBoolLit>(inst)->getValue(); + case kIROp_MakeVector: + case kIROp_MakeVectorFromScalar: + case kIROp_MakeMatrix: + case kIROp_MakeMatrixFromScalar: + case kIROp_MatrixReshape: + case kIROp_VectorReshape: + { + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + if (!isOne(inst->getOperand(i))) + { + return false; + } + } + return true; + } + case kIROp_CastIntToFloat: + case kIROp_CastFloatToInt: + return isOne(inst->getOperand(0)); + default: + return false; + } +} + struct GenericChildrenMigrationContextImpl { IRCloneEnv cloneEnv; diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 9fd6dd972..65d081f42 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -214,6 +214,10 @@ void removePhiArgs(IRInst* phiParam); int getParamIndexInBlock(IRParam* paramInst); bool isGlobalOrUnknownMutableAddress(IRGlobalValueWithCode* parentFunc, IRInst* inst); + +bool isZero(IRInst* inst); + +bool isOne(IRInst* inst); } #endif |
