summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ast-expr.h17
-rw-r--r--source/slang/slang-check-expr.cpp2
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp20
-rw-r--r--source/slang/slang-ir-inst-defs.h7
-rw-r--r--source/slang/slang-ir-insts.h20
-rw-r--r--source/slang/slang-lower-to-ir.cpp21
-rw-r--r--source/slang/slang-parser.cpp1
7 files changed, 77 insertions, 11 deletions
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index c441e1b9b..07bf2f033 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -541,6 +541,23 @@ class TreatAsDifferentiableExpr : public Expr
Expr* innerExpr;
Scope* scope;
+
+ enum Flavor
+ {
+ /// Represents a no_diff wrapper over
+ /// a non-differentiable method.
+ /// i.e. no_diff(fn(...))
+ ///
+ NoDiff,
+
+ /// Represents a call to a method that
+ /// is either marked differentiable, or has
+ /// a user-defined derivative in scope.
+ ///
+ Differentiable
+ };
+
+ Flavor flavor;
};
/// A type expression of the form `__TaggedUnion(A, ...)`.
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 76af3694f..3c90c3ed8 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2265,6 +2265,7 @@ namespace Slang
{
maybeRegisterDifferentiableType(m_astBuilder, arg->type.type);
}
+
if (auto calleeExpr = as<DeclRefExpr>(checkedInvokeExpr->functionExpr))
{
if (auto calleeDecl = as<FunctionDeclBase>(calleeExpr->declRef.getDecl()))
@@ -2279,6 +2280,7 @@ namespace Slang
newFuncExpr->type = checkedInvokeExpr->type;
newFuncExpr->innerExpr = checkedInvokeExpr;
newFuncExpr->loc = checkedInvokeExpr->loc;
+ newFuncExpr->flavor = TreatAsDifferentiableExpr::Flavor::Differentiable;
checkedExpr = newFuncExpr;
}
else
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index d4b93be5e..3207e0729 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -71,6 +71,15 @@ public:
return false;
}
+ bool shouldTreatCallAsDifferentiable(IRInst* callInst)
+ {
+ SLANG_ASSERT(as<IRCall>(callInst));
+
+ return (
+ callInst->findDecoration<IRTreatCallAsDifferentiableDecoration>() ||
+ callInst->findDecoration<IRDifferentiableCallDecoration>());
+ }
+
bool isDifferentiableFunc(IRInst* func, DifferentiableLevel level)
{
switch (func->getOp())
@@ -300,7 +309,7 @@ public:
case kIROp_FloatLit:
return true;
case kIROp_Call:
- return inst->findDecoration<IRTreatAsDifferentiableDecoration>() ||
+ return shouldTreatCallAsDifferentiable(inst) ||
isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel) && isDifferentiableType(diffTypeContext, inst->getFullType());
case kIROp_Load:
// We don't have more knowledge on whether diff is available at the destination address.
@@ -330,7 +339,7 @@ public:
case kIROp_DetachDerivative:
return false;
case kIROp_Call:
- if (inst->findDecoration<IRTreatAsDifferentiableDecoration>())
+ if (shouldTreatCallAsDifferentiable(inst))
return false;
return isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel) &&
isDifferentiableType(diffTypeContext, inst->getFullType());
@@ -451,7 +460,8 @@ public:
// If inst's type is differentiable, and it is in expectDiffInstWorkList,
// then some user is expecting the result of the call to produce a derivative.
// In this case we need to issue a diagnostic.
- if (isDifferentiableType(diffTypeContext, inst->getFullType()))
+ if (isDifferentiableType(diffTypeContext, inst->getFullType()) &&
+ !isDifferentiableFunc(call->getCallee(), requiredDiffLevel))
{
sink->diagnose(
inst,
@@ -490,9 +500,7 @@ public:
case kIROp_Call:
{
auto callInst = as<IRCall>(inst);
- if (callInst->findDecoration<IRTreatAsDifferentiableDecoration>())
- continue;
- if (!isDifferentiableFunc(callInst->getCallee(), DifferentiableLevel::Forward))
+ if (callInst->findDecoration<IRTreatCallAsDifferentiableDecoration>())
continue;
auto calleeFuncType = as<IRFuncType>(callInst->getCallee()->getFullType());
if (!calleeFuncType) continue;
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 60f07c17f..231ae6dbe 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -867,9 +867,14 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// differential member of a type in its associated differential type.
INST(DerivativeMemberDecoration, derivativeMemberDecoration, 1, 0)
- /// Treat a function as differentiable function, or an IRCall as a call to a differentiable function.
+ /// Treat a function as differentiable function
INST(TreatAsDifferentiableDecoration, treatAsDifferentiableDecoration, 0, 0)
+ /// Treat a call to arbitrary function as a differentiable call.
+ INST(TreatCallAsDifferentiableDecoration, treatCallAsDifferentiableDecoration, 0, 0)
+
+ /// Mark a call as explicitly calling a differentiable function.
+ INST(DifferentiableCallDecoration, differentiableCallDecoration, 0, 0)
/// Hint that the result from a call to the decorated function should be stored in backward prop function.
INST(PreferCheckpointDecoration, PreferCheckpointDecoration, 0, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index f933697a2..417b39da3 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -899,6 +899,26 @@ struct IRTreatAsDifferentiableDecoration : IRDecoration
IR_LEAF_ISA(TreatAsDifferentiableDecoration)
};
+// Mark a call as explicitly calling a differentiable function.
+struct IRDifferentiableCallDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_DifferentiableCallDecoration
+ };
+ IR_LEAF_ISA(DifferentiableCallDecoration)
+};
+
+// Treat a call to a non-differentiable function as a differentiable call.
+struct IRTreatCallAsDifferentiableDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_TreatCallAsDifferentiableDecoration
+ };
+ IR_LEAF_ISA(TreatCallAsDifferentiableDecoration)
+};
+
struct IRDerivativeMemberDecoration : IRDecoration
{
enum
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index e49da00c2..0773226d1 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -3311,7 +3311,14 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
// to handle that case might be.
//
if (as<IRCall>(materializedVal.val))
- getBuilder()->addDecoration(materializedVal.val, kIROp_TreatAsDifferentiableDecoration);
+ {
+ if (expr->flavor == TreatAsDifferentiableExpr::Flavor::NoDiff)
+ getBuilder()->addDecoration(materializedVal.val, kIROp_TreatCallAsDifferentiableDecoration);
+ else if (expr->flavor == TreatAsDifferentiableExpr::Flavor::Differentiable)
+ getBuilder()->addDecoration(materializedVal.val, kIROp_DifferentiableCallDecoration);
+ else
+ SLANG_UNEXPECTED("Unknown TreatAsDifferentiableExpr::Flavor");
+ }
innerInst = getSimpleVal(context, materializedVal);
@@ -3324,13 +3331,19 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
}
else
{
- SLANG_ASSERT("TreatAsDifferentiableExpr on non-simple l-values not properly defined.");
+ SLANG_UNEXPECTED("TreatAsDifferentiableExpr on non-simple l-values not properly defined.");
}
}
else
{
- if (as<IRCall>(baseVal.val))
- getBuilder()->addDecoration(baseVal.val, kIROp_TreatAsDifferentiableDecoration);
+ if (auto callInst = as<IRCall>(baseVal.val))
+ if (expr->flavor == TreatAsDifferentiableExpr::Flavor::NoDiff)
+ getBuilder()->addDecoration(callInst, kIROp_TreatCallAsDifferentiableDecoration);
+ else if (expr->flavor == TreatAsDifferentiableExpr::Flavor::Differentiable)
+ getBuilder()->addDecoration(callInst, kIROp_DifferentiableCallDecoration);
+ else
+ SLANG_UNEXPECTED("Unknown TreatAsDifferentiableExpr::Flavor");
+
innerInst = baseVal.val;
}
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index b8310451c..b0af5378c 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -5298,6 +5298,7 @@ namespace Slang
auto noDiffExpr = parser->astBuilder->create<TreatAsDifferentiableExpr>();
noDiffExpr->innerExpr = parser->ParseLeafExpression();
noDiffExpr->scope = parser->currentScope;
+ noDiffExpr->flavor = TreatAsDifferentiableExpr::Flavor::NoDiff;
return noDiffExpr;
}