summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-emit-wgsl.cpp126
-rw-r--r--source/slang/slang-emit-wgsl.h2
-rw-r--r--source/slang/slang-ir-insts.h1
-rw-r--r--source/slang/slang-ir-wgsl-legalize.cpp90
-rw-r--r--source/slang/slang-ir.cpp11
5 files changed, 224 insertions, 6 deletions
diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp
index 4aca03a61..d8ec01776 100644
--- a/source/slang/slang-emit-wgsl.cpp
+++ b/source/slang/slang-emit-wgsl.cpp
@@ -497,6 +497,34 @@ void WGSLSourceEmitter::emitLayoutQualifiersImpl(IRVarLayout* layout)
}
}
+static bool isStaticConst(IRInst* inst)
+{
+ if (inst->getParent()->getOp() == kIROp_Module)
+ {
+ return true;
+ }
+ switch (inst->getOp())
+ {
+ case kIROp_MakeVector:
+ case kIROp_swizzle:
+ case kIROp_swizzleSet:
+ case kIROp_IntCast:
+ case kIROp_FloatCast:
+ case kIROp_CastFloatToInt:
+ case kIROp_CastIntToFloat:
+ case kIROp_BitCast:
+ {
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ if (!isStaticConst(inst->getOperand(i)))
+ return false;
+ }
+ return true;
+ }
+ }
+ return false;
+}
+
void WGSLSourceEmitter::emitVarKeywordImpl(IRType* type, IRInst* varDecl)
{
switch (varDecl->getOp())
@@ -505,14 +533,10 @@ void WGSLSourceEmitter::emitVarKeywordImpl(IRType* type, IRInst* varDecl)
case kIROp_GlobalVar:
case kIROp_Var: m_writer->emit("var"); break;
default:
- if (as<IRModuleInst>(varDecl->getParent()))
- {
+ if (isStaticConst(varDecl))
m_writer->emit("const");
- }
else
- {
m_writer->emit("var");
- }
break;
}
@@ -977,6 +1001,33 @@ void WGSLSourceEmitter::emitCallArg(IRInst* inst)
}
}
+bool WGSLSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst)
+{
+ bool result = CLikeSourceEmitter::shouldFoldInstIntoUseSites(inst);
+ if (result)
+ {
+ // If inst is a matrix, and is used in a component-wise multiply,
+ // we need to not fold it.
+ if (as<IRMatrixType>(inst->getDataType()))
+ {
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ auto user = use->getUser();
+ if (user->getOp() == kIROp_Mul)
+ {
+ if (as<IRMatrixType>(user->getOperand(0)->getDataType()) &&
+ as<IRMatrixType>(user->getOperand(1)->getDataType()))
+ {
+ return false;
+ }
+ }
+ }
+ }
+ }
+ return result;
+}
+
+
bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec)
{
EmitOpInfo outerPrec = inOuterPrec;
@@ -1126,6 +1177,71 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
return true;
}
break;
+
+ case kIROp_GetStringHash:
+ {
+ auto getStringHashInst = as<IRGetStringHash>(inst);
+ auto stringLit = getStringHashInst->getStringLit();
+
+ if (stringLit)
+ {
+ auto slice = stringLit->getStringSlice();
+ emitType(inst->getDataType());
+ m_writer->emit("(");
+ m_writer->emit((int)getStableHashCode32(slice.begin(), slice.getLength()).hash);
+ m_writer->emit(")");
+ }
+ else
+ {
+ // Couldn't handle
+ diagnoseUnhandledInst(inst);
+ }
+ return true;
+ }
+
+ case kIROp_Mul:
+ {
+ if (!as<IRMatrixType>(inst->getOperand(0)->getDataType()) ||
+ !as<IRMatrixType>(inst->getOperand(1)->getDataType()))
+ {
+ return false;
+ }
+ // Mul(m1, m2) should be translated to component-wise multiplication in WGSL.
+ auto matrixType = as<IRMatrixType>(inst->getDataType());
+ auto rowCount = getIntVal(matrixType->getRowCount());
+ emitType(inst->getDataType());
+ m_writer->emit("(");
+ for (IRIntegerValue i = 0; i < rowCount; i++)
+ {
+ if (i != 0)
+ {
+ m_writer->emit(", ");
+ }
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::Postfix));
+ m_writer->emit("[");
+ m_writer->emit(i);
+ m_writer->emit("] * ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::Postfix));
+ m_writer->emit("[");
+ m_writer->emit(i);
+ m_writer->emit("]");
+ }
+ m_writer->emit(")");
+
+ return true;
+ }
+
+ case kIROp_Select:
+ {
+ m_writer->emit("select(");
+ emitOperand(inst->getOperand(2), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(")");
+ return true;
+ }
}
return false;
diff --git a/source/slang/slang-emit-wgsl.h b/source/slang/slang-emit-wgsl.h
index 70df65933..1a8ec2fd5 100644
--- a/source/slang/slang-emit-wgsl.h
+++ b/source/slang/slang-emit-wgsl.h
@@ -50,6 +50,8 @@ public:
void emit(const AddressSpace addressSpace);
+ virtual bool shouldFoldInstIntoUseSites(IRInst* inst) SLANG_OVERRIDE;
+
private:
// Emit the matrix type with 'rowCountWGSL' WGSL-rows and 'colCountWGSL' WGSL-columns
void emitMatrixType(
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index c44211c1c..9a081f9de 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -4021,6 +4021,7 @@ public:
IRInst* emitDifferentialPairGetPrimalUserCode(IRInst* diffPair);
IRInst* emitMakeVector(IRType* type, UInt argCount, IRInst* const* args);
IRInst* emitMakeVectorFromScalar(IRType* type, IRInst* scalarValue);
+ IRInst* emitMakeCompositeFromScalar(IRType* type, IRInst* scalarValue);
IRInst* emitMakeVector(IRType* type, List<IRInst*> const& args)
{
diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp
index c97a8a89f..96eb13be4 100644
--- a/source/slang/slang-ir-wgsl-legalize.cpp
+++ b/source/slang/slang-ir-wgsl-legalize.cpp
@@ -51,6 +51,8 @@ struct LegalizeWGSLEntryPointContext
String* optionalSemanticIndex,
IRInst* parentVar);
void legalizeCall(IRCall* call);
+ void legalizeSwitch(IRSwitch* switchInst);
+ void legalizeBinaryOp(IRInst* inst);
void processInst(IRInst* inst);
};
@@ -349,11 +351,97 @@ void LegalizeWGSLEntryPointContext::legalizeCall(IRCall* call)
}
}
+void LegalizeWGSLEntryPointContext::legalizeSwitch(IRSwitch* switchInst)
+{
+ // WGSL Requires all switch statements to contain a default case.
+ // If the switch statement does not contain a default case, we will add one.
+ if (switchInst->getDefaultLabel() != switchInst->getBreakLabel())
+ return;
+ IRBuilder builder(switchInst);
+ auto defaultBlock = builder.createBlock();
+ builder.setInsertInto(defaultBlock);
+ builder.emitBranch(switchInst->getBreakLabel());
+ defaultBlock->insertBefore(switchInst->getBreakLabel());
+ List<IRInst*> cases;
+ for (UInt i = 0; i < switchInst->getCaseCount(); i++)
+ {
+ cases.add(switchInst->getCaseValue(i));
+ cases.add(switchInst->getCaseLabel(i));
+ }
+ builder.setInsertBefore(switchInst);
+ auto newSwitch = builder.emitSwitch(
+ switchInst->getCondition(),
+ switchInst->getBreakLabel(),
+ defaultBlock,
+ (UInt)cases.getCount(),
+ cases.getBuffer());
+ switchInst->transferDecorationsTo(newSwitch);
+ switchInst->removeAndDeallocate();
+}
+
+void LegalizeWGSLEntryPointContext::legalizeBinaryOp(IRInst* inst)
+{
+ auto isVectorOrMatrix = [](IRType* type)
+ {
+ switch (type->getOp())
+ {
+ case kIROp_VectorType:
+ case kIROp_MatrixType: return true;
+ default: return false;
+ }
+ };
+ if (isVectorOrMatrix(inst->getOperand(0)->getDataType()) &&
+ as<IRBasicType>(inst->getOperand(1)->getDataType()))
+ {
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto newRhs = builder.emitMakeCompositeFromScalar(
+ inst->getOperand(0)->getDataType(),
+ inst->getOperand(1));
+ builder.replaceOperand(inst->getOperands() + 1, newRhs);
+ }
+ else if (
+ as<IRBasicType>(inst->getOperand(0)->getDataType()) &&
+ isVectorOrMatrix(inst->getOperand(1)->getDataType()))
+ {
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto newLhs = builder.emitMakeCompositeFromScalar(
+ inst->getOperand(1)->getDataType(),
+ inst->getOperand(0));
+ builder.replaceOperand(inst->getOperands(), newLhs);
+ }
+}
+
void LegalizeWGSLEntryPointContext::processInst(IRInst* inst)
{
switch (inst->getOp())
{
- case kIROp_Call: legalizeCall(static_cast<IRCall*>(inst)); break;
+ case kIROp_Call: legalizeCall(static_cast<IRCall*>(inst)); break;
+ case kIROp_Switch: legalizeSwitch(as<IRSwitch>(inst)); break;
+
+ // For all binary operators, make sure both side of the operator have the same type
+ // (vector-ness and matrix-ness).
+ case kIROp_Add:
+ case kIROp_Sub:
+ case kIROp_Mul:
+ case kIROp_Div:
+ case kIROp_FRem:
+ case kIROp_IRem:
+ case kIROp_And:
+ case kIROp_Or:
+ case kIROp_BitAnd:
+ case kIROp_BitOr:
+ case kIROp_BitXor:
+ case kIROp_Lsh:
+ case kIROp_Rsh:
+ case kIROp_Eql:
+ case kIROp_Neq:
+ case kIROp_Greater:
+ case kIROp_Less:
+ case kIROp_Geq:
+ case kIROp_Leq: legalizeBinaryOp(inst); break;
+
default:
for (auto child : inst->getModifiableChildren())
processInst(child);
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 49273163e..3bd31d6e9 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -4162,6 +4162,17 @@ IRInst* IRBuilder::emitMakeVectorFromScalar(IRType* type, IRInst* scalarValue)
return emitIntrinsicInst(type, kIROp_MakeVectorFromScalar, 1, &scalarValue);
}
+IRInst* IRBuilder::emitMakeCompositeFromScalar(IRType* type, IRInst* scalarValue)
+{
+ switch (type->getOp())
+ {
+ case kIROp_VectorType: return emitMakeVectorFromScalar(type, scalarValue);
+ case kIROp_MatrixType: return emitMakeMatrixFromScalar(type, scalarValue);
+ case kIROp_ArrayType: return emitMakeArrayFromElement(type, scalarValue);
+ default: SLANG_UNEXPECTED("unhandled composite type"); UNREACHABLE_RETURN(nullptr);
+ }
+}
+
IRInst* IRBuilder::emitMatrixReshape(IRType* type, IRInst* inst)
{
return emitIntrinsicInst(type, kIROp_MatrixReshape, 1, &inst);