summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--examples/heterogeneous-hello-world/main.cpp2
-rw-r--r--examples/heterogeneous-hello-world/shader.slang6
-rw-r--r--source/slang/slang-ast-stmt.h11
-rw-r--r--source/slang/slang-check-impl.h2
-rw-r--r--source/slang/slang-check-stmt.cpp9
-rw-r--r--source/slang/slang-lower-to-ir.cpp6
-rw-r--r--source/slang/slang-parser.cpp52
7 files changed, 85 insertions, 3 deletions
diff --git a/examples/heterogeneous-hello-world/main.cpp b/examples/heterogeneous-hello-world/main.cpp
index 47df20dc5..6bb1bc071 100644
--- a/examples/heterogeneous-hello-world/main.cpp
+++ b/examples/heterogeneous-hello-world/main.cpp
@@ -35,7 +35,7 @@
#include "gfx/render.h"
#include "gfx/d3d11/render-d3d11.h"
#include "gfx/window.h"
-#include "../../prelude/slang-cpp-types.h";
+#include "../../prelude/slang-cpp-types.h"
using namespace gfx;
// We create global ref pointers to avoid dereferencing values
diff --git a/examples/heterogeneous-hello-world/shader.slang b/examples/heterogeneous-hello-world/shader.slang
index a9ad66cc7..6b56c8700 100644
--- a/examples/heterogeneous-hello-world/shader.slang
+++ b/examples/heterogeneous-hello-world/shader.slang
@@ -1,10 +1,10 @@
// shader.slang
//TEST_INPUT:ubuffer(random(float, 4096, -1.0, 1.0), stride=4):name=ioBuffer
-RWStructuredBuffer<float> ioBuffer;
+RWStructuredBuffer<float> convertBuffer(Ptr<gfx::BufferResource> x);
[numthreads(4, 1, 1)]
-void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+void computeMain(uniform RWStructuredBuffer<float> ioBuffer, uint3 dispatchThreadID : SV_DispatchThreadID)
{
uint tid = dispatchThreadID.x;
@@ -69,6 +69,8 @@ public bool executeComputation() {
let window = createWindow(windowWidth, windowHeight);
let renderer = createRenderer(windowWidth, windowHeight, window);
let structuredBuffer = createStructuredBuffer(renderer, initialArray);
+ __GPU_FOREACH(renderer, uint3(4, 1, 1), LAMBDA(uint3 dispatchThreadID)
+ { computeMain(convertBuffer(structuredBuffer), dispatchThreadID) ; });
let shaderProgram = loadShaderProgram(renderer);
let descriptorSetLayout = buildDescriptorSetLayout(renderer);
let pipelineLayout = buildPipeline(renderer, descriptorSetLayout);
diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h
index e9e5ea4f3..2fb10db1f 100644
--- a/source/slang/slang-ast-stmt.h
+++ b/source/slang/slang-ast-stmt.h
@@ -117,6 +117,17 @@ class DefaultStmt : public CaseStmtBase
SLANG_CLASS(DefaultStmt)
};
+// a `default` statement inside a `switch`
+class GpuForeachStmt : public ScopeStmt
+{
+ SLANG_CLASS(GpuForeachStmt)
+
+ Expr* renderer = nullptr;
+ Expr* gridDims = nullptr;
+ VarDecl* dispatchThreadID = nullptr;
+ Expr* kernelCall = nullptr;
+};
+
// A statement that represents a loop, and can thus be escaped with a `continue`
class LoopStmt : public BreakableStmt
{
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 97e77ec3e..19f5553c7 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1596,6 +1596,8 @@ namespace Slang
void visitReturnStmt(ReturnStmt *stmt);
void visitWhileStmt(WhileStmt *stmt);
+
+ void visitGpuForeachStmt(GpuForeachStmt *stmt);
void visitExpressionStmt(ExpressionStmt *stmt);
};
diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp
index 2d01086f1..9a5aee15c 100644
--- a/source/slang/slang-check-stmt.cpp
+++ b/source/slang/slang-check-stmt.cpp
@@ -288,4 +288,13 @@ namespace Slang
stmt->expression = CheckExpr(stmt->expression);
}
+ void SemanticsStmtVisitor::visitGpuForeachStmt(GpuForeachStmt*stmt)
+ {
+ stmt->renderer = CheckExpr(stmt->renderer);
+ stmt->gridDims = CheckExpr(stmt->gridDims);
+ ensureDeclBase(stmt->dispatchThreadID, DeclCheckState::Checked);
+ WithOuterStmt subContext(this, stmt);
+ stmt->kernelCall = subContext.CheckExpr(stmt->kernelCall);
+ return;
+ }
}
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index a4542647a..9c4808f31 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -4114,6 +4114,12 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
insertBlock(breakLabel);
}
+ void visitGpuForeachStmt(GpuForeachStmt* stmt)
+ {
+ startBlockIfNeeded(stmt);
+ return;
+ }
+
void visitExpressionStmt(ExpressionStmt* stmt)
{
startBlockIfNeeded(stmt);
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index 460f781fe..179587550 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -3380,6 +3380,56 @@ namespace Slang
return stmt;
}
+ GpuForeachStmt* ParseGpuForeachStmt(Parser* parser)
+ {
+ // Hard-coding parsing of the following:
+ // __GPU_FOREACH(renderer, gridDims, LAMBDA(uint3 dispatchThreadID) {
+ // kernelCall(args, ...); });
+
+ // Setup the scope so that dispatchThreadID is in scope for kernelCall
+ ScopeDecl* scopeDecl = parser->astBuilder->create<ScopeDecl>();
+ GpuForeachStmt* stmt = parser->astBuilder->create<GpuForeachStmt>();
+ stmt->scopeDecl = scopeDecl;
+
+ parser->FillPosition(stmt);
+ parser->ReadToken("__GPU_FOREACH");
+ parser->ReadToken(TokenType::LParent);
+ stmt->renderer = parser->ParseArgExpr();
+ parser->ReadToken(TokenType::Comma);
+ stmt->gridDims = parser->ParseArgExpr();
+
+ parser->ReadToken(TokenType::Comma);
+ parser->ReadToken("LAMBDA");
+ parser->ReadToken(TokenType::LParent);
+
+ auto idType = parser->ParseTypeExp();
+ NameLoc varNameAndLoc = expectIdentifier(parser);
+ VarDecl* varDecl = parser->astBuilder->create<VarDecl>();
+ varDecl->nameAndLoc = varNameAndLoc;
+ varDecl->loc = varNameAndLoc.loc;
+ varDecl->type = idType;
+ stmt->dispatchThreadID = varDecl;
+
+ parser->ReadToken(TokenType::RParent);
+ parser->ReadToken(TokenType::LBrace);
+
+ parser->pushScopeAndSetParent(scopeDecl);
+ AddMember(parser->currentScope, varDecl);
+
+ stmt->kernelCall = parser->ParseExpression();
+
+ parser->PopScope();
+
+ parser->ReadToken(TokenType::Semicolon);
+ parser->ReadToken(TokenType::RBrace);
+
+ parser->ReadToken(TokenType::RParent);
+
+ parser->ReadToken(TokenType::Semicolon);
+
+ return stmt;
+ }
+
static bool _isType(Decl* decl)
{
return decl && (as<AggTypeDecl>(decl) || as<SimpleTypeDecl>(decl));
@@ -3552,6 +3602,8 @@ namespace Slang
statement = ParseCaseStmt(this);
else if (LookAheadToken("default"))
statement = ParseDefaultStmt(this);
+ else if (LookAheadToken("__GPU_FOREACH"))
+ statement = ParseGpuForeachStmt(this);
else if (LookAheadToken(TokenType::Dollar))
{
statement = parseCompileTimeStmt(this);