summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-ast-expr.h8
-rw-r--r--source/slang/slang-check-conversion.cpp21
-rw-r--r--source/slang/slang-check-expr.cpp7
-rw-r--r--source/slang/slang-check-impl.h2
-rw-r--r--source/slang/slang-lower-to-ir.cpp10
5 files changed, 48 insertions, 0 deletions
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index c07f7f5b9..f9b8831f1 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -544,6 +544,13 @@ class OpenRefExpr : public Expr
Expr* innerExpr = nullptr;
};
+class DetachExpr : public Expr
+{
+ SLANG_AST_CLASS(DetachExpr)
+
+ Expr* inner = nullptr;
+};
+
/// Base class for higher-order function application
/// Eg: foo(fn) where fn is a function expression.
///
@@ -563,6 +570,7 @@ class DifferentiateExpr : public HigherOrderInvokeExpr
{
SLANG_ABSTRACT_AST_CLASS(DifferentiateExpr)
};
+
/// An expression of the form `__fwd_diff(fn)` to access the
/// forward-mode derivative version of the function `fn`
///
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp
index c135ada8d..85040dc55 100644
--- a/source/slang/slang-check-conversion.cpp
+++ b/source/slang/slang-check-conversion.cpp
@@ -498,6 +498,27 @@ namespace Slang
toInitializerListExpr->type = QualType(toType);
toInitializerListExpr->args = coercedArgs;
+ // Wrap initalizer list args if we're creating a non-differentiable struct within a
+ // differentiable function.
+ //
+ if (auto func = getParentFuncOfVisitor())
+ {
+ if (func->findModifier<DifferentiableAttribute>() &&
+ !isTypeDifferentiable(toType))
+ {
+ for (auto &arg : toInitializerListExpr->args)
+ {
+ if (isTypeDifferentiable(arg->type.type))
+ {
+ auto detachedArg = m_astBuilder->create<DetachExpr>();
+ detachedArg->inner = arg;
+ detachedArg->type = arg->type;
+ arg = detachedArg;
+ }
+ }
+ }
+ }
+
*outToExpr = toInitializerListExpr;
}
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 4d36299bb..3072c3257 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -3242,6 +3242,13 @@ namespace Slang
return expr;
}
+ Expr* SemanticsExprVisitor::visitDetachExpr(DetachExpr* expr)
+ {
+ expr->inner = CheckTerm(expr->inner);
+ expr->type = expr->inner->type;
+ return expr;
+ }
+
static bool _isSizeOfType(Type* type)
{
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 82ac8de59..f997abc57 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -2806,6 +2806,8 @@ namespace Slang
Expr* visitDefaultConstructExpr(DefaultConstructExpr* expr);
+ Expr* visitDetachExpr(DetachExpr* expr);
+
Expr* visitSPIRVAsmExpr(SPIRVAsmExpr*);
/// Perform semantic checking on a `modifier` that is being applied to the given `type`
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 02c4fae68..b9d7a898f 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -3929,6 +3929,16 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo>
baseVal.val));
}
+ LoweredValInfo visitDetachExpr(DetachExpr* expr)
+ {
+ auto baseVal = lowerRValueExpr(context, expr->inner);
+
+ return LoweredValInfo::simple(
+ getBuilder()->emitDetachDerivative(
+ lowerType(context, expr->type),
+ getSimpleVal(context, baseVal)));
+ }
+
LoweredValInfo visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr)
{
auto baseVal = lowerSubExpr(expr->baseFunction);