summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-check-expr.cpp5
-rw-r--r--tests/language-server/high-order-expr.slang21
-rw-r--r--tests/language-server/high-order-expr.slang.expected.txt3
3 files changed, 29 insertions, 0 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index c7d69262d..fe37f5099 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1977,6 +1977,11 @@ namespace Slang
// Check/Resolve inner function declaration.
expr->baseFunction = CheckTerm(expr->baseFunction);
+
+ // For now we only support using higher order expr as callee in an invoke expr.
+ // The actual type of the higher order function will be derived during resolve invoke.
+ expr->type = m_astBuilder->getBottomType();
+
return expr;
}
diff --git a/tests/language-server/high-order-expr.slang b/tests/language-server/high-order-expr.slang
new file mode 100644
index 000000000..f7f502e8f
--- /dev/null
+++ b/tests/language-server/high-order-expr.slang
@@ -0,0 +1,21 @@
+//TEST:LANG_SERVER:
+RWStructuredBuffer<float> outputBuffer;
+
+[ForwardDifferentiable]
+float f(float x)
+{
+ return x * x;
+}
+
+[ForwardDifferentiable]
+float df(float x)
+{
+//HOVER:14,17
+ return __fwd_diff();
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ outputBuffer[0] = __fwd_diff(df)(__DifferentialPair<float>(x, 1.0)).d(); // Expect: 2.0
+}
diff --git a/tests/language-server/high-order-expr.slang.expected.txt b/tests/language-server/high-order-expr.slang.expected.txt
new file mode 100644
index 000000000..dd2c29b31
--- /dev/null
+++ b/tests/language-server/high-order-expr.slang.expected.txt
@@ -0,0 +1,3 @@
+--------
+null
+