summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-02-28 22:46:56 -0800
committerGitHub <noreply@github.com>2025-02-28 22:46:56 -0800
commitdd9d24d29c4a9e05a4510eb9959fafa0ed36618b (patch)
tree240e2e4ecd8fc15fa835db4377670ec7fdf90e71
parent700c38ae7c16a49de7f720ae3b1940df5b2b4b33 (diff)
Allow partial specialization of existential arguments. (#6487)
* Allow partial specialization of existential arguments. * Fix. * Add test case for improved diagnostics. * Fix compile error. * Fix tests. * Fix. * Fix test. * Fix compile issue. * Fix typo. * Address comment.
-rw-r--r--source/slang/slang-diagnostic-defs.h6
-rw-r--r--source/slang/slang-ir-generics-lowering-context.cpp13
-rw-r--r--source/slang/slang-ir-generics-lowering-context.h3
-rw-r--r--source/slang/slang-ir-layout.cpp1
-rw-r--r--source/slang/slang-ir-legalize-types.cpp8
-rw-r--r--source/slang/slang-ir-lower-reinterpret.cpp2
-rw-r--r--source/slang/slang-ir-specialize.cpp111
-rw-r--r--source/slang/slang-ir-util.h1
-rw-r--r--source/slang/slang-ir-witness-table-wrapper.cpp40
-rw-r--r--source/slang/slang-legalize-types.cpp30
-rw-r--r--source/slang/slang-legalize-types.h2
-rw-r--r--tests/bugs/gh-6482-interface-method-existential-specialize.slang93
-rw-r--r--tests/diagnostics/interfaces/anyvalue-size-validation.slang.expected2
-rw-r--r--tests/diagnostics/resource-type-in-dynamic-dispatch.slang28
14 files changed, 270 insertions, 70 deletions
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 761dc4768..85ee545a4 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -2203,12 +2203,12 @@ DIAGNOSTIC(
Error,
typeDoesNotFitAnyValueSize,
"type '$0' does not fit in the size required by its conforming interface.")
-DIAGNOSTIC(41012, Note, typeAndLimit, "sizeof($0) is $1, limit is $2")
+DIAGNOSTIC(-1, Note, typeAndLimit, "sizeof($0) is $1, limit is $2")
DIAGNOSTIC(
- 41012,
+ 41014,
Error,
typeCannotBePackedIntoAnyValue,
- "type '$0' contains fields that cannot be packed into an AnyValue.")
+ "type '$0' contains fields that cannot be packed into ordinary bytes for dynamic dispatch.")
DIAGNOSTIC(
41020,
Error,
diff --git a/source/slang/slang-ir-generics-lowering-context.cpp b/source/slang/slang-ir-generics-lowering-context.cpp
index 097fa58b8..199eae2fc 100644
--- a/source/slang/slang-ir-generics-lowering-context.cpp
+++ b/source/slang/slang-ir-generics-lowering-context.cpp
@@ -405,12 +405,23 @@ bool SharedGenericsLoweringContext::doesTypeFitInAnyValue(
IRType* concreteType,
IRInterfaceType* interfaceType,
IRIntegerValue* outTypeSize,
- IRIntegerValue* outLimit)
+ IRIntegerValue* outLimit,
+ bool* outIsTypeOpaque)
{
auto anyValueSize = getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc);
if (outLimit)
*outLimit = anyValueSize;
+ if (!areResourceTypesBindlessOnTarget(targetProgram->getTargetReq()))
+ {
+ IRType* opaqueType = nullptr;
+ if (isOpaqueType(concreteType, &opaqueType))
+ {
+ if (outIsTypeOpaque)
+ *outIsTypeOpaque = true;
+ return false;
+ }
+ }
IRSizeAndAlignment sizeAndAlignment;
Result result =
getNaturalSizeAndAlignment(targetProgram->getOptionSet(), concreteType, &sizeAndAlignment);
diff --git a/source/slang/slang-ir-generics-lowering-context.h b/source/slang/slang-ir-generics-lowering-context.h
index baa1f7757..62848c4b7 100644
--- a/source/slang/slang-ir-generics-lowering-context.h
+++ b/source/slang/slang-ir-generics-lowering-context.h
@@ -98,7 +98,8 @@ struct SharedGenericsLoweringContext
IRType* concreteType,
IRInterfaceType* interfaceType,
IRIntegerValue* outTypeSize = nullptr,
- IRIntegerValue* outLimit = nullptr);
+ IRIntegerValue* outLimit = nullptr,
+ bool* outIsTypeOpaque = nullptr);
};
List<IRWitnessTable*> getWitnessTablesFromInterfaceType(IRModule* module, IRInst* interfaceType);
diff --git a/source/slang/slang-ir-layout.cpp b/source/slang/slang-ir-layout.cpp
index 7ce19bf67..558877aaf 100644
--- a/source/slang/slang-ir-layout.cpp
+++ b/source/slang/slang-ir-layout.cpp
@@ -341,6 +341,7 @@ static Result _calcSizeAndAlignment(
case kIROp_ComPtrType:
case kIROp_NativeStringType:
case kIROp_HLSLConstBufferPointerType:
+ case kIROp_RaytracingAccelerationStructureType:
{
*outSizeAndAlignment = IRSizeAndAlignment(kPointerSize, kPointerSize);
return SLANG_OK;
diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp
index 9b857f899..3d6f18569 100644
--- a/source/slang/slang-ir-legalize-types.cpp
+++ b/source/slang/slang-ir-legalize-types.cpp
@@ -2053,15 +2053,13 @@ static LegalVal coerceToLegalType(IRTypeLegalizationContext* context, LegalType
static LegalVal legalizeUndefined(IRTypeLegalizationContext* context, IRInst* inst)
{
- List<IRType*> opaqueTypes;
- if (isOpaqueType(inst->getFullType(), opaqueTypes))
+ IRType* opaqueType = nullptr;
+ if (isOpaqueType(inst->getFullType(), &opaqueType))
{
- auto opaqueType = opaqueTypes[0];
- auto containerType = opaqueTypes.getCount() > 1 ? opaqueTypes[1] : opaqueType;
SourceLoc loc = findBestSourceLocFromUses(inst);
if (!loc.isValid())
- loc = getDiagnosticPos(containerType);
+ loc = getDiagnosticPos(opaqueType);
context->m_sink->diagnose(loc, Diagnostics::useOfUninitializedOpaqueHandle, opaqueType);
}
diff --git a/source/slang/slang-ir-lower-reinterpret.cpp b/source/slang/slang-ir-lower-reinterpret.cpp
index 1b733d832..88c288f28 100644
--- a/source/slang/slang-ir-lower-reinterpret.cpp
+++ b/source/slang/slang-ir-lower-reinterpret.cpp
@@ -79,7 +79,7 @@ struct ReinterpretLoweringContext
Slang::Diagnostics::typeCannotBePackedIntoAnyValue,
toType);
}
- if (fromTypeSize != toTypeSize && cantPack == false)
+ if (fromTypeSize != toTypeSize && !cantPack && !as<IRExtractExistentialType>(fromType))
{
sink->diagnose(
inst->sourceLoc,
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index a9b0d4412..a200a907b 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -1373,11 +1373,16 @@ struct SpecializationContext
if (!isExistentialType(param->getDataType()))
continue;
+ // Is arg in the most simplified form for specialization? If not we are
+ // not ready to consider specialization yet.
+ if (!isSimplifiedExistentialArg(arg))
+ return false;
+
// We *cannot* specialize unless the argument value corresponding
// to such a parameter is one we can specialize.
//
if (!canSpecializeExistentialArg(arg))
- return false;
+ continue;
argumentNeedSpecialization = true;
}
@@ -1416,7 +1421,6 @@ struct SpecializationContext
auto arg = inst->getArg(argCounter++);
if (!isExistentialType(param->getDataType()))
continue;
-
if (auto makeExistential = as<IRMakeExistential>(arg))
{
// Note that we use the *type* stored in the
@@ -1426,25 +1430,32 @@ struct SpecializationContext
// call sites that pass in the exact same argument).
//
auto val = makeExistential->getWrappedValue();
- auto valType = val->getFullType();
- key.vals.add(valType);
-
- // We are also including the witness table in the key.
- // This isn't required with our current language model,
- // since a given type can only conform to a given interface
- // in one way (so there can be only one witness table).
- // That means that the `valType` and the existential
- // type of `param` above should uniquely determine
- // the witness table we see.
- //
- // There are forward-looking cases where supporting
- // "overlapping conformances" could be required, and
- // there is low incremental cost to future-proofing
- // this code, so we go ahead and add the witness
- // table even if it is redundant.
- //
- auto witnessTable = makeExistential->getWitnessTable();
- key.vals.add(witnessTable);
+ auto valType = val->getDataType();
+ if (isCompileTimeConstantType(valType))
+ {
+ key.vals.add(valType);
+
+ // We are also including the witness table in the key.
+ // This isn't required with our current language model,
+ // since a given type can only conform to a given interface
+ // in one way (so there can be only one witness table).
+ // That means that the `valType` and the existential
+ // type of `param` above should uniquely determine
+ // the witness table we see.
+ //
+ // There are forward-looking cases where supporting
+ // "overlapping conformances" could be required, and
+ // there is low incremental cost to future-proofing
+ // this code, so we go ahead and add the witness
+ // table even if it is redundant.
+ //
+ auto witnessTable = makeExistential->getWitnessTable();
+ key.vals.add(witnessTable);
+ }
+ else
+ {
+ key.vals.add(param->getDataType());
+ }
}
else if (auto wrapExistential = as<IRWrapExistential>(arg))
{
@@ -1508,7 +1519,11 @@ struct SpecializationContext
if (auto makeExistential = as<IRMakeExistential>(arg))
{
auto val = makeExistential->getWrappedValue();
- newArgs.add(val);
+ auto valType = val->getDataType();
+ if (isCompileTimeConstantType(valType))
+ newArgs.add(val);
+ else
+ newArgs.add(arg);
}
else if (auto wrapExistential = as<IRWrapExistential>(arg))
{
@@ -1634,6 +1649,18 @@ struct SpecializationContext
return true;
}
+
+ // Returns true if `inst` is a simplified existential argument ready for specialization.
+ bool isSimplifiedExistentialArg(IRInst* inst)
+ {
+ if (as<IRMakeExistential>(inst))
+ return true;
+ if (as<IRWrapExistential>(inst))
+ return true;
+ return false;
+ }
+
+
// Similarly, we want to be able to test whether an instruction
// used as an argument for an existential-type parameter is
// suitable for use in specialization.
@@ -1760,21 +1787,33 @@ struct SpecializationContext
// created.
//
auto valType = val->getFullType();
- auto newParam = builder->createParam(valType);
- newParams.add(newParam);
+ if (auto extractExistentialType = as<IRExtractExistentialType>(valType))
+ {
+ valType = extractExistentialType->getOperand(0)->getDataType();
+ auto newParam = builder->createParam(valType);
+ newParams.add(newParam);
+ replacementVal = newParam;
+ }
+ else
+ {
+ auto newParam = builder->createParam(valType);
+ newParams.add(newParam);
- // Within the body of the function we cannot just use `val`
- // directly, because the existing code expects an existential
- // value, including its witness table.
- //
- // Therefore we will create a `makeExistential(newParam, witnessTable)`
- // in the body of the new function and use *that* as the replacement
- // value for the original parameter (since it will have the
- // correct existential type, and stores the right witness table).
- //
- auto newMakeExistential =
- builder->emitMakeExistential(oldParam->getFullType(), newParam, witnessTable);
- replacementVal = newMakeExistential;
+ // Within the body of the function we cannot just use `val`
+ // directly, because the existing code expects an existential
+ // value, including its witness table.
+ //
+ // Therefore we will create a `makeExistential(newParam, witnessTable)`
+ // in the body of the new function and use *that* as the replacement
+ // value for the original parameter (since it will have the
+ // correct existential type, and stores the right witness table).
+ //
+ auto newMakeExistential = builder->emitMakeExistential(
+ oldParam->getFullType(),
+ newParam,
+ witnessTable);
+ replacementVal = newMakeExistential;
+ }
}
else if (auto oldWrapExistential = as<IRWrapExistential>(arg))
{
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index aed63da47..549981f58 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -83,6 +83,7 @@ IRType* getMatrixElementType(IRType* type);
// True if type is a resource backing memory
bool isResourceType(IRType* type);
+bool isOpaqueType(IRType* type, IRType** outLeafOpaqueHandleType);
// True if type is a pointer to a resource
bool isPointerToResourceType(IRType* type);
diff --git a/source/slang/slang-ir-witness-table-wrapper.cpp b/source/slang/slang-ir-witness-table-wrapper.cpp
index f835b68c3..fabfd1611 100644
--- a/source/slang/slang-ir-witness-table-wrapper.cpp
+++ b/source/slang/slang-ir-witness-table-wrapper.cpp
@@ -190,19 +190,35 @@ struct GenerateWitnessTableWrapperContext
//
auto concreteType = witnessTable->getConcreteType();
IRIntegerValue typeSize, sizeLimit;
- if (!sharedContext
- ->doesTypeFitInAnyValue(concreteType, interfaceType, &typeSize, &sizeLimit))
- {
- sharedContext->sink->diagnose(
- concreteType,
- Diagnostics::typeDoesNotFitAnyValueSize,
- concreteType);
- sharedContext->sink->diagnoseWithoutSourceView(
- concreteType,
- Diagnostics::typeAndLimit,
+ bool isTypeOpaque = false;
+ if (!sharedContext->doesTypeFitInAnyValue(
concreteType,
- typeSize,
- sizeLimit);
+ interfaceType,
+ &typeSize,
+ &sizeLimit,
+ &isTypeOpaque))
+ {
+ HashSet<IRType*> visited;
+ if (isTypeOpaque)
+ {
+ sharedContext->sink->diagnose(
+ concreteType,
+ Diagnostics::typeCannotBePackedIntoAnyValue,
+ concreteType);
+ }
+ else
+ {
+ sharedContext->sink->diagnose(
+ concreteType,
+ Diagnostics::typeDoesNotFitAnyValueSize,
+ concreteType);
+ sharedContext->sink->diagnoseWithoutSourceView(
+ concreteType,
+ Diagnostics::typeAndLimit,
+ concreteType,
+ typeSize,
+ sizeLimit);
+ }
return;
}
diff --git a/source/slang/slang-legalize-types.cpp b/source/slang/slang-legalize-types.cpp
index 7d107bbce..e26475522 100644
--- a/source/slang/slang-legalize-types.cpp
+++ b/source/slang/slang-legalize-types.cpp
@@ -198,31 +198,40 @@ bool isResourceType(IRType* type)
return false;
}
-bool isOpaqueType(IRType* type, List<IRType*>& opaqueTypes)
+
+bool isOpaqueTypeImpl(IRType* type, HashSet<IRType*>& visited, IRType** outLeafOpaqueHandleType)
{
+ if (visited.contains(type))
+ {
+ if (outLeafOpaqueHandleType)
+ *outLeafOpaqueHandleType = type;
+ return true;
+ }
+
if (isResourceType(type))
{
- opaqueTypes.add(type);
+ if (outLeafOpaqueHandleType)
+ *outLeafOpaqueHandleType = type;
return true;
}
if (auto structType = as<IRStructType>(type))
{
+ visited.add(type);
for (auto field : structType->getFields())
{
- if (isOpaqueType(field->getFieldType(), opaqueTypes))
+ if (isOpaqueTypeImpl(field->getFieldType(), visited, outLeafOpaqueHandleType))
{
- opaqueTypes.add(type);
return true;
}
}
+ visited.remove(type);
}
if (auto arrayType = as<IRArrayTypeBase>(type))
{
- if (isOpaqueType(arrayType->getElementType(), opaqueTypes))
+ if (isOpaqueTypeImpl(arrayType->getElementType(), visited, outLeafOpaqueHandleType))
{
- opaqueTypes.add(type);
return true;
}
}
@@ -233,9 +242,8 @@ bool isOpaqueType(IRType* type, List<IRType*>& opaqueTypes)
{
if (auto elementType = as<IRType>(tupleType->getOperand(i)))
{
- if (isOpaqueType(elementType, opaqueTypes))
+ if (isOpaqueTypeImpl(elementType, visited, outLeafOpaqueHandleType))
{
- opaqueTypes.add(type);
return true;
}
}
@@ -245,6 +253,12 @@ bool isOpaqueType(IRType* type, List<IRType*>& opaqueTypes)
return false;
}
+bool isOpaqueType(IRType* type, IRType** outLeafOpaqueHandleType)
+{
+ HashSet<IRType*> visited;
+ return isOpaqueTypeImpl(type, visited, outLeafOpaqueHandleType);
+}
+
SourceLoc findBestSourceLocFromUses(IRInst* inst)
{
for (auto use = inst->firstUse; use; use = use->nextUse)
diff --git a/source/slang/slang-legalize-types.h b/source/slang/slang-legalize-types.h
index 17498ce08..ae76cbd39 100644
--- a/source/slang/slang-legalize-types.h
+++ b/source/slang/slang-legalize-types.h
@@ -703,8 +703,6 @@ void legalizeEmptyTypes(TargetProgram* target, IRModule* module, DiagnosticSink*
bool isResourceType(IRType* type);
-bool isOpaqueType(IRType* type, List<IRType*>& opaqueTypes);
-
SourceLoc findBestSourceLocFromUses(IRInst* inst);
} // namespace Slang
diff --git a/tests/bugs/gh-6482-interface-method-existential-specialize.slang b/tests/bugs/gh-6482-interface-method-existential-specialize.slang
new file mode 100644
index 000000000..d01e5b7ff
--- /dev/null
+++ b/tests/bugs/gh-6482-interface-method-existential-specialize.slang
@@ -0,0 +1,93 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv
+
+// This is a test that checks that we can apply partial specialization to a function
+// that takes existential parameters.
+
+// CHECK: OpRayQueryProceedKHR
+// CHECK: OpImageWrite
+
+public interface IRandom {
+ [mutating] uint32_t next_uint();
+ [mutating] float next_float();
+};
+
+public struct TEA : IRandom {
+ int val;
+ public __init(uint32_t v0, uint32_t v1) {}
+ [mutating] public uint32_t next_uint() {
+ return val++;
+ }
+ [mutating] public float next_float() {
+ return val++;
+ }
+};
+
+public interface IScene {
+ property RaytracingAccelerationStructure as;
+};
+
+// The Scene type contains a resource field, if dynamic dispatch code were generated
+// for this type, we will get a compile error.
+struct Scene : IScene {
+ RaytracingAccelerationStructure as;
+};
+
+public interface IIntegrator {
+
+ // This function takes two existential parameters, `scene` and `rng`.
+ // if we call this function with `rng` being dynamic, and `scene` being static,
+ // we should still be able to specialize the `sample` function with the statically known
+ // type of `scene`.
+ public float3 sample(IScene scene, RayDesc ray, IRandom rng);
+};
+namespace integrator {
+ public struct NoShading : IIntegrator {
+ public float3 sample(IScene scene, RayDesc _ray, IRandom rng) {
+ return float3( 0.0f, 0.0f, 0.0f );
+ }
+ };
+
+ struct Test {
+
+ float4 sample(IScene scene, RayDesc _ray, IRandom rng, int pixel_aabb_uv, uint3 id) {
+ float4 grad = { 0.0f, 0.0f, 0.0f, 1.0f };
+
+ RayDesc ray = _ray; uint32_t depth = 0;
+ RayQuery<0> rayQuery;
+ while (rayQuery.Proceed()) {
+ float rand = rng.next_float();
+
+ IIntegrator integrator = integrator::NoShading();
+
+ // Here `rng` is mutating in the loop, so its type is dynamic and we need
+ // to generate dynamic dispatch code around it.
+ // But this shouldn't result in `scene` being dynamic as well.
+ // We should still be able to specialize `integrator.sample` with the statically
+ // known type of `scene`.
+ // If this doesn't happen, then the compiler will try to synthesize dynamic dispatch
+ // logic for `scene` and fail to compile.
+ let color_in = integrator.sample(scene, {}, rng);
+ let color_out = integrator.sample(scene, {}, rng);
+ grad.xyz += color_out;
+ }
+ return grad;
+ }
+ };
+};
+
+[[vk::binding(0, 0)]] RWTexture2D<float4> output;
+
+[vk::constant_id(0)] const int WGS_X = 1;
+[vk::constant_id(1)] const int WGS_Y = 1;
+[shader("compute"), numthreads(WGS_X, WGS_Y, 1)]
+void main(
+ uint3 id : SV_DispatchThreadID
+)
+{
+ IRandom rng = TEA(id.y * id.x, 1);
+ Scene scene = { RaytracingAccelerationStructure(0) };
+
+ integrator::Test integrator = integrator::Test();
+ float4 grad = integrator.sample(scene, {}, rng, {}, id);
+ output[id.xy] += float4(grad.xyz, grad.w);
+} \ No newline at end of file
diff --git a/tests/diagnostics/interfaces/anyvalue-size-validation.slang.expected b/tests/diagnostics/interfaces/anyvalue-size-validation.slang.expected
index f2000909b..e94651671 100644
--- a/tests/diagnostics/interfaces/anyvalue-size-validation.slang.expected
+++ b/tests/diagnostics/interfaces/anyvalue-size-validation.slang.expected
@@ -3,7 +3,7 @@ standard error = {
tests/diagnostics/interfaces/anyvalue-size-validation.slang(11): error 41011: type 'S' does not fit in the size required by its conforming interface.
struct S : IInterface
^
-tests/diagnostics/interfaces/anyvalue-size-validation.slang(11): note 41012: sizeof(S) is 12, limit is 8
+tests/diagnostics/interfaces/anyvalue-size-validation.slang(11): note: sizeof(S) is 12, limit is 8
}
standard output = {
}
diff --git a/tests/diagnostics/resource-type-in-dynamic-dispatch.slang b/tests/diagnostics/resource-type-in-dynamic-dispatch.slang
new file mode 100644
index 000000000..114eb800b
--- /dev/null
+++ b/tests/diagnostics/resource-type-in-dynamic-dispatch.slang
@@ -0,0 +1,28 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv
+// CHECK: type 'FooImpl' contains fields that cannot be packed into ordinary bytes for dynamic dispatch.
+interface IFoo
+{
+ float get();
+}
+
+export struct FooImpl : IFoo
+{
+ Texture2D t;
+ float get() { return 1.0; }
+}
+
+export struct FooImpl2 : IFoo
+{
+ float v;
+ float get() { return v; }
+}
+
+
+RWStructuredBuffer<float> outputBuffer;
+
+[numthreads(1,1,1)]
+void computeMain()
+{
+ IFoo f = createDynamicObject<IFoo>(0, 5);
+ outputBuffer[0] = f.get();
+} \ No newline at end of file