summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-11-04 17:37:50 -0800
committerGitHub <noreply@github.com>2024-11-04 17:37:50 -0800
commit7c2ff54758d26b73074fd14143ecd843ba685e0d (patch)
tree0abe5c4f11de2bdb1e960a3fef441c36d420966e
parent2c8dacfa471903a802a252905ec108420ee25d63 (diff)
Various WGSL fixes. (#5490)
* [WGSL] make sure switch has a default label. * Various WGSL fixes. * Update rhi submodule commit * format code * Remove unnecessary DISABLE_TEST directive on not applicable test. * Matrix comp mul + `select`. * Legalize binary ops for wgsl. --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
m---------external/slang-rhi0
-rw-r--r--source/slang/slang-emit-wgsl.cpp126
-rw-r--r--source/slang/slang-emit-wgsl.h2
-rw-r--r--source/slang/slang-ir-insts.h1
-rw-r--r--source/slang/slang-ir-wgsl-legalize.cpp90
-rw-r--r--source/slang/slang-ir.cpp11
-rw-r--r--tests/autodiff-dstdlib/dstdlib-abs.slang2
-rw-r--r--tests/autodiff/matrix-arithmetic-fwd.slang2
-rw-r--r--tests/autodiff/reverse-loop-checkpoint-test.slang1
-rw-r--r--tests/bugs/nested-switch.slang2
-rw-r--r--tests/hlsl-intrinsic/sampler-feedback/compute-sampler-feedback.slang1
-rw-r--r--tests/ir/string-literal-hash.slang2
-rw-r--r--tests/language-feature/constants/constexpr-loop.slang2
-rw-r--r--tests/library/linked.spirvbin816 -> 0 bytes
14 files changed, 230 insertions, 12 deletions
diff --git a/external/slang-rhi b/external/slang-rhi
-Subproject 93c2ba8f68edee6732372ce4505bfc2a8640a1b
+Subproject 10ab9c69fb0f1e3f476c7fd66ca7f3bedffebe5
diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp
index 4aca03a61..d8ec01776 100644
--- a/source/slang/slang-emit-wgsl.cpp
+++ b/source/slang/slang-emit-wgsl.cpp
@@ -497,6 +497,34 @@ void WGSLSourceEmitter::emitLayoutQualifiersImpl(IRVarLayout* layout)
}
}
+static bool isStaticConst(IRInst* inst)
+{
+ if (inst->getParent()->getOp() == kIROp_Module)
+ {
+ return true;
+ }
+ switch (inst->getOp())
+ {
+ case kIROp_MakeVector:
+ case kIROp_swizzle:
+ case kIROp_swizzleSet:
+ case kIROp_IntCast:
+ case kIROp_FloatCast:
+ case kIROp_CastFloatToInt:
+ case kIROp_CastIntToFloat:
+ case kIROp_BitCast:
+ {
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ if (!isStaticConst(inst->getOperand(i)))
+ return false;
+ }
+ return true;
+ }
+ }
+ return false;
+}
+
void WGSLSourceEmitter::emitVarKeywordImpl(IRType* type, IRInst* varDecl)
{
switch (varDecl->getOp())
@@ -505,14 +533,10 @@ void WGSLSourceEmitter::emitVarKeywordImpl(IRType* type, IRInst* varDecl)
case kIROp_GlobalVar:
case kIROp_Var: m_writer->emit("var"); break;
default:
- if (as<IRModuleInst>(varDecl->getParent()))
- {
+ if (isStaticConst(varDecl))
m_writer->emit("const");
- }
else
- {
m_writer->emit("var");
- }
break;
}
@@ -977,6 +1001,33 @@ void WGSLSourceEmitter::emitCallArg(IRInst* inst)
}
}
+bool WGSLSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst)
+{
+ bool result = CLikeSourceEmitter::shouldFoldInstIntoUseSites(inst);
+ if (result)
+ {
+ // If inst is a matrix, and is used in a component-wise multiply,
+ // we need to not fold it.
+ if (as<IRMatrixType>(inst->getDataType()))
+ {
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ auto user = use->getUser();
+ if (user->getOp() == kIROp_Mul)
+ {
+ if (as<IRMatrixType>(user->getOperand(0)->getDataType()) &&
+ as<IRMatrixType>(user->getOperand(1)->getDataType()))
+ {
+ return false;
+ }
+ }
+ }
+ }
+ }
+ return result;
+}
+
+
bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec)
{
EmitOpInfo outerPrec = inOuterPrec;
@@ -1126,6 +1177,71 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
return true;
}
break;
+
+ case kIROp_GetStringHash:
+ {
+ auto getStringHashInst = as<IRGetStringHash>(inst);
+ auto stringLit = getStringHashInst->getStringLit();
+
+ if (stringLit)
+ {
+ auto slice = stringLit->getStringSlice();
+ emitType(inst->getDataType());
+ m_writer->emit("(");
+ m_writer->emit((int)getStableHashCode32(slice.begin(), slice.getLength()).hash);
+ m_writer->emit(")");
+ }
+ else
+ {
+ // Couldn't handle
+ diagnoseUnhandledInst(inst);
+ }
+ return true;
+ }
+
+ case kIROp_Mul:
+ {
+ if (!as<IRMatrixType>(inst->getOperand(0)->getDataType()) ||
+ !as<IRMatrixType>(inst->getOperand(1)->getDataType()))
+ {
+ return false;
+ }
+ // Mul(m1, m2) should be translated to component-wise multiplication in WGSL.
+ auto matrixType = as<IRMatrixType>(inst->getDataType());
+ auto rowCount = getIntVal(matrixType->getRowCount());
+ emitType(inst->getDataType());
+ m_writer->emit("(");
+ for (IRIntegerValue i = 0; i < rowCount; i++)
+ {
+ if (i != 0)
+ {
+ m_writer->emit(", ");
+ }
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::Postfix));
+ m_writer->emit("[");
+ m_writer->emit(i);
+ m_writer->emit("] * ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::Postfix));
+ m_writer->emit("[");
+ m_writer->emit(i);
+ m_writer->emit("]");
+ }
+ m_writer->emit(")");
+
+ return true;
+ }
+
+ case kIROp_Select:
+ {
+ m_writer->emit("select(");
+ emitOperand(inst->getOperand(2), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(")");
+ return true;
+ }
}
return false;
diff --git a/source/slang/slang-emit-wgsl.h b/source/slang/slang-emit-wgsl.h
index 70df65933..1a8ec2fd5 100644
--- a/source/slang/slang-emit-wgsl.h
+++ b/source/slang/slang-emit-wgsl.h
@@ -50,6 +50,8 @@ public:
void emit(const AddressSpace addressSpace);
+ virtual bool shouldFoldInstIntoUseSites(IRInst* inst) SLANG_OVERRIDE;
+
private:
// Emit the matrix type with 'rowCountWGSL' WGSL-rows and 'colCountWGSL' WGSL-columns
void emitMatrixType(
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index c44211c1c..9a081f9de 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -4021,6 +4021,7 @@ public:
IRInst* emitDifferentialPairGetPrimalUserCode(IRInst* diffPair);
IRInst* emitMakeVector(IRType* type, UInt argCount, IRInst* const* args);
IRInst* emitMakeVectorFromScalar(IRType* type, IRInst* scalarValue);
+ IRInst* emitMakeCompositeFromScalar(IRType* type, IRInst* scalarValue);
IRInst* emitMakeVector(IRType* type, List<IRInst*> const& args)
{
diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp
index c97a8a89f..96eb13be4 100644
--- a/source/slang/slang-ir-wgsl-legalize.cpp
+++ b/source/slang/slang-ir-wgsl-legalize.cpp
@@ -51,6 +51,8 @@ struct LegalizeWGSLEntryPointContext
String* optionalSemanticIndex,
IRInst* parentVar);
void legalizeCall(IRCall* call);
+ void legalizeSwitch(IRSwitch* switchInst);
+ void legalizeBinaryOp(IRInst* inst);
void processInst(IRInst* inst);
};
@@ -349,11 +351,97 @@ void LegalizeWGSLEntryPointContext::legalizeCall(IRCall* call)
}
}
+void LegalizeWGSLEntryPointContext::legalizeSwitch(IRSwitch* switchInst)
+{
+ // WGSL Requires all switch statements to contain a default case.
+ // If the switch statement does not contain a default case, we will add one.
+ if (switchInst->getDefaultLabel() != switchInst->getBreakLabel())
+ return;
+ IRBuilder builder(switchInst);
+ auto defaultBlock = builder.createBlock();
+ builder.setInsertInto(defaultBlock);
+ builder.emitBranch(switchInst->getBreakLabel());
+ defaultBlock->insertBefore(switchInst->getBreakLabel());
+ List<IRInst*> cases;
+ for (UInt i = 0; i < switchInst->getCaseCount(); i++)
+ {
+ cases.add(switchInst->getCaseValue(i));
+ cases.add(switchInst->getCaseLabel(i));
+ }
+ builder.setInsertBefore(switchInst);
+ auto newSwitch = builder.emitSwitch(
+ switchInst->getCondition(),
+ switchInst->getBreakLabel(),
+ defaultBlock,
+ (UInt)cases.getCount(),
+ cases.getBuffer());
+ switchInst->transferDecorationsTo(newSwitch);
+ switchInst->removeAndDeallocate();
+}
+
+void LegalizeWGSLEntryPointContext::legalizeBinaryOp(IRInst* inst)
+{
+ auto isVectorOrMatrix = [](IRType* type)
+ {
+ switch (type->getOp())
+ {
+ case kIROp_VectorType:
+ case kIROp_MatrixType: return true;
+ default: return false;
+ }
+ };
+ if (isVectorOrMatrix(inst->getOperand(0)->getDataType()) &&
+ as<IRBasicType>(inst->getOperand(1)->getDataType()))
+ {
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto newRhs = builder.emitMakeCompositeFromScalar(
+ inst->getOperand(0)->getDataType(),
+ inst->getOperand(1));
+ builder.replaceOperand(inst->getOperands() + 1, newRhs);
+ }
+ else if (
+ as<IRBasicType>(inst->getOperand(0)->getDataType()) &&
+ isVectorOrMatrix(inst->getOperand(1)->getDataType()))
+ {
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto newLhs = builder.emitMakeCompositeFromScalar(
+ inst->getOperand(1)->getDataType(),
+ inst->getOperand(0));
+ builder.replaceOperand(inst->getOperands(), newLhs);
+ }
+}
+
void LegalizeWGSLEntryPointContext::processInst(IRInst* inst)
{
switch (inst->getOp())
{
- case kIROp_Call: legalizeCall(static_cast<IRCall*>(inst)); break;
+ case kIROp_Call: legalizeCall(static_cast<IRCall*>(inst)); break;
+ case kIROp_Switch: legalizeSwitch(as<IRSwitch>(inst)); break;
+
+ // For all binary operators, make sure both side of the operator have the same type
+ // (vector-ness and matrix-ness).
+ case kIROp_Add:
+ case kIROp_Sub:
+ case kIROp_Mul:
+ case kIROp_Div:
+ case kIROp_FRem:
+ case kIROp_IRem:
+ case kIROp_And:
+ case kIROp_Or:
+ case kIROp_BitAnd:
+ case kIROp_BitOr:
+ case kIROp_BitXor:
+ case kIROp_Lsh:
+ case kIROp_Rsh:
+ case kIROp_Eql:
+ case kIROp_Neq:
+ case kIROp_Greater:
+ case kIROp_Less:
+ case kIROp_Geq:
+ case kIROp_Leq: legalizeBinaryOp(inst); break;
+
default:
for (auto child : inst->getModifiableChildren())
processInst(child);
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 49273163e..3bd31d6e9 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -4162,6 +4162,17 @@ IRInst* IRBuilder::emitMakeVectorFromScalar(IRType* type, IRInst* scalarValue)
return emitIntrinsicInst(type, kIROp_MakeVectorFromScalar, 1, &scalarValue);
}
+IRInst* IRBuilder::emitMakeCompositeFromScalar(IRType* type, IRInst* scalarValue)
+{
+ switch (type->getOp())
+ {
+ case kIROp_VectorType: return emitMakeVectorFromScalar(type, scalarValue);
+ case kIROp_MatrixType: return emitMakeMatrixFromScalar(type, scalarValue);
+ case kIROp_ArrayType: return emitMakeArrayFromElement(type, scalarValue);
+ default: SLANG_UNEXPECTED("unhandled composite type"); UNREACHABLE_RETURN(nullptr);
+ }
+}
+
IRInst* IRBuilder::emitMatrixReshape(IRType* type, IRInst* inst)
{
return emitIntrinsicInst(type, kIROp_MatrixReshape, 1, &inst);
diff --git a/tests/autodiff-dstdlib/dstdlib-abs.slang b/tests/autodiff-dstdlib/dstdlib-abs.slang
index c0878bfb4..d11f06b31 100644
--- a/tests/autodiff-dstdlib/dstdlib-abs.slang
+++ b/tests/autodiff-dstdlib/dstdlib-abs.slang
@@ -1,6 +1,6 @@
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
-//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu
+//TEST(compute):COMPARE_COMPUTE_EX:-wgpu -compute -output-using-type
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
diff --git a/tests/autodiff/matrix-arithmetic-fwd.slang b/tests/autodiff/matrix-arithmetic-fwd.slang
index 0dd1936af..0c2db76e9 100644
--- a/tests/autodiff/matrix-arithmetic-fwd.slang
+++ b/tests/autodiff/matrix-arithmetic-fwd.slang
@@ -1,6 +1,6 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-wgpu -compute -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
-//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
diff --git a/tests/autodiff/reverse-loop-checkpoint-test.slang b/tests/autodiff/reverse-loop-checkpoint-test.slang
index 19316a786..8191608fd 100644
--- a/tests/autodiff/reverse-loop-checkpoint-test.slang
+++ b/tests/autodiff/reverse-loop-checkpoint-test.slang
@@ -1,6 +1,7 @@
//TEST(compute):COMPARE_COMPUTE_EX:-dx12 -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-wgpu -compute -shaderobj -output-using-type
//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none
//DISABLE_TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates
diff --git a/tests/bugs/nested-switch.slang b/tests/bugs/nested-switch.slang
index 485a83e1f..90abe70d5 100644
--- a/tests/bugs/nested-switch.slang
+++ b/tests/bugs/nested-switch.slang
@@ -3,7 +3,7 @@
//TEST(compute):COMPARE_COMPUTE: -shaderobj
//TEST(compute):COMPARE_COMPUTE:-vk -shaderobj
//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj
-//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu
+//TEST(compute):COMPARE_COMPUTE:-wgpu
int test(int t, int r)
{
diff --git a/tests/hlsl-intrinsic/sampler-feedback/compute-sampler-feedback.slang b/tests/hlsl-intrinsic/sampler-feedback/compute-sampler-feedback.slang
index a7fc8731c..77e7c2050 100644
--- a/tests/hlsl-intrinsic/sampler-feedback/compute-sampler-feedback.slang
+++ b/tests/hlsl-intrinsic/sampler-feedback/compute-sampler-feedback.slang
@@ -1,5 +1,4 @@
//TEST:COMPILE: -entry computeMain -stage compute -target callable tests/hlsl-intrinsic/sampler-feedback/compute-sampler-feedback.slang
-//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu
// Not available on non PS shader
// dx.op.writeSamplerFeedback WriteSamplerFeedback
diff --git a/tests/ir/string-literal-hash.slang b/tests/ir/string-literal-hash.slang
index 678a8d9c7..2d61a84c1 100644
--- a/tests/ir/string-literal-hash.slang
+++ b/tests/ir/string-literal-hash.slang
@@ -1,6 +1,6 @@
//TEST(compute):COMPARE_COMPUTE: -shaderobj
//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj
-//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu
+//TEST(compute):COMPARE_COMPUTE:-wgpu
// Note: disabled on CPU target until we can fill
// in a more correct/complete `String` and `getStringHash`
diff --git a/tests/language-feature/constants/constexpr-loop.slang b/tests/language-feature/constants/constexpr-loop.slang
index 81b0a5c17..7af9c60b2 100644
--- a/tests/language-feature/constants/constexpr-loop.slang
+++ b/tests/language-feature/constants/constexpr-loop.slang
@@ -1,6 +1,6 @@
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
-//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu
+//TEST(compute):COMPARE_COMPUTE_EX: -wgpu -compute -output-using-type
//TEST_INPUT: set g_texture = Texture2D(size=8, content = one)
//TEST_INPUT: set g_sampler = Sampler
diff --git a/tests/library/linked.spirv b/tests/library/linked.spirv
deleted file mode 100644
index 7ea385e71..000000000
--- a/tests/library/linked.spirv
+++ /dev/null
Binary files differ