summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-12-09 04:48:03 -0800
committerGitHub <noreply@github.com>2024-12-09 20:48:03 +0800
commit09a9d673322ebf4ca2fcb7d48f13a44e015ea33f (patch)
tree8bae8fa5718669dcbff98b8bcb29784483905f34
parent051ae8acec0a641bcaf86e7eeff35eff29e8922d (diff)
Allow pointers to existential values. (#5793)
* Fix pointer offset logic and add executable tests. * Fix. * Fix test. * Add existential ptr test. * Allow pointers to existential values. * Fix. * Fix. --------- Co-authored-by: Ellie Hermaszewska <ellieh@nvidia.com>
-rw-r--r--source/slang/slang-check-decl.cpp11
-rw-r--r--source/slang/slang-check-expr.cpp47
-rw-r--r--source/slang/slang-check-impl.h19
-rw-r--r--source/slang/slang-check-type.cpp5
-rw-r--r--source/slang/slang-check.h1
-rw-r--r--source/slang/slang-ir-lower-buffer-element-type.cpp22
-rw-r--r--tests/bugs/gh-3825.slang1
-rw-r--r--tests/spirv/existential-ptr.slang37
-rw-r--r--tests/spirv/ptr-member-func.slang29
-rw-r--r--tests/spirv/ptr-unsized-array-3.slang29
-rw-r--r--tests/spirv/ptr-unsized-array-4.slang25
-rw-r--r--tools/render-test/render-test-main.cpp32
12 files changed, 215 insertions, 43 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 4a4ade047..eeb75e3fd 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -3105,6 +3105,17 @@ Type* unwrapArrayType(Type* type)
}
}
+Type* unwrapModifiedType(Type* type)
+{
+ for (;;)
+ {
+ if (auto modType = as<ModifiedType>(type))
+ type = modType->getBase();
+ else
+ return type;
+ }
+}
+
void discoverExtensionDecls(List<ExtensionDecl*>& decls, Decl* parent)
{
if (auto extDecl = as<ExtensionDecl>(parent))
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 2840cdd39..1f2776ba0 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2307,7 +2307,10 @@ Expr* SemanticsVisitor::CheckSimpleSubscriptExpr(IndexExpr* subscriptExpr, Type*
Expr* SemanticsExprVisitor::visitIndexExpr(IndexExpr* subscriptExpr)
{
bool needDeref = false;
- auto baseExpr = checkBaseForMemberExpr(subscriptExpr->baseExpression, needDeref);
+ auto baseExpr = checkBaseForMemberExpr(
+ subscriptExpr->baseExpression,
+ CheckBaseContext::Subscript,
+ needDeref);
// If the base expression is a type, it means that this is an array declaration,
// then we should disable short-circuit in case there is logical expression in
@@ -2951,7 +2954,10 @@ Expr* SemanticsExprVisitor::visitInvokeExpr(InvokeExpr* expr)
auto operatorName = getName("()");
bool needDeref = false;
- expr->functionExpr = maybeInsertImplicitOpForMemberBase(expr->functionExpr, needDeref);
+ expr->functionExpr = maybeInsertImplicitOpForMemberBase(
+ expr->functionExpr,
+ CheckBaseContext::Member,
+ needDeref);
LookupResult lookupResult = lookUpMember(
m_astBuilder,
@@ -4060,19 +4066,29 @@ void SemanticsExprVisitor::maybeCheckKnownBuiltinInvocation(Expr* invokeExpr)
}
}
-Expr* SemanticsVisitor::MaybeDereference(Expr* inExpr)
+Expr* SemanticsVisitor::maybeDereference(Expr* inExpr, CheckBaseContext checkBaseContext)
{
Expr* expr = inExpr;
for (;;)
{
auto baseType = expr->type;
+ QualType elementType;
if (auto pointerLikeType = as<PointerLikeType>(baseType))
{
- auto elementType = QualType(pointerLikeType->getElementType());
+ elementType = QualType(pointerLikeType->getElementType());
elementType.isLeftValue = baseType.isLeftValue;
elementType.hasReadOnlyOnTarget = baseType.hasReadOnlyOnTarget;
elementType.isWriteOnly = baseType.isWriteOnly;
-
+ }
+ else if (auto ptrType = as<PtrType>(baseType))
+ {
+ if (checkBaseContext == CheckBaseContext::Subscript)
+ return expr;
+ elementType = QualType(ptrType->getValueType());
+ elementType.isLeftValue = true;
+ }
+ if (elementType.type)
+ {
auto derefExpr = m_astBuilder->create<DerefExpr>();
derefExpr->base = expr;
derefExpr->type = elementType;
@@ -4080,7 +4096,6 @@ Expr* SemanticsVisitor::MaybeDereference(Expr* inExpr)
expr = derefExpr;
continue;
}
-
// Default case: just use the expression as-is
return expr;
}
@@ -4751,7 +4766,7 @@ Expr* SemanticsExprVisitor::visitStaticMemberExpr(StaticMemberExpr* expr)
expr->baseExpression = CheckTerm(expr->baseExpression);
// Not sure this is needed -> but guess someone could do
- expr->baseExpression = MaybeDereference(expr->baseExpression);
+ expr->baseExpression = maybeDereference(expr->baseExpression, CheckBaseContext::Member);
// If the base of the member lookup has an interface type
// *without* a suitable this-type substitution, then we are
@@ -4779,9 +4794,12 @@ Expr* SemanticsVisitor::lookupMemberResultFailure(
return expr;
}
-Expr* SemanticsVisitor::maybeInsertImplicitOpForMemberBase(Expr* baseExpr, bool& outNeedDeref)
+Expr* SemanticsVisitor::maybeInsertImplicitOpForMemberBase(
+ Expr* baseExpr,
+ CheckBaseContext checkBaseContext,
+ bool& outNeedDeref)
{
- auto derefExpr = MaybeDereference(baseExpr);
+ auto derefExpr = maybeDereference(baseExpr, checkBaseContext);
if (derefExpr != baseExpr)
outNeedDeref = true;
@@ -4834,11 +4852,15 @@ Expr* SemanticsVisitor::maybeInsertImplicitOpForMemberBase(Expr* baseExpr, bool&
return baseExpr;
}
-Expr* SemanticsVisitor::checkBaseForMemberExpr(Expr* inBaseExpr, bool& outNeedDeref)
+Expr* SemanticsVisitor::checkBaseForMemberExpr(
+ Expr* inBaseExpr,
+ CheckBaseContext checkBaseContext,
+ bool& outNeedDeref)
{
auto baseExpr = inBaseExpr;
baseExpr = CheckTerm(baseExpr);
- return maybeInsertImplicitOpForMemberBase(baseExpr, outNeedDeref);
+
+ return maybeInsertImplicitOpForMemberBase(baseExpr, checkBaseContext, outNeedDeref);
}
Expr* SemanticsVisitor::checkGeneralMemberLookupExpr(MemberExpr* expr, Type* baseType)
@@ -4861,7 +4883,8 @@ Expr* SemanticsVisitor::checkGeneralMemberLookupExpr(MemberExpr* expr, Type* bas
Expr* SemanticsExprVisitor::visitMemberExpr(MemberExpr* expr)
{
bool needDeref = false;
- expr->baseExpression = checkBaseForMemberExpr(expr->baseExpression, needDeref);
+ expr->baseExpression =
+ checkBaseForMemberExpr(expr->baseExpression, CheckBaseContext::Member, needDeref);
if (!needDeref && as<DerefMemberExpr>(expr) && !as<PtrType>(expr->baseExpression->type))
{
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 95ec872a5..460e87cb9 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -2654,8 +2654,6 @@ public:
//
//
- Expr* MaybeDereference(Expr* inExpr);
-
Expr* CheckMatrixSwizzleExpr(
MemberExpr* memberRefExpr,
Type* baseElementType,
@@ -2696,11 +2694,24 @@ public:
/// Perform checking operations required for the "base" expression of a member-reference like
/// `base.someField`
- Expr* checkBaseForMemberExpr(Expr* baseExpr, bool& outNeedDeref);
+ enum class CheckBaseContext
+ {
+ Member,
+ Subscript,
+ };
+ Expr* checkBaseForMemberExpr(
+ Expr* baseExpr,
+ CheckBaseContext checkBaseContext,
+ bool& outNeedDeref);
+
+ Expr* maybeDereference(Expr* inExpr, CheckBaseContext checkBaseContext);
/// Prepare baseExpr for use as the base of a member expr.
/// This include inserting implicit open-existential operations as needed.
- Expr* maybeInsertImplicitOpForMemberBase(Expr* baseExpr, bool& outNeedDeref);
+ Expr* maybeInsertImplicitOpForMemberBase(
+ Expr* baseExpr,
+ CheckBaseContext checkBaseContext,
+ bool& outNeedDeref);
Expr* lookupMemberResultFailure(
DeclRefExpr* expr,
diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp
index d9691a828..2c8f3d0c0 100644
--- a/source/slang/slang-check-type.cpp
+++ b/source/slang/slang-check-type.cpp
@@ -216,9 +216,10 @@ bool isManagedType(Type* type)
{
if (auto declRefValueType = as<DeclRefType>(type))
{
- if (as<ClassDecl>(declRefValueType->getDeclRef().getDecl()))
+ auto decl = declRefValueType->getDeclRef().getDecl();
+ if (as<ClassDecl>(decl))
return true;
- if (as<InterfaceDecl>(declRefValueType->getDeclRef().getDecl()))
+ if (as<InterfaceDecl>(decl) && decl->findModifier<ComInterfaceAttribute>())
return true;
}
return false;
diff --git a/source/slang/slang-check.h b/source/slang/slang-check.h
index bd2bdce41..f1392e9ce 100644
--- a/source/slang/slang-check.h
+++ b/source/slang/slang-check.h
@@ -24,6 +24,7 @@ bool isFromCoreModule(Decl* decl);
void registerBuiltinDecls(Session* session, Decl* decl);
Type* unwrapArrayType(Type* type);
+Type* unwrapModifiedType(Type* type);
OrderedDictionary<GenericTypeParamDeclBase*, List<Type*>> getCanonicalGenericConstraints(
ASTBuilder* builder,
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp
index bd3e350bc..dd62ca02c 100644
--- a/source/slang/slang-ir-lower-buffer-element-type.cpp
+++ b/source/slang/slang-ir-lower-buffer-element-type.cpp
@@ -901,28 +901,6 @@ struct LoweredElementTypeContext
{
builder.setInsertBefore(ptrVal);
auto newArrayPtrVal = fieldAddr->getBase();
- // Is base a pointer to an empty struct? If so, don't offset it.
- // For example, if the user has written:
- // ```
- // struct S {int arr[]};
- // uniform S* p;
- // void test() { p->arr[1]; }
- // ```
- // Then `S` will become an empty struct after we remove `arr[]`.
- // And `p` will be come a `void*`.
- // We don't want to offset `p` to `p+1` to get the starting address of
- // the array in this case.
- IRSizeAndAlignment parentStructSize = {};
- getNaturalSizeAndAlignment(
- target->getOptionSet(),
- tryGetPointedToType(&builder, fieldAddr->getBase()->getDataType()),
- &parentStructSize);
- if (parentStructSize.size != 0)
- {
- newArrayPtrVal = builder.emitGetOffsetPtr(
- fieldAddr->getBase(),
- builder.getIntValue(builder.getIntType(), 1));
- }
auto loweredInnerType =
getLoweredTypeInfo(unsizedArrayType->getElementType(), layoutRules);
diff --git a/tests/bugs/gh-3825.slang b/tests/bugs/gh-3825.slang
index c7c325864..5953a858b 100644
--- a/tests/bugs/gh-3825.slang
+++ b/tests/bugs/gh-3825.slang
@@ -21,7 +21,6 @@ float4 fragment(): SV_Target
}
// CHECK: OpDecorate %_ptr_PhysicalStorageBuffer_Descriptors_natural ArrayStride 4
-// CHECK: %{{.*}} = OpPtrAccessChain %_ptr_PhysicalStorageBuffer_Descriptors_natural %{{.*}} %int_1
// CHECK: OpBitcast %ulong
// CHECK: OpIAdd %ulong %{{.*}} %ulong_4
// CHECK: OpBitcast %_ptr_PhysicalStorageBuffer \ No newline at end of file
diff --git a/tests/spirv/existential-ptr.slang b/tests/spirv/existential-ptr.slang
new file mode 100644
index 000000000..66f1c64a2
--- /dev/null
+++ b/tests/spirv/existential-ptr.slang
@@ -0,0 +1,37 @@
+//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -emit-spirv-directly -output-using-type
+//DISABLED_TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -wgpu
+//DISABLED_TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d12
+//DISABLED_TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d11
+//DISABLED_TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -metal
+
+interface IFoo
+{
+ int getVal();
+}
+
+struct Foo : IFoo
+{
+ int val;
+ int getVal() { return val; }
+}
+
+struct Bar : IFoo
+{
+ float val;
+ int getVal() { return (int)val + 1; }
+}
+
+//TEST_INPUT: set pFoo = ubuffer(data=[0 0 2 0 2.0f], stride=4);
+//TEST_INPUT: type_conformance Foo:IFoo = 1;
+//TEST_INPUT: type_conformance Bar:IFoo = 2;
+uniform IFoo* pFoo;
+
+//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0], stride=4);
+RWStructuredBuffer<float> outputBuffer;
+
+[numthreads(1,1,1)]
+void computeMain()
+{
+ // CHECK: 3.0
+ outputBuffer[0] = pFoo->getVal();
+} \ No newline at end of file
diff --git a/tests/spirv/ptr-member-func.slang b/tests/spirv/ptr-member-func.slang
new file mode 100644
index 000000000..0dcf572ee
--- /dev/null
+++ b/tests/spirv/ptr-member-func.slang
@@ -0,0 +1,29 @@
+//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -emit-spirv-directly
+//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -wgpu
+//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d11
+//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d12
+//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -metal
+
+struct Obj
+{
+ int val;
+
+ [mutating]
+ void addOne() { val++; }
+
+ int getValPlusOne() { return val + 1; }
+}
+
+//TEST_INPUT: set pObj = ubuffer(data=[2 0 0 0], stride=4);
+uniform Obj* pObj;
+
+//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0],stride=4);
+uniform RWStructuredBuffer<uint> outputBuffer;
+
+[numthreads(1,1,1)]
+void computeMain()
+{
+ pObj->addOne();
+ // CHECK: 4
+ outputBuffer[0] = pObj->getValPlusOne();
+} \ No newline at end of file
diff --git a/tests/spirv/ptr-unsized-array-3.slang b/tests/spirv/ptr-unsized-array-3.slang
new file mode 100644
index 000000000..ffd1345ea
--- /dev/null
+++ b/tests/spirv/ptr-unsized-array-3.slang
@@ -0,0 +1,29 @@
+//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -emit-spirv-directly
+//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -wgpu
+//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d11
+//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d12
+//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -metal
+
+// Test a pointer to a struct with a trailing unsized array.
+
+struct MeshStorage {
+ int foo;
+ uint64_t QuadData[];
+};
+
+//TEST_INPUT: set pStorage = ubuffer(data=[1 2 3 4 5 6 7 8],stride=4);
+uniform MeshStorage* pStorage;
+
+//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0],stride=4);
+uniform RWStructuredBuffer<uint> outputBuffer;
+
+[numthreads(1,1,1)]
+void computeMain()
+{
+ // CHECK: 5
+ // CHECK: 6
+ // CHECK: 1
+ outputBuffer[0] = (int)(pStorage.QuadData[1]&0xFFFFFFFF);
+ outputBuffer[1] = (int)(pStorage.QuadData[1]>>32);
+ outputBuffer[2] = pStorage.foo;
+} \ No newline at end of file
diff --git a/tests/spirv/ptr-unsized-array-4.slang b/tests/spirv/ptr-unsized-array-4.slang
new file mode 100644
index 000000000..561dfab22
--- /dev/null
+++ b/tests/spirv/ptr-unsized-array-4.slang
@@ -0,0 +1,25 @@
+//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -emit-spirv-directly
+//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -wgpu
+//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d11
+//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -d3d12
+//DISABLED_TEST: COMPARE_COMPUTE(filecheck-buffer=CHECK): -metal
+
+// Test a pointer to a struct that has only one field and is an unsized array.
+struct MeshStorage {
+ uint64_t QuadData[];
+};
+
+//TEST_INPUT: set pStorage = ubuffer(data=[1 2 3 4 5 6 7 8],stride=4);
+uniform MeshStorage* pStorage;
+
+//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0],stride=4);
+uniform RWStructuredBuffer<uint> outputBuffer;
+
+[numthreads(1,1,1)]
+void computeMain()
+{
+ // CHECK: 3
+ // CHECK: 4
+ outputBuffer[0] = (int)(pStorage.QuadData[1]&0xFFFFFFFF);
+ outputBuffer[1] = (int)(pStorage.QuadData[1]>>32);
+} \ No newline at end of file
diff --git a/tools/render-test/render-test-main.cpp b/tools/render-test/render-test-main.cpp
index 2e07a7689..5907be66d 100644
--- a/tools/render-test/render-test-main.cpp
+++ b/tools/render-test/render-test-main.cpp
@@ -76,6 +76,12 @@ struct ShaderOutputPlan
List<Item> items;
};
+// A context for hodling resources allocated for a test.
+struct TestResourceContext
+{
+ List<ComPtr<IResource>> resources;
+};
+
class RenderTestApp
{
public:
@@ -134,6 +140,7 @@ protected:
Options m_options;
ShaderOutputPlan m_outputPlan;
+ TestResourceContext m_resourceContext;
};
struct AssignValsFromLayoutContext
@@ -141,6 +148,7 @@ struct AssignValsFromLayoutContext
IDevice* device;
slang::ISession* slangSession;
ShaderOutputPlan& outputPlan;
+ TestResourceContext& resourceContext;
slang::ProgramLayout* slangReflection;
IAccelerationStructure* accelerationStructure;
@@ -148,11 +156,13 @@ struct AssignValsFromLayoutContext
IDevice* device,
slang::ISession* slangSession,
ShaderOutputPlan& outputPlan,
+ TestResourceContext& resourceContext,
slang::ProgramLayout* slangReflection,
IAccelerationStructure* accelerationStructure)
: device(device)
, slangSession(slangSession)
, outputPlan(outputPlan)
+ , resourceContext(resourceContext)
, slangReflection(slangReflection)
, accelerationStructure(accelerationStructure)
{
@@ -204,6 +214,7 @@ struct AssignValsFromLayoutContext
bufferData.add(0);
ComPtr<IBuffer> bufferResource;
+
SLANG_RETURN_ON_FAIL(ShaderRendererUtil::createBuffer(
srcBuffer,
/*entry.isOutput,*/ bufferSize,
@@ -211,6 +222,16 @@ struct AssignValsFromLayoutContext
device,
bufferResource));
+ if (dstCursor.getTypeLayout()->getType()->getKind() == slang::TypeReflection::Kind::Pointer)
+ {
+ // dstCursor is pointer to an ordinary uniform data field,
+ // we should write bufferResource as a pointer.
+ uint64_t addr = bufferResource->getDeviceAddress();
+ dstCursor.setData(&addr, sizeof(addr));
+ resourceContext.resources.add(ComPtr<IResource>(bufferResource.get()));
+ return SLANG_OK;
+ }
+
ComPtr<IBuffer> counterResource;
const auto explicitCounterCursor = dstCursor.getExplicitCounter();
if (srcBuffer.counter != ~0u)
@@ -488,11 +509,17 @@ SlangResult _assignVarsFromLayout(
IShaderObject* shaderObject,
ShaderInputLayout const& layout,
ShaderOutputPlan& ioOutputPlan,
+ TestResourceContext& ioResourceContext,
slang::ProgramLayout* slangReflection,
IAccelerationStructure* accelerationStructure)
{
- AssignValsFromLayoutContext
- context(device, slangSession, ioOutputPlan, slangReflection, accelerationStructure);
+ AssignValsFromLayoutContext context(
+ device,
+ slangSession,
+ ioOutputPlan,
+ ioResourceContext,
+ slangReflection,
+ accelerationStructure);
ShaderCursor rootCursor = ShaderCursor(shaderObject);
return context.assign(rootCursor, layout.rootVal);
}
@@ -510,6 +537,7 @@ Result RenderTestApp::applyBinding(IShaderObject* rootObject)
rootObject,
m_compilationOutput.layout,
m_outputPlan,
+ m_resourceContext,
slangReflection,
m_topLevelAccelerationStructure);
}