summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-10-09 18:30:24 -0700
committerGitHub <noreply@github.com>2025-10-10 01:30:24 +0000
commite420f2f980813559b186a6a6bcd5540f74310d02 (patch)
tree8f7a833ed86e8ce2a7b40bd1e9e7da5cb95d66ab /source
parent3cf1f5a616917480c63b76aae906dc36b29e46ce (diff)
Defer `IRCastStorageToLogicalDeref` in lowerBufferElementType pass. (#8668)
Fix a regression on metal test. In `lowerBufferElementTypeToStorageType` pass, not only we want to defer an argument that is `CastStorageToLogical` to the callee, but also apply the same defer logic to `CastStorageToLogicalDeref` as well. Because `CastStorageToLogicalDeref` will appear as argumnet if `lowerBufferElementTypeToStorageType` is run before we apply the `in->borrow` transformation pass, which is the case for metal parameter block legalization.
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-emit-metal.cpp3
-rw-r--r--source/slang/slang-emit-wgsl.cpp4
-rw-r--r--source/slang/slang-emit.cpp4
-rw-r--r--source/slang/slang-ir-insts.h5
-rw-r--r--source/slang/slang-ir-lower-buffer-element-type.cpp47
-rw-r--r--source/slang/slang-ir-specialize-arrays.cpp15
-rw-r--r--source/slang/slang-ir-wgsl-legalize.cpp89
-rw-r--r--source/slang/slang-ir-wgsl-legalize.h3
-rw-r--r--source/slang/slang-ir.cpp10
9 files changed, 153 insertions, 27 deletions
diff --git a/source/slang/slang-emit-metal.cpp b/source/slang/slang-emit-metal.cpp
index e992d17f5..d614bd4b1 100644
--- a/source/slang/slang-emit-metal.cpp
+++ b/source/slang/slang-emit-metal.cpp
@@ -1278,7 +1278,8 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type)
m_writer->emit("uint32_t device*");
break;
case kIROp_RaytracingAccelerationStructureType:
- m_writer->emit("acceleration_structure<instancing>");
+ m_writer->emit(
+ "metal::raytracing::acceleration_structure<metal::raytracing::instancing>");
break;
default:
SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled buffer type");
diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp
index 3cebae97c..7b7c4ffa3 100644
--- a/source/slang/slang-emit-wgsl.cpp
+++ b/source/slang/slang-emit-wgsl.cpp
@@ -336,10 +336,12 @@ void WGSLSourceEmitter::emit(const AddressSpace addressSpace)
break;
case AddressSpace::StorageBuffer:
+ case AddressSpace::Global:
m_writer->emit("storage");
break;
case AddressSpace::Generic:
+ case AddressSpace::Function:
m_writer->emit("function");
break;
@@ -1311,7 +1313,7 @@ bool WGSLSourceEmitter::tryEmitInstStmtImpl(IRInst* inst)
void WGSLSourceEmitter::emitCallArg(IRInst* inst)
{
- if (as<IRPtrTypeBase>(inst->getDataType()))
+ if (as<IRPointerLikeType>(inst->getDataType()) || as<IRPtrTypeBase>(inst->getDataType()))
{
// If we are calling a function with a pointer-typed argument, we need to
// explicitly prefix the argument with `&` to pass a pointer.
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 7d72fb77f..c57851bc2 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -1881,6 +1881,10 @@ Result linkAndOptimizeIR(
{
specializeAddressSpaceForMetal(irModule);
}
+ else if (isWGPUTarget(targetRequest))
+ {
+ specializeAddressSpaceForWGSL(irModule);
+ }
performForceInlining(irModule);
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 42db9cb44..5c27d5e25 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -4584,10 +4584,7 @@ public:
IRInst* emitCastIntToPtr(IRType* ptrType, IRInst* val);
IRInst* emitCastStorageToLogical(IRType* type, IRInst* val, IRInst* bufferType);
- IRCastStorageToLogicalDeref* emitCastStorageToLogicalDeref(
- IRType* type,
- IRInst* val,
- IRInst* bufferType);
+ IRInst* emitCastStorageToLogicalDeref(IRType* type, IRInst* val, IRInst* bufferType);
IRGlobalConstant* emitGlobalConstant(IRType* type);
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp
index c69592939..4bf36259b 100644
--- a/source/slang/slang-ir-lower-buffer-element-type.cpp
+++ b/source/slang/slang-ir-lower-buffer-element-type.cpp
@@ -12,13 +12,16 @@
///
/// Many of our targets have special restrictions on what is allowed to be used as a
/// buffer element. Examples are:
-/// - In HLSL and SPIRV, if you have ConstantBuffer<T>, T must be a struct.
/// - In SPIRV, `bool` is considered a logical type, meaning it cannot appear inside
/// buffers. bool vectors and matrices needs to be lowered into arrays.
/// - In SPIRV, if `T` is used to declare a buffer, then every member in `T` must have
/// explicit offset. But if it is used to declare a local variable, then it cannot
/// have explicit member offset. This means that we cannot use the same `Foo` struct
/// inside a `StructuredBuffer<Foo>` and also use it to declare a local variable.
+/// - In Metal, if we have a `struct Foo {Texture2D member; }` and
+/// `ParameterBlock<Foo>`, then we should translate it to
+/// `struct Foo_pb { Texture2D.Handle member; }` and `ParameterBlock<Foo_pb>`, so that
+/// the resource legalization pass won't hoist the texture out of the parameter block.
///
/// We use the terms "physical", "storage", or "lowered" types to refer to types that
/// are legal to use as buffer elements. In contrast, the terms "original" or "logical"
@@ -187,6 +190,11 @@
/// m_ptr = FieldAddr(ptr, member)
/// call f_1, m_ptr
/// ```
+/// Note that it is only correct to defer a load/CastStorageToLogicalDeref if the location
+/// being loaded from is immutable. Otherwise, we might be changing the order of memory
+/// operations and result in a change in application behavior. So this pass will also make sure
+/// that we only create `CastStorageToLogicalDeref(x)` such that `x` is an immutable location,
+/// such as an immutable temporary variable.
///
/// # Trailing Pointer Rewrite
///
@@ -1192,10 +1200,7 @@ struct LoweredElementTypeContext
// and push the cast to inside the callee.
// We will process calls after other gep insts, so for now just add
// it into a separate worklist.
- if (castInst->getOp() == kIROp_CastStorageToLogical)
- {
- callWorkListSet.add((IRCall*)user);
- }
+ callWorkListSet.add((IRCall*)user);
break;
}
case kIROp_Load:
@@ -1261,7 +1266,8 @@ struct LoweredElementTypeContext
castInst->getBufferType());
user->replaceUsesWith(newCast);
user->removeAndDeallocate();
- castInstWorkList.add(newCast);
+ if (auto newCastStorage = as<IRCastStorageToLogicalBase>(newCast))
+ castInstWorkList.add(newCastStorage);
break;
}
case kIROp_FieldExtract:
@@ -1353,6 +1359,16 @@ struct LoweredElementTypeContext
paramTypes.add(storagePtrType);
newArgs.add(castArg->getOperand(0));
}
+ else if (auto castArgDeref = as<IRCastStorageToLogicalDeref>(arg))
+ {
+ auto storageValueType = tryGetPointedToOrBufferElementType(
+ &builder,
+ castArgDeref->getOperand(0)->getDataType());
+ auto storagePtrType =
+ builder.getBorrowInParamType(storageValueType, AddressSpace::Generic);
+ paramTypes.add(storagePtrType);
+ newArgs.add(castArgDeref->getOperand(0));
+ }
else
{
paramTypes.add(arg->getDataType());
@@ -1422,7 +1438,7 @@ struct LoweredElementTypeContext
auto param = params[i];
SLANG_RELEASE_ASSERT(i < call->getArgCount());
auto arg = call->getArg(i);
- auto cast = as<IRCastStorageToLogical>(arg);
+ auto cast = as<IRCastStorageToLogicalBase>(arg);
if (!cast)
continue;
auto logicalParamType = param->getFullType();
@@ -1434,8 +1450,21 @@ struct LoweredElementTypeContext
uses.clear();
for (auto use = param->firstUse; use; use = use->nextUse)
uses.add(use);
- auto castedParam =
- builder.emitCastStorageToLogical(logicalParamType, param, cast->getBufferType());
+ IRInst* castedParam = nullptr;
+ if (arg->getOp() == kIROp_CastStorageToLogical)
+ {
+ castedParam = builder.emitCastStorageToLogical(
+ logicalParamType,
+ param,
+ cast->getBufferType());
+ }
+ else
+ {
+ castedParam = builder.emitCastStorageToLogicalDeref(
+ logicalParamType,
+ param,
+ cast->getBufferType());
+ }
if (auto castStorage = as<IRCastStorageToLogicalBase>(castedParam))
outNewCasts.add(castStorage);
diff --git a/source/slang/slang-ir-specialize-arrays.cpp b/source/slang/slang-ir-specialize-arrays.cpp
index c2bc4d14e..1f50a7579 100644
--- a/source/slang/slang-ir-specialize-arrays.cpp
+++ b/source/slang/slang-ir-specialize-arrays.cpp
@@ -29,20 +29,23 @@ struct ArrayParameterSpecializationCondition : FunctionCallSpecializeCondition
if (auto outTypeBase = as<IROutParamTypeBase>(paramType))
{
paramType = outTypeBase->getValueType();
- SLANG_ASSERT(as<IRPtrTypeBase>(argType));
- argType = as<IRPtrTypeBase>(argType)->getValueType();
+ IRBuilder builder(paramType);
+ argType = tryGetPointedToType(&builder, argType);
+ SLANG_ASSERT(argType);
}
else if (auto refType = as<IRRefParamType>(paramType))
{
paramType = refType->getValueType();
- SLANG_ASSERT(as<IRPtrTypeBase>(argType));
- argType = as<IRPtrTypeBase>(argType)->getValueType();
+ IRBuilder builder(paramType);
+ argType = tryGetPointedToType(&builder, argType);
+ SLANG_ASSERT(argType);
}
else if (auto constRefType = as<IRBorrowInParamType>(paramType))
{
paramType = constRefType->getValueType();
- SLANG_ASSERT(as<IRPtrTypeBase>(argType));
- argType = as<IRPtrTypeBase>(argType)->getValueType();
+ IRBuilder builder(paramType);
+ argType = tryGetPointedToType(&builder, argType);
+ SLANG_ASSERT(argType);
}
auto arrayType = as<IRUnsizedArrayType>(paramType);
if (!arrayType)
diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp
index 51f16e603..571e6311f 100644
--- a/source/slang/slang-ir-wgsl-legalize.cpp
+++ b/source/slang/slang-ir-wgsl-legalize.cpp
@@ -4,6 +4,8 @@
#include "slang-ir-legalize-binary-operator.h"
#include "slang-ir-legalize-global-values.h"
#include "slang-ir-legalize-varying-params.h"
+#include "slang-ir-specialize-address-space.h"
+#include "slang-ir-util.h"
#include "slang-ir.h"
namespace Slang
@@ -225,4 +227,91 @@ void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink)
GlobalInstInliningContext().inlineGlobalValuesAndRemoveIfUnused(module);
}
+struct WGSLAddressSpaceAssigner : InitialAddressSpaceAssigner
+{
+ virtual bool tryAssignAddressSpace(IRInst* inst, AddressSpace& outAddressSpace) override
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_Var:
+ if (as<IRBlock>(inst->getParent()))
+ outAddressSpace = AddressSpace::Function;
+ else
+ outAddressSpace = AddressSpace::ThreadLocal;
+ return true;
+ case kIROp_RWStructuredBufferGetElementPtr:
+ outAddressSpace = AddressSpace::Global;
+ return true;
+ case kIROp_Load:
+ {
+ auto addrSpace = getAddressSpaceFromVarType(inst->getDataType());
+ if (addrSpace != AddressSpace::Generic)
+ {
+ outAddressSpace = addrSpace;
+ return true;
+ }
+ }
+ return false;
+ default:
+ return false;
+ }
+ }
+
+ virtual AddressSpace getAddressSpaceFromVarType(IRInst* type) override
+ {
+ if (as<IRUniformParameterGroupType>(type))
+ {
+ return AddressSpace::Uniform;
+ }
+ if (as<IRByteAddressBufferTypeBase>(type))
+ {
+ return AddressSpace::Global;
+ }
+ if (as<IRHLSLStructuredBufferTypeBase>(type))
+ {
+ return AddressSpace::Global;
+ }
+ if (as<IRGLSLShaderStorageBufferType>(type))
+ {
+ return AddressSpace::Global;
+ }
+ if (auto ptrType = as<IRPtrTypeBase>(type))
+ {
+ if (ptrType->hasAddressSpace())
+ return ptrType->getAddressSpace();
+ return AddressSpace::Generic;
+ }
+ return AddressSpace::Generic;
+ }
+
+ virtual AddressSpace getLeafInstAddressSpace(IRInst* inst) override
+ {
+ if (as<IRGroupSharedRate>(inst->getRate()))
+ return AddressSpace::GroupShared;
+ switch (inst->getOp())
+ {
+ case kIROp_RWStructuredBufferGetElementPtr:
+ return AddressSpace::Global;
+ case kIROp_Var:
+ if (as<IRBlock>(inst->getParent()))
+ return AddressSpace::Function;
+ else
+ return AddressSpace::ThreadLocal;
+ break;
+ default:
+ break;
+ }
+ auto type = unwrapAttributedType(inst->getDataType());
+ if (!type)
+ return AddressSpace::Generic;
+ return getAddressSpaceFromVarType(type);
+ }
+};
+
+void specializeAddressSpaceForWGSL(IRModule* module)
+{
+ WGSLAddressSpaceAssigner wgslAddressSpaceAssigner;
+ specializeAddressSpace(module, &wgslAddressSpaceAssigner);
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ir-wgsl-legalize.h b/source/slang/slang-ir-wgsl-legalize.h
index 11e25ea88..b9b270ba3 100644
--- a/source/slang/slang-ir-wgsl-legalize.h
+++ b/source/slang/slang-ir-wgsl-legalize.h
@@ -7,4 +7,7 @@ namespace Slang
class DiagnosticSink;
void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink);
+
+void specializeAddressSpaceForWGSL(IRModule* module);
+
} // namespace Slang
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index ebaebcc8a..ba8864684 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -6187,14 +6187,12 @@ IRInst* IRBuilder::emitCastStorageToLogical(IRType* type, IRInst* val, IRInst* b
return (IRCastStorageToLogical*)emitIntrinsicInst(type, kIROp_CastStorageToLogical, 2, args);
}
-IRCastStorageToLogicalDeref* IRBuilder::emitCastStorageToLogicalDeref(
- IRType* type,
- IRInst* val,
- IRInst* bufferType)
+IRInst* IRBuilder::emitCastStorageToLogicalDeref(IRType* type, IRInst* val, IRInst* bufferType)
{
IRInst* args[] = {val, bufferType};
- return (IRCastStorageToLogicalDeref*)
- emitIntrinsicInst(type, kIROp_CastStorageToLogicalDeref, 2, args);
+ if (type == tryGetPointedToType(this, val->getDataType()))
+ return emitLoad(type, val);
+ return emitIntrinsicInst(type, kIROp_CastStorageToLogicalDeref, 2, args);
}
IRGlobalConstant* IRBuilder::emitGlobalConstant(IRType* type)