summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-10-02 15:39:34 -0700
committerGitHub <noreply@github.com>2023-10-02 15:39:34 -0700
commitd87493a46c00be37b820a473c0827bbb865eb222 (patch)
tree33155e6be017238e07314f7793423dd50b748150 /source/slang
parentcea230bc686ef87db4cff47e367bbf824b90377d (diff)
More direct-SPIRV fixes. (#3257)
* More direct-SPIRV fixes. * Fix array-reg-to-mem. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-check-decl.cpp32
-rw-r--r--source/slang/slang-emit-glsl.cpp12
-rw-r--r--source/slang/slang-emit-spirv.cpp14
-rw-r--r--source/slang/slang-ir-array-reg-to-mem.cpp5
-rw-r--r--source/slang/slang-ir-glsl-legalize.cpp12
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp79
-rw-r--r--source/slang/slang-ir-util.cpp18
-rw-r--r--source/slang/slang-ir-util.h2
-rw-r--r--source/slang/slang-lower-to-ir.cpp6
9 files changed, 153 insertions, 27 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 351d5a9cc..1e3c6a361 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -1836,18 +1836,20 @@ namespace Slang
RefPtr<WitnessTable> witnessTable)
{
if(satisfyingMemberDeclRef.getDecl()->hasModifier<MutatingAttribute>()
- && !requiredMemberDeclRef.getDecl()->hasModifier<MutatingAttribute>())
+ != requiredMemberDeclRef.getDecl()->hasModifier<MutatingAttribute>())
{
- // A `[mutating]` method can't satisfy a non-`[mutating]` requirement,
- // but vice-versa is okay.
+ // A `[mutating]` method can't satisfy a non-`[mutating]` requirement.
+ // The opposite direction is okay, but we will need to synthesize a wrapper
+ // to ensure type matches, so we will return false here either way.
return false;
}
if (satisfyingMemberDeclRef.getDecl()->hasModifier<ConstRefAttribute>()
- && !requiredMemberDeclRef.getDecl()->hasModifier<ConstRefAttribute>())
+ != requiredMemberDeclRef.getDecl()->hasModifier<ConstRefAttribute>())
{
- // A `[constref]` method can't satisfy a non-`[constref]` requirement,
- // but vice-versa is okay.
+ // A `[constref]` method can't satisfy a non-`[constref]` requirement.
+ // The opposite direction is okay, but we will need to synthesize a wrapper
+ // to ensure type matches, so we will return false here either way.
return false;
}
@@ -2677,11 +2679,21 @@ namespace Slang
synArg->type = paramType;
synArgs.add(synArg);
- if (paramDeclRef.getDecl()->findModifier<NoDiffModifier>())
+ // Add modifiers
+ for (auto modifier : paramDeclRef.getDecl()->modifiers)
{
- auto noDiffModifier = m_astBuilder->create<NoDiffModifier>();
- noDiffModifier->keywordName = getSession()->getNameObj("no_diff");
- addModifier(synParamDecl, noDiffModifier);
+ if (as<NoDiffModifier>(modifier))
+ {
+ auto noDiffModifier = m_astBuilder->create<NoDiffModifier>();
+ noDiffModifier->keywordName = getSession()->getNameObj("no_diff");
+ addModifier(synParamDecl, noDiffModifier);
+ }
+ else if (as<InOutModifier>(modifier) || as<OutModifier>(modifier) || as<ConstRefModifier>(modifier) || as<RefModifier>(modifier))
+ {
+ auto clonedModifier = (Modifier*)m_astBuilder->createByNodeType(modifier->astNodeType);
+ clonedModifier->keywordName = modifier->keywordName;
+ addModifier(synParamDecl, clonedModifier);
+ }
}
}
diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp
index e9472bc96..ccb296e84 100644
--- a/source/slang/slang-emit-glsl.cpp
+++ b/source/slang/slang-emit-glsl.cpp
@@ -1893,6 +1893,18 @@ bool GLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
return true;
}
+ case kIROp_GetVulkanRayTracingPayloadLocation:
+ {
+ auto payloadVar = inst->getOperand(0);
+ IRInst* location = getVulkanPayloadLocation(payloadVar);
+ if (!location)
+ {
+ SLANG_DIAGNOSE_UNEXPECTED(getSink(), inst, "no payload location assigned.");
+ m_writer->emit("0");
+ }
+ m_writer->emit(getIntVal(location));
+ return true;
+ }
case kIROp_ImageLoad:
{
m_writer->emit("imageLoad(");
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 104fab339..62e7d428b 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -2274,6 +2274,20 @@ struct SPIRVEmitContext
registerInst(inst, inner);
return inner;
}
+ case kIROp_GetVulkanRayTracingPayloadLocation:
+ {
+ IRInst* location = getVulkanPayloadLocation(inst->getOperand(0));
+ if (!location)
+ {
+ SLANG_DIAGNOSE_UNEXPECTED(m_sink, inst, "no payload location assigned.");
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ location = builder.getIntValue(builder.getIntType(), 0);
+ }
+ auto inner = ensureInst(location);
+ registerInst(inst, inner);
+ return inner;
+ }
case kIROp_Return:
if (as<IRReturn>(inst)->getVal()->getOp() == kIROp_VoidLit)
return emitOpReturn(parent, inst);
diff --git a/source/slang/slang-ir-array-reg-to-mem.cpp b/source/slang/slang-ir-array-reg-to-mem.cpp
index 6f749f242..34bd5b148 100644
--- a/source/slang/slang-ir-array-reg-to-mem.cpp
+++ b/source/slang/slang-ir-array-reg-to-mem.cpp
@@ -33,9 +33,9 @@ namespace Slang
if (auto arrayType = as<IRArrayTypeBase>(param->getFullType()))
{
changed = true;
- builder.setInsertBefore(param);
auto ptrArrayType = builder.getPtrType(arrayType);
- auto newParam = builder.emitParam(ptrArrayType);
+ auto newParam = builder.createParam(ptrArrayType);
+ newParam->insertBefore(param);
setInsertAfterOrdinaryInst(&builder, param);
auto regVal = builder.emitLoad(newParam);
param->replaceUsesWith(regVal);
@@ -62,6 +62,7 @@ namespace Slang
for (auto paramId : arrayParamIds)
{
auto arg = call->getArg(paramId);
+ SLANG_ASSERT(as<IRPtrTypeBase>(paramTypes[paramId]));
auto var = builder.emitVar(as<IRPtrTypeBase>(paramTypes[paramId])->getValueType());
builder.emitStore(var, arg);
call->setArg(paramId, var);
diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp
index c4a71a3e9..956685d85 100644
--- a/source/slang/slang-ir-glsl-legalize.cpp
+++ b/source/slang/slang-ir-glsl-legalize.cpp
@@ -2693,18 +2693,6 @@ void assignRayPayloadHitObjectAttributeLocations(IRModule* module)
}
}
end:;
- if (location)
- {
- traverseUses(globalVar, [&](IRUse* use)
- {
- auto user = use->getUser();
- if (user->getOp() == kIROp_GetVulkanRayTracingPayloadLocation)
- {
- user->replaceUsesWith(location);
- user->removeAndDeallocate();
- }
- });
- }
}
}
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index 5d4673981..b92f5e910 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -265,7 +265,38 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
return false;
}
- void inferTextureFormat(IRInst* textureInst, IRTextureTypeBase* textureType)
+ static IRType* replaceImageElementType(IRInst* originalType, IRInst* newElementType)
+ {
+ switch(originalType->getOp())
+ {
+ case kIROp_ArrayType:
+ case kIROp_UnsizedArrayType:
+ case kIROp_PtrType:
+ case kIROp_OutType:
+ case kIROp_RefType:
+ case kIROp_ConstRefType:
+ case kIROp_InOutType:
+ {
+ auto newInnerType = replaceImageElementType(originalType->getOperand(0), newElementType);
+ if (newInnerType != originalType->getOperand(0))
+ {
+ IRBuilder builder(originalType);
+ builder.setInsertBefore(originalType);
+ IRCloneEnv cloneEnv;
+ cloneEnv.mapOldValToNew.add(originalType->getOperand(0), newInnerType);
+ return (IRType*)cloneInst(&cloneEnv, &builder, originalType);
+ }
+ return (IRType*)originalType;
+ }
+
+ default:
+ if (as<IRResourceTypeBase>(originalType))
+ return (IRType*)newElementType;
+ return (IRType*)originalType;
+ }
+ }
+
+ static void inferTextureFormat(IRInst* textureInst, IRTextureTypeBase* textureType)
{
ImageFormat format = ImageFormat::unknown;
if (auto decor = textureInst->findDecoration<IRFormatDecoration>())
@@ -368,14 +399,52 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
args.add(builder.getIntValue(builder.getUIntType(), IRIntegerValue(format)));
auto newType = (IRType*)builder.emitIntrinsicInst(builder.getTypeKind(), textureType->getOp(), 3, args.getBuffer());
- textureInst->setFullType(newType);
+ if (textureInst->getFullType() == textureType)
+ {
+ // Simple texture typed global param.
+ textureInst->setFullType(newType);
+ }
+ else
+ {
+ // Array typed global param. We need to replace the type and the types of all getElement insts.
+ auto newInstType = (IRType*)replaceImageElementType(textureInst->getFullType(), newType);
+ textureInst->setFullType(newInstType);
+ List<IRUse*> typeReplacementWorkList;
+ HashSet<IRUse*> typeReplacementWorkListSet;
+ for (auto use = textureInst->firstUse; use; use = use->nextUse)
+ {
+ if (typeReplacementWorkListSet.add(use))
+ typeReplacementWorkList.add(use);
+ }
+ for (Index i = 0; i < typeReplacementWorkList.getCount(); i++)
+ {
+ auto use = typeReplacementWorkList[i];
+ auto user = use->getUser();
+ switch (user->getOp())
+ {
+ case kIROp_GetElementPtr:
+ case kIROp_GetElement:
+ case kIROp_Load:
+ {
+ auto newUserType = (IRType*)replaceImageElementType(user->getFullType(), newType);
+ user->setFullType(newUserType);
+ for (auto u = user->firstUse; u; u = u->nextUse)
+ {
+ if (typeReplacementWorkListSet.add(u))
+ typeReplacementWorkList.add(u);
+ }
+ break;
+ };
+ }
+ }
+ }
}
}
void processGlobalParam(IRGlobalParam* inst)
{
// If the param is a texture, infer its format.
- if (auto textureType = as<IRTextureTypeBase>(inst->getDataType()))
+ if (auto textureType = as<IRTextureTypeBase>(unwrapArray(inst->getDataType())))
{
inferTextureFormat(inst, textureType);
}
@@ -1447,6 +1516,10 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
void processModule()
{
+#if 0
+ eliminateArrayTypeSSARegisters(m_module);
+#endif
+
// Process global params before anything else, so we don't generate inefficient
// array marhalling code for array-typed global params.
for (auto globalInst : m_module->getGlobalInsts())
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index b4a41f8a5..5ecbc8121 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -814,6 +814,24 @@ IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key)
return nullptr;
}
+IRInst* getVulkanPayloadLocation(IRInst* payloadGlobalVar)
+{
+ IRInst* location = nullptr;
+ for (auto decor : payloadGlobalVar->getDecorations())
+ {
+ switch (decor->getOp())
+ {
+ case kIROp_VulkanRayPayloadDecoration:
+ case kIROp_VulkanCallablePayloadDecoration:
+ case kIROp_VulkanHitObjectAttributesDecoration:
+ return decor->getOperand(0);
+ default:
+ continue;
+ }
+ }
+ return location;
+}
+
void moveParams(IRBlock* dest, IRBlock* src)
{
for (auto param = src->getFirstChild(); param;)
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 82ce1344c..c13ce1931 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -209,6 +209,8 @@ IRInst* findInterfaceRequirement(IRInterfaceType* type, IRInst* key);
IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key);
+IRInst* getVulkanPayloadLocation(IRInst* payloadGlobalVar);
+
void moveParams(IRBlock* dest, IRBlock* src);
void removePhiArgs(IRInst* phiParam);
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 065263a59..f62d4b24a 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -9903,7 +9903,13 @@ static void lowerFrontEndEntryPointToIR(
instToDecorate = findGenericReturnVal(irGeneric);
}
+ // If the entry-point decorations has already been created (because the user
+ // specified duplicate entries in the entry point list), we can stop now.
+ if (instToDecorate->findDecoration<IREntryPointDecoration>())
+ return;
+
{
+
Name* entryPointName = entryPoint->getFuncDecl()->getName();
builder->addEntryPointDecoration(instToDecorate, entryPoint->getProfile(), entryPointName->text.getUnownedSlice(), moduleName.getUnownedSlice());
}