summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-06-11 23:58:25 -0700
committerGitHub <noreply@github.com>2024-06-11 23:58:25 -0700
commit3fe4a77287345c303aeb985e24ee237f272e8eca (patch)
tree59d5096a8f0e42286f8db2fb72d04f3db82f166f
parent5da06d43bb0997455211ca56597c4302b09909ab (diff)
Fix crash when using optional type in a generic. (#4341)
-rw-r--r--source/slang/slang-ir-lower-optional-type.cpp63
-rw-r--r--tests/bugs/optional-generic.slang22
-rw-r--r--tests/bugs/optional.slang42
3 files changed, 102 insertions, 25 deletions
diff --git a/source/slang/slang-ir-lower-optional-type.cpp b/source/slang/slang-ir-lower-optional-type.cpp
index ba2584976..272f04545 100644
--- a/source/slang/slang-ir-lower-optional-type.cpp
+++ b/source/slang/slang-ir-lower-optional-type.cpp
@@ -15,6 +15,10 @@ namespace Slang
InstWorkList workList;
InstHashSet workListSet;
+ IRGeneric* genericOptionalStructType = nullptr;
+ IRStructKey* valueKey = nullptr;
+ IRStructKey* hasValueKey = nullptr;
+
OptionalTypeLoweringContext(IRModule* inModule)
:module(inModule), workList(inModule), workListSet(inModule)
{}
@@ -24,8 +28,6 @@ namespace Slang
IRType* optionalType = nullptr;
IRType* valueType = nullptr;
IRType* loweredType = nullptr;
- IRStructField* valueField = nullptr;
- IRStructField* hasValueField = nullptr;
};
Dictionary<IRInst*, RefPtr<LoweredOptionalTypeInfo>> mapLoweredTypeToOptionalTypeInfo;
Dictionary<IRInst*, RefPtr<LoweredOptionalTypeInfo>> loweredOptionalTypes;
@@ -38,6 +40,34 @@ namespace Slang
return type;
}
+ IRInst* getOrCreateGenericOptionalStruct()
+ {
+ if (genericOptionalStructType)
+ return genericOptionalStructType;
+ IRBuilder builder(module);
+ builder.setInsertInto(module->getModuleInst());
+
+ valueKey = builder.createStructKey();
+ builder.addNameHintDecoration(valueKey, UnownedStringSlice("value"));
+ hasValueKey = builder.createStructKey();
+ builder.addNameHintDecoration(hasValueKey, UnownedStringSlice("hasValue"));
+
+ genericOptionalStructType = builder.emitGeneric();
+ builder.addNameHintDecoration(genericOptionalStructType, UnownedStringSlice("_slang_Optional"));
+
+ builder.setInsertInto(genericOptionalStructType);
+ auto block = builder.emitBlock();
+ auto typeParam = builder.emitParam(builder.getTypeKind());
+ auto structType = builder.createStructType();
+ builder.addNameHintDecoration(structType, UnownedStringSlice("_slang_Optional"));
+ builder.createStructField(structType, valueKey, (IRType*)typeParam);
+ builder.createStructField(structType, hasValueKey, builder.getBoolType());
+ builder.setInsertInto(block);
+ builder.emitReturn(structType);
+ genericOptionalStructType->setFullType(builder.getTypeKind());
+ return genericOptionalStructType;
+ }
+
bool typeHasNullValue(IRInst* type)
{
switch (type->getOp())
@@ -78,19 +108,10 @@ namespace Slang
}
else
{
- auto structType = builder->createStructType();
- info->loweredType = structType;
- builder->addNameHintDecoration(structType, UnownedStringSlice("OptionalType"));
-
- info->valueType = valueType;
- auto valueKey = builder->createStructKey();
- builder->addNameHintDecoration(valueKey, UnownedStringSlice("value"));
- info->valueField = builder->createStructField(structType, valueKey, (IRType*)valueType);
-
- auto boolType = builder->getBoolType();
- auto hasValueKey = builder->createStructKey();
- builder->addNameHintDecoration(hasValueKey, UnownedStringSlice("hasValue"));
- info->hasValueField = builder->createStructField(structType, hasValueKey, (IRType*)boolType);
+ auto genericType = getOrCreateGenericOptionalStruct();
+ IRInst* args[] = { valueType };
+ auto specializedType = builder->emitSpecializeInst(builder->getTypeKind(), genericType, 1, args);
+ info->loweredType = (IRType*)specializedType;
}
mapLoweredTypeToOptionalTypeInfo[info->loweredType] = info;
loweredOptionalTypes[type] = info;
@@ -100,12 +121,6 @@ namespace Slang
void addToWorkList(
IRInst* inst)
{
- for (auto ii = inst->getParent(); ii; ii = ii->getParent())
- {
- if (as<IRGeneric>(ii))
- return;
- }
-
if (workListSet.contains(inst))
return;
@@ -169,7 +184,7 @@ namespace Slang
result = builder->emitFieldExtract(
builder->getBoolType(),
optionalInst,
- loweredOptionalTypeInfo->hasValueField->getKey());
+ hasValueKey);
}
else
{
@@ -201,11 +216,10 @@ namespace Slang
if (loweredOptionalTypeInfo->loweredType != loweredOptionalTypeInfo->valueType)
{
SLANG_ASSERT(loweredOptionalTypeInfo);
- SLANG_ASSERT(loweredOptionalTypeInfo->valueField);
auto getElement = builder->emitFieldExtract(
loweredOptionalTypeInfo->valueType,
base,
- loweredOptionalTypeInfo->valueField->getKey());
+ valueKey);
inst->replaceUsesWith(getElement);
}
else
@@ -257,7 +271,6 @@ namespace Slang
while (workList.getCount() != 0)
{
IRInst* inst = workList.getLast();
-
workList.removeLast();
workListSet.remove(inst);
diff --git a/tests/bugs/optional-generic.slang b/tests/bugs/optional-generic.slang
new file mode 100644
index 000000000..16b466273
--- /dev/null
+++ b/tests/bugs/optional-generic.slang
@@ -0,0 +1,22 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-slang -compute
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-slang -compute -vk
+
+
+Optional<T> genFunc<T : IArithmetic>(T v)
+{
+ if (v is int)
+ return v;
+ return none;
+}
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name buffer
+
+RWStructuredBuffer<int> buffer;
+
+[numthreads(1,1,1)]
+void computeMain()
+{
+ // BUF: 2
+ buffer[0] = genFunc(2).value;
+}
+
diff --git a/tests/bugs/optional.slang b/tests/bugs/optional.slang
new file mode 100644
index 000000000..3512ba29f
--- /dev/null
+++ b/tests/bugs/optional.slang
@@ -0,0 +1,42 @@
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-slang -compute -vk
+
+interface IFoo
+{
+ void foo();
+}
+
+struct S : IFoo { int x; void foo(); }
+
+struct P
+{
+ IFoo f;
+}
+struct Tr
+{
+ int test<T:IArithmetic>(T t, inout P p)
+ {
+ const IFoo hit = p.f;
+ let castResult = hit as S;
+ if (!castResult.hasValue)
+ return 0;
+ return castResult.value.x;
+ }
+}
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name buffer
+
+RWStructuredBuffer<int> buffer;
+
+[numthreads(1,1,1)]
+void computeMain()
+{
+ P p;
+ S s;
+ s.x = 2;
+ p.f = s;
+ Tr tt;
+ // BUF: 2
+ buffer[0] = tt.test(0, p);
+}
+