From c8e6a6452f4e531dca09152178bae2f9a2fb999a Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 10 May 2023 09:11:36 -0700 Subject: Generate faster derivative for div by const operations. (#2877) * Generate faster derivative for div by const operations. * Increase `kMaxIterationsToAttempt` to 256. --------- Co-authored-by: Yong He --- source/slang/slang-ir-autodiff-fwd.cpp | 50 +++++++++++++++++++------ source/slang/slang-ir-loop-unroll.cpp | 2 +- source/slang/slang-ir-peephole.cpp | 68 ---------------------------------- source/slang/slang-ir-util.cpp | 68 ++++++++++++++++++++++++++++++++++ source/slang/slang-ir-util.h | 4 ++ 5 files changed, 111 insertions(+), 81 deletions(-) (limited to 'source') diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index e6bfc751c..38d0c0706 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -177,7 +177,10 @@ InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRIns if (diffLeft || diffRight) { diffLeft = diffLeft ? diffLeft : getDifferentialZeroOfType(builder, primalLeft->getDataType()); + + bool diffRightIsZero = (diffRight == nullptr); diffRight = diffRight ? diffRight : getDifferentialZeroOfType(builder, primalRight->getDataType()); + diffRightIsZero = diffRightIsZero || isZero(diffRight); auto resultType = primalArith->getDataType(); auto origResultType = origArith->getDataType(); @@ -196,7 +199,7 @@ InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRIns case kIROp_Mul: { auto diffLeftTimesRight = builder->emitMul(diffType, diffLeft, primalRight); - auto diffRightTimesLeft = builder->emitMul(diffType, primalLeft, diffRight); + auto diffRightTimesLeft = builder->emitMul(diffType, diffRight, primalLeft); builder->markInstAsDifferential(diffLeftTimesRight, resultType); builder->markInstAsDifferential(diffRightTimesLeft, resultType); @@ -215,20 +218,43 @@ InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRIns } case kIROp_Div: { - auto diffLeftTimesRight = builder->emitMul(diffType, diffLeft, primalRight); - auto diffRightTimesLeft = builder->emitMul(diffType, primalLeft, diffRight); - auto diffSub = builder->emitSub(diffType, diffLeftTimesRight, diffRightTimesLeft); - builder->markInstAsDifferential(diffLeftTimesRight, resultType); - builder->markInstAsDifferential(diffRightTimesLeft, resultType); - builder->markInstAsDifferential(diffSub, resultType); + if (diffRightIsZero) + { + // Special case the dRight = 0 case here since it would be difficult + // to optimize out in the future. + IRInst* diff = nullptr; + if (auto constant = as(primalRight)) + { + diff = builder->emitMul( + diffType, + diffLeft, + builder->getFloatValue( + constant->getDataType(), 1.0 / constant->getValue())); + } + else + { + diff = builder->emitDiv(diffType, diffLeft, primalRight); + } + return InstPair(primalArith, diff); + } + else + { + auto diffLeftTimesRight = builder->emitMul(diffType, diffLeft, primalRight); + builder->markInstAsDifferential(diffLeftTimesRight, resultType); + + auto diffRightTimesLeft = builder->emitMul(diffType, primalLeft, diffRight); + builder->markInstAsDifferential(diffRightTimesLeft, resultType); - auto diffMul = builder->emitMul(primalRight->getFullType(), primalRight, primalRight); - builder->markInstAsPrimal(diffMul); + auto diffSub = builder->emitSub(diffType, diffLeftTimesRight, diffRightTimesLeft); + builder->markInstAsDifferential(diffSub, resultType); + auto diffMul = builder->emitMul(primalRight->getFullType(), primalRight, primalRight); + builder->markInstAsPrimal(diffMul); - auto diffDiv = builder->emitDiv(diffType, diffSub, diffMul); - builder->markInstAsDifferential(diffDiv, resultType); + auto diffDiv = builder->emitDiv(diffType, diffSub, diffMul); + builder->markInstAsDifferential(diffDiv, resultType); - return InstPair(primalArith, diffDiv); + return InstPair(primalArith, diffDiv); + } } default: getSink()->diagnose(origArith->sourceLoc, diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp index f068eded4..a05700277 100644 --- a/source/slang/slang-ir-loop-unroll.cpp +++ b/source/slang/slang-ir-loop-unroll.cpp @@ -88,7 +88,7 @@ List collectBlocksInLoop(IRGlobalValueWithCode* func, IRLoop* loopIns static int _getLoopMaxIterationsToUnroll(IRLoop* loopInst) { - static constexpr int kMaxIterationsToAttempt = 100; + static constexpr int kMaxIterationsToAttempt = 256; auto forceUnrollDecor = loopInst->findDecoration(); if (!forceUnrollDecor) diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index f04012112..e07d1f9c4 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -112,74 +112,6 @@ struct PeepholeContext : InstPassBase return false; } - bool isZero(IRInst* inst) - { - switch (inst->getOp()) - { - case kIROp_IntLit: - return as(inst)->getValue() == 0; - case kIROp_FloatLit: - return as(inst)->getValue() == 0.0; - case kIROp_BoolLit: - return as(inst)->getValue() == false; - case kIROp_MakeVector: - case kIROp_MakeVectorFromScalar: - case kIROp_MakeMatrix: - case kIROp_MakeMatrixFromScalar: - case kIROp_MatrixReshape: - case kIROp_VectorReshape: - { - for (UInt i = 0; i < inst->getOperandCount(); i++) - { - if (!isZero(inst->getOperand(i))) - { - return false; - } - } - return true; - } - case kIROp_CastIntToFloat: - case kIROp_CastFloatToInt: - return isZero(inst->getOperand(0)); - default: - return false; - } - } - - bool isOne(IRInst* inst) - { - switch (inst->getOp()) - { - case kIROp_IntLit: - return as(inst)->getValue() == 1; - case kIROp_FloatLit: - return as(inst)->getValue() == 1.0; - case kIROp_BoolLit: - return as(inst)->getValue(); - case kIROp_MakeVector: - case kIROp_MakeVectorFromScalar: - case kIROp_MakeMatrix: - case kIROp_MakeMatrixFromScalar: - case kIROp_MatrixReshape: - case kIROp_VectorReshape: - { - for (UInt i = 0; i < inst->getOperandCount(); i++) - { - if (!isOne(inst->getOperand(i))) - { - return false; - } - } - return true; - } - case kIROp_CastIntToFloat: - case kIROp_CastFloatToInt: - return isOne(inst->getOperand(0)); - default: - return false; - } - } - bool tryOptimizeArithmeticInst(IRInst* inst) { bool allowUnsafeOptimizations = diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 55c2b18c0..a978edc48 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -846,6 +846,74 @@ bool isGlobalOrUnknownMutableAddress(IRGlobalValueWithCode* parentFunc, IRInst* return (addrInstParent != parentFunc); } +bool isZero(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_IntLit: + return as(inst)->getValue() == 0; + case kIROp_FloatLit: + return as(inst)->getValue() == 0.0; + case kIROp_BoolLit: + return as(inst)->getValue() == false; + case kIROp_MakeVector: + case kIROp_MakeVectorFromScalar: + case kIROp_MakeMatrix: + case kIROp_MakeMatrixFromScalar: + case kIROp_MatrixReshape: + case kIROp_VectorReshape: + { + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + if (!isZero(inst->getOperand(i))) + { + return false; + } + } + return true; + } + case kIROp_CastIntToFloat: + case kIROp_CastFloatToInt: + return isZero(inst->getOperand(0)); + default: + return false; + } +} + +bool isOne(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_IntLit: + return as(inst)->getValue() == 1; + case kIROp_FloatLit: + return as(inst)->getValue() == 1.0; + case kIROp_BoolLit: + return as(inst)->getValue(); + case kIROp_MakeVector: + case kIROp_MakeVectorFromScalar: + case kIROp_MakeMatrix: + case kIROp_MakeMatrixFromScalar: + case kIROp_MatrixReshape: + case kIROp_VectorReshape: + { + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + if (!isOne(inst->getOperand(i))) + { + return false; + } + } + return true; + } + case kIROp_CastIntToFloat: + case kIROp_CastFloatToInt: + return isOne(inst->getOperand(0)); + default: + return false; + } +} + struct GenericChildrenMigrationContextImpl { IRCloneEnv cloneEnv; diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 9fd6dd972..65d081f42 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -214,6 +214,10 @@ void removePhiArgs(IRInst* phiParam); int getParamIndexInBlock(IRParam* paramInst); bool isGlobalOrUnknownMutableAddress(IRGlobalValueWithCode* parentFunc, IRInst* inst); + +bool isZero(IRInst* inst); + +bool isOne(IRInst* inst); } #endif -- cgit v1.2.3