summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ast-expr.h8
-rw-r--r--source/slang/slang-check-conversion.cpp21
-rw-r--r--source/slang/slang-check-expr.cpp7
-rw-r--r--source/slang/slang-check-impl.h2
-rw-r--r--source/slang/slang-lower-to-ir.cpp10
-rw-r--r--tests/autodiff/field-extract-of-make-struct.slang43
-rw-r--r--tests/autodiff/field-extract-of-make-struct.slang.expected.txt6
7 files changed, 97 insertions, 0 deletions
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index c07f7f5b9..f9b8831f1 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -544,6 +544,13 @@ class OpenRefExpr : public Expr
Expr* innerExpr = nullptr;
};
+class DetachExpr : public Expr
+{
+ SLANG_AST_CLASS(DetachExpr)
+
+ Expr* inner = nullptr;
+};
+
/// Base class for higher-order function application
/// Eg: foo(fn) where fn is a function expression.
///
@@ -563,6 +570,7 @@ class DifferentiateExpr : public HigherOrderInvokeExpr
{
SLANG_ABSTRACT_AST_CLASS(DifferentiateExpr)
};
+
/// An expression of the form `__fwd_diff(fn)` to access the
/// forward-mode derivative version of the function `fn`
///
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp
index c135ada8d..85040dc55 100644
--- a/source/slang/slang-check-conversion.cpp
+++ b/source/slang/slang-check-conversion.cpp
@@ -498,6 +498,27 @@ namespace Slang
toInitializerListExpr->type = QualType(toType);
toInitializerListExpr->args = coercedArgs;
+ // Wrap initalizer list args if we're creating a non-differentiable struct within a
+ // differentiable function.
+ //
+ if (auto func = getParentFuncOfVisitor())
+ {
+ if (func->findModifier<DifferentiableAttribute>() &&
+ !isTypeDifferentiable(toType))
+ {
+ for (auto &arg : toInitializerListExpr->args)
+ {
+ if (isTypeDifferentiable(arg->type.type))
+ {
+ auto detachedArg = m_astBuilder->create<DetachExpr>();
+ detachedArg->inner = arg;
+ detachedArg->type = arg->type;
+ arg = detachedArg;
+ }
+ }
+ }
+ }
+
*outToExpr = toInitializerListExpr;
}
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 4d36299bb..3072c3257 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -3242,6 +3242,13 @@ namespace Slang
return expr;
}
+ Expr* SemanticsExprVisitor::visitDetachExpr(DetachExpr* expr)
+ {
+ expr->inner = CheckTerm(expr->inner);
+ expr->type = expr->inner->type;
+ return expr;
+ }
+
static bool _isSizeOfType(Type* type)
{
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 82ac8de59..f997abc57 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -2806,6 +2806,8 @@ namespace Slang
Expr* visitDefaultConstructExpr(DefaultConstructExpr* expr);
+ Expr* visitDetachExpr(DetachExpr* expr);
+
Expr* visitSPIRVAsmExpr(SPIRVAsmExpr*);
/// Perform semantic checking on a `modifier` that is being applied to the given `type`
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 02c4fae68..b9d7a898f 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -3929,6 +3929,16 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo>
baseVal.val));
}
+ LoweredValInfo visitDetachExpr(DetachExpr* expr)
+ {
+ auto baseVal = lowerRValueExpr(context, expr->inner);
+
+ return LoweredValInfo::simple(
+ getBuilder()->emitDetachDerivative(
+ lowerType(context, expr->type),
+ getSimpleVal(context, baseVal)));
+ }
+
LoweredValInfo visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr)
{
auto baseVal = lowerSubExpr(expr->baseFunction);
diff --git a/tests/autodiff/field-extract-of-make-struct.slang b/tests/autodiff/field-extract-of-make-struct.slang
new file mode 100644
index 000000000..1cf2b40c4
--- /dev/null
+++ b/tests/autodiff/field-extract-of-make-struct.slang
@@ -0,0 +1,43 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+// This test checks to ensure that expressions of the form FieldExtract(MakeStruct(...))
+// are handled correctly in the presence of differentiation.
+// Such expressions are often optimized out, and the non-differentiability of the struct
+// type can be accidentally overlooked.
+//
+
+typedef DifferentialPair<float> dpfloat;
+typedef float.Differential dfloat;
+
+struct Data
+{
+ [PreferRecompute]
+ __init(float tin)
+ {
+ this.t = tin;
+ }
+ float t;
+};
+
+[BackwardDifferentiable]
+float test_make_struct(float y)
+{
+ Data d = { y };
+ return d.t;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ {
+ dpfloat dpa = dpfloat(1.0, 0.0);
+
+ bwd_diff(test_make_struct)(dpa, 1.0f);
+ outputBuffer[0] = dpa.d; // Expect: 0.0
+ }
+} \ No newline at end of file
diff --git a/tests/autodiff/field-extract-of-make-struct.slang.expected.txt b/tests/autodiff/field-extract-of-make-struct.slang.expected.txt
new file mode 100644
index 000000000..e070cf84d
--- /dev/null
+++ b/tests/autodiff/field-extract-of-make-struct.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+0.000000
+0.000000
+0.000000
+0.000000
+0.000000