diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ast-expr.h | 17 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 20 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 20 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 21 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 1 |
7 files changed, 77 insertions, 11 deletions
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index c441e1b9b..07bf2f033 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -541,6 +541,23 @@ class TreatAsDifferentiableExpr : public Expr Expr* innerExpr; Scope* scope; + + enum Flavor + { + /// Represents a no_diff wrapper over + /// a non-differentiable method. + /// i.e. no_diff(fn(...)) + /// + NoDiff, + + /// Represents a call to a method that + /// is either marked differentiable, or has + /// a user-defined derivative in scope. + /// + Differentiable + }; + + Flavor flavor; }; /// A type expression of the form `__TaggedUnion(A, ...)`. diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 76af3694f..3c90c3ed8 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2265,6 +2265,7 @@ namespace Slang { maybeRegisterDifferentiableType(m_astBuilder, arg->type.type); } + if (auto calleeExpr = as<DeclRefExpr>(checkedInvokeExpr->functionExpr)) { if (auto calleeDecl = as<FunctionDeclBase>(calleeExpr->declRef.getDecl())) @@ -2279,6 +2280,7 @@ namespace Slang newFuncExpr->type = checkedInvokeExpr->type; newFuncExpr->innerExpr = checkedInvokeExpr; newFuncExpr->loc = checkedInvokeExpr->loc; + newFuncExpr->flavor = TreatAsDifferentiableExpr::Flavor::Differentiable; checkedExpr = newFuncExpr; } else diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index d4b93be5e..3207e0729 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -71,6 +71,15 @@ public: return false; } + bool shouldTreatCallAsDifferentiable(IRInst* callInst) + { + SLANG_ASSERT(as<IRCall>(callInst)); + + return ( + callInst->findDecoration<IRTreatCallAsDifferentiableDecoration>() || + callInst->findDecoration<IRDifferentiableCallDecoration>()); + } + bool isDifferentiableFunc(IRInst* func, DifferentiableLevel level) { switch (func->getOp()) @@ -300,7 +309,7 @@ public: case kIROp_FloatLit: return true; case kIROp_Call: - return inst->findDecoration<IRTreatAsDifferentiableDecoration>() || + return shouldTreatCallAsDifferentiable(inst) || isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel) && isDifferentiableType(diffTypeContext, inst->getFullType()); case kIROp_Load: // We don't have more knowledge on whether diff is available at the destination address. @@ -330,7 +339,7 @@ public: case kIROp_DetachDerivative: return false; case kIROp_Call: - if (inst->findDecoration<IRTreatAsDifferentiableDecoration>()) + if (shouldTreatCallAsDifferentiable(inst)) return false; return isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel) && isDifferentiableType(diffTypeContext, inst->getFullType()); @@ -451,7 +460,8 @@ public: // If inst's type is differentiable, and it is in expectDiffInstWorkList, // then some user is expecting the result of the call to produce a derivative. // In this case we need to issue a diagnostic. - if (isDifferentiableType(diffTypeContext, inst->getFullType())) + if (isDifferentiableType(diffTypeContext, inst->getFullType()) && + !isDifferentiableFunc(call->getCallee(), requiredDiffLevel)) { sink->diagnose( inst, @@ -490,9 +500,7 @@ public: case kIROp_Call: { auto callInst = as<IRCall>(inst); - if (callInst->findDecoration<IRTreatAsDifferentiableDecoration>()) - continue; - if (!isDifferentiableFunc(callInst->getCallee(), DifferentiableLevel::Forward)) + if (callInst->findDecoration<IRTreatCallAsDifferentiableDecoration>()) continue; auto calleeFuncType = as<IRFuncType>(callInst->getCallee()->getFullType()); if (!calleeFuncType) continue; diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 60f07c17f..231ae6dbe 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -867,9 +867,14 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// differential member of a type in its associated differential type. INST(DerivativeMemberDecoration, derivativeMemberDecoration, 1, 0) - /// Treat a function as differentiable function, or an IRCall as a call to a differentiable function. + /// Treat a function as differentiable function INST(TreatAsDifferentiableDecoration, treatAsDifferentiableDecoration, 0, 0) + /// Treat a call to arbitrary function as a differentiable call. + INST(TreatCallAsDifferentiableDecoration, treatCallAsDifferentiableDecoration, 0, 0) + + /// Mark a call as explicitly calling a differentiable function. + INST(DifferentiableCallDecoration, differentiableCallDecoration, 0, 0) /// Hint that the result from a call to the decorated function should be stored in backward prop function. INST(PreferCheckpointDecoration, PreferCheckpointDecoration, 0, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index f933697a2..417b39da3 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -899,6 +899,26 @@ struct IRTreatAsDifferentiableDecoration : IRDecoration IR_LEAF_ISA(TreatAsDifferentiableDecoration) }; +// Mark a call as explicitly calling a differentiable function. +struct IRDifferentiableCallDecoration : IRDecoration +{ + enum + { + kOp = kIROp_DifferentiableCallDecoration + }; + IR_LEAF_ISA(DifferentiableCallDecoration) +}; + +// Treat a call to a non-differentiable function as a differentiable call. +struct IRTreatCallAsDifferentiableDecoration : IRDecoration +{ + enum + { + kOp = kIROp_TreatCallAsDifferentiableDecoration + }; + IR_LEAF_ISA(TreatCallAsDifferentiableDecoration) +}; + struct IRDerivativeMemberDecoration : IRDecoration { enum diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index e49da00c2..0773226d1 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3311,7 +3311,14 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> // to handle that case might be. // if (as<IRCall>(materializedVal.val)) - getBuilder()->addDecoration(materializedVal.val, kIROp_TreatAsDifferentiableDecoration); + { + if (expr->flavor == TreatAsDifferentiableExpr::Flavor::NoDiff) + getBuilder()->addDecoration(materializedVal.val, kIROp_TreatCallAsDifferentiableDecoration); + else if (expr->flavor == TreatAsDifferentiableExpr::Flavor::Differentiable) + getBuilder()->addDecoration(materializedVal.val, kIROp_DifferentiableCallDecoration); + else + SLANG_UNEXPECTED("Unknown TreatAsDifferentiableExpr::Flavor"); + } innerInst = getSimpleVal(context, materializedVal); @@ -3324,13 +3331,19 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> } else { - SLANG_ASSERT("TreatAsDifferentiableExpr on non-simple l-values not properly defined."); + SLANG_UNEXPECTED("TreatAsDifferentiableExpr on non-simple l-values not properly defined."); } } else { - if (as<IRCall>(baseVal.val)) - getBuilder()->addDecoration(baseVal.val, kIROp_TreatAsDifferentiableDecoration); + if (auto callInst = as<IRCall>(baseVal.val)) + if (expr->flavor == TreatAsDifferentiableExpr::Flavor::NoDiff) + getBuilder()->addDecoration(callInst, kIROp_TreatCallAsDifferentiableDecoration); + else if (expr->flavor == TreatAsDifferentiableExpr::Flavor::Differentiable) + getBuilder()->addDecoration(callInst, kIROp_DifferentiableCallDecoration); + else + SLANG_UNEXPECTED("Unknown TreatAsDifferentiableExpr::Flavor"); + innerInst = baseVal.val; } diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index b8310451c..b0af5378c 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -5298,6 +5298,7 @@ namespace Slang auto noDiffExpr = parser->astBuilder->create<TreatAsDifferentiableExpr>(); noDiffExpr->innerExpr = parser->ParseLeafExpression(); noDiffExpr->scope = parser->currentScope; + noDiffExpr->flavor = TreatAsDifferentiableExpr::Flavor::NoDiff; return noDiffExpr; } |
