summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp50
-rw-r--r--source/slang/slang-ir-loop-unroll.cpp2
-rw-r--r--source/slang/slang-ir-peephole.cpp68
-rw-r--r--source/slang/slang-ir-util.cpp68
-rw-r--r--source/slang/slang-ir-util.h4
5 files changed, 111 insertions, 81 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index e6bfc751c..38d0c0706 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -177,7 +177,10 @@ InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRIns
if (diffLeft || diffRight)
{
diffLeft = diffLeft ? diffLeft : getDifferentialZeroOfType(builder, primalLeft->getDataType());
+
+ bool diffRightIsZero = (diffRight == nullptr);
diffRight = diffRight ? diffRight : getDifferentialZeroOfType(builder, primalRight->getDataType());
+ diffRightIsZero = diffRightIsZero || isZero(diffRight);
auto resultType = primalArith->getDataType();
auto origResultType = origArith->getDataType();
@@ -196,7 +199,7 @@ InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRIns
case kIROp_Mul:
{
auto diffLeftTimesRight = builder->emitMul(diffType, diffLeft, primalRight);
- auto diffRightTimesLeft = builder->emitMul(diffType, primalLeft, diffRight);
+ auto diffRightTimesLeft = builder->emitMul(diffType, diffRight, primalLeft);
builder->markInstAsDifferential(diffLeftTimesRight, resultType);
builder->markInstAsDifferential(diffRightTimesLeft, resultType);
@@ -215,20 +218,43 @@ InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRIns
}
case kIROp_Div:
{
- auto diffLeftTimesRight = builder->emitMul(diffType, diffLeft, primalRight);
- auto diffRightTimesLeft = builder->emitMul(diffType, primalLeft, diffRight);
- auto diffSub = builder->emitSub(diffType, diffLeftTimesRight, diffRightTimesLeft);
- builder->markInstAsDifferential(diffLeftTimesRight, resultType);
- builder->markInstAsDifferential(diffRightTimesLeft, resultType);
- builder->markInstAsDifferential(diffSub, resultType);
+ if (diffRightIsZero)
+ {
+ // Special case the dRight = 0 case here since it would be difficult
+ // to optimize out in the future.
+ IRInst* diff = nullptr;
+ if (auto constant = as<IRFloatLit>(primalRight))
+ {
+ diff = builder->emitMul(
+ diffType,
+ diffLeft,
+ builder->getFloatValue(
+ constant->getDataType(), 1.0 / constant->getValue()));
+ }
+ else
+ {
+ diff = builder->emitDiv(diffType, diffLeft, primalRight);
+ }
+ return InstPair(primalArith, diff);
+ }
+ else
+ {
+ auto diffLeftTimesRight = builder->emitMul(diffType, diffLeft, primalRight);
+ builder->markInstAsDifferential(diffLeftTimesRight, resultType);
+
+ auto diffRightTimesLeft = builder->emitMul(diffType, primalLeft, diffRight);
+ builder->markInstAsDifferential(diffRightTimesLeft, resultType);
- auto diffMul = builder->emitMul(primalRight->getFullType(), primalRight, primalRight);
- builder->markInstAsPrimal(diffMul);
+ auto diffSub = builder->emitSub(diffType, diffLeftTimesRight, diffRightTimesLeft);
+ builder->markInstAsDifferential(diffSub, resultType);
+ auto diffMul = builder->emitMul(primalRight->getFullType(), primalRight, primalRight);
+ builder->markInstAsPrimal(diffMul);
- auto diffDiv = builder->emitDiv(diffType, diffSub, diffMul);
- builder->markInstAsDifferential(diffDiv, resultType);
+ auto diffDiv = builder->emitDiv(diffType, diffSub, diffMul);
+ builder->markInstAsDifferential(diffDiv, resultType);
- return InstPair(primalArith, diffDiv);
+ return InstPair(primalArith, diffDiv);
+ }
}
default:
getSink()->diagnose(origArith->sourceLoc,
diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp
index f068eded4..a05700277 100644
--- a/source/slang/slang-ir-loop-unroll.cpp
+++ b/source/slang/slang-ir-loop-unroll.cpp
@@ -88,7 +88,7 @@ List<IRBlock*> collectBlocksInLoop(IRGlobalValueWithCode* func, IRLoop* loopIns
static int _getLoopMaxIterationsToUnroll(IRLoop* loopInst)
{
- static constexpr int kMaxIterationsToAttempt = 100;
+ static constexpr int kMaxIterationsToAttempt = 256;
auto forceUnrollDecor = loopInst->findDecoration<IRForceUnrollDecoration>();
if (!forceUnrollDecor)
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index f04012112..e07d1f9c4 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -112,74 +112,6 @@ struct PeepholeContext : InstPassBase
return false;
}
- bool isZero(IRInst* inst)
- {
- switch (inst->getOp())
- {
- case kIROp_IntLit:
- return as<IRIntLit>(inst)->getValue() == 0;
- case kIROp_FloatLit:
- return as<IRFloatLit>(inst)->getValue() == 0.0;
- case kIROp_BoolLit:
- return as<IRBoolLit>(inst)->getValue() == false;
- case kIROp_MakeVector:
- case kIROp_MakeVectorFromScalar:
- case kIROp_MakeMatrix:
- case kIROp_MakeMatrixFromScalar:
- case kIROp_MatrixReshape:
- case kIROp_VectorReshape:
- {
- for (UInt i = 0; i < inst->getOperandCount(); i++)
- {
- if (!isZero(inst->getOperand(i)))
- {
- return false;
- }
- }
- return true;
- }
- case kIROp_CastIntToFloat:
- case kIROp_CastFloatToInt:
- return isZero(inst->getOperand(0));
- default:
- return false;
- }
- }
-
- bool isOne(IRInst* inst)
- {
- switch (inst->getOp())
- {
- case kIROp_IntLit:
- return as<IRIntLit>(inst)->getValue() == 1;
- case kIROp_FloatLit:
- return as<IRFloatLit>(inst)->getValue() == 1.0;
- case kIROp_BoolLit:
- return as<IRBoolLit>(inst)->getValue();
- case kIROp_MakeVector:
- case kIROp_MakeVectorFromScalar:
- case kIROp_MakeMatrix:
- case kIROp_MakeMatrixFromScalar:
- case kIROp_MatrixReshape:
- case kIROp_VectorReshape:
- {
- for (UInt i = 0; i < inst->getOperandCount(); i++)
- {
- if (!isOne(inst->getOperand(i)))
- {
- return false;
- }
- }
- return true;
- }
- case kIROp_CastIntToFloat:
- case kIROp_CastFloatToInt:
- return isOne(inst->getOperand(0));
- default:
- return false;
- }
- }
-
bool tryOptimizeArithmeticInst(IRInst* inst)
{
bool allowUnsafeOptimizations =
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 55c2b18c0..a978edc48 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -846,6 +846,74 @@ bool isGlobalOrUnknownMutableAddress(IRGlobalValueWithCode* parentFunc, IRInst*
return (addrInstParent != parentFunc);
}
+bool isZero(IRInst* inst)
+{
+ switch (inst->getOp())
+ {
+ case kIROp_IntLit:
+ return as<IRIntLit>(inst)->getValue() == 0;
+ case kIROp_FloatLit:
+ return as<IRFloatLit>(inst)->getValue() == 0.0;
+ case kIROp_BoolLit:
+ return as<IRBoolLit>(inst)->getValue() == false;
+ case kIROp_MakeVector:
+ case kIROp_MakeVectorFromScalar:
+ case kIROp_MakeMatrix:
+ case kIROp_MakeMatrixFromScalar:
+ case kIROp_MatrixReshape:
+ case kIROp_VectorReshape:
+ {
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ if (!isZero(inst->getOperand(i)))
+ {
+ return false;
+ }
+ }
+ return true;
+ }
+ case kIROp_CastIntToFloat:
+ case kIROp_CastFloatToInt:
+ return isZero(inst->getOperand(0));
+ default:
+ return false;
+ }
+}
+
+bool isOne(IRInst* inst)
+{
+ switch (inst->getOp())
+ {
+ case kIROp_IntLit:
+ return as<IRIntLit>(inst)->getValue() == 1;
+ case kIROp_FloatLit:
+ return as<IRFloatLit>(inst)->getValue() == 1.0;
+ case kIROp_BoolLit:
+ return as<IRBoolLit>(inst)->getValue();
+ case kIROp_MakeVector:
+ case kIROp_MakeVectorFromScalar:
+ case kIROp_MakeMatrix:
+ case kIROp_MakeMatrixFromScalar:
+ case kIROp_MatrixReshape:
+ case kIROp_VectorReshape:
+ {
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ if (!isOne(inst->getOperand(i)))
+ {
+ return false;
+ }
+ }
+ return true;
+ }
+ case kIROp_CastIntToFloat:
+ case kIROp_CastFloatToInt:
+ return isOne(inst->getOperand(0));
+ default:
+ return false;
+ }
+}
+
struct GenericChildrenMigrationContextImpl
{
IRCloneEnv cloneEnv;
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 9fd6dd972..65d081f42 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -214,6 +214,10 @@ void removePhiArgs(IRInst* phiParam);
int getParamIndexInBlock(IRParam* paramInst);
bool isGlobalOrUnknownMutableAddress(IRGlobalValueWithCode* parentFunc, IRInst* inst);
+
+bool isZero(IRInst* inst);
+
+bool isOne(IRInst* inst);
}
#endif